1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Dirichlet filters running on AMD and now integrated in Fermion op

This commit is contained in:
Peter Boyle 2022-02-23 19:29:28 -05:00
parent 70988e43d2
commit 0f1c5b08a1
9 changed files with 125 additions and 61 deletions

View File

@ -37,6 +37,9 @@ NAMESPACE_CHECK(ActionSet);
#include <Grid/qcd/action/ActionParams.h> #include <Grid/qcd/action/ActionParams.h>
NAMESPACE_CHECK(ActionParams); NAMESPACE_CHECK(ActionParams);
#include <Grid/qcd/action/filters/MomentumFilter.h>
#include <Grid/qcd/action/filters/DirichletFilter.h>
//////////////////////////////////////////// ////////////////////////////////////////////
// Gauge Actions // Gauge Actions
//////////////////////////////////////////// ////////////////////////////////////////////

View File

@ -49,6 +49,8 @@ public:
virtual FermionField &tmp(void) = 0; virtual FermionField &tmp(void) = 0;
virtual void DirichletBlock(Coordinate & _Block) { assert(0); };
GridBase * Grid(void) { return FermionGrid(); }; // this is all the linalg routines need to know GridBase * Grid(void) { return FermionGrid(); }; // this is all the linalg routines need to know
GridBase * RedBlackGrid(void) { return FermionRedBlackGrid(); }; GridBase * RedBlackGrid(void) { return FermionRedBlackGrid(); };

View File

@ -75,6 +75,10 @@ public:
FermionField _tmp; FermionField _tmp;
FermionField &tmp(void) { return _tmp; } FermionField &tmp(void) { return _tmp; }
int Dirichlet;
Coordinate Block;
/********** Deprecate timers **********/
void Report(void); void Report(void);
void ZeroCounters(void); void ZeroCounters(void);
double DhopCalls; double DhopCalls;
@ -174,10 +178,16 @@ public:
GridRedBlackCartesian &FourDimRedBlackGrid, GridRedBlackCartesian &FourDimRedBlackGrid,
double _M5,const ImplParams &p= ImplParams()); double _M5,const ImplParams &p= ImplParams());
void DirichletBlock(std::vector<int> & block){ virtual void DirichletBlock(Coordinate & block)
Stencil.DirichletBlock(block); {
StencilEven.DirichletBlock(block); assert(block.size()==Nd+1);
StencilOdd.DirichletBlock(block); if ( block[0] || block[1] || block[2] || block[3] || block[4] ){
Dirichlet = 1;
Block = block;
Stencil.DirichletBlock(block);
StencilEven.DirichletBlock(block);
StencilOdd.DirichletBlock(block);
}
} }
// Constructors // Constructors
/* /*

View File

@ -60,7 +60,8 @@ WilsonFermion5D<Impl>::WilsonFermion5D(GaugeField &_Umu,
UmuOdd (_FourDimRedBlackGrid), UmuOdd (_FourDimRedBlackGrid),
Lebesgue(_FourDimGrid), Lebesgue(_FourDimGrid),
LebesgueEvenOdd(_FourDimRedBlackGrid), LebesgueEvenOdd(_FourDimRedBlackGrid),
_tmp(&FiveDimRedBlackGrid) _tmp(&FiveDimRedBlackGrid),
Dirichlet(0)
{ {
// some assertions // some assertions
assert(FiveDimGrid._ndimension==5); assert(FiveDimGrid._ndimension==5);
@ -218,6 +219,14 @@ void WilsonFermion5D<Impl>::ImportGauge(const GaugeField &_Umu)
{ {
GaugeField HUmu(_Umu.Grid()); GaugeField HUmu(_Umu.Grid());
HUmu = _Umu*(-0.5); HUmu = _Umu*(-0.5);
if ( Dirichlet ) {
std::cout << GridLogMessage << " Dirichlet BCs 5d " <<Block<<std::endl;
Coordinate GaugeBlock(Nd);
for(int d=0;d<Nd;d++) GaugeBlock[d] = Block[d+1];
std::cout << GridLogMessage << " Dirichlet BCs 4d " <<GaugeBlock<<std::endl;
DirichletFilter<GaugeField> Filter(GaugeBlock);
Filter.applyFilter(HUmu);
}
Impl::DoubleStore(GaugeGrid(),Umu,HUmu); Impl::DoubleStore(GaugeGrid(),Umu,HUmu);
pickCheckerboard(Even,UmuEven,Umu); pickCheckerboard(Even,UmuEven,Umu);
pickCheckerboard(Odd ,UmuOdd,Umu); pickCheckerboard(Odd ,UmuOdd,Umu);

View File

@ -0,0 +1,71 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/hmc/integrators/DirichletFilter.h
Copyright (C) 2015
Author: Peter Boyle <paboyle@ph.ed.ac.uk>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/* END LEGAL */
//--------------------------------------------------------------------
#pragma once
NAMESPACE_BEGIN(Grid);
template<typename MomentaField>
struct DirichletFilter: public MomentumFilterBase<MomentaField>
{
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
Coordinate Block;
DirichletFilter(const Coordinate &_Block): Block(_Block){}
void applyFilter(MomentaField &P) const override
{
GridBase *grid = P.Grid();
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
////////////////////////////////////////////////////
// Zero strictly links crossing between domains
////////////////////////////////////////////////////
LatticeInteger coor(grid);
LatCM zz(grid); zz = Zero();
for(int mu=0;mu<Nd;mu++) {
if ( (Block[mu]) && (Block[mu] < grid->GlobalDimensions()[mu] ) ) {
// If costly could provide Grid earlier and precompute masks
std::cout << " Dirichlet in mu="<<mu<<std::endl;
LatticeCoordinate(coor,mu);
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
PokeIndex<LorentzIndex>(P, P_mu, mu);
}
}
}
};
NAMESPACE_END(Grid);

