1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

Passes multinode dirichlet test with boundaries at

node boundary or at the single rank boundary
This commit is contained in:
Peter Boyle 2022-02-23 01:42:14 -05:00
parent aab3bcb46f
commit 70988e43d2
2 changed files with 91 additions and 107 deletions

View File

@ -36,100 +36,41 @@ using namespace Grid;
/// Move to domains ////
////////////////////////
struct DomainDecomposition
{
Coordinate Block;
DomainDecomposition(const Coordinate &_Block): Block(_Block){ assert(Block.size()==Nd);};
template<class Field>
void ProjectDomain(Field &f,Integer domain)
{
GridBase *grid = f.Grid();
int dims = grid->Nd();
int isDWF= (dims==Nd+1);
assert((dims==Nd)||(dims==Nd+1));
Field zz(grid); zz = Zero();
LatticeInteger coor(grid);
LatticeInteger domaincoor(grid);
LatticeInteger mask(grid); mask = Integer(1);
LatticeInteger zi(grid); zi = Integer(0);
for(int d=0;d<Nd;d++){
Integer B= Block[d];
if ( B ) {
LatticeCoordinate(coor,d+isDWF);
domaincoor = mod(coor,B);
mask = where(domaincoor==Integer(0),zi,mask);
mask = where(domaincoor==Integer(B-1),zi,mask);
}
}
if ( !domain )
f = where(mask==Integer(1),f,zz);
else
f = where(mask==Integer(0),f,zz);
};
};
template<typename MomentaField>
struct DirichletFilter: public MomentumFilterBase<MomentaField>
{
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
Coordinate Block;
DirichletFilter(const Coordinate &_Block): Block(_Block){}
// Edge detect using domain projectors
void applyFilter (MomentaField &U) const override
void applyFilter(MomentaField &P) const override
{
DomainDecomposition Domains(Block);
GridBase *grid = U.Grid();
GridBase *grid = P.Grid();
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
////////////////////////////////////////////////////
// Zero strictly links crossing between domains
////////////////////////////////////////////////////
LatticeInteger coor(grid);
LatticeInteger face(grid);
LatticeInteger one(grid); one = 1;
LatticeInteger zero(grid); zero = 0;
LatticeInteger omega(grid);
LatticeInteger omegabar(grid);
LatticeInteger tmp(grid);
omega=one; Domains.ProjectDomain(omega,0);
omegabar=one; Domains.ProjectDomain(omegabar,1);
LatticeInteger nface(grid); nface=Zero();
MomentaField projected(grid); projected=Zero();
typedef decltype(PeekIndex<LorentzIndex>(U,0)) MomentaLinkField;
MomentaLinkField Umu(grid);
MomentaLinkField zz(grid); zz=Zero();
int dims = grid->Nd();
Coordinate Global=grid->GlobalDimensions();
assert(dims==Nd);
LatCM zz(grid); zz = Zero();
for(int mu=0;mu<Nd;mu++) {
if ( Block[mu]!=0 ) {
Umu = PeekIndex<LorentzIndex>(U,mu);
// Upper face
tmp = Cshift(omegabar,mu,1);
tmp = tmp + omega;
face = where(tmp == Integer(2),one,zero );
tmp = Cshift(omega,mu,1);
tmp = tmp + omegabar;
face = where(tmp == Integer(2),one,face );
Umu = where(face,zz,Umu);
PokeIndex<LorentzIndex>(U, Umu, mu);
if ( Block[mu] ) {
// If costly could provide Grid earlier and precompute masks
LatticeCoordinate(coor,mu);
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
PokeIndex<LorentzIndex>(P, P_mu, mu);
}
}
}
};
Gamma::Algebra Gmu [] = {
Gamma::Algebra::GammaX,
Gamma::Algebra::GammaY,
@ -152,27 +93,57 @@ int main (int argc, char ** argv)
std::stringstream ss(argv[i+1]); ss >> Ls;
}
}
//////////////////
// With comms
//////////////////
std::vector<int> Dirichlet(5,0);
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing with full communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
Benchmark(Ls,Dirichlet);
//////////////////
// Domain decomposed
//////////////////
Coordinate latt4 = GridDefaultLatt();
Coordinate mpi = GridDefaultMpi();
std::vector<int> CommDim(Nd);
Coordinate shm;
GlobalSharedMemory::GetShmDims(mpi,shm);
/*
Dirichlet = std::vector<int>({0,
latt4[0]/mpi[0] * shm[0],
latt4[1]/mpi[1] * shm[1],
latt4[2]/mpi[2] * shm[2],
latt4[3]/mpi[3] * shm[3]});
*/
Dirichlet = std::vector<int>({0,
latt4[0]/mpi[0] ,
latt4[1]/mpi[1] ,
latt4[2]/mpi[2] ,
latt4[3]/mpi[3] });
std::cout << " Dirichlet block "<< Dirichlet<< std::endl;
//////////////////////
// Node level
//////////////////////
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0;
Dirichlet = std::vector<int>({0,
CommDim[0]*latt4[0]/mpi[0] * shm[0],
CommDim[1]*latt4[1]/mpi[1] * shm[1],
CommDim[2]*latt4[2]/mpi[2] * shm[2],
CommDim[3]*latt4[3]/mpi[3] * shm[3]});
Benchmark(Ls,Dirichlet);
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0;
Dirichlet = std::vector<int>({0,
CommDim[0]*latt4[0]/mpi[0],
CommDim[1]*latt4[1]/mpi[1],
CommDim[2]*latt4[2]/mpi[2],
CommDim[3]*latt4[3]/mpi[3]});
Benchmark(Ls,Dirichlet);
Grid_finalize();
exit(0);
}
@ -203,8 +174,20 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
GridParallelRNG RNG5(FGrid); RNG5.SeedUniqueString(std::string("The 5D RNG"));
LatticeFermionF src (FGrid); random(RNG5,src);
#if 1
src = Zero();
{
Coordinate origin({0,0,0,latt4[2]-1,0});
SpinColourVectorF tmp;
tmp=Zero();
tmp()(0)(0)=Complex(-2.0,0.0);
std::cout << " source site 0 " << tmp<<std::endl;
pokeSite(tmp,src,origin);
}
#else
RealD N2 = 1.0/::sqrt(norm2(src));
src = src*N2;
#endif
LatticeFermionF result(FGrid); result=Zero();
LatticeFermionF ref(FGrid); ref=Zero();
@ -219,23 +202,22 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
////////////////////////////////////
// Apply BCs
////////////////////////////////////
std::cout << GridLogMessage << "Applying BCs " << std::endl;
Coordinate Block(4);
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
std::cout << GridLogMessage << "Dirichlet Block " << Block<< std::endl;
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block " << Block << std::endl;
DirichletFilter<LatticeGaugeFieldF> Filter(Block);
Filter.applyFilter(Umu);
////////////////////////////////////
// Naive wilson implementation
////////////////////////////////////
// replicate across fifth dimension
// LatticeGaugeFieldF Umu5d(FGrid);
std::vector<LatticeColourMatrixF> U(4,UGrid);
for(int mu=0;mu<Nd;mu++){
U[mu] = PeekIndex<LorentzIndex>(Umu,mu);
}
std::cout << GridLogMessage << "Setting up Cshift based reference " << std::endl;
if (1)
@ -297,6 +279,7 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
Dw.DirichletBlock(Dirichlet);
int ncall =300;
if (1) {
@ -328,8 +311,8 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
std::cout<<GridLogMessage << "mflop/s = "<< flops/(t1-t0)<<std::endl;
std::cout<<GridLogMessage << "mflop/s per rank = "<< flops/(t1-t0)/NP<<std::endl;
std::cout<<GridLogMessage << "mflop/s per node = "<< flops/(t1-t0)/NN<<std::endl;
std::cout<<GridLogMessage << "RF GiB/s (base 2) = "<< 1000000. * data_rf/((t1-t0))<<std::endl;
std::cout<<GridLogMessage << "mem GiB/s (base 2) = "<< 1000000. * data_mem/((t1-t0))<<std::endl;
// std::cout<<GridLogMessage << "RF GiB/s (base 2) = "<< 1000000. * data_rf/((t1-t0))<<std::endl;
// std::cout<<GridLogMessage << "mem GiB/s (base 2) = "<< 1000000. * data_mem/((t1-t0))<<std::endl;
err = ref-result;
std::cout<<GridLogMessage << "norm diff "<< norm2(err)<<std::endl;
@ -382,19 +365,20 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
}
ref = -0.5*ref;
}
// dump=1;
Dw.Dhop(src,result,1);
Dw.Dhop(src,result,DaggerYes);
std::cout << GridLogMessage << "----------------------------------------------------------------" << std::endl;
std::cout << GridLogMessage << "Compare to naive wilson implementation Dag to verify correctness" << std::endl;
std::cout << GridLogMessage << "----------------------------------------------------------------" << std::endl;
std::cout<<GridLogMessage << "Called DwDag"<<std::endl;
std::cout<<GridLogMessage << "norm dag result "<< norm2(result)<<std::endl;
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
err = ref-result;
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
if ( norm2(err)>1.0e-4) {
std::cout << "Error vector is\n" <<err << std::endl;
std::cout << "Ref vector is\n" <<ref << std::endl;
std::cout << "Result vector is\n" <<result << std::endl;
std::cout << err << std::endl;
}
assert((norm2(err)<1.0e-4));

View File

@ -1,5 +1,5 @@
module load PrgEnv-gnu
module load rocm/4.3.0
module load rocm/4.5.0
module load gmp
module load cray-fftw
module load craype-accel-amd-gfx908