1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

Working and faster version

This commit is contained in:
Peter Boyle 2024-02-21 14:46:43 -05:00
parent e1d0a7cec3
commit 9c2565f64e

View File

@ -32,21 +32,6 @@ Author: Peter Boyle <pboyle@bnl.gov>
NAMESPACE_BEGIN(Grid); NAMESPACE_BEGIN(Grid);
// TODO: move these helpers to accelerator.h.
// Also provide a device copy variant.
// Rename deviceSet to acceleratorPut.
// Rename deviceGet to acceleratorGet.
// Transfers a single object of type T: the address of `host` is passed as the
// source, the address of `dev` as the destination, and sizeof(T) as the byte
// count to acceleratorCopyToDevice.
// NOTE(review): presumably `dev` must refer to device-accessible storage -- confirm against callers.
template<class T>
void deviceSet(T& dev,T&host)
{
  T* const source      = &host;
  T* const destination = &dev;
  acceleratorCopyToDevice(source,destination,sizeof(T));
}
// Returns a by-value copy of the object referenced by `dev`: a default-initialised
// local T is filled by acceleratorCopyFromDevice (source = &dev, destination = the
// local, byte count = sizeof(T)) and then returned.
// NOTE(review): requires T to be default-constructible; presumably T is also expected
// to be safe to copy bytewise -- confirm.
template<class T>
T deviceGet(T& dev)
{
  T fetched;
  T* const source      = &dev;
  T* const destination = &fetched;
  acceleratorCopyFromDevice(source,destination,sizeof(T));
  return fetched;
}
// Fine Object == (per site) type of fine field // Fine Object == (per site) type of fine field
// nbasis == number of deflation vectors // nbasis == number of deflation vectors
template<class Fobj,class CComplex,int nbasis> template<class Fobj,class CComplex,int nbasis>
@ -133,14 +118,12 @@ public:
for(int p=0;p<geom.npoint;p++){ for(int p=0;p<geom.npoint;p++){
for(int ss=0;ss<unpadded_sites;ss++){ for(int ss=0;ss<unpadded_sites;ss++){
ComplexD *ptr = (ComplexD *)&BLAS_A[p][ss]; ComplexD *ptr = (ComplexD *)&BLAS_A[p][ss];
//ComplexD *ptr = (ComplexD *)&BLAS_A[p][0]; std::cout << " A ptr "<<std::hex<<ptr<<std::dec<<" "<<ss<<"/"<<BLAS_A[p].size()<<std::endl; acceleratorPut(BLAS_AP[p][ss],ptr);
deviceSet(BLAS_AP[p][ss],ptr);
} }
} }
for(int ss=0;ss<unpadded_sites;ss++){ for(int ss=0;ss<unpadded_sites;ss++){
ComplexD *ptr = (ComplexD *)&BLAS_C[ss*nrhs]; ComplexD *ptr = (ComplexD *)&BLAS_C[ss*nrhs];
//ComplexD *ptr = (ComplexD *)&BLAS_C[0]; std::cout << " C ptr "<<std::hex<<ptr<<std::dec<<" "<<ss<<"/"<<BLAS_C.size()<<std::endl; acceleratorPut(BLAS_CP[ss],ptr);
deviceSet(BLAS_CP[ss],ptr);
} }
///////////////////////////////////////////////// /////////////////////////////////////////////////
@ -155,19 +138,14 @@ public:
ghost_zone=1; // If general stencil wrapped in any direction, wrap=1 ghost_zone=1; // If general stencil wrapped in any direction, wrap=1
} }
} }
// GeneralStencilEntryReordered tmp;
if( ghost_zone==0) { if( ghost_zone==0) {
for(int32_t point = 0 ; point < geom.npoint; point++){ for(int32_t point = 0 ; point < geom.npoint; point++){
int i=s*orhs*geom.npoint+point; int i=s*orhs*geom.npoint+point;
int32_t nbr = Stencil._entries[i]._offset*CComplex::Nsimd(); // oSite -> lSite int32_t nbr = Stencil._entries[i]._offset*CComplex::Nsimd(); // oSite -> lSite
// std::cout << " B ptr "<< nbr<<"/"<<BLAS_B.size()<<std::endl;
assert(nbr<BLAS_B.size()); assert(nbr<BLAS_B.size());
ComplexD * ptr = (ComplexD *)&BLAS_B[nbr]; ComplexD * ptr = (ComplexD *)&BLAS_B[nbr];
// ComplexD * ptr = (ComplexD *)&BLAS_B[0]; acceleratorPut(BLAS_BP[point][j],ptr); // neighbour indexing in ghost zone volume
// std::cout << " B ptr unpadded "<<std::hex<<ptr<<std::dec<<" "<<s<<"/"<<padded_sites<<std::endl;
// std::cout << " B ptr padded "<<std::hex<<ptr<<std::dec<<" "<<j<<"/"<<unpadded_sites<<std::endl;
deviceSet(BLAS_BP[point][j],ptr); // neighbour indexing in ghost zone volume
// auto tmp = deviceGet(*BLAS_BP[point][j]); // debug trigger SEGV if bad ptr
} }
j++; j++;
} }
@ -236,7 +214,6 @@ public:
#if 0 #if 0
std::vector<typename vobj::scalar_object> tmp; std::vector<typename vobj::scalar_object> tmp;
tmp.resize(in.size()); tmp.resize(in.size());
// std::cout << "BLAStoGrid volume " <<tmp.size()<<" "<< grid.Grid()->lSites()<<std::endl;
assert(in.size()==grid.Grid()->lSites()); assert(in.size()==grid.Grid()->lSites());
acceleratorCopyFromDevice(&in[0],&tmp[0],sizeof(typename vobj::scalar_object)*in.size()); acceleratorCopyFromDevice(&in[0],&tmp[0],sizeof(typename vobj::scalar_object)*in.size());
vectorizeFromLexOrdArray(tmp,grid); vectorizeFromLexOrdArray(tmp,grid);
@ -289,19 +266,10 @@ public:
} }
// CopyMatrix: for every stencil point p in [0, geom.npoint), passes _Op._A[p]
// through _Op.Cell.Extract and hands the result to GridtoBLAS with BLAS_A[p]
// as the destination. Takes no arguments and returns nothing.
// NOTE(review): the inline comment below labels the extracted field as "unpadded";
// the layout GridtoBLAS produces is not visible here -- confirm against its definition.
void CopyMatrix (void) void CopyMatrix (void)
{ {
// Clone "A" to be lexicographic in the physics coords
// Use unvectorisetolexordarray
// Copy to device
for(int p=0;p<geom.npoint;p++){ for(int p=0;p<geom.npoint;p++){
//Unpadded //Unpadded
auto Aup = _Op.Cell.Extract(_Op._A[p]); auto Aup = _Op.Cell.Extract(_Op._A[p]);
// Coordinate coor({0,0,0,0,0});
// auto sval = peekSite(Aup,coor);
// std::cout << "CopyMatrix: p "<<p<<" Aup[0] :"<<sval<<std::endl;
// sval = peekSite(_Op._A[p],coor);
// std::cout << "CopyMatrix: p "<<p<<" _Op._Ap[0] :"<<sval<<std::endl;
GridtoBLAS(Aup,BLAS_A[p]); GridtoBLAS(Aup,BLAS_A[p]);
// std::cout << "Copy Matrix p "<<p<<" "<< deviceGet(BLAS_A[p][0])<<std::endl;
} }
} }
void Mdag(const CoarseVector &in, CoarseVector &out) void Mdag(const CoarseVector &in, CoarseVector &out)
@ -346,11 +314,8 @@ public:
int64_t nrhs =pin.Grid()->GlobalDimensions()[0]; int64_t nrhs =pin.Grid()->GlobalDimensions()[0];
assert(nrhs>=1); assert(nrhs>=1);
// std::cout << GridLogMessage << "New Mrhs GridtoBLAS in sizes "<<in.Grid()->lSites()<<" "<<pin.Grid()->lSites()<<std::endl;
t_GtoB=-usecond(); t_GtoB=-usecond();
GridtoBLAS(pin,BLAS_B); GridtoBLAS(pin,BLAS_B);
// out = Zero();
// GridtoBLAS(out,BLAS_C);
t_GtoB+=usecond(); t_GtoB+=usecond();
GridBLAS BLAS; GridBLAS BLAS;
@ -360,7 +325,7 @@ public:
RealD c = 1.0; RealD c = 1.0;
if (p==0) c = 0.0; if (p==0) c = 0.0;
ComplexD beta(c); ComplexD beta(c);
// std::cout << GridLogMessage << "New Mrhs coarse gemmBatched "<<p<<std::endl;
BLAS.gemmBatched(nbasis,nrhs,nbasis, BLAS.gemmBatched(nbasis,nrhs,nbasis,
ComplexD(1.0), ComplexD(1.0),
BLAS_AP[p], BLAS_AP[p],
@ -370,16 +335,12 @@ public:
} }
BLAS.synchronise(); BLAS.synchronise();
t_mult+=usecond(); t_mult+=usecond();
// std::cout << GridLogMessage << "New Mrhs coarse BLAStoGrid "<<std::endl;
t_BtoG=-usecond(); t_BtoG=-usecond();
BLAStoGrid(out,BLAS_C); BLAStoGrid(out,BLAS_C);
t_BtoG+=usecond(); t_BtoG+=usecond();
t_tot+=usecond(); t_tot+=usecond();
// auto check =deviceGet(BLAS_C[0]);
// std::cout << "C[0] "<<check<<std::endl;
// Coordinate coor({0,0,0,0,0,0});
// peekLocalSite(check,out,coor);
// std::cout << "C[0] "<< check<<std::endl;
std::cout << GridLogMessage << "New Mrhs coarse DONE "<<std::endl; std::cout << GridLogMessage << "New Mrhs coarse DONE "<<std::endl;
std::cout << GridLogMessage<<"Coarse Mult exch "<<t_exch<<" us"<<std::endl; std::cout << GridLogMessage<<"Coarse Mult exch "<<t_exch<<" us"<<std::endl;
std::cout << GridLogMessage<<"Coarse Mult mult "<<t_mult<<" us"<<std::endl; std::cout << GridLogMessage<<"Coarse Mult mult "<<t_mult<<" us"<<std::endl;