1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

Fixing Laplace flopcount Minor cleanup

This commit is contained in:
Chulwoo Jung 2024-02-13 12:06:08 -05:00
parent 948d16fb06
commit fe98e9f555
2 changed files with 67 additions and 48 deletions

View File

@ -330,6 +330,72 @@ public:
coalescedWrite(out[ss], res,lane); coalescedWrite(out[ss], res,lane);
}); });
}; };
virtual void MDerivLink(const Field &_in, Field &_out)
{
///////////////////////////////////////////////
// Halo exchange for this geometry of stencil
///////////////////////////////////////////////
Stencil.HaloExchange(_in, Compressor);
///////////////////////////////////
// Arithmetic expressions
///////////////////////////////////
// auto st = Stencil.View(AcceleratorRead);
autoView( st , Stencil , AcceleratorRead);
auto buf = st.CommBuf();
autoView( in , _in , AcceleratorRead);
autoView( out , _out , AcceleratorWrite);
autoView( U , Uds , AcceleratorRead);
typedef typename Field::vector_object vobj;
typedef decltype(coalescedRead(in[0])) calcObj;
typedef decltype(coalescedRead(U[0](0))) calcLink;
const int Nsimd = vobj::Nsimd();
const uint64_t NN = grid->oSites();
accelerator_for( ss, NN, Nsimd, {
StencilEntry *SE;
const int lane=acceleratorSIMTlane(Nsimd);
calcObj chi;
calcObj res;
calcObj Uchi;
calcObj Utmp;
calcObj Utmp2;
calcLink UU;
calcLink Udag;
int ptype;
res = coalescedRead(in[ss])*(-8.0);
#define LEG_LOAD_MULT(leg,polarisation) \
UU = coalescedRead(U[ss](polarisation)); \
Udag = adj(UU); \
LEG_LOAD(leg); \
mult(&Utmp(), &UU, &chi()); \
Utmp2 = adj(Utmp); \
mult(&Utmp(), &UU, &Utmp2()); \
Uchi = adj(Utmp); \
res = res + Uchi;
LEG_LOAD_MULT(0,Xp);
LEG_LOAD_MULT(1,Yp);
LEG_LOAD_MULT(2,Zp);
LEG_LOAD_MULT(3,Tp);
LEG_LOAD_MULT(4,Xm);
LEG_LOAD_MULT(5,Ym);
LEG_LOAD_MULT(6,Zm);
LEG_LOAD_MULT(7,Tm);
coalescedWrite(out[ss], res,lane);
});
};
virtual void M(const Field &in, Field &out) {Mnew(in,out);}; virtual void M(const Field &in, Field &out) {Mnew(in,out);};
virtual void Mdag (const Field &in, Field &out) { M(in,out);}; // Laplacian is hermitian virtual void Mdag (const Field &in, Field &out) { M(in,out);}; // Laplacian is hermitian
virtual void Mdiag (const Field &in, Field &out) {assert(0);}; // Unimplemented need only for multigrid virtual void Mdiag (const Field &in, Field &out) {assert(0);}; // Unimplemented need only for multigrid
@ -432,53 +498,6 @@ public:
// std::cout << GridLogDebug <<"M:norm2(out) = "<<norm2(out)<<std::endl; // std::cout << GridLogDebug <<"M:norm2(out) = "<<norm2(out)<<std::endl;
} }
#if 0
void Quad(const GaugeField& in, GaugeField& out,RealD a0,RealD a1,RealD a2) {
GaugeLinkField tmp(in.Grid());
GaugeLinkField tmp2(in.Grid());
#if 0
std::vector<GaugeLinkField> sum(in.Grid(),Nd);
std::vector<GaugeLinkField> sum2(in.Grid(),Nd);
std::vector<GaugeLinkField> in_nu(in.Grid(),Nd);
std::vector<GaugeLinkField> out_nu(in.Grid(),Nd);
for (int nu = 0; nu < Nd; nu++) {
sum[nu] = Zero();
in_nu[nu] = PeekIndex<LorentzIndex>(in, nu);
out_nu[nu] = a0*in_nu[nu];
for (int mu = 0; mu < Nd; mu++) {
tmp = U[mu] * Cshift(in_nu[nu], mu, +1) * adj(U[mu]);
tmp2 = adj(U[mu]) * in_nu[nu] * U[mu];
sum[nu] += tmp + Cshift(tmp2, mu, -1) - 2.0 * in_nu;
}
out_nu[nu] += a1* 1. / (double(4 * Nd)) * sum[nu];
sum2[nu] = Zero();
for (int mu = 0; mu < Nd; mu++) {
tmp = U[mu] * Cshift(sum[nu], mu, +1) * adj(U[mu]);
tmp2 = adj(U[mu]) * in_nu * U[mu];
sum2[nu] += tmp + Cshift(tmp2, mu, -1) - 2.0 * in_nu;
}
out_nu[nu] += a2* ( 1. / (double(4 * Nd)))^2 * sum[nu];
PokeIndex<LorentzIndex>(out, out_nu[nu], nu);
}
#else
for (int nu = 0; nu < Nd; nu++) {
GaugeLinkField in_nu = PeekIndex<LorentzIndex>(in, nu);
GaugeLinkField out_nu(out.Grid());
GaugeLinkField sum(out.Grid());
GaugeLinkField sum2(out.Grid());
out_nu=a0*in_nu;
LapStencil.M(in_nu,sum);
out_nu += a1* 1. / (double(4 * Nd)) * sum;
LapStencil.M(sum,sum2);
out_nu += a2* ( 1. / (double(4 * Nd)))^2 * sum2;
// out_nu += (1.0 - kappa) * in_nu - kappa / (double(4 * Nd)) * sum;
PokeIndex<LorentzIndex>(out, out_nu, nu);
}
#endif
}
#endif
void MDeriv(const GaugeField& in, GaugeField& der) { void MDeriv(const GaugeField& in, GaugeField& der) {
// in is anti-hermitian // in is anti-hermitian

View File

@ -768,7 +768,7 @@ public:
double volume=1; for(int mu=0;mu<Nd;mu++) volume=volume*latt4[mu]; double volume=1; for(int mu=0;mu<Nd;mu++) volume=volume*latt4[mu];
// double flops=(1146.0*volume)/2; // double flops=(1146.0*volume)/2;
double flops=(2*8*216.0*volume); double flops=(2*2*8*216.0*volume);
double mf_hi, mf_lo, mf_err; double mf_hi, mf_lo, mf_err;
timestat.statistics(t_time); timestat.statistics(t_time);