From fe98e9f5554317f7bf5757bd959116ee7220e40d Mon Sep 17 00:00:00 2001 From: Chulwoo Jung Date: Tue, 13 Feb 2024 12:06:08 -0500 Subject: [PATCH] Fixing Laplace flopcount Minor cleanup --- Grid/qcd/utils/CovariantLaplacian.h | 113 ++++++++++++++++------------ benchmarks/Benchmark_ITT.cc | 2 +- 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/Grid/qcd/utils/CovariantLaplacian.h b/Grid/qcd/utils/CovariantLaplacian.h index 9e1204d4..7410a991 100644 --- a/Grid/qcd/utils/CovariantLaplacian.h +++ b/Grid/qcd/utils/CovariantLaplacian.h @@ -330,6 +330,72 @@ public: coalescedWrite(out[ss], res,lane); }); }; + + virtual void MDerivLink(const Field &_in, Field &_out) + { + /////////////////////////////////////////////// + // Halo exchange for this geometry of stencil + /////////////////////////////////////////////// + Stencil.HaloExchange(_in, Compressor); + + /////////////////////////////////// + // Arithmetic expressions + /////////////////////////////////// +// auto st = Stencil.View(AcceleratorRead); + autoView( st , Stencil , AcceleratorRead); + auto buf = st.CommBuf(); + + autoView( in , _in , AcceleratorRead); + autoView( out , _out , AcceleratorWrite); + autoView( U , Uds , AcceleratorRead); + + typedef typename Field::vector_object vobj; + typedef decltype(coalescedRead(in[0])) calcObj; + typedef decltype(coalescedRead(U[0](0))) calcLink; + + const int Nsimd = vobj::Nsimd(); + const uint64_t NN = grid->oSites(); + + accelerator_for( ss, NN, Nsimd, { + + StencilEntry *SE; + + const int lane=acceleratorSIMTlane(Nsimd); + + calcObj chi; + calcObj res; + calcObj Uchi; + calcObj Utmp; + calcObj Utmp2; + calcLink UU; + calcLink Udag; + int ptype; + + res = coalescedRead(in[ss])*(-8.0); + +#define LEG_LOAD_MULT(leg,polarisation) \ + UU = coalescedRead(U[ss](polarisation)); \ + Udag = adj(UU); \ + LEG_LOAD(leg); \ + mult(&Utmp(), &UU, &chi()); \ + Utmp2 = adj(Utmp); \ + mult(&Utmp(), &UU, &Utmp2()); \ + Uchi = adj(Utmp); \ + res = res + Uchi; + + LEG_LOAD_MULT(0,Xp); + LEG_LOAD_MULT(1,Yp); + LEG_LOAD_MULT(2,Zp); + LEG_LOAD_MULT(3,Tp); + LEG_LOAD_MULT(4,Xm); + LEG_LOAD_MULT(5,Ym); + LEG_LOAD_MULT(6,Zm); + LEG_LOAD_MULT(7,Tm); + + coalescedWrite(out[ss], res,lane); + }); + + }; virtual void M(const Field &in, Field &out) {Mnew(in,out);}; virtual void Mdag (const Field &in, Field &out) { M(in,out);}; // Laplacian is hermitian virtual void Mdiag (const Field &in, Field &out) {assert(0);}; // Unimplemented need only for multigrid @@ -432,53 +498,6 @@ public: // std::cout << GridLogDebug <<"M:norm2(out) = "< sum(in.Grid(),Nd); - std::vector sum2(in.Grid(),Nd); - std::vector in_nu(in.Grid(),Nd); - std::vector out_nu(in.Grid(),Nd); - - for (int nu = 0; nu < Nd; nu++) { - sum[nu] = Zero(); - in_nu[nu] = PeekIndex(in, nu); - out_nu[nu] = a0*in_nu[nu]; - for (int mu = 0; mu < Nd; mu++) { - tmp = U[mu] * Cshift(in_nu[nu], mu, +1) * adj(U[mu]); - tmp2 = adj(U[mu]) * in_nu[nu] * U[mu]; - sum[nu] += tmp + Cshift(tmp2, mu, -1) - 2.0 * in_nu; - } - out_nu[nu] += a1* 1. / (double(4 * Nd)) * sum[nu]; - sum2[nu] = Zero(); - for (int mu = 0; mu < Nd; mu++) { - tmp = U[mu] * Cshift(sum[nu], mu, +1) * adj(U[mu]); - tmp2 = adj(U[mu]) * in_nu * U[mu]; - sum2[nu] += tmp + Cshift(tmp2, mu, -1) - 2.0 * in_nu; - } - out_nu[nu] += a2* ( 1. / (double(4 * Nd)))^2 * sum[nu]; - PokeIndex(out, out_nu[nu], nu); - } -#else - for (int nu = 0; nu < Nd; nu++) { - GaugeLinkField in_nu = PeekIndex(in, nu); - GaugeLinkField out_nu(out.Grid()); - GaugeLinkField sum(out.Grid()); - GaugeLinkField sum2(out.Grid()); - out_nu=a0*in_nu; - LapStencil.M(in_nu,sum); - out_nu += a1* 1. / (double(4 * Nd)) * sum; - LapStencil.M(sum,sum2); - out_nu += a2* ( 1. / (double(4 * Nd)))^2 * sum2; -// out_nu += (1.0 - kappa) * in_nu - kappa / (double(4 * Nd)) * sum; - PokeIndex(out, out_nu, nu); - } -#endif - } -#endif void MDeriv(const GaugeField& in, GaugeField& der) { // in is anti-hermitian diff --git a/benchmarks/Benchmark_ITT.cc b/benchmarks/Benchmark_ITT.cc index 91ecc9c9..40c230c0 100644 --- a/benchmarks/Benchmark_ITT.cc +++ b/benchmarks/Benchmark_ITT.cc @@ -768,7 +768,7 @@ public: double volume=1; for(int mu=0;mu