From 70988e43d23c6fc4c86b482d56a56c9efd6e6912 Mon Sep 17 00:00:00 2001 From: Peter Boyle Date: Wed, 23 Feb 2022 01:42:14 -0500 Subject: [PATCH] Passes multinode dirichlet test with boundaries at node boundary or at the single rank boundary --- benchmarks/Benchmark_dwf_fp32.cc | 196 ++++++++++++++----------------- systems/Spock/sourceme.sh | 2 +- 2 files changed, 91 insertions(+), 107 deletions(-) diff --git a/benchmarks/Benchmark_dwf_fp32.cc b/benchmarks/Benchmark_dwf_fp32.cc index 6896bddf..5a64aaa8 100644 --- a/benchmarks/Benchmark_dwf_fp32.cc +++ b/benchmarks/Benchmark_dwf_fp32.cc @@ -36,100 +36,41 @@ using namespace Grid; /// Move to domains //// //////////////////////// -struct DomainDecomposition -{ - Coordinate Block; - - DomainDecomposition(const Coordinate &_Block): Block(_Block){ assert(Block.size()==Nd);}; - - template - void ProjectDomain(Field &f,Integer domain) - { - GridBase *grid = f.Grid(); - int dims = grid->Nd(); - int isDWF= (dims==Nd+1); - assert((dims==Nd)||(dims==Nd+1)); - - Field zz(grid); zz = Zero(); - LatticeInteger coor(grid); - LatticeInteger domaincoor(grid); - LatticeInteger mask(grid); mask = Integer(1); - LatticeInteger zi(grid); zi = Integer(0); - for(int d=0;d struct DirichletFilter: public MomentumFilterBase { + typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type + typedef typename MomentaField::scalar_type scalar_type; //scalar complex type + + typedef iScalar > > ScalarType; //complex phase for each site + Coordinate Block; - DirichletFilter(const Coordinate &_Block): Block(_Block) {} + DirichletFilter(const Coordinate &_Block): Block(_Block){} - // Edge detect using domain projectors - void applyFilter (MomentaField &U) const override + void applyFilter(MomentaField &P) const override { - DomainDecomposition Domains(Block); - GridBase *grid = U.Grid(); - LatticeInteger coor(grid); - LatticeInteger face(grid); - LatticeInteger one(grid); one = 1; - LatticeInteger zero(grid); zero = 0; - LatticeInteger omega(grid); - LatticeInteger omegabar(grid); - LatticeInteger tmp(grid); + GridBase *grid = P.Grid(); + typedef decltype(PeekIndex(P, 0)) LatCM; + //////////////////////////////////////////////////// + // Zero strictly links crossing between domains + //////////////////////////////////////////////////// + LatticeInteger coor(grid); + LatCM zz(grid); zz = Zero(); + for(int mu=0;mu(U,0)) MomentaLinkField; - MomentaLinkField Umu(grid); - MomentaLinkField zz(grid); zz=Zero(); - - int dims = grid->Nd(); - Coordinate Global=grid->GlobalDimensions(); - assert(dims==Nd); - - for(int mu=0;mu(U,mu); - - // Upper face - tmp = Cshift(omegabar,mu,1); - tmp = tmp + omega; - face = where(tmp == Integer(2),one,zero ); - - tmp = Cshift(omega,mu,1); - tmp = tmp + omegabar; - face = where(tmp == Integer(2),one,face ); - - Umu = where(face,zz,Umu); - - PokeIndex(U, Umu, mu); + if ( Block[mu] ) { + // If costly could provide Grid earlier and precompute masks + LatticeCoordinate(coor,mu); + auto P_mu = PeekIndex(P, mu); + P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu); + PokeIndex(P, P_mu, mu); } } } }; - Gamma::Algebra Gmu [] = { Gamma::Algebra::GammaX, Gamma::Algebra::GammaY, @@ -152,27 +93,57 @@ int main (int argc, char ** argv) std::stringstream ss(argv[i+1]); ss >> Ls; } } + + ////////////////// + // With comms + ////////////////// std::vector Dirichlet(5,0); + + std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" < CommDim(Nd); Coordinate shm; GlobalSharedMemory::GetShmDims(mpi,shm); - /* + + + ////////////////////// + // Node level + ////////////////////// + std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <1 ? 1 : 0; Dirichlet = std::vector({0, - latt4[0]/mpi[0] * shm[0], - latt4[1]/mpi[1] * shm[1], - latt4[2]/mpi[2] * shm[2], - latt4[3]/mpi[3] * shm[3]}); - */ - Dirichlet = std::vector({0, - latt4[0]/mpi[0] , - latt4[1]/mpi[1] , - latt4[2]/mpi[2] , - latt4[3]/mpi[3] }); - - std::cout << " Dirichlet block "<< Dirichlet<< std::endl; + CommDim[0]*latt4[0]/mpi[0] * shm[0], + CommDim[1]*latt4[1]/mpi[1] * shm[1], + CommDim[2]*latt4[2]/mpi[2] * shm[2], + CommDim[3]*latt4[3]/mpi[3] * shm[3]}); + Benchmark(Ls,Dirichlet); + + std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <1 ? 1 : 0; + Dirichlet = std::vector({0, + CommDim[0]*latt4[0]/mpi[0], + CommDim[1]*latt4[1]/mpi[1], + CommDim[2]*latt4[2]/mpi[2], + CommDim[3]*latt4[3]/mpi[3]}); + + Benchmark(Ls,Dirichlet); + Grid_finalize(); exit(0); } @@ -203,8 +174,20 @@ void Benchmark(int Ls, std::vector Dirichlet) GridParallelRNG RNG5(FGrid); RNG5.SeedUniqueString(std::string("The 5D RNG")); LatticeFermionF src (FGrid); random(RNG5,src); +#if 1 + src = Zero(); + { + Coordinate origin({0,0,0,latt4[2]-1,0}); + SpinColourVectorF tmp; + tmp=Zero(); + tmp()(0)(0)=Complex(-2.0,0.0); + std::cout << " source site 0 " << tmp< Dirichlet) //////////////////////////////////// // Apply BCs //////////////////////////////////// - std::cout << GridLogMessage << "Applying BCs " << std::endl; Coordinate Block(4); for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1]; - std::cout << GridLogMessage << "Dirichlet Block " << Block<< std::endl; + std::cout << GridLogMessage << "Applying BCs for Dirichlet Block " << Block << std::endl; + DirichletFilter Filter(Block); Filter.applyFilter(Umu); //////////////////////////////////// // Naive wilson implementation //////////////////////////////////// - // replicate across fifth dimension - // LatticeGaugeFieldF Umu5d(FGrid); std::vector U(4,UGrid); for(int mu=0;mu(Umu,mu); } + std::cout << GridLogMessage << "Setting up Cshift based reference " << std::endl; if (1) @@ -297,6 +279,7 @@ void Benchmark(int Ls, std::vector Dirichlet) DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); Dw.DirichletBlock(Dirichlet); + int ncall =300; if (1) { @@ -328,8 +311,8 @@ void Benchmark(int Ls, std::vector Dirichlet) std::cout< Dirichlet) } ref = -0.5*ref; } - // dump=1; - Dw.Dhop(src,result,1); + + Dw.Dhop(src,result,DaggerYes); + + std::cout << GridLogMessage << "----------------------------------------------------------------" << std::endl; std::cout << GridLogMessage << "Compare to naive wilson implementation Dag to verify correctness" << std::endl; + std::cout << GridLogMessage << "----------------------------------------------------------------" << std::endl; + std::cout< 1.0e-4 ) { - std::cout << "Error vector is\n" <1.0e-4) { + std::cout << err << std::endl; } assert((norm2(err)<1.0e-4)); diff --git a/systems/Spock/sourceme.sh b/systems/Spock/sourceme.sh index 40d864b5..72a2ff4e 100644 --- a/systems/Spock/sourceme.sh +++ b/systems/Spock/sourceme.sh @@ -1,5 +1,5 @@ module load PrgEnv-gnu -module load rocm/4.3.0 +module load rocm/4.5.0 module load gmp module load cray-fftw module load craype-accel-amd-gfx908