1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-17 15:27:06 +01:00

Dirichlet filters running on AMD and now integrated in Fermion op

This commit is contained in:
Peter Boyle
2022-02-23 19:29:28 -05:00
parent 70988e43d2
commit 0f1c5b08a1
9 changed files with 125 additions and 61 deletions

View File

@ -36,41 +36,6 @@ using namespace Grid;
/// Move to domains ////
////////////////////////
template<typename MomentaField>
struct DirichletFilter: public MomentumFilterBase<MomentaField>
{
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
Coordinate Block;
DirichletFilter(const Coordinate &_Block): Block(_Block){}
void applyFilter(MomentaField &P) const override
{
GridBase *grid = P.Grid();
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
////////////////////////////////////////////////////
// Zero strictly links crossing between domains
////////////////////////////////////////////////////
LatticeInteger coor(grid);
LatCM zz(grid); zz = Zero();
for(int mu=0;mu<Nd;mu++) {
if ( Block[mu] ) {
// If costly could provide Grid earlier and precompute masks
LatticeCoordinate(coor,mu);
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
PokeIndex<LorentzIndex>(P, P_mu, mu);
}
}
}
};
Gamma::Algebra Gmu [] = {
Gamma::Algebra::GammaX,
Gamma::Algebra::GammaY,
@ -78,7 +43,7 @@ Gamma::Algebra Gmu [] = {
Gamma::Algebra::GammaT
};
void Benchmark(int Ls, std::vector<int> Dirichlet);
void Benchmark(int Ls, Coordinate Dirichlet);
int main (int argc, char ** argv)
{
@ -97,8 +62,9 @@ int main (int argc, char ** argv)
//////////////////
// With comms
//////////////////
std::vector<int> Dirichlet(5,0);
Coordinate Dirichlet(Nd+1,0);
std::cout << "\n\n\n\n\n\n" <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing with full communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
@ -110,7 +76,7 @@ int main (int argc, char ** argv)
//////////////////
Coordinate latt4 = GridDefaultLatt();
Coordinate mpi = GridDefaultMpi();
std::vector<int> CommDim(Nd);
Coordinate CommDim(Nd);
Coordinate shm;
GlobalSharedMemory::GetShmDims(mpi,shm);
@ -118,36 +84,39 @@ int main (int argc, char ** argv)
//////////////////////
// Node level
//////////////////////
std::cout << "\n\n\n\n\n\n" <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0;
Dirichlet = std::vector<int>({0,
CommDim[0]*latt4[0]/mpi[0] * shm[0],
CommDim[1]*latt4[1]/mpi[1] * shm[1],
CommDim[2]*latt4[2]/mpi[2] * shm[2],
CommDim[3]*latt4[3]/mpi[3] * shm[3]});
Dirichlet[0] = 0;
Dirichlet[1] = CommDim[0]*latt4[0]/mpi[0] * shm[0];
Dirichlet[2] = CommDim[1]*latt4[1]/mpi[1] * shm[1];
Dirichlet[3] = CommDim[2]*latt4[2]/mpi[2] * shm[2];
Dirichlet[4] = CommDim[3]*latt4[3]/mpi[3] * shm[3];
Benchmark(Ls,Dirichlet);
std::cout << "\n\n\n\n\n\n" <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0;
Dirichlet = std::vector<int>({0,
CommDim[0]*latt4[0]/mpi[0],
CommDim[1]*latt4[1]/mpi[1],
CommDim[2]*latt4[2]/mpi[2],
CommDim[3]*latt4[3]/mpi[3]});
Dirichlet[0] = 0;
Dirichlet[1] = CommDim[0]*latt4[0]/mpi[0];
Dirichlet[2] = CommDim[1]*latt4[1]/mpi[1];
Dirichlet[3] = CommDim[2]*latt4[2]/mpi[2];
Dirichlet[4] = CommDim[3]*latt4[3]/mpi[3];
Benchmark(Ls,Dirichlet);
Grid_finalize();
exit(0);
}
void Benchmark(int Ls, std::vector<int> Dirichlet)
void Benchmark(int Ls, Coordinate Dirichlet)
{
Coordinate latt4 = GridDefaultLatt();
GridLogLayout();
@ -196,7 +165,9 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
std::cout << GridLogMessage << "Drawing gauge field" << std::endl;
LatticeGaugeFieldF Umu(UGrid);
LatticeGaugeFieldF UmuCopy(UGrid);
SU<Nc>::HotConfiguration(RNG4,Umu);
UmuCopy=Umu;
std::cout << GridLogMessage << "Random gauge initialised " << std::endl;
////////////////////////////////////
@ -205,7 +176,8 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
Coordinate Block(4);
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block " << Block << std::endl;
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block5 " << Dirichlet << std::endl;
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block4 " << Block << std::endl;
DirichletFilter<LatticeGaugeFieldF> Filter(Block);
Filter.applyFilter(Umu);
@ -279,6 +251,7 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
Dw.DirichletBlock(Dirichlet);
Dw.ImportGauge(Umu);
int ncall =300;
@ -377,9 +350,6 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
err = ref-result;
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
if ( norm2(err)>1.0e-4) {
std::cout << err << std::endl;
}
assert((norm2(err)<1.0e-4));
LatticeFermionF src_e (FrbGrid);