1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

Getting rid of one more non-auto View, comms overlap in Laplace operator

This commit is contained in:
Chulwoo Jung 2024-02-25 22:37:48 -05:00
parent fe98e9f555
commit cfa0576ffd
2 changed files with 109 additions and 67 deletions

View File

@ -109,6 +109,74 @@ public:
};
virtual GridBase *Grid(void) { return grid; };
//broken
#if 0
virtual void MDeriv(const Field &_left, Field &_right,Field &_der, int mu)
{
///////////////////////////////////////////////
// Halo exchange for this geometry of stencil
///////////////////////////////////////////////
Stencil.HaloExchange(_lef, Compressor);
///////////////////////////////////
// Arithmetic expressions
///////////////////////////////////
autoView( st , Stencil , AcceleratorRead);
auto buf = st.CommBuf();
autoView( in , _left , AcceleratorRead);
autoView( right , _right , AcceleratorRead);
autoView( der , _der , AcceleratorWrite);
autoView( U , Uds , AcceleratorRead);
typedef typename Field::vector_object vobj;
typedef decltype(coalescedRead(left[0])) calcObj;
typedef decltype(coalescedRead(U[0](0))) calcLink;
const int Nsimd = vobj::Nsimd();
const uint64_t NN = grid->oSites();
accelerator_for( ss, NN, Nsimd, {
StencilEntry *SE;
const int lane=acceleratorSIMTlane(Nsimd);
calcObj chi;
calcObj phi;
calcObj res;
calcObj Uchi;
calcObj Utmp;
calcObj Utmp2;
calcLink UU;
calcLink Udag;
int ptype;
res = coalescedRead(def[ss]);
phi = coalescedRead(right[ss]);
#define LEG_LOAD_MULT_LINK(leg,polarisation) \
UU = coalescedRead(U[ss](polarisation)); \
Udag = adj(UU); \
LEG_LOAD(leg); \
mult(&Utmp(), &UU, &chi()); \
Utmp2 = adj(Utmp); \
mult(&Utmp(), &UU, &Utmp2()); \
Utmp2 = adj(Utmp); \
mult(&Uchi(), &phi(), &Utmp2()); \
res = res + Uchi;
LEG_LOAD_MULT_LINK(0,Xp);
LEG_LOAD_MULT_LINK(1,Yp);
LEG_LOAD_MULT_LINK(2,Zp);
LEG_LOAD_MULT_LINK(3,Tp);
coalescedWrite(der[ss], res,lane);
});
};
#endif
virtual void Morig(const Field &_in, Field &_out)
{
///////////////////////////////////////////////
@ -331,71 +399,6 @@ public:
});
};
virtual void MDerivLink(const Field &_in, Field &_out)
{
///////////////////////////////////////////////
// Halo exchange for this geometry of stencil
///////////////////////////////////////////////
Stencil.HaloExchange(_in, Compressor);
///////////////////////////////////
// Arithmetic expressions
///////////////////////////////////
// auto st = Stencil.View(AcceleratorRead);
autoView( st , Stencil , AcceleratorRead);
auto buf = st.CommBuf();
autoView( in , _in , AcceleratorRead);
autoView( out , _out , AcceleratorWrite);
autoView( U , Uds , AcceleratorRead);
typedef typename Field::vector_object vobj;
typedef decltype(coalescedRead(in[0])) calcObj;
typedef decltype(coalescedRead(U[0](0))) calcLink;
const int Nsimd = vobj::Nsimd();
const uint64_t NN = grid->oSites();
accelerator_for( ss, NN, Nsimd, {
StencilEntry *SE;
const int lane=acceleratorSIMTlane(Nsimd);
calcObj chi;
calcObj res;
calcObj Uchi;
calcObj Utmp;
calcObj Utmp2;
calcLink UU;
calcLink Udag;
int ptype;
res = coalescedRead(in[ss])*(-8.0);
#define LEG_LOAD_MULT(leg,polarisation) \
UU = coalescedRead(U[ss](polarisation)); \
Udag = adj(UU); \
LEG_LOAD(leg); \
mult(&Utmp(), &UU, &chi()); \
Utmp2 = adj(Utmp); \
mult(&Utmp(), &UU, &Utmp2()); \
Uchi = adj(Utmp); \
res = res + Uchi;
LEG_LOAD_MULT(0,Xp);
LEG_LOAD_MULT(1,Yp);
LEG_LOAD_MULT(2,Zp);
LEG_LOAD_MULT(3,Tp);
LEG_LOAD_MULT(4,Xm);
LEG_LOAD_MULT(5,Ym);
LEG_LOAD_MULT(6,Zm);
LEG_LOAD_MULT(7,Tm);
coalescedWrite(out[ss], res,lane);
});
};
virtual void M(const Field &in, Field &out) {Mnew(in,out);};
virtual void Mdag (const Field &in, Field &out) { M(in,out);}; // Laplacian is hermitian
virtual void Mdiag (const Field &in, Field &out) {assert(0);}; // Unimplemented need only for multigrid
@ -404,6 +407,7 @@ public:
};
#undef LEG_LOAD_MULT
#undef LEG_LOAD_MULT_LINK
#undef LEG_LOAD
////////////////////////////////////////////////////////////

