1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-04 19:25:56 +01:00

The stencils for the staple and rect-staple padded cell implementations are now created and stored by workspace classes that allow for reuse providing the grids remain consistent

The workspaces are now used by the plaq+rectangle gauge action resulting in a further 2x performance improvement as measured on a 16^4 local volume for 2 nodes (16 ranks) of Crusher
This commit is contained in:
Christopher Kelly 2023-06-28 15:11:24 -04:00
parent f44dce390f
commit 1dfaa08afb
2 changed files with 220 additions and 103 deletions

View File

@ -43,7 +43,7 @@ public:
private:
RealD c_plaq;
RealD c_rect;
typename WilsonLoops<Gimpl>::StapleAndRectStapleAllWorkspace workspace;
public:
PlaqPlusRectangleAction(RealD b,RealD c): c_plaq(b),c_rect(c){};
@ -83,7 +83,7 @@ public:
U[mu] = PeekIndex<LorentzIndex>(Umu,mu);
}
std::vector<GaugeLinkField> RectStaple(Nd,grid), Staple(Nd,grid);
WilsonLoops<Gimpl>::StapleAndRectStapleAll(Staple, RectStaple, U);
WilsonLoops<Gimpl>::StapleAndRectStapleAll(Staple, RectStaple, U, workspace);
GaugeLinkField dSdU_mu(grid);
GaugeLinkField staple(grid);

View File

