From 38806343a873ea10264c79103db31182d6770947 Mon Sep 17 00:00:00 2001 From: Guido Cossu Date: Wed, 15 Mar 2017 15:16:16 +0900 Subject: [PATCH] Improving efficiency of the force term --- .../action/scalar/ScalarInteractionAction.h | 91 ++++++++++++++++--- tests/Test_stencil.cc | 43 +++++---- tests/hmc/Test_hmc_ScalarActionNxN.cc | 11 +-- 3 files changed, 104 insertions(+), 41 deletions(-) diff --git a/lib/qcd/action/scalar/ScalarInteractionAction.h b/lib/qcd/action/scalar/ScalarInteractionAction.h index 2607b041..5a322a5e 100644 --- a/lib/qcd/action/scalar/ScalarInteractionAction.h +++ b/lib/qcd/action/scalar/ScalarInteractionAction.h @@ -30,17 +30,34 @@ directory #ifndef SCALAR_INT_ACTION_H #define SCALAR_INT_ACTION_H + +// Note: this action can completely absorb the ScalarAction for real float fields +// use the scalarObjs to generalise the structure + namespace Grid { // FIXME drop the QCD namespace everywhere here template class ScalarInteractionAction : public QCD::Action { +public: + INHERIT_FIELD_TYPES(Impl); +private: RealD mass_square; RealD lambda; + + typedef typename Field::vector_object vobj; + typedef CartesianStencil Stencil; + + SimpleCompressor compressor; + int npoint = 8; + std::vector directions = {0,1,2,3,0,1,2,3}; // forcing 4 dimensions + std::vector displacements = {1,1,1,1, -1,-1,-1,-1}; + + public: - INHERIT_FIELD_TYPES(Impl); - ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l) {} + + ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l){} virtual std::string LogParameters() { std::stringstream sstream; @@ -51,27 +68,75 @@ class ScalarInteractionAction : public QCD::Action { virtual std::string action_name() {return "ScalarAction";} - virtual void refresh(const Field &U, - GridParallelRNG &pRNG) {} // noop as no pseudoferms + virtual void refresh(const Field &U, GridParallelRNG &pRNG) {} virtual RealD S(const Field &p) { - Field action(p._grid); - Field pshift(p._grid); - Field phisquared(p._grid); + static Stencil phiStencil(p._grid, npoint, 0, directions, displacements); + phiStencil.HaloExchange(p, compressor); + + Field action(p._grid), pshift(p._grid), phisquared(p._grid); phisquared = p*p; action = (2.0*QCD::Nd + mass_square)*phisquared + lambda*phisquared*phisquared; for (int mu = 0; mu < QCD::Nd; mu++) { - pshift = Cshift(p, mu, +1); // not efficient implement with stencils - action -= pshift*p + p*pshift; + // pshift = Cshift(p, mu, +1); // not efficient, implement with stencils + PARALLEL_FOR_LOOP + for (int i = 0; i < p._grid->oSites(); i++) { + int permute_type; + StencilEntry *SE; + vobj temp2; + vobj *temp; + vobj *t_p; + + SE = phiStencil.GetEntry(permute_type, mu, i); + t_p = &p._odata[i]; + if ( SE->_is_local ) { + temp = &p._odata[SE->_offset]; + if ( SE->_permute ) { + permute(temp2, *temp, permute_type); + action._odata[i] -= temp2*(*t_p) + (*t_p)*temp2; + } else { + action._odata[i] -= *temp*(*t_p) + (*t_p)*(*temp); + } + } else { + action._odata[i] -= phiStencil.CommBuf()[SE->_offset]*(*t_p) + (*t_p)*phiStencil.CommBuf()[SE->_offset]; + } + } + // action -= pshift*p + p*pshift; } + // NB the trace in the algebra is normalised to 1/2 + // minus sign coming from the antihermitian fields return -(TensorRemove(sum(trace(action)))).real(); }; - virtual void deriv(const Field &p, - Field &force) { + virtual void deriv(const Field &p, Field &force) { force = (2.0*QCD::Nd + mass_square)*p + 2.0*lambda*p*p*p; - // following is inefficient - for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1); + // move this outside + static Stencil phiStencil(p._grid, npoint, 0, directions, displacements); + phiStencil.HaloExchange(p, compressor); + + //for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1); + for (int point = 0; point < npoint; point++) { + PARALLEL_FOR_LOOP + for (int i = 0; i < p._grid->oSites(); i++) { + vobj *temp; + vobj temp2; + int permute_type; + StencilEntry *SE; + SE = phiStencil.GetEntry(permute_type, point, i); + + if ( SE->_is_local ) { + temp = &p._odata[SE->_offset]; + if ( SE->_permute ) { + permute(temp2, *temp, permute_type); + force._odata[i] -= temp2; + } else { + force._odata[i] -= *temp; + } + } else { + force._odata[i] -= phiStencil.CommBuf()[SE->_offset]; + } + } + } } }; diff --git a/tests/Test_stencil.cc b/tests/Test_stencil.cc index 1b71b8a5..1d35e1bb 100644 --- a/tests/Test_stencil.cc +++ b/tests/Test_stencil.cc @@ -1,6 +1,6 @@ /************************************************************************************* - Grid physics library, www.github.com/paboyle/Grid + Grid physics library, www.github.com/paboyle/Grid Source file: ./tests/Test_stencil.cc @@ -33,9 +33,8 @@ using namespace std; using namespace Grid; using namespace Grid::QCD; -int main (int argc, char ** argv) -{ - Grid_init(&argc,&argv); +int main(int argc, char ** argv) { + Grid_init(&argc, &argv); // typedef LatticeColourMatrix Field; typedef LatticeComplex Field; @@ -47,7 +46,7 @@ int main (int argc, char ** argv) std::vector mpi_layout = GridDefaultMpi(); double volume = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3]; - + GridCartesian Fine(latt_size,simd_layout,mpi_layout); GridRedBlackCartesian rbFine(latt_size,simd_layout,mpi_layout); GridParallelRNG fRNG(&Fine); @@ -55,14 +54,14 @@ int main (int argc, char ** argv) // fRNG.SeedRandomDevice(); std::vector seeds({1,2,3,4}); fRNG.SeedFixedIntegers(seeds); - + Field Foo(&Fine); Field Bar(&Fine); Field Check(&Fine); Field Diff(&Fine); LatticeComplex lex(&Fine); - lex = zero; + lex = zero; random(fRNG,Foo); gaussian(fRNG,Bar); @@ -98,7 +97,7 @@ int main (int argc, char ** argv) Fine.oCoorFromOindex(ocoor,o); ocoor[dir]=(ocoor[dir]+disp)%Fine._rdimensions[dir]; } - + SimpleCompressor compress; myStencil.HaloExchange(Foo,compress); @@ -106,16 +105,16 @@ int main (int argc, char ** argv) // Implement a stencil code that should agree with cshift! for(int i=0;ioSites();i++){ - + int permute_type; StencilEntry *SE; SE = myStencil.GetEntry(permute_type,0,i); - + if ( SE->_is_local && SE->_permute ) permute(Check._odata[i],Foo._odata[SE->_offset],permute_type); else if (SE->_is_local) Check._odata[i] = Foo._odata[SE->_offset]; - else + else Check._odata[i] = myStencil.CommBuf()[SE->_offset]; } @@ -144,7 +143,7 @@ int main (int argc, char ** argv) <<") " < compress; EStencil.HaloExchange(EFoo,compress); OStencil.HaloExchange(OFoo,compress); - + Bar = Cshift(Foo,dir,disp); if ( disp & 0x1 ) { ECheck.checkerboard = Even; OCheck.checkerboard = Odd; - } else { + } else { ECheck.checkerboard = Odd; OCheck.checkerboard = Even; } @@ -206,7 +205,7 @@ int main (int argc, char ** argv) permute(OCheck._odata[i],EFoo._odata[SE->_offset],permute_type); else if (SE->_is_local) OCheck._odata[i] = EFoo._odata[SE->_offset]; - else + else OCheck._odata[i] = EStencil.CommBuf()[SE->_offset]; } for(int i=0;ioSites();i++){ @@ -214,18 +213,18 @@ int main (int argc, char ** argv) StencilEntry *SE; SE = OStencil.GetEntry(permute_type,0,i); // std::cout << "ODD source "<< i<<" -> " <_offset << " "<< SE->_is_local<_is_local && SE->_permute ) permute(ECheck._odata[i],OFoo._odata[SE->_offset],permute_type); else if (SE->_is_local) ECheck._odata[i] = OFoo._odata[SE->_offset]; - else + else ECheck._odata[i] = OStencil.CommBuf()[SE->_offset]; } - + setCheckerboard(Check,ECheck); setCheckerboard(Check,OCheck); - + Real nrmC = norm2(Check); Real nrmB = norm2(Bar); Diff = Check-Bar; @@ -248,10 +247,10 @@ int main (int argc, char ** argv) diff =norm2(ddiff); if ( diff > 0){ std::cout <<"Coor (" << coor[0]<<","< -namespace Grid{ +namespace Grid { class ScalarActionParameters : Serializable { public: GRID_SERIALIZABLE_CLASS_MEMBERS(ScalarActionParameters, @@ -44,7 +44,7 @@ int main(int argc, char **argv) { // here make a routine to print all the relevant information on the run std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl; - // Typedefs to simplify notation + // Typedefs to simplify notation typedef ScalarAdjGenericHMCRunner HMCWrapper; // Uses the default minimum norm, real scalar fields //:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: @@ -52,7 +52,7 @@ int main(int argc, char **argv) { // Grid from the command line GridModule ScalarGrid; - ScalarGrid.set_full( SpaceTimeGrid::makeFourDimGrid( + ScalarGrid.set_full(SpaceTimeGrid::makeFourDimGrid( GridDefaultLatt(), GridDefaultSimd(Nd, vComplex::Nsimd()), GridDefaultMpi())); ScalarGrid.set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(ScalarGrid.get_full())); @@ -89,12 +89,11 @@ int main(int argc, char **argv) { ///////////////////////////////////////////////////////////// // HMC parameters are serialisable - TheHMC.Parameters.MD.MDsteps = 10; + TheHMC.Parameters.MD.MDsteps = 20; TheHMC.Parameters.MD.trajL = 1.0; TheHMC.ReadCommandLine(argc, argv); TheHMC.Run(); Grid_finalize(); - -} // main +} // main