mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 05:54:32 +00:00 
			
		
		
		
	Improving efficiency of the force term
This commit is contained in:
		@@ -30,17 +30,34 @@ directory
 | 
			
		||||
#ifndef SCALAR_INT_ACTION_H
 | 
			
		||||
#define SCALAR_INT_ACTION_H
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// Note: this action can completely absorb the ScalarAction for real float fields
 | 
			
		||||
// use the scalarObjs to generalise the structure
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
  // FIXME drop the QCD namespace everywhere here
 | 
			
		||||
 | 
			
		||||
template <class Impl>
 | 
			
		||||
class ScalarInteractionAction : public QCD::Action<typename Impl::Field> {
 | 
			
		||||
public:
 | 
			
		||||
    INHERIT_FIELD_TYPES(Impl);
 | 
			
		||||
private:
 | 
			
		||||
    RealD mass_square;
 | 
			
		||||
    RealD lambda;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    typedef typename Field::vector_object vobj;
 | 
			
		||||
    typedef CartesianStencil<vobj,vobj> Stencil;
 | 
			
		||||
 | 
			
		||||
    SimpleCompressor<vobj> compressor;
 | 
			
		||||
    int npoint = 8;
 | 
			
		||||
    std::vector<int> directions    = {0,1,2,3,0,1,2,3};  // forcing 4 dimensions
 | 
			
		||||
    std::vector<int> displacements = {1,1,1,1, -1,-1,-1,-1};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
    INHERIT_FIELD_TYPES(Impl);
 | 
			
