mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-09 23:45:36 +00:00
BlockCG linalg acceleratoin with BLAS
This commit is contained in:
parent
77944437ce
commit
d4dc5e0f43
@ -38,16 +38,25 @@ public:
|
|||||||
|
|
||||||
typedef typename Field::scalar_type scalar;
|
typedef typename Field::scalar_type scalar;
|
||||||
typedef typename Field::scalar_object scalar_object;
|
typedef typename Field::scalar_object scalar_object;
|
||||||
|
typedef typename Field::vector_object vector_object;
|
||||||
|
|
||||||
deviceVector<scalar> BLAS_X; // nrhs x vol -- the sources
|
deviceVector<scalar> BLAS_X; // nrhs x vol -- the sources
|
||||||
deviceVector<scalar> BLAS_Y; // nrhs x vol -- the result
|
deviceVector<scalar> BLAS_Y; // nrhs x vol -- the result
|
||||||
deviceVector<scalar> BLAS_C; // nrhs x nrhs -- the coefficients
|
deviceVector<scalar> BLAS_C; // nrhs x nrhs -- the coefficients
|
||||||
|
deviceVector<scalar> BLAS_Cred; // nrhs x nrhs x oSites -- reduction buffer
|
||||||
|
deviceVector<scalar *> Xdip;
|
||||||
|
deviceVector<scalar *> Ydip;
|
||||||
|
deviceVector<scalar *> Cdip;
|
||||||
|
|
||||||
MultiRHSBlockCGLinalg() {};
|
MultiRHSBlockCGLinalg() {};
|
||||||
~MultiRHSBlockCGLinalg(){ Deallocate(); };
|
~MultiRHSBlockCGLinalg(){ Deallocate(); };
|
||||||
|
|
||||||
void Deallocate(void)
|
void Deallocate(void)
|
||||||
{
|
{
|
||||||
|
Xdip.resize(0);
|
||||||
|
Ydip.resize(0);
|
||||||
|
Cdip.resize(0);
|
||||||
|
BLAS_Cred.resize(0);
|
||||||
BLAS_C.resize(0);
|
BLAS_C.resize(0);
|
||||||
BLAS_X.resize(0);
|
BLAS_X.resize(0);
|
||||||
BLAS_Y.resize(0);
|
BLAS_Y.resize(0);
|
||||||
@ -120,10 +129,10 @@ public:
|
|||||||
/////////////////////////////////////////
|
/////////////////////////////////////////
|
||||||
BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N,
|
BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N,
|
||||||
vw,nrhs,nrhs,
|
vw,nrhs,nrhs,
|
||||||
ComplexD(1.0),
|
scalar(1.0),
|
||||||
Xd,
|
Xd,
|
||||||
Cd,
|
Cd,
|
||||||
ComplexD(0.0), // wipe out Y
|
scalar(0.0), // wipe out Y
|
||||||
Yd);
|
Yd);
|
||||||
BLAS.synchronise();
|
BLAS.synchronise();
|
||||||
RealD t3 = usecond();
|
RealD t3 = usecond();
|
||||||
@ -144,6 +153,7 @@ public:
|
|||||||
|
|
||||||
void InnerProductMatrix(Eigen::MatrixXcd &m , const std::vector<Field> &X, const std::vector<Field> &Y)
|
void InnerProductMatrix(Eigen::MatrixXcd &m , const std::vector<Field> &X, const std::vector<Field> &Y)
|
||||||
{
|
{
|
||||||
|
#if 0
|
||||||
int nrhs;
|
int nrhs;
|
||||||
GridBase *grid;
|
GridBase *grid;
|
||||||
uint64_t vol;
|
uint64_t vol;
|
||||||
@ -242,7 +252,124 @@ public:
|
|||||||
std::cout << "InnerProductMatrix gsum t5 "<< t5-t4<<" us"<<std::endl;
|
std::cout << "InnerProductMatrix gsum t5 "<< t5-t4<<" us"<<std::endl;
|
||||||
std::cout << "InnerProductMatrix cp t6 "<< t6-t5<<" us"<<std::endl;
|
std::cout << "InnerProductMatrix cp t6 "<< t6-t5<<" us"<<std::endl;
|
||||||
std::cout << "InnerProductMatrix took "<< t6-t0<<" us"<<std::endl;
|
std::cout << "InnerProductMatrix took "<< t6-t0<<" us"<<std::endl;
|
||||||
|
#else
|
||||||
|
int nrhs;
|
||||||
|
GridBase *grid;
|
||||||
|
uint64_t vol;
|
||||||
|
uint64_t words;
|
||||||
|
|
||||||
|
nrhs = X.size();
|
||||||
|
assert(X.size()==Y.size());
|
||||||
|
conformable(X[0],Y[0]);
|
||||||
|
|
||||||
|
grid = X[0].Grid();
|
||||||
|
int rd0 = grid->_rdimensions[0] * grid->_rdimensions[1];
|
||||||
|
vol = grid->oSites()/rd0;
|
||||||
|
words = rd0*sizeof(vector_object)/sizeof(scalar);
|
||||||
|
int64_t vw = vol * words;
|
||||||
|
assert(vw == grid->lSites()*sizeof(scalar_object)/sizeof(scalar));
|
||||||
|
|
||||||
|
RealD t0 = usecond();
|
||||||
|
BLAS_X.resize(nrhs * vw); // cost free if size doesn't change
|
||||||
|
BLAS_Y.resize(nrhs * vw); // cost free if size doesn't change
|
||||||
|
BLAS_Cred.resize(nrhs * nrhs * vol);// cost free if size doesn't change
|
||||||
|
RealD t1 = usecond();
|
||||||
|
|
||||||
|
/////////////////////////////////////////////
|
||||||
|
// Copy in the multi-rhs sources -- layout batched BLAS ready
|
||||||
|
/////////////////////////////////////////////
|
||||||
|
for(int r=0;r<nrhs;r++){
|
||||||
|
autoView(x_v,X[r],AcceleratorRead);
|
||||||
|
autoView(y_v,Y[r],AcceleratorRead);
|
||||||
|
scalar *from_x=(scalar *)&x_v[0];
|
||||||
|
scalar *from_y=(scalar *)&y_v[0];
|
||||||
|
scalar *BX = &BLAS_X[0];
|
||||||
|
scalar *BY = &BLAS_Y[0];
|
||||||
|
accelerator_for(ssw,vw,1,{
|
||||||
|
uint64_t ss=ssw/words;
|
||||||
|
uint64_t w=ssw%words;
|
||||||
|
uint64_t offset = w+r*words+ss*nrhs*words; // [ss][rhs][words]
|
||||||
|
BX[offset] = from_x[ssw];
|
||||||
|
BY[offset] = from_y[ssw];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
RealD t2 = usecond();
|
||||||
|
|
||||||
|
/*
|
||||||
|
* in Fortran column major notation (cuBlas order)
|
||||||
|
*
|
||||||
|
* Xxr = [X1(x)][..][Xn(x)]
|
||||||
|
*
|
||||||
|
* Yxr = [Y1(x)][..][Ym(x)]
|
||||||
|
*
|
||||||
|
* C_rs = X^dag Y
|
||||||
|
*/
|
||||||
|
Xdip.resize(vol);
|
||||||
|
Ydip.resize(vol);
|
||||||
|
Cdip.resize(vol);
|
||||||
|
std::vector<scalar *> Xh(vol);
|
||||||
|
std::vector<scalar *> Yh(vol);
|
||||||
|
std::vector<scalar *> Ch(vol);
|
||||||
|
for(uint64_t ss=0;ss<vol;ss++){
|
||||||
|
|
||||||
|
Xh[ss] = & BLAS_X[ss*nrhs*words];
|
||||||
|
Yh[ss] = & BLAS_Y[ss*nrhs*words];
|
||||||
|
Ch[ss] = & BLAS_Cred[ss*nrhs*nrhs];
|
||||||
|
|
||||||
|
}
|
||||||
|
acceleratorCopyToDevice(&Xh[0],&Xdip[0],vol*sizeof(scalar *));
|
||||||
|
acceleratorCopyToDevice(&Yh[0],&Ydip[0],vol*sizeof(scalar *));
|
||||||
|
acceleratorCopyToDevice(&Ch[0],&Cdip[0],vol*sizeof(scalar *));
|
||||||
|
|
||||||
|
GridBLAS BLAS;
|
||||||
|
|
||||||
|
RealD t3 = usecond();
|
||||||
|
/////////////////////////////////////////
|
||||||
|
// C_rs = X^dag Y
|
||||||
|
/////////////////////////////////////////
|
||||||
|
BLAS.gemmBatched(GridBLAS_OP_C,GridBLAS_OP_N,
|
||||||
|
nrhs,nrhs,words,
|
||||||
|
ComplexD(1.0),
|
||||||
|
Xdip,
|
||||||
|
Ydip,
|
||||||
|
ComplexD(0.0), // wipe out C
|
||||||
|
Cdip);
|
||||||
|
BLAS.synchronise();
|
||||||
|
RealD t4 = usecond();
|
||||||
|
|
||||||
|
std::vector<scalar> HOST_C(BLAS_Cred.size()); // nrhs . nrhs -- the coefficients
|
||||||
|
acceleratorCopyFromDevice(&BLAS_Cred[0],&HOST_C[0],BLAS_Cred.size()*sizeof(scalar));
|
||||||
|
|
||||||
|
RealD t5 = usecond();
|
||||||
|
m = Eigen::MatrixXcd::Zero(nrhs,nrhs);
|
||||||
|
for(int ss=0;ss<vol;ss++){
|
||||||
|
Eigen::Map<Eigen::MatrixXcd> eC((std::complex<double> *)&HOST_C[ss*nrhs*nrhs],nrhs,nrhs);
|
||||||
|
m = m + eC;
|
||||||
|
}
|
||||||
|
RealD t6l = usecond();
|
||||||
|
grid->GlobalSumVector((scalar *) &m(0,0),nrhs*nrhs);
|
||||||
|
RealD t6 = usecond();
|
||||||
|
uint64_t M=nrhs;
|
||||||
|
uint64_t N=nrhs;
|
||||||
|
uint64_t K=vw;
|
||||||
|
RealD xybytes = grid->lSites()*sizeof(scalar_object);
|
||||||
|
RealD bytes = 1.0*sizeof(ComplexD)*(M*N*2+N*K+M*K);
|
||||||
|
RealD flops = 8.0*M*N*K;
|
||||||
|
flops = flops/(t4-t3)/1.e3;
|
||||||
|
bytes = bytes/(t4-t3)/1.e3;
|
||||||
|
xybytes = 4*xybytes/(t2-t1)/1.e3;
|
||||||
|
std::cout << "InnerProductMatrix m,n,k "<< M<<","<<N<<","<<K<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix alloc t1 "<< t1-t0<<" us"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix cp t2 "<< t2-t1<<" us "<<xybytes<<" GB/s"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix setup t3 "<< t3-t2<<" us"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix blas t4 "<< t4-t3<<" us"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix blas "<< flops<<" GF/s"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix blas "<< bytes<<" GB/s"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix cp t5 "<< t5-t4<<" us"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix lsum t6l "<< t6l-t5<<" us"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix gsum t6 "<< t6-t6l<<" us"<<std::endl;
|
||||||
|
std::cout << "InnerProductMatrix took "<< t6-t0<<" us"<<std::endl;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -447,10 +447,10 @@ public:
|
|||||||
/////////////////////////////////////////
|
/////////////////////////////////////////
|
||||||
BLAS.gemmBatched(GridBLAS_OP_C,GridBLAS_OP_N,
|
BLAS.gemmBatched(GridBLAS_OP_C,GridBLAS_OP_N,
|
||||||
nbasis,nrhs,vw,
|
nbasis,nrhs,vw,
|
||||||
ComplexD(1.0),
|
scalar(1.0),
|
||||||
Vd,
|
Vd,
|
||||||
Fd,
|
Fd,
|
||||||
ComplexD(0.0), // wipe out C
|
scalar(0.0), // wipe out C
|
||||||
Cd);
|
Cd);
|
||||||
BLAS.synchronise();
|
BLAS.synchronise();
|
||||||
// std::cout << "BlockProject done"<<std::endl;
|
// std::cout << "BlockProject done"<<std::endl;
|
||||||
@ -497,10 +497,10 @@ public:
|
|||||||
int64_t vw = block_vol * words;
|
int64_t vw = block_vol * words;
|
||||||
BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N,
|
BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N,
|
||||||
vw,nrhs,nbasis,
|
vw,nrhs,nbasis,
|
||||||
ComplexD(1.0),
|
scalar(1.0),
|
||||||
Vd,
|
Vd,
|
||||||
Cd,
|
Cd,
|
||||||
ComplexD(0.0), // wipe out C
|
scalar(0.0), // wipe out C
|
||||||
Fd);
|
Fd);
|
||||||
BLAS.synchronise();
|
BLAS.synchronise();
|
||||||
// std::cout << " blas call done"<<std::endl;
|
// std::cout << " blas call done"<<std::endl;
|
||||||
|
@ -182,10 +182,10 @@ public:
|
|||||||
/////////////////////////////////////////
|
/////////////////////////////////////////
|
||||||
BLAS.gemmBatched(GridBLAS_OP_C,GridBLAS_OP_N,
|
BLAS.gemmBatched(GridBLAS_OP_C,GridBLAS_OP_N,
|
||||||
nev,nrhs,vw,
|
nev,nrhs,vw,
|
||||||
ComplexD(1.0),
|
scalar(1.0),
|
||||||
Ed,
|
Ed,
|
||||||
Rd,
|
Rd,
|
||||||
ComplexD(0.0), // wipe out C
|
scalar(0.0), // wipe out C
|
||||||
Cd);
|
Cd);
|
||||||
BLAS.synchronise();
|
BLAS.synchronise();
|
||||||
|
|
||||||
@ -210,10 +210,10 @@ public:
|
|||||||
/////////////////////////////////////////
|
/////////////////////////////////////////
|
||||||
BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N,
|
BLAS.gemmBatched(GridBLAS_OP_N,GridBLAS_OP_N,
|
||||||
vw,nrhs,nev,
|
vw,nrhs,nev,
|
||||||
ComplexD(1.0),
|
scalar(1.0),
|
||||||
Ed, // x . nev
|
Ed, // x . nev
|
||||||
Cd, // nev . nrhs
|
Cd, // nev . nrhs
|
||||||
ComplexD(0.0),
|
scalar(0.0),
|
||||||
Gd);
|
Gd);
|
||||||
BLAS.synchronise();
|
BLAS.synchronise();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user