1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-17 07:17:06 +01:00

Hadrons: overhaul of A2A for production

This commit is contained in:
2018-08-07 18:27:59 +01:00
parent 231cc95be6
commit 0677adb4dd
6 changed files with 558 additions and 531 deletions

View File

@ -4,7 +4,7 @@
#include <Grid/Hadrons/Global.hpp>
#include <Grid/Hadrons/Module.hpp>
#include <Grid/Hadrons/ModuleFactory.hpp>
#include <Grid/Hadrons/AllToAllVectors.hpp>
#include <Grid/Hadrons/A2AVectors.hpp>
#include <Grid/Eigen/unsupported/CXX11/Tensor>
BEGIN_HADRONS_NAMESPACE
@ -24,7 +24,8 @@ class A2AMesonFieldPar : Serializable
int, cacheBlock,
int, schurBlock,
int, Nmom,
std::string, A2A,
std::string, v,
std::string, w,
std::string, output);
};
@ -34,9 +35,6 @@ class TA2AMesonField : public Module<A2AMesonFieldPar>
public:
FERM_TYPE_ALIASES(FImpl, );
SOLVER_TYPE_ALIASES(FImpl, );
typedef A2AModesSchurDiagTwo<typename FImpl::FermionField, FMat, Solver> A2ABase;
public:
// constructor
TA2AMesonField(const std::string name);
@ -80,7 +78,7 @@ TA2AMesonField<FImpl>::TA2AMesonField(const std::string name)
// Builds and returns the list of object names this module takes as input,
// read from the module parameters via par().
// NOTE(review): this is a diff hunk, so both sides of the change are shown.
// The first initialiser of `in` (par().A2A) appears to be the removed line and
// the second (par().v, par().w) its replacement; only one of the two exists
// in either version of the file.
template <typename FImpl>
std::vector<std::string> TA2AMesonField<FImpl>::getInput(void)
{
std::vector<std::string> in = {par().A2A};
std::vector<std::string> in = {par().v, par().w};
return in;
}
@ -97,20 +95,7 @@ std::vector<std::string> TA2AMesonField<FImpl>::getOutput(void)
// setup ///////////////////////////////////////////////////////////////////////
// Module setup step.
// NOTE(review): this is a diff hunk, so both sides of the change are shown.
// The braced body directly below appears to be the removed implementation:
// it fetched the A2A object named by par().A2A, read its Ls, and registered
// temporaries "w" and "v" (vectors of par().schurBlock FermionFields on
// env().getGrid()) plus a "tmp_5d" FermionField created with Ls.
// The trailing `{}` appears to be the replacement: an empty body.
template <typename FImpl>
void TA2AMesonField<FImpl>::setup(void)
{
auto &a2a = envGet(A2ABase, par().A2A);
int Ls = env().getObjectLs(par().A2A);
// Four D fields
envTmp(std::vector<FermionField>, "w", 1, par().schurBlock,
FermionField(env().getGrid()));
envTmp(std::vector<FermionField>, "v", 1, par().schurBlock,
FermionField(env().getGrid()));
// 5D tmp
envTmpLat(FermionField, "tmp_5d", Ls);
}
{}
//////////////////////////////////////////////////////////////////////////////////
// Cache blocked arithmetic routine
@ -304,7 +289,8 @@ void TA2AMesonField<FImpl>::execute(void)
{
LOG(Message) << "Computing A2A meson field" << std::endl;
auto &a2a = envGet(A2ABase, par().A2A);
auto &v = envGet(std::vector<FermionField>, par().v);
auto &w = envGet(std::vector<FermionField>, par().w);
// 2+6+4+4 = 16 gammas
// Ordering defined here
@ -330,15 +316,13 @@ void TA2AMesonField<FImpl>::execute(void)
///////////////////////////////////////////////
// Square assumption for now Nl = Nr = N
///////////////////////////////////////////////
int nt = env().getDim(Tp);
int nx = env().getDim(Xp);
int ny = env().getDim(Yp);
int nz = env().getDim(Zp);
int Nl = a2a.get_Nl();
int N = Nl + a2a.get_Nh();
int nt = env().getDim(Tp);
int nx = env().getDim(Xp);
int ny = env().getDim(Yp);
int nz = env().getDim(Zp);
int N_i = w.size();
int N_j = v.size();
int ngamma = gammas.size();
int schurBlock = par().schurBlock;
int cacheBlock = par().cacheBlock;
int nmom = par().Nmom;
@ -353,14 +337,8 @@ void TA2AMesonField<FImpl>::execute(void)
phases[m] = Complex(1.0); // All zero momentum for now
}
Eigen::Tensor<ComplexD,5> mesonField (nmom,ngamma,nt,N,N);
LOG(Message) << "N = Nh+Nl for A2A MesonField is " << N << std::endl;
envGetTmp(std::vector<FermionField>, w);
envGetTmp(std::vector<FermionField>, v);
envGetTmp(FermionField, tmp_5d);
LOG(Message) << "Finding v and w vectors for N = " << N << std::endl;
Eigen::Tensor<ComplexD,5> mesonField(nmom,ngamma,nt,N_i,N_j);
LOG(Message) << "MesonField size " << N_i << "x" << N_j << "x" << nt << std::endl;
//////////////////////////////////////////////////////////////////////////
// i,j is first loop over SchurBlock factors reusing 5D matrices
@ -379,10 +357,10 @@ void TA2AMesonField<FImpl>::execute(void)
double t_int_2=0;
double t_int_3=0;
double t0 = usecond();
int N_i = N;
int N_j = N;
double t0 = usecond();
int NBlock_i = N_i/schurBlock + (((N_i % schurBlock) != 0) ? 1 : 0);
int NBlock_j = N_j/schurBlock + (((N_j % schurBlock) != 0) ? 1 : 0);
for(int i=0;i<N_i;i+=schurBlock) //loop over SchurBlocking to suppress 5D matrix overhead
for(int j=0;j<N_j;j+=schurBlock)
{
@ -393,12 +371,13 @@ void TA2AMesonField<FImpl>::execute(void)
int N_jj = MIN(N_j-j,schurBlock);
t_schur-=usecond();
for(int ii =0;ii < N_ii;ii++) a2a.return_w(i+ii, tmp_5d, w[ii]);
for(int jj =0;jj < N_jj;jj++) a2a.return_v(j+jj, tmp_5d, v[jj]);
t_schur+=usecond();
LOG(Message) << "Found w vectors " << i <<" .. " << i+N_ii-1 << std::endl;
LOG(Message) << "Found v vectors " << j <<" .. " << j+N_jj-1 << std::endl;
LOG(Message) << "Meson field block "
<< j/schurBlock + NBlock_j*i/schurBlock + 1
<< "/" << NBlock_i*NBlock_j << " [" << i <<" .. "
<< i+N_ii-1 << ", " << j <<" .. " << j+N_jj-1 << "]"
<< std::endl;
///////////////////////////////////////////////////////////////
// Series of cache blocked chunks of the contractions within this SchurBlock
@ -411,11 +390,11 @@ void TA2AMesonField<FImpl>::execute(void)
Eigen::Tensor<ComplexD,5> mesonFieldBlocked(nmom,ngamma,nt,N_iii,N_jjj);
t_contr-=usecond();
MesonField(mesonFieldBlocked, &w[ii], &v[jj], gammas, phases,Tp,
t_int_0,t_int_1,t_int_2,t_int_3);
MesonField(mesonFieldBlocked, &w[i+ii], &v[j+jj], gammas, phases,Tp,
t_int_0,t_int_1,t_int_2,t_int_3);
t_contr+=usecond();
// flops for general N_c & N_s
flops += vol * ( 2 * 8.0 + 6.0 + 8.0*nmom) * N_iii*N_jjj*ngamma;
bytes += vol * (12.0 * sizeof(Complex) ) * N_iii*N_jjj
+ vol * ( 2.0 * sizeof(Complex) *nmom ) * N_iii*N_jjj* ngamma;
@ -435,17 +414,17 @@ void TA2AMesonField<FImpl>::execute(void)
double nodes=grid->NodeCount();
double t1 = usecond();
LOG(Message) << " Contraction of MesonFields took "<<(t1-t0)/1.0e6<< " seconds " << std::endl;
LOG(Message) << " Schur "<<(t_schur)/1.0e6<< " seconds " << std::endl;
LOG(Message) << " Contr "<<(t_contr)/1.0e6<< " seconds " << std::endl;
LOG(Message) << " Intern0 "<<(t_int_0)/1.0e6<< " seconds " << std::endl;
LOG(Message) << " Intern1 "<<(t_int_1)/1.0e6<< " seconds " << std::endl;
LOG(Message) << " Intern2 "<<(t_int_2)/1.0e6<< " seconds " << std::endl;
LOG(Message) << " Intern3 "<<(t_int_3)/1.0e6<< " seconds " << std::endl;
LOG(Message) << "Contraction of MesonFields took "<<(t1-t0)/1.0e6<< " s" << std::endl;
LOG(Message) << " Schur " << (t_schur)/1.0e6 << " s" << std::endl;
LOG(Message) << " Contr " << (t_contr)/1.0e6 << " s" << std::endl;
LOG(Message) << " Intern0 " << (t_int_0)/1.0e6 << " s" << std::endl;
LOG(Message) << " Intern1 " << (t_int_1)/1.0e6 << " s" << std::endl;
LOG(Message) << " Intern2 " << (t_int_2)/1.0e6 << " s" << std::endl;
LOG(Message) << " Intern3 " << (t_int_3)/1.0e6 << " s" << std::endl;
double t_kernel = t_int_0 + t_int_1;
LOG(Message) << " Arith "<<flops/(t_kernel)/1.0e3/nodes<< " Gflop/s / node " << std::endl;
LOG(Message) << " Arith "<<bytes/(t_kernel)/1.0e3/nodes<< " GB/s /node " << std::endl;
LOG(Message) << " Arith " << flops/(t_kernel)/1.0e3/nodes << " Gflop/s/ node " << std::endl;
LOG(Message) << " Arith " << bytes/(t_kernel)/1.0e3/nodes << " GB/s/node " << std::endl;
/////////////////////////////////////////////////////////////////////////
// Test: Build the pion correlator (two end)
@ -453,8 +432,8 @@ void TA2AMesonField<FImpl>::execute(void)
/////////////////////////////////////////////////////////////////////////
std::vector<ComplexD> corr(nt,ComplexD(0.0));
for(int i=0;i<N;i++)
for(int j=0;j<N;j++)
for(int i=0;i<N_i;i++)
for(int j=0;j<N_j;j++)
{
int m=0; // first momentum
int g=0; // first gamma in above ordering is gamma5 for pion