1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Faster GPU basis rotation

May need to later include Regensburg optimised CPU variant
This commit is contained in:
Peter Boyle 2020-04-10 11:06:04 -04:00
parent 8a5c13d5fb
commit dc50190b8f

View File

@ -59,16 +59,15 @@ void basisRotate(std::vector<Field> &basis,Eigen::MatrixXd& Qt,int j0, int j1, i
{ {
typedef decltype(basis[0].View()) View; typedef decltype(basis[0].View()) View;
auto tmp_v = basis[0].View(); auto tmp_v = basis[0].View();
std::vector<View> basis_v(basis.size(),tmp_v); Vector<View> basis_v(basis.size(),tmp_v);
typedef typename Field::vector_object vobj; typedef typename Field::vector_object vobj;
GridBase* grid = basis[0].Grid(); GridBase* grid = basis[0].Grid();
for(int k=0;k<basis.size();k++){ for(int k=0;k<basis.size();k++){
basis_v[k] = basis[k].View(); basis_v[k] = basis[k].View();
} }
#if 0
std::vector < vobj , commAllocator<vobj> > Bt(thread_max() * Nm); // Thread private std::vector < vobj , commAllocator<vobj> > Bt(thread_max() * Nm); // Thread private
thread_region thread_region
{ {
vobj* B = Bt.data() + Nm * thread_num(); vobj* B = Bt.data() + Nm * thread_num();
@ -86,24 +85,89 @@ void basisRotate(std::vector<Field> &basis,Eigen::MatrixXd& Qt,int j0, int j1, i
} }
}); });
} }
#else
int nrot = j1-j0;
uint64_t oSites =grid->oSites();
uint64_t siteBlock=(grid->oSites()+nrot-1)/nrot; // Maximum 1 additional vector overhead
// printf("BasisRotate %d %d nrot %d siteBlock %d\n",j0,j1,nrot,siteBlock);
Vector <vobj> Bt(siteBlock * nrot);
auto Bp=&Bt[0];
// GPU readable copy of Eigen matrix
Vector<double> Qt_jv(Nm*Nm);
double *Qt_p = & Qt_jv[0];
for(int k=0;k<Nm;++k){
for(int j=0;j<Nm;++j){
Qt_p[j*Nm+k]=Qt(j,k);
}
}
// Block the loop to keep storage footprint down
vobj zz=Zero();
for(uint64_t s=0;s<oSites;s+=siteBlock){
// remaining work in this block
int ssites=MIN(siteBlock,oSites-s);
// zero out the accumulators
accelerator_for(ss,siteBlock*nrot,vobj::Nsimd(),{
auto z=coalescedRead(zz);
coalescedWrite(Bp[ss],z);
});
accelerator_for(sj,ssites*nrot,vobj::Nsimd(),{
int j =sj%nrot;
int jj =j0+j;
int ss =sj/nrot;
int sss=ss+s;
for(int k=k0; k<k1; ++k){
auto tmp = coalescedRead(Bp[ss*nrot+j]);
coalescedWrite(Bp[ss*nrot+j],tmp+ Qt_p[jj*Nm+k] * coalescedRead(basis_v[k][sss]));
}
});
accelerator_for(sj,ssites*nrot,vobj::Nsimd(),{
int j =sj%nrot;
int jj =j0+j;
int ss =sj/nrot;
int sss=ss+s;
coalescedWrite(basis_v[jj][sss],coalescedRead(Bp[ss*nrot+j]));
});
}
#endif
} }
// Extract a single rotated vector // Extract a single rotated vector
template<class Field> template<class Field>
void basisRotateJ(Field &result,std::vector<Field> &basis,Eigen::MatrixXd& Qt,int j, int k0,int k1,int Nm) void basisRotateJ(Field &result,std::vector<Field> &basis,Eigen::MatrixXd& Qt,int j, int k0,int k1,int Nm)
{ {
typedef decltype(basis[0].View()) View;
typedef typename Field::vector_object vobj; typedef typename Field::vector_object vobj;
GridBase* grid = basis[0].Grid(); GridBase* grid = basis[0].Grid();
result.Checkerboard() = basis[0].Checkerboard(); result.Checkerboard() = basis[0].Checkerboard();
auto result_v=result.View(); auto result_v=result.View();
thread_for(ss, grid->oSites(),{ Vector<View> basis_v(basis.size(),result_v);
vobj B = Zero(); for(int k=0;k<basis.size();k++){
basis_v[k] = basis[k].View();
}
vobj zz=Zero();
Vector<double> Qt_jv(Nm);
double * Qt_j = & Qt_jv[0];
for(int k=0;k<Nm;++k) Qt_j[k]=Qt(j,k);
accelerator_for(ss, grid->oSites(),vobj::Nsimd(),{
auto B=coalescedRead(zz);
for(int k=k0; k<k1; ++k){ for(int k=k0; k<k1; ++k){
auto basis_k = basis[k].View(); B +=Qt_j[k] * coalescedRead(basis_v[k][ss]);
B +=Qt(j,k) * basis_k[ss];
} }
result_v[ss] = B; coalescedWrite(result_v[ss], B);
}); });
} }
@ -303,7 +367,7 @@ public:
RealD _eresid, // resid in lmdue deficit RealD _eresid, // resid in lmdue deficit
int _MaxIter, // Max iterations int _MaxIter, // Max iterations
RealD _betastp=0.0, // if beta(k) < betastp: converged RealD _betastp=0.0, // if beta(k) < betastp: converged
int _MinRestart=1, int _orth_period = 1, int _MinRestart=0, int _orth_period = 1,
IRLdiagonalisation _diagonalisation= IRLdiagonaliseWithEigen) : IRLdiagonalisation _diagonalisation= IRLdiagonaliseWithEigen) :
SimpleTester(HermOp), _PolyOp(PolyOp), _HermOp(HermOp), _Tester(SimpleTester), SimpleTester(HermOp), _PolyOp(PolyOp), _HermOp(HermOp), _Tester(SimpleTester),
Nstop(_Nstop) , Nk(_Nk), Nm(_Nm), Nstop(_Nstop) , Nk(_Nk), Nm(_Nm),