From 1dfaa08afb6314b3d74b6561687ed1ed7ce4f844 Mon Sep 17 00:00:00 2001 From: Christopher Kelly Date: Wed, 28 Jun 2023 15:11:24 -0400 Subject: [PATCH] The stencils for the staple and rect-staple padded cell implementations are now created and stored by workspace classes that allow for reuse provided the grids remain consistent. The workspaces are now used by the plaq+rectangle gauge action, resulting in a further 2x performance improvement as measured on a 16^4 local volume for 2 nodes (16 ranks) of Crusher. --- .../action/gauge/PlaqPlusRectangleAction.h | 4 +- Grid/qcd/utils/WilsonLoops.h | 319 ++++++++++++------ 2 files changed, 220 insertions(+), 103 deletions(-) diff --git a/Grid/qcd/action/gauge/PlaqPlusRectangleAction.h b/Grid/qcd/action/gauge/PlaqPlusRectangleAction.h index 68eb0a67..b9d6ac16 100644 --- a/Grid/qcd/action/gauge/PlaqPlusRectangleAction.h +++ b/Grid/qcd/action/gauge/PlaqPlusRectangleAction.h @@ -43,7 +43,7 @@ public: private: RealD c_plaq; RealD c_rect; - + typename WilsonLoops::StapleAndRectStapleAllWorkspace workspace; public: PlaqPlusRectangleAction(RealD b,RealD c): c_plaq(b),c_rect(c){}; @@ -83,7 +83,7 @@ public: U[mu] = PeekIndex(Umu,mu); } std::vector RectStaple(Nd,grid), Staple(Nd,grid); - WilsonLoops::StapleAndRectStapleAll(Staple, RectStaple, U); + WilsonLoops::StapleAndRectStapleAll(Staple, RectStaple, U, workspace); GaugeLinkField dSdU_mu(grid); GaugeLinkField staple(grid); diff --git a/Grid/qcd/utils/WilsonLoops.h b/Grid/qcd/utils/WilsonLoops.h index 731b755c..78e25a8d 100644 --- a/Grid/qcd/utils/WilsonLoops.h +++ b/Grid/qcd/utils/WilsonLoops.h @@ -349,47 +349,129 @@ public: for(int mu=0;mu stencil; + size_t nshift; + + void generateStencil(GridBase* padded_grid){ + double t0 = usecond(); + + //Generate shift arrays + std::vector shifts = this->getShifts(); + nshift = shifts.size(); + + double t1 = usecond(); + //Generate local stencil + stencil.reset(new GeneralLocalStencil(padded_grid,shifts)); + double t2 = usecond(); + std::cout 
<< GridLogPerformance << " WilsonLoopPaddedWorkspace timings: coord:" << (t1-t0)/1000 << "ms, stencil:" << (t2-t1)/1000 << "ms" << std::endl; + } + public: + //Get the stencil. If not already generated, or if generated using a different Grid than in PaddedCell, it will be created on-the-fly + const GeneralLocalStencil & getStencil(const PaddedCell &pcell){ + assert(pcell.depth >= this->paddingDepth()); + if(!stencil || stencil->Grid() != (GridBase*)pcell.grids.back() ) generateStencil((GridBase*)pcell.grids.back()); + return *stencil; + } + size_t Nshift() const{ return nshift; } + + virtual std::vector getShifts() const = 0; + virtual int paddingDepth() const = 0; //padding depth required + + virtual ~WilsonLoopPaddedStencilWorkspace(){} + }; + + //This workspace allows the sharing of a common PaddedCell object between multiple stencil workspaces + class WilsonLoopPaddedWorkspace{ + std::vector stencil_wk; + std::unique_ptr pcell; + + void generatePcell(GridBase* unpadded_grid){ + assert(stencil_wk.size()); + int max_depth = 0; + for(auto const &s : stencil_wk) max_depth=std::max(max_depth, s->paddingDepth()); + + pcell.reset(new PaddedCell(max_depth, dynamic_cast(unpadded_grid))); + } + + public: + //Add a stencil definition. This should be done before the first call to retrieve a stencil object. 
+ //Takes ownership of the pointer + void addStencil(WilsonLoopPaddedStencilWorkspace *stencil){ + assert(!pcell); + stencil_wk.push_back(stencil); + } + + const GeneralLocalStencil & getStencil(const size_t stencil_idx, GridBase* unpadded_grid){ + if(!pcell || pcell->unpadded_grid != unpadded_grid) generatePcell(unpadded_grid); + return stencil_wk[stencil_idx]->getStencil(*pcell); + } + const PaddedCell & getPaddedCell(GridBase* unpadded_grid){ + if(!pcell || pcell->unpadded_grid != unpadded_grid) generatePcell(unpadded_grid); + return *pcell; + } + + ~WilsonLoopPaddedWorkspace(){ + for(auto &s : stencil_wk) delete s; + } + }; + + //A workspace class allowing reuse of the stencil + class StaplePaddedAllWorkspace: public WilsonLoopPaddedStencilWorkspace{ + public: + std::vector getShifts() const override{ + std::vector shifts; + for(int mu=0;mu &staple, const std::vector &U_padded, const PaddedCell &Cell) { + StaplePaddedAllWorkspace wk; + StaplePaddedAll(staple,U_padded,Cell,wk.getStencil(Cell)); + } + + //Padded cell implementation of the staple method for all mu, summed over nu != mu + //staple: output staple for each mu, summed over nu != mu (Nd) + //U_padded: the gauge link fields padded out using the PaddedCell class + //Cell: the padded cell class + //gStencil: the precomputed generalized local stencil for the staple + static void StaplePaddedAll(std::vector &staple, const std::vector &U_padded, const PaddedCell &Cell, const GeneralLocalStencil &gStencil) { + double t0 = usecond(); assert(U_padded.size() == Nd); assert(staple.size() == Nd); assert(U_padded[0].Grid() == (GridBase*)Cell.grids.back()); assert(Cell.depth >= 1); GridBase *ggrid = U_padded[0].Grid(); //padded cell grid - double t0 = usecond(); - //Generate shift arrays - std::vector shifts; - for(int mu=0;mu GaugeViewType; @@ -447,9 +529,9 @@ public: free(Ug_dirs_v_host); acceleratorFreeDevice(Ug_dirs_v); - double t3=usecond(); + double t1=usecond(); - std::cout << GridLogPerformance << 
"StaplePaddedAll timings: coord:" << (t1-t0)/1000 << "ms, stencil:" << (t2-t1)/1000 << "ms, kernel:" << (t3-t2)/1000 << "ms" << std::endl; + std::cout << GridLogPerformance << "StaplePaddedAll timing:" << (t1-t0)/1000 << "ms" << std::endl; } @@ -1016,75 +1098,91 @@ public: for(int mu=0;mu getShifts() const override{ + std::vector shifts; + for (int mu = 0; mu < Nd; mu++){ + for (int nu = 0; nu < Nd; nu++) { + if (nu != mu) { + auto genShift = [&](int mushift,int nushift){ + Coordinate out(Nd,0); out[mu]=mushift; out[nu]=nushift; return out; + }; + + //tmp6 = tmp5(x+mu) = U_mu(x+mu)U_nu(x+2mu)U_mu^dag(x+nu+mu) U_mu^dag(x+nu) U_nu^dag(x) + shifts.push_back(genShift(0,0)); + shifts.push_back(genShift(0,+1)); + shifts.push_back(genShift(+1,+1)); + shifts.push_back(genShift(+2,0)); + shifts.push_back(genShift(+1,0)); + + //tmp5 = tmp4(x+mu) = U_mu(x+mu)U^dag_nu(x-nu+2mu)U^dag_mu(x-nu+mu)U^dag_mu(x-nu)U_nu(x-nu) + shifts.push_back(genShift(0,-1)); + shifts.push_back(genShift(0,-1)); + shifts.push_back(genShift(+1,-1)); + shifts.push_back(genShift(+2,-1)); + shifts.push_back(genShift(+1,0)); + + //tmp5 = tmp4(x+mu) = U^dag_nu(x-nu+mu)U^dag_mu(x-nu)U^dag_mu(x-mu-nu)U_nu(x-mu-nu)U_mu(x-mu) + shifts.push_back(genShift(-1,0)); + shifts.push_back(genShift(-1,-1)); + shifts.push_back(genShift(-1,-1)); + shifts.push_back(genShift(0,-1)); + shifts.push_back(genShift(+1,-1)); + + //tmp5 = tmp4(x+mu) = U_nu(x+mu)U_mu^dag(x+nu)U_mu^dag(x-mu+nu)U_nu^dag(x-mu)U_mu(x-mu) + shifts.push_back(genShift(-1,0)); + shifts.push_back(genShift(-1,0)); + shifts.push_back(genShift(-1,+1)); + shifts.push_back(genShift(0,+1)); + shifts.push_back(genShift(+1,0)); + + //tmp6 = tmp5(x+mu) = U_nu(x+mu)U_nu(x+mu+nu)U_mu^dag(x+2nu)U_nu^dag(x+nu)U_nu^dag(x) + shifts.push_back(genShift(0,0)); + shifts.push_back(genShift(0,+1)); + shifts.push_back(genShift(0,+2)); + shifts.push_back(genShift(+1,+1)); + shifts.push_back(genShift(+1,0)); + + //tmp5 = tmp4(x+mu) = 
U_nu^dag(x+mu-nu)U_nu^dag(x+mu-2nu)U_mu^dag(x-2nu)U_nu(x-2nu)U_nu(x-nu) + shifts.push_back(genShift(0,-1)); + shifts.push_back(genShift(0,-2)); + shifts.push_back(genShift(0,-2)); + shifts.push_back(genShift(+1,-2)); + shifts.push_back(genShift(+1,-1)); + } + } + } + return shifts; + } + + int paddingDepth() const override{ return 2; } + }; + //Padded cell implementation of the rectangular staple method for all mu, summed over nu != mu //staple: output staple for each mu, summed over nu != mu (Nd) //U_padded: the gauge link fields padded out using the PaddedCell class //Cell: the padded cell class static void RectStaplePaddedAll(std::vector &staple, const std::vector &U_padded, const PaddedCell &Cell) { + RectStaplePaddedAllWorkspace wk; + RectStaplePaddedAll(staple,U_padded,Cell,wk.getStencil(Cell)); + } + + //Padded cell implementation of the rectangular staple method for all mu, summed over nu != mu + //staple: output staple for each mu, summed over nu != mu (Nd) + //U_padded: the gauge link fields padded out using the PaddedCell class + //Cell: the padded cell class + //gStencil: the stencil + static void RectStaplePaddedAll(std::vector &staple, const std::vector &U_padded, const PaddedCell &Cell, const GeneralLocalStencil &gStencil) { + double t0 = usecond(); assert(U_padded.size() == Nd); assert(staple.size() == Nd); assert(U_padded[0].Grid() == (GridBase*)Cell.grids.back()); assert(Cell.depth >= 2); GridBase *ggrid = U_padded[0].Grid(); //padded cell grid - double t0 = usecond(); - std::vector shifts; - for (int mu = 0; mu < Nd; mu++){ - for (int nu = 0; nu < Nd; nu++) { - if (nu != mu) { - auto genShift = [&](int mushift,int nushift){ - Coordinate out(Nd,0); out[mu]=mushift; out[nu]=nushift; return out; - }; - - //tmp6 = tmp5(x+mu) = U_mu(x+mu)U_nu(x+2mu)U_mu^dag(x+nu+mu) U_mu^dag(x+nu) U_nu^dag(x) - shifts.push_back(genShift(0,0)); - shifts.push_back(genShift(0,+1)); - shifts.push_back(genShift(+1,+1)); - shifts.push_back(genShift(+2,0)); - 
shifts.push_back(genShift(+1,0)); - - //tmp5 = tmp4(x+mu) = U_mu(x+mu)U^dag_nu(x-nu+2mu)U^dag_mu(x-nu+mu)U^dag_mu(x-nu)U_nu(x-nu) - shifts.push_back(genShift(0,-1)); - shifts.push_back(genShift(0,-1)); - shifts.push_back(genShift(+1,-1)); - shifts.push_back(genShift(+2,-1)); - shifts.push_back(genShift(+1,0)); - - //tmp5 = tmp4(x+mu) = U^dag_nu(x-nu+mu)U^dag_mu(x-nu)U^dag_mu(x-mu-nu)U_nu(x-mu-nu)U_mu(x-mu) - shifts.push_back(genShift(-1,0)); - shifts.push_back(genShift(-1,-1)); - shifts.push_back(genShift(-1,-1)); - shifts.push_back(genShift(0,-1)); - shifts.push_back(genShift(+1,-1)); - - //tmp5 = tmp4(x+mu) = U_nu(x+mu)U_mu^dag(x+nu)U_mu^dag(x-mu+nu)U_nu^dag(x-mu)U_mu(x-mu) - shifts.push_back(genShift(-1,0)); - shifts.push_back(genShift(-1,0)); - shifts.push_back(genShift(-1,+1)); - shifts.push_back(genShift(0,+1)); - shifts.push_back(genShift(+1,0)); - - //tmp6 = tmp5(x+mu) = U_nu(x+mu)U_nu(x+mu+nu)U_mu^dag(x+2nu)U_nu^dag(x+nu)U_nu^dag(x) - shifts.push_back(genShift(0,0)); - shifts.push_back(genShift(0,+1)); - shifts.push_back(genShift(0,+2)); - shifts.push_back(genShift(+1,+1)); - shifts.push_back(genShift(+1,0)); - - //tmp5 = tmp4(x+mu) = U_nu^dag(x+mu-nu)U_nu^dag(x+mu-2nu)U_mu^dag(x-2nu)U_nu(x-2nu)U_nu(x-nu) - shifts.push_back(genShift(0,-1)); - shifts.push_back(genShift(0,-2)); - shifts.push_back(genShift(0,-2)); - shifts.push_back(genShift(+1,-2)); - shifts.push_back(genShift(+1,-1)); - } - } - } - size_t nshift = shifts.size(); + size_t nshift = gStencil._npoints; int mu_off_delta = nshift / Nd; - double t1 = usecond(); - - GeneralLocalStencil gStencil(ggrid,shifts); - double t2 = usecond(); //Open views to padded gauge links and keep open over mu loop typedef LatticeView GaugeViewType; @@ -1208,13 +1306,20 @@ public: free(Ug_dirs_v_host); acceleratorFreeDevice(Ug_dirs_v); - double t3 = usecond(); + double t1 = usecond(); - std::cout << GridLogPerformance << "RectStaplePaddedAll timings: coord:" << (t1-t0)/1000 << "ms, stencil:" << (t2-t1)/1000 << "ms, 
kernel:" << (t3-t2)/1000 << "ms" << std::endl; + std::cout << GridLogPerformance << "RectStaplePaddedAll timings:" << (t1-t0)/1000 << "ms" << std::endl; } - - + //A workspace for reusing the PaddedCell and GeneralLocalStencil objects + class StapleAndRectStapleAllWorkspace: public WilsonLoopPaddedWorkspace{ + public: + StapleAndRectStapleAllWorkspace(){ + this->addStencil(new StaplePaddedAllWorkspace); + this->addStencil(new RectStaplePaddedAllWorkspace); + } + }; + ////////////////////////////////////////////////////// //Compute the 1x1 and 1x2 staples for all orientations //Stap : Array of staples (Nd) @@ -1222,27 +1327,39 @@ public: //U: Gauge links in each direction (Nd) ///////////////////////////////////////////////////// static void StapleAndRectStapleAll(std::vector &Stap, std::vector &RectStap, const std::vector &U){ + StapleAndRectStapleAllWorkspace wk; + StapleAndRectStapleAll(Stap,RectStap,U,wk); + } + + ////////////////////////////////////////////////////// + //Compute the 1x1 and 1x2 staples for all orientations + //Stap : Array of staples (Nd) + //RectStap: Array of rectangular staples (Nd) + //U: Gauge links in each direction (Nd) + //wk: a workspace containing stored PaddedCell and GeneralLocalStencil objects to maximize reuse + ///////////////////////////////////////////////////// + static void StapleAndRectStapleAll(std::vector &Stap, std::vector &RectStap, const std::vector &U, StapleAndRectStapleAllWorkspace &wk){ #if 0 StapleAll(Stap, U); RectStapleAll(RectStap, U); #else double t0 = usecond(); - //Use the padded cell with maximal reuse - PaddedCell Ghost(2, dynamic_cast(U[0].Grid())); + + GridCartesian* unpadded_grid = dynamic_cast(U[0].Grid()); + const PaddedCell &Ghost = wk.getPaddedCell(unpadded_grid); + CshiftImplGauge cshift_impl; std::vector U_pad(Nd, Ghost.grids.back()); for(int mu=0;mu