mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-09 23:45:36 +00:00
Working and faster version
This commit is contained in:
parent
e1d0a7cec3
commit
9c2565f64e
@ -32,21 +32,6 @@ Author: Peter Boyle <pboyle@bnl.gov>
|
||||
NAMESPACE_BEGIN(Grid);
|
||||
|
||||
|
||||
// Move this to accelerator.h
|
||||
// Also give a copy device.
|
||||
// Rename acceleratorPut
|
||||
// Rename acceleratorGet
|
||||
// deviceSet: write one object of type T from `host` into the storage referred
// to by `dev`.
//
//   dev  - reference to the destination object; its address is handed to
//          acceleratorCopyToDevice as the second argument
//          (presumably accelerator-resident memory - confirm against caller).
//   host - reference to the source object; its address is handed to
//          acceleratorCopyToDevice as the first argument.
//
// Exactly sizeof(T) bytes are requested, i.e. a raw byte copy of a single T.
// NOTE(review): argument order is assumed to be (source, destination, bytes);
// verify against the acceleratorCopyToDevice declaration.
template<class T> void deviceSet(T& dev,T&host)
|
||||
{
|
||||
acceleratorCopyToDevice(&host,&dev,sizeof(T));
|
||||
}
|
||||
// deviceGet: read one object of type T out of the storage referred to by `dev`
// and return it by value.
//
//   dev - reference to the source object; its address is handed to
//         acceleratorCopyFromDevice as the first argument
//         (presumably accelerator-resident memory - confirm against caller).
//
// Returns the local `host` object after sizeof(T) bytes have been copied into
// it. T must be default-constructible because `host` is declared without an
// initialiser before the copy fills it.
// NOTE(review): argument order is assumed to be (source, destination, bytes);
// verify against the acceleratorCopyFromDevice declaration.
template<class T> T deviceGet(T& dev)
|
||||
{
|
||||
// Staging object on the caller's side; overwritten by the copy below.
T host;
|
||||
acceleratorCopyFromDevice(&dev,&host,sizeof(T));
|
||||
return host;
|
||||
}
|
||||
|
||||
// Fine Object == (per site) type of fine field
|
||||
// nbasis == number of deflation vectors
|
||||
template<class Fobj,class CComplex,int nbasis>
|
||||
@ -133,14 +118,12 @@ public:
|
||||
for(int p=0;p<geom.npoint;p++){
|
||||
for(int ss=0;ss<unpadded_sites;ss++){
|
||||
ComplexD *ptr = (ComplexD *)&BLAS_A[p][ss];
|
||||
//ComplexD *ptr = (ComplexD *)&BLAS_A[p][0]; std::cout << " A ptr "<<std::hex<<ptr<<std::dec<<" "<<ss<<"/"<<BLAS_A[p].size()<<std::endl;
|
||||
deviceSet(BLAS_AP[p][ss],ptr);
|
||||
acceleratorPut(BLAS_AP[p][ss],ptr);
|
||||
}
|
||||
}
|
||||
for(int ss=0;ss<unpadded_sites;ss++){
|
||||
ComplexD *ptr = (ComplexD *)&BLAS_C[ss*nrhs];
|
||||
//ComplexD *ptr = (ComplexD *)&BLAS_C[0]; std::cout << " C ptr "<<std::hex<<ptr<<std::dec<<" "<<ss<<"/"<<BLAS_C.size()<<std::endl;
|
||||
deviceSet(BLAS_CP[ss],ptr);
|
||||
acceleratorPut(BLAS_CP[ss],ptr);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
@ -155,19 +138,14 @@ public:
|
||||
ghost_zone=1; // If general stencil wrapped in any direction, wrap=1
|
||||
}
|
||||
}
|
||||
// GeneralStencilEntryReordered tmp;
|
||||
|
||||
if( ghost_zone==0) {
|
||||
for(int32_t point = 0 ; point < geom.npoint; point++){
|
||||
int i=s*orhs*geom.npoint+point;
|
||||
int32_t nbr = Stencil._entries[i]._offset*CComplex::Nsimd(); // oSite -> lSite
|
||||
// std::cout << " B ptr "<< nbr<<"/"<<BLAS_B.size()<<std::endl;
|
||||
assert(nbr<BLAS_B.size());
|
||||
ComplexD * ptr = (ComplexD *)&BLAS_B[nbr];
|
||||
// ComplexD * ptr = (ComplexD *)&BLAS_B[0];
|
||||
// std::cout << " B ptr unpadded "<<std::hex<<ptr<<std::dec<<" "<<s<<"/"<<padded_sites<<std::endl;
|
||||
// std::cout << " B ptr padded "<<std::hex<<ptr<<std::dec<<" "<<j<<"/"<<unpadded_sites<<std::endl;
|
||||
deviceSet(BLAS_BP[point][j],ptr); // neighbour indexing in ghost zone volume
|
||||
// auto tmp = deviceGet(*BLAS_BP[point][j]); // debug trigger SEGV if bad ptr
|
||||
acceleratorPut(BLAS_BP[point][j],ptr); // neighbour indexing in ghost zone volume
|
||||
}
|
||||
j++;
|
||||
}
|
||||
@ -236,7 +214,6 @@ public:
|
||||
#if 0
|
||||
std::vector<typename vobj::scalar_object> tmp;
|
||||
tmp.resize(in.size());
|
||||
// std::cout << "BLAStoGrid volume " <<tmp.size()<<" "<< grid.Grid()->lSites()<<std::endl;
|
||||
assert(in.size()==grid.Grid()->lSites());
|
||||
acceleratorCopyFromDevice(&in[0],&tmp[0],sizeof(typename vobj::scalar_object)*in.size());
|
||||
vectorizeFromLexOrdArray(tmp,grid);
|
||||
@ -289,19 +266,10 @@ public:
|
||||
}
|
||||
// CopyMatrix: for every stencil point p in [0, geom.npoint), take _Op._A[p],
// pass it through _Op.Cell.Extract, and hand the result to GridtoBLAS together
// with BLAS_A[p]. Takes no arguments and returns nothing; it works purely on
// members (_Op, geom, BLAS_A).
// NOTE(review): presumably Extract strips the padding ("Unpadded" below) and
// GridtoBLAS fills BLAS_A[p] in the layout the batched GEMM expects - confirm
// against their definitions, which are not visible here.
void CopyMatrix (void)
|
||||
{
|
||||
// Clone "A" to be lexicographic in the physics coords
|
||||
// Use unvectorisetolexordarray
|
||||
// Copy to device
|
||||
for(int p=0;p<geom.npoint;p++){
|
||||
//Unpadded
|
||||
auto Aup = _Op.Cell.Extract(_Op._A[p]);
|
||||
// Debug aid (disabled): peek at site 0 of the extracted and original fields.
// Coordinate coor({0,0,0,0,0});
|
||||
// auto sval = peekSite(Aup,coor);
|
||||
// std::cout << "CopyMatrix: p "<<p<<" Aup[0] :"<<sval<<std::endl;
|
||||
// sval = peekSite(_Op._A[p],coor);
|
||||
// std::cout << "CopyMatrix: p "<<p<<" _Op._Ap[0] :"<<sval<<std::endl;
|
||||
GridtoBLAS(Aup,BLAS_A[p]);
|
||||
// Debug aid (disabled): read back the first element of BLAS_A[p].
// std::cout << "Copy Matrix p "<<p<<" "<< deviceGet(BLAS_A[p][0])<<std::endl;
|
||||
}
|
||||
}
|
||||
void Mdag(const CoarseVector &in, CoarseVector &out)
|
||||
@ -346,11 +314,8 @@ public:
|
||||
int64_t nrhs =pin.Grid()->GlobalDimensions()[0];
|
||||
assert(nrhs>=1);
|
||||
|
||||
// std::cout << GridLogMessage << "New Mrhs GridtoBLAS in sizes "<<in.Grid()->lSites()<<" "<<pin.Grid()->lSites()<<std::endl;
|
||||
t_GtoB=-usecond();
|
||||
GridtoBLAS(pin,BLAS_B);
|
||||
// out = Zero();
|
||||
// GridtoBLAS(out,BLAS_C);
|
||||
t_GtoB+=usecond();
|
||||
|
||||
GridBLAS BLAS;
|
||||
@ -360,7 +325,7 @@ public:
|
||||
RealD c = 1.0;
|
||||
if (p==0) c = 0.0;
|
||||
ComplexD beta(c);
|
||||
// std::cout << GridLogMessage << "New Mrhs coarse gemmBatched "<<p<<std::endl;
|
||||
|
||||
BLAS.gemmBatched(nbasis,nrhs,nbasis,
|
||||
ComplexD(1.0),
|
||||
BLAS_AP[p],
|
||||
@ -370,16 +335,12 @@ public:
|
||||
}
|
||||
BLAS.synchronise();
|
||||
t_mult+=usecond();
|
||||
// std::cout << GridLogMessage << "New Mrhs coarse BLAStoGrid "<<std::endl;
|
||||
|
||||
t_BtoG=-usecond();
|
||||
BLAStoGrid(out,BLAS_C);
|
||||
t_BtoG+=usecond();
|
||||
t_tot+=usecond();
|
||||
// auto check =deviceGet(BLAS_C[0]);
|
||||
// std::cout << "C[0] "<<check<<std::endl;
|
||||
// Coordinate coor({0,0,0,0,0,0});
|
||||
// peekLocalSite(check,out,coor);
|
||||
// std::cout << "C[0] "<< check<<std::endl;
|
||||
|
||||
std::cout << GridLogMessage << "New Mrhs coarse DONE "<<std::endl;
|
||||
std::cout << GridLogMessage<<"Coarse Mult exch "<<t_exch<<" us"<<std::endl;
|
||||
std::cout << GridLogMessage<<"Coarse Mult mult "<<t_mult<<" us"<<std::endl;
|
||||
|
Loading…
Reference in New Issue
Block a user