mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-03 21:44:33 +00:00 
			
		
		
		
	Improving efficiency of the force term
This commit is contained in:
		@@ -30,17 +30,34 @@ directory
 | 
			
		||||
#ifndef SCALAR_INT_ACTION_H
 | 
			
		||||
#define SCALAR_INT_ACTION_H
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// Note: this action can completely absorb the ScalarAction for real float fields
 | 
			
		||||
// use the scalarObjs to generalise the structure
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
  // FIXME drop the QCD namespace everywhere here
 | 
			
		||||
 | 
			
		||||
template <class Impl>
 | 
			
		||||
class ScalarInteractionAction : public QCD::Action<typename Impl::Field> {
 | 
			
		||||
public:
 | 
			
		||||
    INHERIT_FIELD_TYPES(Impl);
 | 
			
		||||
private:
 | 
			
		||||
    RealD mass_square;
 | 
			
		||||
    RealD lambda;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    typedef typename Field::vector_object vobj;
 | 
			
		||||
    typedef CartesianStencil<vobj,vobj> Stencil;
 | 
			
		||||
 | 
			
		||||
    SimpleCompressor<vobj> compressor;
 | 
			
		||||
    int npoint = 8;
 | 
			
		||||
    std::vector<int> directions    = {0,1,2,3,0,1,2,3};  // forcing 4 dimensions
 | 
			
		||||
    std::vector<int> displacements = {1,1,1,1, -1,-1,-1,-1};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
    INHERIT_FIELD_TYPES(Impl);
 | 
			
		||||
    ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l) {}
 | 
			
		||||
 | 
			
		||||
    ScalarInteractionAction(RealD ms, RealD l) : mass_square(ms), lambda(l){}
 | 
			
		||||
 | 
			
		||||
    virtual std::string LogParameters() {
 | 
			
		||||
      std::stringstream sstream;
 | 
			
		||||
@@ -51,27 +68,75 @@ class ScalarInteractionAction : public QCD::Action<typename Impl::Field> {
 | 
			
		||||
 | 
			
		||||
    virtual std::string action_name() {return "ScalarAction";}
 | 
			
		||||
 | 
			
		||||
    virtual void refresh(const Field &U,
 | 
			
		||||
                         GridParallelRNG &pRNG) {}  // noop as no pseudoferms
 | 
			
		||||
    virtual void refresh(const Field &U, GridParallelRNG &pRNG) {}
 | 
			
		||||
 | 
			
		||||
    virtual RealD S(const Field &p) {
 | 
			
		||||
        Field action(p._grid);
 | 
			
		||||
        Field pshift(p._grid);
 | 
			
		||||
        Field phisquared(p._grid);
 | 
			
		||||
        static Stencil phiStencil(p._grid, npoint, 0, directions, displacements);
 | 
			
		||||
        phiStencil.HaloExchange(p, compressor);
 | 
			
		||||
 | 
			
		||||
        Field action(p._grid), pshift(p._grid), phisquared(p._grid);
 | 
			
		||||
        phisquared = p*p;
 | 
			
		||||
        action = (2.0*QCD::Nd + mass_square)*phisquared + lambda*phisquared*phisquared;
 | 
			
		||||
        for (int mu = 0; mu < QCD::Nd; mu++) {
 | 
			
		||||
            pshift = Cshift(p, mu, +1);  // not efficient implement with stencils
 | 
			
		||||
            action -= pshift*p + p*pshift;
 | 
			
		||||
            //  pshift = Cshift(p, mu, +1);  // not efficient, implement with stencils
 | 
			
		||||
            PARALLEL_FOR_LOOP
 | 
			
		||||
            for (int i = 0; i < p._grid->oSites(); i++) {
 | 
			
		||||
                int permute_type;
 | 
			
		||||
                StencilEntry *SE;
 | 
			
		||||
                vobj temp2;
 | 
			
		||||
                vobj *temp;
 | 
			
		||||
                vobj *t_p;
 | 
			
		||||
 | 
			
		||||
                SE = phiStencil.GetEntry(permute_type, mu, i);
 | 
			
		||||
                t_p  = &p._odata[i];
 | 
			
		||||
                if ( SE->_is_local ) {
 | 
			
		||||
                    temp = &p._odata[SE->_offset];
 | 
			
		||||
                    if ( SE->_permute ) {
 | 
			
		||||
                        permute(temp2, *temp, permute_type);
 | 
			
		||||
                        action._odata[i] -= temp2*(*t_p) + (*t_p)*temp2;
 | 
			
		||||
                    } else {
 | 
			
		||||
                  action._odata[i] -= *temp*(*t_p) + (*t_p)*(*temp);
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                  action._odata[i] -= phiStencil.CommBuf()[SE->_offset]*(*t_p) + (*t_p)*phiStencil.CommBuf()[SE->_offset];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            //  action -= pshift*p + p*pshift;
 | 
			
		||||
        }
 | 
			
		||||
        // NB the trace in the algebra is normalised to 1/2
 | 
			
		||||
        // minus sign coming from the antihermitian fields
 | 
			
		||||
        return -(TensorRemove(sum(trace(action)))).real();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    virtual void deriv(const Field &p,
 | 
			
		||||
                       Field &force) {
 | 
			
		||||
    virtual void deriv(const Field &p, Field &force) {
 | 
			
		||||
        force = (2.0*QCD::Nd + mass_square)*p + 2.0*lambda*p*p*p;
 | 
			
		||||
        // following is inefficient
 | 
			
		||||
        for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1);
 | 
			
		||||
        // move this outside
 | 
			
		||||
        static Stencil phiStencil(p._grid, npoint, 0, directions, displacements);
 | 
			
		||||
        phiStencil.HaloExchange(p, compressor);
 | 
			
		||||
 | 
			
		||||
        //for (int mu = 0; mu < QCD::Nd; mu++) force -= Cshift(p, mu, -1) + Cshift(p, mu, 1);
 | 
			
		||||
        for (int point = 0; point < npoint; point++) {
 | 
			
		||||
            PARALLEL_FOR_LOOP
 | 
			
		||||
            for (int i = 0; i < p._grid->oSites(); i++) {
 | 
			
		||||
                vobj *temp;
 | 
			
		||||
                vobj temp2;
 | 
			
		||||
                int permute_type;
 | 
			
		||||
                StencilEntry *SE;
 | 
			
		||||
                SE = phiStencil.GetEntry(permute_type, point, i);
 | 
			
		||||
 | 
			
		||||
                if ( SE->_is_local ) {
 | 
			
		||||
                    temp = &p._odata[SE->_offset];
 | 
			
		||||
                    if ( SE->_permute ) {
 | 
			
		||||
                        permute(temp2, *temp, permute_type);
 | 
			
		||||
                        force._odata[i] -= temp2;
 | 
			
		||||
                    } else {
 | 
			
		||||
                        force._odata[i] -= *temp;
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    force._odata[i] -= phiStencil.CommBuf()[SE->_offset];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user