mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
Passes multinode dirichlet test with boundaries at
node boundary or at the single rank boundary
This commit is contained in:
parent
aab3bcb46f
commit
70988e43d2
@ -36,100 +36,41 @@ using namespace Grid;
|
|||||||
/// Move to domains ////
|
/// Move to domains ////
|
||||||
////////////////////////
|
////////////////////////
|
||||||
|
|
||||||
struct DomainDecomposition
|
|
||||||
{
|
|
||||||
Coordinate Block;
|
|
||||||
|
|
||||||
DomainDecomposition(const Coordinate &_Block): Block(_Block){ assert(Block.size()==Nd);};
|
|
||||||
|
|
||||||
template<class Field>
|
|
||||||
void ProjectDomain(Field &f,Integer domain)
|
|
||||||
{
|
|
||||||
GridBase *grid = f.Grid();
|
|
||||||
int dims = grid->Nd();
|
|
||||||
int isDWF= (dims==Nd+1);
|
|
||||||
assert((dims==Nd)||(dims==Nd+1));
|
|
||||||
|
|
||||||
Field zz(grid); zz = Zero();
|
|
||||||
LatticeInteger coor(grid);
|
|
||||||
LatticeInteger domaincoor(grid);
|
|
||||||
LatticeInteger mask(grid); mask = Integer(1);
|
|
||||||
LatticeInteger zi(grid); zi = Integer(0);
|
|
||||||
for(int d=0;d<Nd;d++){
|
|
||||||
Integer B= Block[d];
|
|
||||||
if ( B ) {
|
|
||||||
LatticeCoordinate(coor,d+isDWF);
|
|
||||||
domaincoor = mod(coor,B);
|
|
||||||
mask = where(domaincoor==Integer(0),zi,mask);
|
|
||||||
mask = where(domaincoor==Integer(B-1),zi,mask);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ( !domain )
|
|
||||||
f = where(mask==Integer(1),f,zz);
|
|
||||||
else
|
|
||||||
f = where(mask==Integer(0),f,zz);
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename MomentaField>
|
template<typename MomentaField>
|
||||||
struct DirichletFilter: public MomentumFilterBase<MomentaField>
|
struct DirichletFilter: public MomentumFilterBase<MomentaField>
|
||||||
{
|
{
|
||||||
|
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
|
||||||
|
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
|
||||||
|
|
||||||
|
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
|
||||||
|
|
||||||
Coordinate Block;
|
Coordinate Block;
|
||||||
|
|
||||||
DirichletFilter(const Coordinate &_Block): Block(_Block) {}
|
DirichletFilter(const Coordinate &_Block): Block(_Block){}
|
||||||
|
|
||||||
// Edge detect using domain projectors
|
void applyFilter(MomentaField &P) const override
|
||||||
void applyFilter (MomentaField &U) const override
|
|
||||||
{
|
{
|
||||||
DomainDecomposition Domains(Block);
|
GridBase *grid = P.Grid();
|
||||||
GridBase *grid = U.Grid();
|
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
|
||||||
LatticeInteger coor(grid);
|
////////////////////////////////////////////////////
|
||||||
LatticeInteger face(grid);
|
// Zero strictly links crossing between domains
|
||||||
LatticeInteger one(grid); one = 1;
|
////////////////////////////////////////////////////
|
||||||
LatticeInteger zero(grid); zero = 0;
|
LatticeInteger coor(grid);
|
||||||
LatticeInteger omega(grid);
|
LatCM zz(grid); zz = Zero();
|
||||||
LatticeInteger omegabar(grid);
|
for(int mu=0;mu<Nd;mu++) {
|
||||||
LatticeInteger tmp(grid);
|
|
||||||
|
|
||||||
omega=one; Domains.ProjectDomain(omega,0);
|
if ( Block[mu] ) {
|
||||||
omegabar=one; Domains.ProjectDomain(omegabar,1);
|
// If costly could provide Grid earlier and precompute masks
|
||||||
|
LatticeCoordinate(coor,mu);
|
||||||
LatticeInteger nface(grid); nface=Zero();
|
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
|
||||||
|
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
|
||||||
MomentaField projected(grid); projected=Zero();
|
PokeIndex<LorentzIndex>(P, P_mu, mu);
|
||||||
typedef decltype(PeekIndex<LorentzIndex>(U,0)) MomentaLinkField;
|
|
||||||
MomentaLinkField Umu(grid);
|
|
||||||
MomentaLinkField zz(grid); zz=Zero();
|
|
||||||
|
|
||||||
int dims = grid->Nd();
|
|
||||||
Coordinate Global=grid->GlobalDimensions();
|
|
||||||
assert(dims==Nd);
|
|
||||||
|
|
||||||
for(int mu=0;mu<Nd;mu++){
|
|
||||||
|
|
||||||
if ( Block[mu]!=0 ) {
|
|
||||||
|
|
||||||
Umu = PeekIndex<LorentzIndex>(U,mu);
|
|
||||||
|
|
||||||
// Upper face
|
|
||||||
tmp = Cshift(omegabar,mu,1);
|
|
||||||
tmp = tmp + omega;
|
|
||||||
face = where(tmp == Integer(2),one,zero );
|
|
||||||
|
|
||||||
tmp = Cshift(omega,mu,1);
|
|
||||||
tmp = tmp + omegabar;
|
|
||||||
face = where(tmp == Integer(2),one,face );
|
|
||||||
|
|
||||||
Umu = where(face,zz,Umu);
|
|
||||||
|
|
||||||
PokeIndex<LorentzIndex>(U, Umu, mu);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Gamma::Algebra Gmu [] = {
|
Gamma::Algebra Gmu [] = {
|
||||||
Gamma::Algebra::GammaX,
|
Gamma::Algebra::GammaX,
|
||||||
Gamma::Algebra::GammaY,
|
Gamma::Algebra::GammaY,
|
||||||
@ -152,27 +93,57 @@ int main (int argc, char ** argv)
|
|||||||
std::stringstream ss(argv[i+1]); ss >> Ls;
|
std::stringstream ss(argv[i+1]); ss >> Ls;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////
|
||||||
|
// With comms
|
||||||
|
//////////////////
|
||||||
std::vector<int> Dirichlet(5,0);
|
std::vector<int> Dirichlet(5,0);
|
||||||
|
|
||||||
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
std::cout << GridLogMessage<< " Testing with full communication " <<std::endl;
|
||||||
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
|
||||||
Benchmark(Ls,Dirichlet);
|
Benchmark(Ls,Dirichlet);
|
||||||
|
|
||||||
|
//////////////////
|
||||||
|
// Domain decomposed
|
||||||
|
//////////////////
|
||||||
Coordinate latt4 = GridDefaultLatt();
|
Coordinate latt4 = GridDefaultLatt();
|
||||||
Coordinate mpi = GridDefaultMpi();
|
Coordinate mpi = GridDefaultMpi();
|
||||||
|
std::vector<int> CommDim(Nd);
|
||||||
Coordinate shm;
|
Coordinate shm;
|
||||||
GlobalSharedMemory::GetShmDims(mpi,shm);
|
GlobalSharedMemory::GetShmDims(mpi,shm);
|
||||||
/*
|
|
||||||
Dirichlet = std::vector<int>({0,
|
|
||||||
latt4[0]/mpi[0] * shm[0],
|
|
||||||
latt4[1]/mpi[1] * shm[1],
|
|
||||||
latt4[2]/mpi[2] * shm[2],
|
|
||||||
latt4[3]/mpi[3] * shm[3]});
|
|
||||||
*/
|
|
||||||
Dirichlet = std::vector<int>({0,
|
|
||||||
latt4[0]/mpi[0] ,
|
|
||||||
latt4[1]/mpi[1] ,
|
|
||||||
latt4[2]/mpi[2] ,
|
|
||||||
latt4[3]/mpi[3] });
|
|
||||||
|
|
||||||
std::cout << " Dirichlet block "<< Dirichlet<< std::endl;
|
|
||||||
|
//////////////////////
|
||||||
|
// Node level
|
||||||
|
//////////////////////
|
||||||
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl;
|
||||||
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
|
||||||
|
for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0;
|
||||||
|
Dirichlet = std::vector<int>({0,
|
||||||
|
CommDim[0]*latt4[0]/mpi[0] * shm[0],
|
||||||
|
CommDim[1]*latt4[1]/mpi[1] * shm[1],
|
||||||
|
CommDim[2]*latt4[2]/mpi[2] * shm[2],
|
||||||
|
CommDim[3]*latt4[3]/mpi[3] * shm[3]});
|
||||||
|
|
||||||
Benchmark(Ls,Dirichlet);
|
Benchmark(Ls,Dirichlet);
|
||||||
|
|
||||||
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl;
|
||||||
|
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
|
||||||
|
|
||||||
|
for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0;
|
||||||
|
Dirichlet = std::vector<int>({0,
|
||||||
|
CommDim[0]*latt4[0]/mpi[0],
|
||||||
|
CommDim[1]*latt4[1]/mpi[1],
|
||||||
|
CommDim[2]*latt4[2]/mpi[2],
|
||||||
|
CommDim[3]*latt4[3]/mpi[3]});
|
||||||
|
|
||||||
|
Benchmark(Ls,Dirichlet);
|
||||||
|
|
||||||
Grid_finalize();
|
Grid_finalize();
|
||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
@ -203,8 +174,20 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
GridParallelRNG RNG5(FGrid); RNG5.SeedUniqueString(std::string("The 5D RNG"));
|
GridParallelRNG RNG5(FGrid); RNG5.SeedUniqueString(std::string("The 5D RNG"));
|
||||||
|
|
||||||
LatticeFermionF src (FGrid); random(RNG5,src);
|
LatticeFermionF src (FGrid); random(RNG5,src);
|
||||||
|
#if 1
|
||||||
|
src = Zero();
|
||||||
|
{
|
||||||
|
Coordinate origin({0,0,0,latt4[2]-1,0});
|
||||||
|
SpinColourVectorF tmp;
|
||||||
|
tmp=Zero();
|
||||||
|
tmp()(0)(0)=Complex(-2.0,0.0);
|
||||||
|
std::cout << " source site 0 " << tmp<<std::endl;
|
||||||
|
pokeSite(tmp,src,origin);
|
||||||
|
}
|
||||||
|
#else
|
||||||
RealD N2 = 1.0/::sqrt(norm2(src));
|
RealD N2 = 1.0/::sqrt(norm2(src));
|
||||||
src = src*N2;
|
src = src*N2;
|
||||||
|
#endif
|
||||||
|
|
||||||
LatticeFermionF result(FGrid); result=Zero();
|
LatticeFermionF result(FGrid); result=Zero();
|
||||||
LatticeFermionF ref(FGrid); ref=Zero();
|
LatticeFermionF ref(FGrid); ref=Zero();
|
||||||
@ -219,23 +202,22 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
////////////////////////////////////
|
////////////////////////////////////
|
||||||
// Apply BCs
|
// Apply BCs
|
||||||
////////////////////////////////////
|
////////////////////////////////////
|
||||||
std::cout << GridLogMessage << "Applying BCs " << std::endl;
|
|
||||||
Coordinate Block(4);
|
Coordinate Block(4);
|
||||||
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
|
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
|
||||||
|
|
||||||
std::cout << GridLogMessage << "Dirichlet Block " << Block<< std::endl;
|
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block " << Block << std::endl;
|
||||||
|
|
||||||
DirichletFilter<LatticeGaugeFieldF> Filter(Block);
|
DirichletFilter<LatticeGaugeFieldF> Filter(Block);
|
||||||
Filter.applyFilter(Umu);
|
Filter.applyFilter(Umu);
|
||||||
|
|
||||||
////////////////////////////////////
|
////////////////////////////////////
|
||||||
// Naive wilson implementation
|
// Naive wilson implementation
|
||||||
////////////////////////////////////
|
////////////////////////////////////
|
||||||
// replicate across fifth dimension
|
|
||||||
// LatticeGaugeFieldF Umu5d(FGrid);
|
|
||||||
std::vector<LatticeColourMatrixF> U(4,UGrid);
|
std::vector<LatticeColourMatrixF> U(4,UGrid);
|
||||||
for(int mu=0;mu<Nd;mu++){
|
for(int mu=0;mu<Nd;mu++){
|
||||||
U[mu] = PeekIndex<LorentzIndex>(Umu,mu);
|
U[mu] = PeekIndex<LorentzIndex>(Umu,mu);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cout << GridLogMessage << "Setting up Cshift based reference " << std::endl;
|
std::cout << GridLogMessage << "Setting up Cshift based reference " << std::endl;
|
||||||
|
|
||||||
if (1)
|
if (1)
|
||||||
@ -297,6 +279,7 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
|
|
||||||
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
|
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
|
||||||
Dw.DirichletBlock(Dirichlet);
|
Dw.DirichletBlock(Dirichlet);
|
||||||
|
|
||||||
int ncall =300;
|
int ncall =300;
|
||||||
|
|
||||||
if (1) {
|
if (1) {
|
||||||
@ -328,8 +311,8 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
std::cout<<GridLogMessage << "mflop/s = "<< flops/(t1-t0)<<std::endl;
|
std::cout<<GridLogMessage << "mflop/s = "<< flops/(t1-t0)<<std::endl;
|
||||||
std::cout<<GridLogMessage << "mflop/s per rank = "<< flops/(t1-t0)/NP<<std::endl;
|
std::cout<<GridLogMessage << "mflop/s per rank = "<< flops/(t1-t0)/NP<<std::endl;
|
||||||
std::cout<<GridLogMessage << "mflop/s per node = "<< flops/(t1-t0)/NN<<std::endl;
|
std::cout<<GridLogMessage << "mflop/s per node = "<< flops/(t1-t0)/NN<<std::endl;
|
||||||
std::cout<<GridLogMessage << "RF GiB/s (base 2) = "<< 1000000. * data_rf/((t1-t0))<<std::endl;
|
// std::cout<<GridLogMessage << "RF GiB/s (base 2) = "<< 1000000. * data_rf/((t1-t0))<<std::endl;
|
||||||
std::cout<<GridLogMessage << "mem GiB/s (base 2) = "<< 1000000. * data_mem/((t1-t0))<<std::endl;
|
// std::cout<<GridLogMessage << "mem GiB/s (base 2) = "<< 1000000. * data_mem/((t1-t0))<<std::endl;
|
||||||
err = ref-result;
|
err = ref-result;
|
||||||
std::cout<<GridLogMessage << "norm diff "<< norm2(err)<<std::endl;
|
std::cout<<GridLogMessage << "norm diff "<< norm2(err)<<std::endl;
|
||||||
|
|
||||||
@ -382,19 +365,20 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
|
|||||||
}
|
}
|
||||||
ref = -0.5*ref;
|
ref = -0.5*ref;
|
||||||
}
|
}
|
||||||
// dump=1;
|
|
||||||
Dw.Dhop(src,result,1);
|
Dw.Dhop(src,result,DaggerYes);
|
||||||
|
|
||||||
|
std::cout << GridLogMessage << "----------------------------------------------------------------" << std::endl;
|
||||||
std::cout << GridLogMessage << "Compare to naive wilson implementation Dag to verify correctness" << std::endl;
|
std::cout << GridLogMessage << "Compare to naive wilson implementation Dag to verify correctness" << std::endl;
|
||||||
|
std::cout << GridLogMessage << "----------------------------------------------------------------" << std::endl;
|
||||||
|
|
||||||
std::cout<<GridLogMessage << "Called DwDag"<<std::endl;
|
std::cout<<GridLogMessage << "Called DwDag"<<std::endl;
|
||||||
std::cout<<GridLogMessage << "norm dag result "<< norm2(result)<<std::endl;
|
std::cout<<GridLogMessage << "norm dag result "<< norm2(result)<<std::endl;
|
||||||
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
|
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
|
||||||
err = ref-result;
|
err = ref-result;
|
||||||
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
|
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
|
||||||
|
if ( norm2(err)>1.0e-4) {
|
||||||
if ( norm2(err) > 1.0e-4 ) {
|
std::cout << err << std::endl;
|
||||||
std::cout << "Error vector is\n" <<err << std::endl;
|
|
||||||
std::cout << "Ref vector is\n" <<ref << std::endl;
|
|
||||||
std::cout << "Result vector is\n" <<result << std::endl;
|
|
||||||
}
|
}
|
||||||
assert((norm2(err)<1.0e-4));
|
assert((norm2(err)<1.0e-4));
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
module load PrgEnv-gnu
|
module load PrgEnv-gnu
|
||||||
module load rocm/4.3.0
|
module load rocm/4.5.0
|
||||||
module load gmp
|
module load gmp
|
||||||
module load cray-fftw
|
module load cray-fftw
|
||||||
module load craype-accel-amd-gfx908
|
module load craype-accel-amd-gfx908
|
||||||
|
Loading…
Reference in New Issue
Block a user