View File

@ -33,7 +33,6 @@ directory
#define INTEGRATOR_INCLUDED #define INTEGRATOR_INCLUDED
#include <memory> #include <memory>
#include "MomentumFilter.h"
NAMESPACE_BEGIN(Grid); NAMESPACE_BEGIN(Grid);

View File

@ -648,7 +648,7 @@ public:
} }
} }
/// Introduce a block structure and switch off comms on boundaries /// Introduce a block structure and switch off comms on boundaries
void DirichletBlock(const std::vector<int> &dirichlet_block) void DirichletBlock(const Coordinate &dirichlet_block)
{ {
this->_dirichlet = 1; this->_dirichlet = 1;
for(int ii=0;ii<this->_npoints;ii++){ for(int ii=0;ii<this->_npoints;ii++){

View File

@ -36,41 +36,6 @@ using namespace Grid;
/// Move to domains //// /// Move to domains ////
//////////////////////// ////////////////////////
template<typename MomentaField>
struct DirichletFilter: public MomentumFilterBase<MomentaField>
{
typedef typename MomentaField::vector_type vector_type; //SIMD-vectorized complex type
typedef typename MomentaField::scalar_type scalar_type; //scalar complex type
typedef iScalar<iScalar<iScalar<vector_type> > > ScalarType; //complex phase for each site
Coordinate Block;
DirichletFilter(const Coordinate &_Block): Block(_Block){}
void applyFilter(MomentaField &P) const override
{
GridBase *grid = P.Grid();
typedef decltype(PeekIndex<LorentzIndex>(P, 0)) LatCM;
////////////////////////////////////////////////////
// Zero strictly links crossing between domains
////////////////////////////////////////////////////
LatticeInteger coor(grid);
LatCM zz(grid); zz = Zero();
for(int mu=0;mu<Nd;mu++) {
if ( Block[mu] ) {
// If costly could provide Grid earlier and precompute masks
LatticeCoordinate(coor,mu);
auto P_mu = PeekIndex<LorentzIndex>(P, mu);
P_mu = where(mod(coor,Block[mu])==Integer(Block[mu]-1),zz,P_mu);
PokeIndex<LorentzIndex>(P, P_mu, mu);
}
}
}
};
Gamma::Algebra Gmu [] = { Gamma::Algebra Gmu [] = {
Gamma::Algebra::GammaX, Gamma::Algebra::GammaX,
Gamma::Algebra::GammaY, Gamma::Algebra::GammaY,
@ -78,7 +43,7 @@ Gamma::Algebra Gmu [] = {
Gamma::Algebra::GammaT Gamma::Algebra::GammaT
}; };
void Benchmark(int Ls, std::vector<int> Dirichlet); void Benchmark(int Ls, Coordinate Dirichlet);
int main (int argc, char ** argv) int main (int argc, char ** argv)
{ {
@ -97,8 +62,9 @@ int main (int argc, char ** argv)
////////////////// //////////////////
// With comms // With comms
////////////////// //////////////////
std::vector<int> Dirichlet(5,0); Coordinate Dirichlet(Nd+1,0);
std::cout << "\n\n\n\n\n\n" <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl; std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing with full communication " <<std::endl; std::cout << GridLogMessage<< " Testing with full communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl; std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
@ -110,7 +76,7 @@ int main (int argc, char ** argv)
////////////////// //////////////////
Coordinate latt4 = GridDefaultLatt(); Coordinate latt4 = GridDefaultLatt();
Coordinate mpi = GridDefaultMpi(); Coordinate mpi = GridDefaultMpi();
std::vector<int> CommDim(Nd); Coordinate CommDim(Nd);
Coordinate shm; Coordinate shm;
GlobalSharedMemory::GetShmDims(mpi,shm); GlobalSharedMemory::GetShmDims(mpi,shm);
@ -118,36 +84,39 @@ int main (int argc, char ** argv)
////////////////////// //////////////////////
// Node level // Node level
////////////////////// //////////////////////
std::cout << "\n\n\n\n\n\n" <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl; std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl; std::cout << GridLogMessage<< " Testing without internode communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl; std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0; for(int d=0;d<Nd;d++) CommDim[d]= (mpi[d]/shm[d])>1 ? 1 : 0;
Dirichlet = std::vector<int>({0, Dirichlet[0] = 0;
CommDim[0]*latt4[0]/mpi[0] * shm[0], Dirichlet[1] = CommDim[0]*latt4[0]/mpi[0] * shm[0];
CommDim[1]*latt4[1]/mpi[1] * shm[1], Dirichlet[2] = CommDim[1]*latt4[1]/mpi[1] * shm[1];
CommDim[2]*latt4[2]/mpi[2] * shm[2], Dirichlet[3] = CommDim[2]*latt4[2]/mpi[2] * shm[2];
CommDim[3]*latt4[3]/mpi[3] * shm[3]}); Dirichlet[4] = CommDim[3]*latt4[3]/mpi[3] * shm[3];
Benchmark(Ls,Dirichlet); Benchmark(Ls,Dirichlet);
std::cout << "\n\n\n\n\n\n" <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl; std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl; std::cout << GridLogMessage<< " Testing without intranode communication " <<std::endl;
std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl; std::cout << GridLogMessage<< "++++++++++++++++++++++++++++++++++++++++++++++++" <<std::endl;
for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0; for(int d=0;d<Nd;d++) CommDim[d]= mpi[d]>1 ? 1 : 0;
Dirichlet = std::vector<int>({0, Dirichlet[0] = 0;
CommDim[0]*latt4[0]/mpi[0], Dirichlet[1] = CommDim[0]*latt4[0]/mpi[0];
CommDim[1]*latt4[1]/mpi[1], Dirichlet[2] = CommDim[1]*latt4[1]/mpi[1];
CommDim[2]*latt4[2]/mpi[2], Dirichlet[3] = CommDim[2]*latt4[2]/mpi[2];
CommDim[3]*latt4[3]/mpi[3]}); Dirichlet[4] = CommDim[3]*latt4[3]/mpi[3];
Benchmark(Ls,Dirichlet); Benchmark(Ls,Dirichlet);
Grid_finalize(); Grid_finalize();
exit(0); exit(0);
} }
void Benchmark(int Ls, std::vector<int> Dirichlet) void Benchmark(int Ls, Coordinate Dirichlet)
{ {
Coordinate latt4 = GridDefaultLatt(); Coordinate latt4 = GridDefaultLatt();
GridLogLayout(); GridLogLayout();
@ -196,7 +165,9 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
std::cout << GridLogMessage << "Drawing gauge field" << std::endl; std::cout << GridLogMessage << "Drawing gauge field" << std::endl;
LatticeGaugeFieldF Umu(UGrid); LatticeGaugeFieldF Umu(UGrid);
LatticeGaugeFieldF UmuCopy(UGrid);
SU<Nc>::HotConfiguration(RNG4,Umu); SU<Nc>::HotConfiguration(RNG4,Umu);
UmuCopy=Umu;
std::cout << GridLogMessage << "Random gauge initialised " << std::endl; std::cout << GridLogMessage << "Random gauge initialised " << std::endl;
//////////////////////////////////// ////////////////////////////////////
@ -205,7 +176,8 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
Coordinate Block(4); Coordinate Block(4);
for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1]; for(int d=0;d<4;d++) Block[d]= Dirichlet[d+1];
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block " << Block << std::endl; std::cout << GridLogMessage << "Applying BCs for Dirichlet Block5 " << Dirichlet << std::endl;
std::cout << GridLogMessage << "Applying BCs for Dirichlet Block4 " << Block << std::endl;
DirichletFilter<LatticeGaugeFieldF> Filter(Block); DirichletFilter<LatticeGaugeFieldF> Filter(Block);
Filter.applyFilter(Umu); Filter.applyFilter(Umu);
@ -279,6 +251,7 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); DomainWallFermionF Dw(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5);
Dw.DirichletBlock(Dirichlet); Dw.DirichletBlock(Dirichlet);
Dw.ImportGauge(Umu);
int ncall =300; int ncall =300;
@ -377,9 +350,6 @@ void Benchmark(int Ls, std::vector<int> Dirichlet)
std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl; std::cout<<GridLogMessage << "norm dag ref "<< norm2(ref)<<std::endl;
err = ref-result; err = ref-result;
std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl; std::cout<<GridLogMessage << "norm dag diff "<< norm2(err)<<std::endl;
if ( norm2(err)>1.0e-4) {
std::cout << err << std::endl;
}
assert((norm2(err)<1.0e-4)); assert((norm2(err)<1.0e-4));
LatticeFermionF src_e (FrbGrid); LatticeFermionF src_e (FrbGrid);