View File

@ -128,8 +128,8 @@ public:
void MDerivLink(const GaugeLinkField& left, const GaugeLinkField& right,
GaugeField& der) {
std::cout<<GridLogMessage << "MDerivLink start "<< std::endl;
RealD factor = -1. / (double(4 * Nd));
for (int mu = 0; mu < Nd; mu++) {
GaugeLinkField der_mu(der.Grid());
der_mu = Zero();
@ -141,7 +141,26 @@ public:
// }
PokeIndex<LorentzIndex>(der, -factor * der_mu, mu);
}
std::cout << GridLogDebug <<"MDerivLink: norm2(der) = "<<norm2(der)<<std::endl;
// std::cout << GridLogDebug <<"MDerivLink: norm2(der) = "<<norm2(der)<<std::endl;
std::cout<<GridLogMessage << "MDerivLink end "<< std::endl;
}
void MDerivLink(const GaugeLinkField& left, const GaugeLinkField& right,
std::vector<GaugeLinkField> & der) {
// std::cout<<GridLogMessage << "MDerivLink "<< std::endl;
RealD factor = -1. / (double(4 * Nd));
for (int mu = 0; mu < Nd; mu++) {
GaugeLinkField der_mu(left.Grid());
der_mu = Zero();
der_mu += U[mu] * Cshift(left, mu, 1) * adj(U[mu]) * right;
der_mu += U[mu] * Cshift(right, mu, 1) * adj(U[mu]) * left;
// PokeIndex<LorentzIndex>(der, -factor * der_mu, mu);
der[mu] = -factor*der_mu;
// std::cout << GridLogDebug <<"MDerivLink: norm2(der) = "<<norm2(der[mu])<<std::endl;
}
// std::cout<<GridLogMessage << "MDerivLink end "<< std::endl;
}
void MDerivInt(LaplacianRatParams &par, const GaugeField& left, const GaugeField& right,
@ -243,8 +262,12 @@ public:
GaugeField tempDer(left.Grid());
std::vector<GaugeLinkField> DerLink(Nd,left.Grid());
std::vector<GaugeLinkField> tempDerLink(Nd,left.Grid());
std::cout<<GridLogMessage << "force contraction "<< i <<std::endl;
// roctxRangePushA("RMHMC force contraction");
#if 0
MDerivLink(GMom,MinvMom[i],tempDer); der += coef*2*par.a1[i]*tempDer;
MDerivLink(left_nu,MinvGMom,tempDer); der += coef*2*par.a1[i]*tempDer;
MDerivLink(LMinvAGMom,MinvMom[i],tempDer); der += coef*-2.*par.b2*tempDer;
@ -253,6 +276,21 @@ public:
MDerivLink(AMinvMom,LMinvGMom,tempDer); der += coef*-2.*par.b2*tempDer;
MDerivLink(MinvAGMom,MinvMom[i],tempDer); der += coef*-2.*par.b1[i]*tempDer;
MDerivLink(AMinvMom,MinvGMom,tempDer); der += coef*-2.*par.b1[i]*tempDer;
#else
for (int mu=0;mu<Nd;mu++) DerLink[mu]=Zero();
MDerivLink(GMom,MinvMom[i],tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*2*par.a1[i]*tempDerLink[mu];
MDerivLink(left_nu,MinvGMom,tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*2*par.a1[i]*tempDerLink[mu];
MDerivLink(LMinvAGMom,MinvMom[i],tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*-2.*par.b2*tempDerLink[mu];
MDerivLink(LMinvAMom,MinvGMom,tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*-2.*par.b2*tempDerLink[mu];
MDerivLink(MinvAGMom,LMinvMom,tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*-2.*par.b2*tempDerLink[mu];
MDerivLink(AMinvMom,LMinvGMom,tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*-2.*par.b2*tempDerLink[mu];
MDerivLink(MinvAGMom,MinvMom[i],tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*-2.*par.b1[i]*tempDerLink[mu];
MDerivLink(AMinvMom,MinvGMom,tempDerLink); for (int mu=0;mu<Nd;mu++) DerLink[mu] += coef*-2.*par.b1[i]*tempDerLink[mu];
// PokeIndex<LorentzIndex>(der, -factor * der_mu, mu);
for (int mu=0;mu<Nd;mu++) PokeIndex<LorentzIndex>(tempDer, tempDerLink[mu], mu);
der += tempDer;
#endif
std::cout<<GridLogMessage << "coef = force contraction "<< i << "done "<< coef <<std::endl;
// roctxRangePop();