mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-03 21:44:33 +00:00 
			
		
		
		
	Fixing Laplace flopcount Minor cleanup
This commit is contained in:
		@@ -330,6 +330,72 @@ public:
 | 
			
		||||
	coalescedWrite(out[ss], res,lane);
 | 
			
		||||
    });
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  virtual void  MDerivLink(const Field &_in, Field &_out)
 | 
			
		||||
  {
 | 
			
		||||
    ///////////////////////////////////////////////
 | 
			
		||||
    // Halo exchange for this geometry of stencil
 | 
			
		||||
    ///////////////////////////////////////////////
 | 
			
		||||
    Stencil.HaloExchange(_in, Compressor);
 | 
			
		||||
 | 
			
		||||
    ///////////////////////////////////
 | 
			
		||||
    // Arithmetic expressions
 | 
			
		||||
    ///////////////////////////////////
 | 
			
		||||
//    auto st = Stencil.View(AcceleratorRead);
 | 
			
		||||
    autoView( st     , Stencil    , AcceleratorRead);
 | 
			
		||||
    auto buf = st.CommBuf();
 | 
			
		||||
 | 
			
		||||
    autoView( in     , _in    , AcceleratorRead);
 | 
			
		||||
    autoView( out    , _out   , AcceleratorWrite);
 | 
			
		||||
    autoView( U     , Uds    , AcceleratorRead);
 | 
			
		||||
 | 
			
		||||
    typedef typename Field::vector_object        vobj;
 | 
			
		||||
    typedef decltype(coalescedRead(in[0]))    calcObj;
 | 
			
		||||
    typedef decltype(coalescedRead(U[0](0))) calcLink;
 | 
			
		||||
 | 
			
		||||
    const int      Nsimd = vobj::Nsimd();
 | 
			
		||||
    const uint64_t NN = grid->oSites();
 | 
			
		||||
 | 
			
		||||
    accelerator_for( ss, NN, Nsimd, {
 | 
			
		||||
 | 
			
		||||
	StencilEntry *SE;
 | 
			
		||||
	
 | 
			
		||||
	const int lane=acceleratorSIMTlane(Nsimd);
 | 
			
		||||
 | 
			
		||||
	calcObj chi;
 | 
			
		||||
	calcObj res;
 | 
			
		||||
	calcObj Uchi;
 | 
			
		||||
	calcObj Utmp;
 | 
			
		||||
	calcObj Utmp2;
 | 
			
		||||
	calcLink UU;
 | 
			
		||||
	calcLink Udag;
 | 
			
		||||
	int ptype;
 | 
			
		||||
 | 
			
		||||
	res                 = coalescedRead(in[ss])*(-8.0);
 | 
			
		||||
 | 
			
		||||
#define LEG_LOAD_MULT(leg,polarisation)			\
 | 
			
		||||
	UU = coalescedRead(U[ss](polarisation));	\
 | 
			
		||||
	Udag = adj(UU);					\
 | 
			
		||||
	LEG_LOAD(leg);					\
 | 
			
		||||
	mult(&Utmp(), &UU, &chi());			\
 | 
			
		||||
	Utmp2 = adj(Utmp);				\
 | 
			
		||||
	mult(&Utmp(), &UU, &Utmp2());			\
 | 
			
		||||
	Uchi = adj(Utmp);				\
 | 
			
		||||
	res = res + Uchi;
 | 
			
		||||
	
 | 
			
		||||
	LEG_LOAD_MULT(0,Xp);
 | 
			
		||||
	LEG_LOAD_MULT(1,Yp);
 | 
			
		||||
	LEG_LOAD_MULT(2,Zp);
 | 
			
		||||
	LEG_LOAD_MULT(3,Tp);
 | 
			
		||||
	LEG_LOAD_MULT(4,Xm);
 | 
			
		||||
	LEG_LOAD_MULT(5,Ym);
 | 
			
		||||
	LEG_LOAD_MULT(6,Zm);
 | 
			
		||||
	LEG_LOAD_MULT(7,Tm);
 | 
			
		||||
 | 
			
		||||
	coalescedWrite(out[ss], res,lane);
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
  };
 | 
			
		||||
  virtual void  M(const Field &in, Field &out) {Mnew(in,out);};
 | 
			
		||||
  virtual void  Mdag (const Field &in, Field &out) { M(in,out);}; // Laplacian is hermitian
 | 
			
		||||
  virtual  void Mdiag    (const Field &in, Field &out)                  {assert(0);}; // Unimplemented need only for multigrid
 | 
			
		||||
@@ -432,53 +498,6 @@ public:
 | 
			
		||||
//    std::cout << GridLogDebug <<"M:norm2(out) = "<<norm2(out)<<std::endl;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#if 0
 | 
			
		||||
  void Quad(const GaugeField& in, GaugeField& out,RealD a0,RealD a1,RealD a2) {
 | 
			
		||||
 | 
			
		||||
    GaugeLinkField tmp(in.Grid());
 | 
			
		||||
    GaugeLinkField tmp2(in.Grid());
 | 
			
		||||
#if 0
 | 
			
		||||
    std::vector<GaugeLinkField> sum(in.Grid(),Nd);
 | 
			
		||||
    std::vector<GaugeLinkField> sum2(in.Grid(),Nd);
 | 
			
		||||
    std::vector<GaugeLinkField> in_nu(in.Grid(),Nd);
 | 
			
		||||
    std::vector<GaugeLinkField> out_nu(in.Grid(),Nd);
 | 
			
		||||
 | 
			
		||||
    for (int nu = 0; nu < Nd; nu++) {
 | 
			
		||||
      sum[nu] = Zero();
 | 
			
		||||
      in_nu[nu] = PeekIndex<LorentzIndex>(in, nu);
 | 
			
		||||
      out_nu[nu] = a0*in_nu[nu];
 | 
			
		||||
      for (int mu = 0; mu < Nd; mu++) {
 | 
			
		||||
        tmp = U[mu] * Cshift(in_nu[nu], mu, +1) * adj(U[mu]);
 | 
			
		||||
        tmp2 = adj(U[mu]) * in_nu[nu] * U[mu];
 | 
			
		||||
        sum[nu] += tmp + Cshift(tmp2, mu, -1) - 2.0 * in_nu;
 | 
			
		||||
      }
 | 
			
		||||
      out_nu[nu] +=  a1*  1. / (double(4 * Nd)) * sum[nu];
 | 
			
		||||
      sum2[nu] = Zero();
 | 
			
		||||
      for (int mu = 0; mu < Nd; mu++) {
 | 
			
		||||
        tmp = U[mu] * Cshift(sum[nu], mu, +1) * adj(U[mu]);
 | 
			
		||||
        tmp2 = adj(U[mu]) * in_nu * U[mu];
 | 
			
		||||
        sum2[nu] += tmp + Cshift(tmp2, mu, -1) - 2.0 * in_nu;
 | 
			
		||||
      }
 | 
			
		||||
      out_nu[nu] +=  a2* ( 1. / (double(4 * Nd)))^2 * sum[nu];
 | 
			
		||||
      PokeIndex<LorentzIndex>(out, out_nu[nu], nu);
 | 
			
		||||
    }
 | 
			
		||||
#else
 | 
			
		||||
    for (int nu = 0; nu < Nd; nu++) {
 | 
			
		||||
      GaugeLinkField in_nu = PeekIndex<LorentzIndex>(in, nu);
 | 
			
		||||
      GaugeLinkField out_nu(out.Grid());
 | 
			
		||||
      GaugeLinkField sum(out.Grid());
 | 
			
		||||
      GaugeLinkField sum2(out.Grid());
 | 
			
		||||
      out_nu=a0*in_nu;
 | 
			
		||||
      LapStencil.M(in_nu,sum);
 | 
			
		||||
      out_nu +=  a1*  1. / (double(4 * Nd)) * sum;
 | 
			
		||||
      LapStencil.M(sum,sum2);
 | 
			
		||||
      out_nu +=  a2* ( 1. / (double(4 * Nd)))^2 * sum2;
 | 
			
		||||
//      out_nu += (1.0 - kappa) * in_nu - kappa / (double(4 * Nd)) * sum;
 | 
			
		||||
      PokeIndex<LorentzIndex>(out, out_nu, nu);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  void MDeriv(const GaugeField& in, GaugeField& der) {
 | 
			
		||||
    // in is anti-hermitian
 | 
			
		||||
 
 | 
			
		||||
@@ -768,7 +768,7 @@ public:
 | 
			
		||||
	
 | 
			
		||||
	double volume=1;  for(int mu=0;mu<Nd;mu++) volume=volume*latt4[mu];
 | 
			
		||||
//	double flops=(1146.0*volume)/2;
 | 
			
		||||
	double flops=(2*8*216.0*volume);
 | 
			
		||||
	double flops=(2*2*8*216.0*volume);
 | 
			
		||||
	double mf_hi, mf_lo, mf_err;
 | 
			
		||||
	
 | 
			
		||||
	timestat.statistics(t_time);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user