@ -349,47 +349,129 @@ public:
for(int mu=0;mu<Nd;mu++) Staple(staple[mu], U, mu);
}
//A workspace class allowing reuse of the stencil
class WilsonLoopPaddedStencilWorkspace{
std::unique_ptr<GeneralLocalStencil> stencil;
size_t nshift;
void generateStencil(GridBase* padded_grid){
double t0 = usecond();
//Generate shift arrays
std::vector<Coordinate> shifts = this->getShifts();
nshift = shifts.size();
double t1 = usecond();
//Generate local stencil
stencil.reset(new GeneralLocalStencil(padded_grid,shifts));
double t2 = usecond();
std::cout << GridLogPerformance << " WilsonLoopPaddedWorkspace timings: coord:" << (t1-t0)/1000 << "ms, stencil:" << (t2-t1)/1000 << "ms" << std::endl;
}
public:
//Get the stencil. If not already generated, or if generated using a different Grid than in PaddedCell, it will be created on-the-fly
const GeneralLocalStencil & getStencil(const PaddedCell &pcell){
assert(pcell.depth >= this->paddingDepth());
if(!stencil || stencil->Grid() != (GridBase*)pcell.grids.back() ) generateStencil((GridBase*)pcell.grids.back());
return *stencil;
}
size_t Nshift() const{ return nshift; }
virtual std::vector<Coordinate> getShifts() const = 0;
virtual int paddingDepth() const = 0; //padding depth required
virtual ~WilsonLoopPaddedStencilWorkspace(){}
};
//This workspace allows the sharing of a common PaddedCell object between multiple stencil workspaces
class WilsonLoopPaddedWorkspace{
std::vector<WilsonLoopPaddedStencilWorkspace*> stencil_wk;
std::unique_ptr<PaddedCell> pcell;
void generatePcell(GridBase* unpadded_grid){
assert(stencil_wk.size());
int max_depth = 0;
for(auto const &s : stencil_wk) max_depth=std::max(max_depth, s->paddingDepth());
pcell.reset(new PaddedCell(max_depth, dynamic_cast<GridCartesian*>(unpadded_grid)));
}
public:
//Add a stencil definition. This should be done before the first call to retrieve a stencil object.
//Takes ownership of the pointer
void addStencil(WilsonLoopPaddedStencilWorkspace *stencil){
assert(!pcell);
stencil_wk.push_back(stencil);
}
const GeneralLocalStencil & getStencil(const size_t stencil_idx, GridBase* unpadded_grid){
if(!pcell || pcell->unpadded_grid != unpadded_grid) generatePcell(unpadded_grid);
return stencil_wk[stencil_idx]->getStencil(*pcell);
}
const PaddedCell & getPaddedCell(GridBase* unpadded_grid){
if(!pcell || pcell->unpadded_grid != unpadded_grid) generatePcell(unpadded_grid);
return *pcell;
}
~WilsonLoopPaddedWorkspace(){
for(auto &s : stencil_wk) delete s;
}
};
//A workspace class allowing reuse of the stencil
class StaplePaddedAllWorkspace: public WilsonLoopPaddedStencilWorkspace{
public:
std::vector<Coordinate> getShifts() const override{
std::vector<Coordinate> shifts;
for(int mu=0;mu<Nd;mu++){
for(int nu=0;nu<Nd;nu++){
if(nu != mu){
Coordinate shift_0(Nd,0);
Coordinate shift_mu(Nd,0); shift_mu[mu]=1;
Coordinate shift_nu(Nd,0); shift_nu[nu]=1;
Coordinate shift_mnu(Nd,0); shift_mnu[nu]=-1;
Coordinate shift_mnu_pmu(Nd,0); shift_mnu_pmu[nu]=-1; shift_mnu_pmu[mu]=1;
//U_nu(x+mu)U^dag_mu(x+nu) U^dag_nu(x)
shifts.push_back(shift_0);
shifts.push_back(shift_nu);
shifts.push_back(shift_mu);
//U_nu^dag(x-nu+mu) U_mu^dag(x-nu) U_nu(x-nu)
shifts.push_back(shift_mnu);
shifts.push_back(shift_mnu);
shifts.push_back(shift_mnu_pmu);
}
}
}
return shifts;
}
int paddingDepth() const override{ return 1; }
};
//Padded cell implementation of the staple method for all mu, summed over nu != mu
//staple: output staple for each mu, summed over nu != mu (Nd)
//U_padded: the gauge link fields padded out using the PaddedCell class
//Cell: the padded cell class
static void StaplePaddedAll(std::vector<GaugeMat> &staple, const std::vector<GaugeMat> &U_padded, const PaddedCell &Cell) {
StaplePaddedAllWorkspace wk;
StaplePaddedAll(staple,U_padded,Cell,wk.getStencil(Cell));
}
//Padded cell implementation of the staple method for all mu, summed over nu != mu
//staple: output staple for each mu, summed over nu != mu (Nd)
//U_padded: the gauge link fields padded out using the PaddedCell class
//Cell: the padded cell class
//gStencil: the precomputed generalized local stencil for the staple
static void StaplePaddedAll(std::vector<GaugeMat> &staple, const std::vector<GaugeMat> &U_padded, const PaddedCell &Cell, const GeneralLocalStencil &gStencil) {
double t0 = usecond();
assert(U_padded.size() == Nd); assert(staple.size() == Nd);
assert(U_padded[0].Grid() == (GridBase*)Cell.grids.back());
assert(Cell.depth >= 1);
GridBase *ggrid = U_padded[0].Grid(); //padded cell grid
double t0 = usecond();
//Generate shift arrays
std::vector<Coordinate> shifts;
for(int mu=0;mu<Nd;mu++){
for(int nu=0;nu<Nd;nu++){
if(nu != mu){
Coordinate shift_0(Nd,0);
Coordinate shift_mu(Nd,0); shift_mu[mu]=1;
Coordinate shift_nu(Nd,0); shift_nu[nu]=1;
Coordinate shift_mnu(Nd,0); shift_mnu[nu]=-1;
Coordinate shift_mnu_pmu(Nd,0); shift_mnu_pmu[nu]=-1; shift_mnu_pmu[mu]=1;
//U_nu(x+mu)U^dag_mu(x+nu) U^dag_nu(x)
shifts.push_back(shift_0);
shifts.push_back(shift_nu);
shifts.push_back(shift_mu);
//U_nu^dag(x-nu+mu) U_mu^dag(x-nu) U_nu(x-nu)
shifts.push_back(shift_mnu);
shifts.push_back(shift_mnu);
shifts.push_back(shift_mnu_pmu);
}
}
}
int shift_mu_off = shifts.size()/Nd;
double t1 = usecond();
//Generate local stencil
GeneralLocalStencil gStencil(ggrid,shifts);
double t2 = usecond();
int shift_mu_off = gStencil._npoints/Nd;
//Open views to padded gauge links and keep open over mu loop
typedef LatticeView<typename GaugeMat::vector_object> GaugeViewType;
@ -447,9 +529,9 @@ public:
free(Ug_dirs_v_host);
acceleratorFreeDevice(Ug_dirs_v);
double t3=usecond();
double t1=usecond();
std::cout << GridLogPerformance << "StaplePaddedAll timings: coord:" << (t1-t0)/1000 << "ms, stencil:" << (t2-t1)/1000 << "ms, kernel:" << (t3-t2)/1000 << "ms" << std::endl;
std::cout << GridLogPerformance << "StaplePaddedAll timing:" << (t1-t0)/1000 << "ms" << std::endl;
}
@ -1016,75 +1098,91 @@ public:
for(int mu=0;mu<Nd;mu++) RectStapleOptimised(Stap[mu], U2, U, mu);
}
//A workspace class allowing reuse of the stencil
class RectStaplePaddedAllWorkspace: public WilsonLoopPaddedStencilWorkspace{
public:
std::vector<Coordinate> getShifts() const override{
std::vector<Coordinate> shifts;
for (int mu = 0; mu < Nd; mu++){
for (int nu = 0; nu < Nd; nu++) {
if (nu != mu) {
auto genShift = [&](int mushift,int nushift){
Coordinate out(Nd,0); out[mu]=mushift; out[nu]=nushift; return out;
};
//tmp6 = tmp5(x+mu) = U_mu(x+mu)U_nu(x+2mu)U_mu^dag(x+nu+mu) U_mu^dag(x+nu) U_nu^dag(x)
shifts.push_back(genShift(0,0));
shifts.push_back(genShift(0,+1));
shifts.push_back(genShift(+1,+1));
shifts.push_back(genShift(+2,0));
shifts.push_back(genShift(+1,0));
//tmp5 = tmp4(x+mu) = U_mu(x+mu)U^dag_nu(x-nu+2mu)U^dag_mu(x-nu+mu)U^dag_mu(x-nu)U_nu(x-nu)
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(+1,-1));
shifts.push_back(genShift(+2,-1));
shifts.push_back(genShift(+1,0));
//tmp5 = tmp4(x+mu) = U^dag_nu(x-nu+mu)U^dag_mu(x-nu)U^dag_mu(x-mu-nu)U_nu(x-mu-nu)U_mu(x-mu)
shifts.push_back(genShift(-1,0));
shifts.push_back(genShift(-1,-1));
shifts.push_back(genShift(-1,-1));
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(+1,-1));
//tmp5 = tmp4(x+mu) = U_nu(x+mu)U_mu^dag(x+nu)U_mu^dag(x-mu+nu)U_nu^dag(x-mu)U_mu(x-mu)
shifts.push_back(genShift(-1,0));
shifts.push_back(genShift(-1,0));
shifts.push_back(genShift(-1,+1));
shifts.push_back(genShift(0,+1));
shifts.push_back(genShift(+1,0));
//tmp6 = tmp5(x+mu) = U_nu(x+mu)U_nu(x+mu+nu)U_mu^dag(x+2nu)U_nu^dag(x+nu)U_nu^dag(x)
shifts.push_back(genShift(0,0));
shifts.push_back(genShift(0,+1));
shifts.push_back(genShift(0,+2));
shifts.push_back(genShift(+1,+1));
shifts.push_back(genShift(+1,0));
//tmp5 = tmp4(x+mu) = U_nu^dag(x+mu-nu)U_nu^dag(x+mu-2nu)U_mu^dag(x-2nu)U_nu(x-2nu)U_nu(x-nu)
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(0,-2));
shifts.push_back(genShift(0,-2));
shifts.push_back(genShift(+1,-2));
shifts.push_back(genShift(+1,-1));
}
}
}
return shifts;
}
int paddingDepth() const override{ return 2; }
};
//Padded cell implementation of the rectangular staple method for all mu, summed over nu != mu
//staple: output staple for each mu, summed over nu != mu (Nd)
//U_padded: the gauge link fields padded out using the PaddedCell class
//Cell: the padded cell class
static void RectStaplePaddedAll(std::vector<GaugeMat> &staple, const std::vector<GaugeMat> &U_padded, const PaddedCell &Cell) {
RectStaplePaddedAllWorkspace wk;
RectStaplePaddedAll(staple,U_padded,Cell,wk.getStencil(Cell));
}
//Padded cell implementation of the rectangular staple method for all mu, summed over nu != mu
//staple: output staple for each mu, summed over nu != mu (Nd)
//U_padded: the gauge link fields padded out using the PaddedCell class
//Cell: the padded cell class
//gStencil: the stencil
static void RectStaplePaddedAll(std::vector<GaugeMat> &staple, const std::vector<GaugeMat> &U_padded, const PaddedCell &Cell, const GeneralLocalStencil &gStencil) {
double t0 = usecond();
assert(U_padded.size() == Nd); assert(staple.size() == Nd);
assert(U_padded[0].Grid() == (GridBase*)Cell.grids.back());
assert(Cell.depth >= 2);
GridBase *ggrid = U_padded[0].Grid(); //padded cell grid
double t0 = usecond();
std::vector<Coordinate> shifts;
for (int mu = 0; mu < Nd; mu++){
for (int nu = 0; nu < Nd; nu++) {
if (nu != mu) {
auto genShift = [&](int mushift,int nushift){
Coordinate out(Nd,0); out[mu]=mushift; out[nu]=nushift; return out;
};
//tmp6 = tmp5(x+mu) = U_mu(x+mu)U_nu(x+2mu)U_mu^dag(x+nu+mu) U_mu^dag(x+nu) U_nu^dag(x)
shifts.push_back(genShift(0,0));
shifts.push_back(genShift(0,+1));
shifts.push_back(genShift(+1,+1));
shifts.push_back(genShift(+2,0));
shifts.push_back(genShift(+1,0));
//tmp5 = tmp4(x+mu) = U_mu(x+mu)U^dag_nu(x-nu+2mu)U^dag_mu(x-nu+mu)U^dag_mu(x-nu)U_nu(x-nu)
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(+1,-1));
shifts.push_back(genShift(+2,-1));
shifts.push_back(genShift(+1,0));
//tmp5 = tmp4(x+mu) = U^dag_nu(x-nu+mu)U^dag_mu(x-nu)U^dag_mu(x-mu-nu)U_nu(x-mu-nu)U_mu(x-mu)
shifts.push_back(genShift(-1,0));
shifts.push_back(genShift(-1,-1));
shifts.push_back(genShift(-1,-1));
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(+1,-1));
//tmp5 = tmp4(x+mu) = U_nu(x+mu)U_mu^dag(x+nu)U_mu^dag(x-mu+nu)U_nu^dag(x-mu)U_mu(x-mu)
shifts.push_back(genShift(-1,0));
shifts.push_back(genShift(-1,0));
shifts.push_back(genShift(-1,+1));
shifts.push_back(genShift(0,+1));
shifts.push_back(genShift(+1,0));
//tmp6 = tmp5(x+mu) = U_nu(x+mu)U_nu(x+mu+nu)U_mu^dag(x+2nu)U_nu^dag(x+nu)U_nu^dag(x)
shifts.push_back(genShift(0,0));
shifts.push_back(genShift(0,+1));
shifts.push_back(genShift(0,+2));
shifts.push_back(genShift(+1,+1));
shifts.push_back(genShift(+1,0));
//tmp5 = tmp4(x+mu) = U_nu^dag(x+mu-nu)U_nu^dag(x+mu-2nu)U_mu^dag(x-2nu)U_nu(x-2nu)U_nu(x-nu)
shifts.push_back(genShift(0,-1));
shifts.push_back(genShift(0,-2));
shifts.push_back(genShift(0,-2));
shifts.push_back(genShift(+1,-2));
shifts.push_back(genShift(+1,-1));
}
}
}
size_t nshift = shifts.size();
size_t nshift = gStencil._npoints;
int mu_off_delta = nshift / Nd;
double t1 = usecond();
GeneralLocalStencil gStencil(ggrid,shifts);
double t2 = usecond();
//Open views to padded gauge links and keep open over mu loop
typedef LatticeView<typename GaugeMat::vector_object> GaugeViewType;
@ -1208,13 +1306,20 @@ public:
free(Ug_dirs_v_host);
acceleratorFreeDevice(Ug_dirs_v);
double t3 = usecond();
double t1 = usecond();
std::cout << GridLogPerformance << "RectStaplePaddedAll timings: coord:" << (t1-t0)/1000 << "ms, stencil:" << (t2-t1)/1000 << "ms, kernel:" << (t3-t2)/1000 << "ms" << std::endl;
std::cout << GridLogPerformance << "RectStaplePaddedAll timings:" << (t1-t0)/1000 << "ms" << std::endl;
}
//A workspace for reusing the PaddedCell and GeneralLocalStencil objects
class StapleAndRectStapleAllWorkspace: public WilsonLoopPaddedWorkspace{
public:
StapleAndRectStapleAllWorkspace(){
this->addStencil(new StaplePaddedAllWorkspace);
this->addStencil(new RectStaplePaddedAllWorkspace);
}
};
//////////////////////////////////////////////////////
//Compute the 1x1 and 1x2 staples for all orientations
//Stap : Array of staples (Nd)
@ -1222,27 +1327,39 @@ public:
//U: Gauge links in each direction (Nd)
/////////////////////////////////////////////////////
static void StapleAndRectStapleAll(std::vector<GaugeMat> &Stap, std::vector<GaugeMat> &RectStap, const std::vector<GaugeMat> &U){
StapleAndRectStapleAllWorkspace wk;
StapleAndRectStapleAll(Stap,RectStap,U,wk);
}
//////////////////////////////////////////////////////
//Compute the 1x1 and 1x2 staples for all orientations
//Stap : Array of staples (Nd)
//RectStap: Array of rectangular staples (Nd)
//U: Gauge links in each direction (Nd)
//wk: a workspace containing stored PaddedCell and GeneralLocalStencil objects to maximize reuse
/////////////////////////////////////////////////////
static void StapleAndRectStapleAll(std::vector<GaugeMat> &Stap, std::vector<GaugeMat> &RectStap, const std::vector<GaugeMat> &U, StapleAndRectStapleAllWorkspace &wk){
#if 0
StapleAll(Stap, U);
RectStapleAll(RectStap, U);
#else
double t0 = usecond();
//Use the padded cell with maximal reuse
PaddedCell Ghost(2, dynamic_cast<GridCartesian*>(U[0].Grid()));
GridCartesian* unpadded_grid = dynamic_cast<GridCartesian*>(U[0].Grid());
const PaddedCell &Ghost = wk.getPaddedCell(unpadded_grid);
CshiftImplGauge<Gimpl> cshift_impl;
std::vector<GaugeMat> U_pad(Nd, Ghost.grids.back());
for(int mu=0;mu<Nd;mu++) U_pad[mu] = Ghost.Exchange(U[mu], cshift_impl);
double t1 = usecond();
StaplePaddedAll(Stap, U_pad, Ghost);
StaplePaddedAll(Stap, U_pad, Ghost, wk.getStencil(0,unpadded_grid) );
double t2 = usecond();
RectStaplePaddedAll(RectStap, U_pad, Ghost);
RectStaplePaddedAll(RectStap, U_pad, Ghost, wk.getStencil(1,unpadded_grid));
double t3 = usecond();
std::cout << GridLogPerformance << "StapleAndRectStapleAll timings: pad:" << (t1-t0)/1000 << "ms, staple:" << (t2-t1)/1000 << "ms, rect-staple:" << (t3-t2)/1000 << "ms" << std::endl;
#endif
}
//////////////////////////////////////////////////
// Wilson loop of size (R1, R2), oriented in mu,nu plane
//////////////////////////////////////////////////