		||||
    ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l) {}
 | 
			
		||||
 | 
			
		||||
    ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l){}
 | 
			
		||||
 | 
			
		||||
    virtual std::string LogParameters() {
 | 
			
		||||
      std::stringstream sstream;
 | 
			
		||||
@@ -51,27 +68,75 @@ class ScalarInteractionAction : public QCD::Action<typename Impl::Field> {
 | 
			
		||||
 | 
			
		||||
    virtual std::string action_name() {return "ScalarAction";}
 | 
			
		||||
 | 
			
		||||
    virtual void refresh(const Field &U,
 | 
			
		||||
                         GridParallelRNG &pRNG) {}  // noop as no pseudoferms
 | 
			
		||||
    virtual void refresh(const Field &U, GridParallelRNG &pRNG) {}
 | 
			
		||||
 | 
			
		||||
    virtual RealD S(const Field &p) {
 | 
			
		||||
        Field action(p._grid);
 | 
			
		||||
        Field pshift(p._grid);
 | 
			
		||||
        Field phisquared(p._grid);
 | 
			
		||||
        static Stencil phiStencil(p._grid, npoint, 0, directions, displacements);
 | 
			
		||||
        phiStencil.HaloExchange(p, compressor);
 | 
			
		||||
 | 
			
		||||
        Field action(p._grid), pshift(p._grid), phisquared(p._grid);
 | 
			
		||||
        phisquared = p*p;
 | 
			
		||||
        action = (2.0*QCD::Nd + mass_square)*phisquared + lambda*phisquared*phisquared;
 | 
			
		||||
        for (int mu = 0; mu < QCD::Nd; mu++) {
 | 
			
		||||
            pshift = Cshift(p, mu, +1);  // not efficient implement with stencils
 | 
			
		||||
            action -= pshift*p + p*pshift;
 | 
			
		||||
            //  pshift = Cshift(p, mu, +1);  // not efficient, implement with stencils
 | 
			
		||||
            PARALLEL_FOR_LOOP
 | 
			
		||||
            for (int i = 0; i < p._grid->oSites(); i++) {
 | 
			
		||||
                int permute_type;
 | 
			
		||||
                StencilEntry *SE;
 | 
			
		||||
                vobj temp2;
 | 
			
		||||
                vobj *temp;
 | 
			
		||||
                vobj *t_p;
 | 
			
		||||
 | 
			
		||||
                SE = phiStencil.GetEntry(permute_type, mu, i);
 | 
			
		||||
                t_p  = &p._odata[i];
 | 
			
		||||
                if ( SE->_is_local ) {
 | 
			
		||||
                    temp = &p._odata[SE->_offset];
 | 
			
		||||
                    if ( SE->_permute ) {
 | 
			
		||||
                        permute(temp2, *temp, permute_type);
 | 
			
		||||
                        action._odata[i] -= temp2*(*t_p) + (*t_p)*temp2;
 | 
			
		||||
                    } else {
 | 
			
		||||
                  action._odata[i] -= *temp*(*t_p) + (*t_p)*(*temp);
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                  action._odata[i] -= phiStencil.CommBuf()[SE->_offset]*(*t_p) + (*t_p)*phiStencil.CommBuf()[SE->_offset];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            //  action -= pshift*p + p*pshift;
 | 
			
		||||
        }
 | 
			
		||||
        // NB the trace in the algebra is normalised to 1/2
 | 
			
		||||
        // minus sign coming from the antihermitian fields
 | 
			
		||||
        return -(TensorRemove(sum(trace(action)))).real();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    virtual void deriv(const Field &p,
 | 
			
		||||
                       Field &force) {
 | 
			
		||||
    virtual void deriv(const Field &p, Field &force) {
 | 
			
		||||
        force = (2.0*QCD::Nd + mass_square)*p + 2.0*lambda*p*p*p;
 | 
			
		||||
        // following is inefficient
 | 
			
		||||
        for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1);
 | 
			
		||||
        // move this outside
 | 
			
		||||
        static Stencil phiStencil(p._grid, npoint, 0, directions, displacements);
 | 
			
		||||
        phiStencil.HaloExchange(p, compressor);
 | 
			
		||||
 | 
			
		||||
        //for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1);
 | 
			
		||||
        for (int point = 0; point < npoint; point++) {
 | 
			
		||||
            PARALLEL_FOR_LOOP
 | 
			
		||||
            for (int i = 0; i < p._grid->oSites(); i++) {
 | 
			
		||||
                vobj *temp;
 | 
			
		||||
                vobj temp2;
 | 
			
		||||
                int permute_type;
 | 
			
		||||
                StencilEntry *SE;
 | 
			
		||||
                SE = phiStencil.GetEntry(permute_type, point, i);
 | 
			
		||||
 | 
			
		||||
                if ( SE->_is_local ) {
 | 
			
		||||
                    temp = &p._odata[SE->_offset];
 | 
			
		||||
                    if ( SE->_permute ) {
 | 
			
		||||
                        permute(temp2, *temp, permute_type);
 | 
			
		||||
                        force._odata[i] -= temp2;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        force._odata[i] -= *temp;
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    force._odata[i] -= phiStencil.CommBuf()[SE->_offset];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
    /*************************************************************************************
 | 
			
		||||
 | 
			
		||||
    Grid physics library, www.github.com/paboyle/Grid 
 | 
			
		||||
    Grid physics library, www.github.com/paboyle/Grid
 | 
			
		||||
 | 
			
		||||
    Source file: ./tests/Test_stencil.cc
 | 
			
		||||
 | 
			
		||||
@@ -33,9 +33,8 @@ using namespace std;
 | 
			
		||||
using namespace Grid;
 | 
			
		||||
using namespace Grid::QCD;
 | 
			
		||||
 | 
			
		||||
int main (int argc, char ** argv)
 | 
			
		||||
{
 | 
			
		||||
  Grid_init(&argc,&argv);
 | 
			
		||||
int main(int argc, char ** argv) {
 | 
			
		||||
  Grid_init(&argc, &argv);
 | 
			
		||||
 | 
			
		||||
  //  typedef LatticeColourMatrix Field;
 | 
			
		||||
  typedef LatticeComplex Field;
 | 
			
		||||
@@ -47,7 +46,7 @@ int main (int argc, char ** argv)
 | 
			
		||||
  std::vector<int> mpi_layout  = GridDefaultMpi();
 | 
			
		||||
 | 
			
		||||
  double volume = latt_size[0]*latt_size[1]*latt_size[2]*latt_size[3];
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
  GridCartesian Fine(latt_size,simd_layout,mpi_layout);
 | 
			
		||||
  GridRedBlackCartesian rbFine(latt_size,simd_layout,mpi_layout);
 | 
			
		||||
  GridParallelRNG       fRNG(&Fine);
 | 
			
		||||
@@ -55,14 +54,14 @@ int main (int argc, char ** argv)
 | 
			
		||||
  //  fRNG.SeedRandomDevice();
 | 
			
		||||
  std::vector<int> seeds({1,2,3,4});
 | 
			
		||||
  fRNG.SeedFixedIntegers(seeds);
 | 
			
		||||
  
 | 
			
		||||
 | 
			
		||||
  Field Foo(&Fine);
 | 
			
		||||
  Field Bar(&Fine);
 | 
			
		||||
  Field Check(&Fine);
 | 
			
		||||
  Field Diff(&Fine);
 | 
			
		||||
  LatticeComplex lex(&Fine);
 | 
			
		||||
 | 
			
		||||
  lex = zero;  
 | 
			
		||||
  lex = zero;
 | 
			
		||||
  random(fRNG,Foo);
 | 
			
		||||
  gaussian(fRNG,Bar);
 | 
			
		||||
 | 
			
		||||
@@ -98,7 +97,7 @@ int main (int argc, char ** argv)
 | 
			
		||||
	  Fine.oCoorFromOindex(ocoor,o);
 | 
			
		||||
	  ocoor[dir]=(ocoor[dir]+disp)%Fine._rdimensions[dir];
 | 
			
		||||
	}
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	SimpleCompressor<vobj> compress;
 | 
			
		||||
	myStencil.HaloExchange(Foo,compress);
 | 
			
		||||
 | 
			
		||||
@@ -106,16 +105,16 @@ int main (int argc, char ** argv)
 | 
			
		||||
 | 
			
		||||
	// Implement a stencil code that should agree with cshift!
 | 
			
		||||
	for(int i=0;i<Check._grid->oSites();i++){
 | 
			
		||||
	  
 | 
			
		||||
 | 
			
		||||
	  int permute_type;
 | 
			
		||||
	  StencilEntry *SE;
 | 
			
		||||
	  SE = myStencil.GetEntry(permute_type,0,i);
 | 
			
		||||
	  
 | 
			
		||||
 | 
			
		||||
	  if ( SE->_is_local && SE->_permute )
 | 
			
		||||
	    permute(Check._odata[i],Foo._odata[SE->_offset],permute_type);
 | 
			
		||||
	  else if (SE->_is_local)
 | 
			
		||||
	    Check._odata[i] = Foo._odata[SE->_offset];
 | 
			
		||||
	  else 
 | 
			
		||||
	  else
 | 
			
		||||
	    Check._odata[i] = myStencil.CommBuf()[SE->_offset];
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -144,7 +143,7 @@ int main (int argc, char ** argv)
 | 
			
		||||
		      <<") " <<check<<" vs "<<bar<<std::endl;
 | 
			
		||||
	  }
 | 
			
		||||
 | 
			
		||||
	 
 | 
			
		||||
 | 
			
		||||
	}}}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -179,18 +178,18 @@ int main (int argc, char ** argv)
 | 
			
		||||
	  Fine.oCoorFromOindex(ocoor,o);
 | 
			
		||||
	  ocoor[dir]=(ocoor[dir]+disp)%Fine._rdimensions[dir];
 | 
			
		||||
	}
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	SimpleCompressor<vobj> compress;
 | 
			
		||||
 | 
			
		||||
	EStencil.HaloExchange(EFoo,compress);
 | 
			
		||||
	OStencil.HaloExchange(OFoo,compress);
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	Bar = Cshift(Foo,dir,disp);
 | 
			
		||||
 | 
			
		||||
	if ( disp & 0x1 ) {
 | 
			
		||||
	  ECheck.checkerboard = Even;
 | 
			
		||||
	  OCheck.checkerboard = Odd;
 | 
			
		||||
	} else { 
 | 
			
		||||
	} else {
 | 
			
		||||
	  ECheck.checkerboard = Odd;
 | 
			
		||||
	  OCheck.checkerboard = Even;
 | 
			
		||||
	}
 | 
			
		||||
@@ -206,7 +205,7 @@ int main (int argc, char ** argv)
 | 
			
		||||
	    permute(OCheck._odata[i],EFoo._odata[SE->_offset],permute_type);
 | 
			
		||||
	  else if (SE->_is_local)
 | 
			
		||||
	    OCheck._odata[i] = EFoo._odata[SE->_offset];
 | 
			
		||||
	  else 
 | 
			
		||||
	  else
 | 
			
		||||
	    OCheck._odata[i] = EStencil.CommBuf()[SE->_offset];
 | 
			
		||||
	}
 | 
			
		||||
	for(int i=0;i<ECheck._grid->oSites();i++){
 | 
			
		||||
@@ -214,18 +213,18 @@ int main (int argc, char ** argv)
 | 
			
		||||
	  StencilEntry *SE;
 | 
			
		||||
	  SE = OStencil.GetEntry(permute_type,0,i);
 | 
			
		||||
	  //	  std::cout << "ODD source "<< i<<" -> " <<SE->_offset << " "<< SE->_is_local<<std::endl;
 | 
			
		||||
	  
 | 
			
		||||
 | 
			
		||||
	  if ( SE->_is_local && SE->_permute )
 | 
			
		||||
	    permute(ECheck._odata[i],OFoo._odata[SE->_offset],permute_type);
 | 
			
		||||
	  else if (SE->_is_local)
 | 
			
		||||
	    ECheck._odata[i] = OFoo._odata[SE->_offset];
 | 
			
		||||
	  else 
 | 
			
		||||
	  else
 | 
			
		||||
	    ECheck._odata[i] = OStencil.CommBuf()[SE->_offset];
 | 
			
		||||
	}
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	setCheckerboard(Check,ECheck);
 | 
			
		||||
	setCheckerboard(Check,OCheck);
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	Real nrmC = norm2(Check);
 | 
			
		||||
	Real nrmB = norm2(Bar);
 | 
			
		||||
	Diff = Check-Bar;
 | 
			
		||||
@@ -248,10 +247,10 @@ int main (int argc, char ** argv)
 | 
			
		||||
	  diff =norm2(ddiff);
 | 
			
		||||
	  if ( diff > 0){
 | 
			
		||||
	    std::cout <<"Coor (" << coor[0]<<","<<coor[1]<<","<<coor[2]<<","<<coor[3] <<") "
 | 
			
		||||
		      <<"shift "<<disp<<" dir "<< dir 
 | 
			
		||||
		      <<"shift "<<disp<<" dir "<< dir
 | 
			
		||||
		      << "  stencil impl " <<check<<" vs cshift impl "<<bar<<std::endl;
 | 
			
		||||
	  }
 | 
			
		||||
	 
 | 
			
		||||
 | 
			
		||||
	}}}}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -26,7 +26,7 @@ See the full license in the file "LICENSE" in the top level distribution directo
 | 
			
		||||
*************************************************************************************/
 | 
			
		||||
/*  END LEGAL */
 | 
			
		||||
#include <Grid/Grid.h>
 | 
			
		||||
namespace Grid{
 | 
			
		||||
namespace Grid {
 | 
			
		||||
class ScalarActionParameters : Serializable {
 | 
			
		||||
 public:
 | 
			
		||||
  GRID_SERIALIZABLE_CLASS_MEMBERS(ScalarActionParameters,
 | 
			
		||||
@@ -44,7 +44,7 @@ int main(int argc, char **argv) {
 | 
			
		||||
  // here make a routine to print all the relevant information on the run
 | 
			
		||||
  std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl;
 | 
			
		||||
 | 
			
		||||
   // Typedefs to simplify notation
 | 
			
		||||
  // Typedefs to simplify notation
 | 
			
		||||
  typedef ScalarAdjGenericHMCRunner HMCWrapper;  // Uses the default minimum norm, real scalar fields
 | 
			
		||||
 | 
			
		||||
  //::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
 | 
			
		||||
@@ -52,7 +52,7 @@ int main(int argc, char **argv) {
 | 
			
		||||
 | 
			
		||||
  // Grid from the command line
 | 
			
		||||
  GridModule ScalarGrid;
 | 
			
		||||
  ScalarGrid.set_full( SpaceTimeGrid::makeFourDimGrid(
 | 
			
		||||
  ScalarGrid.set_full(SpaceTimeGrid::makeFourDimGrid(
 | 
			
		||||
        GridDefaultLatt(), GridDefaultSimd(Nd, vComplex::Nsimd()),
 | 
			
		||||
        GridDefaultMpi()));
 | 
			
		||||
  ScalarGrid.set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(ScalarGrid.get_full()));
 | 
			
		||||
@@ -89,12 +89,11 @@ int main(int argc, char **argv) {
 | 
			
		||||
  /////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
  // HMC parameters are serialisable
 | 
			
		||||
  TheHMC.Parameters.MD.MDsteps = 10;
 | 
			
		||||
  TheHMC.Parameters.MD.MDsteps = 20;
 | 
			
		||||
  TheHMC.Parameters.MD.trajL   = 1.0;
 | 
			
		||||
 | 
			
		||||
  TheHMC.ReadCommandLine(argc, argv);
 | 
			
		||||
  TheHMC.Run();
 | 
			
		||||
 | 
			
		||||
  Grid_finalize();
 | 
			
		||||
 | 
			
		||||
} // main
 | 
			
		||||
}  // main
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user