mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-17 07:17:06 +01:00
Hadrons: overhaul of A2A for production
This commit is contained in:
@ -4,7 +4,7 @@
|
||||
#include <Grid/Hadrons/Global.hpp>
|
||||
#include <Grid/Hadrons/Module.hpp>
|
||||
#include <Grid/Hadrons/ModuleFactory.hpp>
|
||||
#include <Grid/Hadrons/AllToAllVectors.hpp>
|
||||
#include <Grid/Hadrons/A2AVectors.hpp>
|
||||
#include <Grid/Eigen/unsupported/CXX11/Tensor>
|
||||
|
||||
BEGIN_HADRONS_NAMESPACE
|
||||
@ -24,7 +24,8 @@ class A2AMesonFieldPar : Serializable
|
||||
int, cacheBlock,
|
||||
int, schurBlock,
|
||||
int, Nmom,
|
||||
std::string, A2A,
|
||||
std::string, v,
|
||||
std::string, w,
|
||||
std::string, output);
|
||||
};
|
||||
|
||||
@ -34,9 +35,6 @@ class TA2AMesonField : public Module<A2AMesonFieldPar>
|
||||
public:
|
||||
FERM_TYPE_ALIASES(FImpl, );
|
||||
SOLVER_TYPE_ALIASES(FImpl, );
|
||||
|
||||
typedef A2AModesSchurDiagTwo<typename FImpl::FermionField, FMat, Solver> A2ABase;
|
||||
|
||||
public:
|
||||
// constructor
|
||||
TA2AMesonField(const std::string name);
|
||||
@ -80,7 +78,7 @@ TA2AMesonField<FImpl>::TA2AMesonField(const std::string name)
|
||||
// Builds and returns the list of input names for this module, taken from
// the module parameters via par().
// NOTE(review): this span is a rendered diff, not compilable code. `in` is
// declared twice below; the enclosing hunk header (-80,7 +78,7) indicates a
// one-line replacement, so the first declaration (par().A2A) is presumably
// the removed line and the second (par().v, par().w) the added one --
// confirm against the real file before relying on either.
template <typename FImpl>
|
||||
std::vector<std::string> TA2AMesonField<FImpl>::getInput(void)
|
||||
{
|
||||
// Presumably the pre-change line: a single input named by par().A2A.
std::vector<std::string> in = {par().A2A};
|
||||
// Presumably the post-change line: two inputs named by par().v and par().w.
std::vector<std::string> in = {par().v, par().w};
|
||||
|
||||
return in;
|
||||
}
|
||||
@ -97,20 +95,7 @@ std::vector<std::string> TA2AMesonField<FImpl>::getOutput(void)
|
||||
// setup ///////////////////////////////////////////////////////////////////////
|
||||
// Setup hook of TA2AMesonField.
// NOTE(review): this span is a rendered diff, not compilable code. A full
// function body is followed by a bare `{}` further down; the enclosing hunk
// header (-97,20 +95,7) suggests the whole body was replaced by an empty one
// in this commit -- confirm against the real file.
template <typename FImpl>
|
||||
void TA2AMesonField<FImpl>::setup(void)
|
||||
{
|
||||
// Fetches the object named par().A2A as an A2ABase; `a2a` is not used again
// inside this body.
auto &a2a = envGet(A2ABase, par().A2A);
|
||||
// Ls is queried from the environment for the same object name.
int Ls = env().getObjectLs(par().A2A);
|
||||
|
||||
// Four D fields
// Two temporaries "w" and "v", each a std::vector<FermionField> built from
// par().schurBlock and a FermionField on env().getGrid(); presumably
// schurBlock is the element count -- TODO confirm envTmp argument meaning.
|
||||
envTmp(std::vector<FermionField>, "w", 1, par().schurBlock,
|
||||
FermionField(env().getGrid()));
|
||||
envTmp(std::vector<FermionField>, "v", 1, par().schurBlock,
|
||||
FermionField(env().getGrid()));
|
||||
|
||||
// 5D tmp
// Temporary lattice field "tmp_5d" registered with the Ls value read above.
|
||||
envTmpLat(FermionField, "tmp_5d", Ls);
|
||||
}
|
||||
|
||||
// NOTE(review): stray empty body -- presumably the post-change setup() body.
{}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////
|
||||
// Cache blocked arithmetic routine
|
||||
@ -304,7 +289,8 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
{
|
||||
LOG(Message) << "Computing A2A meson field" << std::endl;
|
||||
|
||||
auto &a2a = envGet(A2ABase, par().A2A);
|
||||
auto &v = envGet(std::vector<FermionField>, par().v);
|
||||
auto &w = envGet(std::vector<FermionField>, par().w);
|
||||
|
||||
// 2+6+4+4 = 16 gammas
|
||||
// Ordering defined here
|
||||
@ -330,15 +316,13 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
///////////////////////////////////////////////
|
||||
// Square assumption for now Nl = Nr = N
|
||||
///////////////////////////////////////////////
|
||||
int nt = env().getDim(Tp);
|
||||
int nx = env().getDim(Xp);
|
||||
int ny = env().getDim(Yp);
|
||||
int nz = env().getDim(Zp);
|
||||
int Nl = a2a.get_Nl();
|
||||
int N = Nl + a2a.get_Nh();
|
||||
|
||||
int nt = env().getDim(Tp);
|
||||
int nx = env().getDim(Xp);
|
||||
int ny = env().getDim(Yp);
|
||||
int nz = env().getDim(Zp);
|
||||
int N_i = w.size();
|
||||
int N_j = v.size();
|
||||
int ngamma = gammas.size();
|
||||
|
||||
int schurBlock = par().schurBlock;
|
||||
int cacheBlock = par().cacheBlock;
|
||||
int nmom = par().Nmom;
|
||||
@ -353,14 +337,8 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
phases[m] = Complex(1.0); // All zero momentum for now
|
||||
}
|
||||
|
||||
Eigen::Tensor<ComplexD,5> mesonField (nmom,ngamma,nt,N,N);
|
||||
LOG(Message) << "N = Nh+Nl for A2A MesonField is " << N << std::endl;
|
||||
|
||||
envGetTmp(std::vector<FermionField>, w);
|
||||
envGetTmp(std::vector<FermionField>, v);
|
||||
envGetTmp(FermionField, tmp_5d);
|
||||
|
||||
LOG(Message) << "Finding v and w vectors for N = " << N << std::endl;
|
||||
Eigen::Tensor<ComplexD,5> mesonField(nmom,ngamma,nt,N_i,N_j);
|
||||
LOG(Message) << "MesonField size " << N_i << "x" << N_j << "x" << nt << std::endl;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// i,j is first loop over SchurBlock factors reusing 5D matrices
|
||||
@ -379,10 +357,10 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
double t_int_2=0;
|
||||
double t_int_3=0;
|
||||
|
||||
double t0 = usecond();
|
||||
int N_i = N;
|
||||
int N_j = N;
|
||||
|
||||
double t0 = usecond();
|
||||
int NBlock_i = N_i/schurBlock + (((N_i % schurBlock) != 0) ? 1 : 0);
|
||||
int NBlock_j = N_j/schurBlock + (((N_j % schurBlock) != 0) ? 1 : 0);
|
||||
|
||||
for(int i=0;i<N_i;i+=schurBlock) //loop over SchurBlocking to suppress 5D matrix overhead
|
||||
for(int j=0;j<N_j;j+=schurBlock)
|
||||
{
|
||||
@ -393,12 +371,13 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
int N_jj = MIN(N_j-j,schurBlock);
|
||||
|
||||
t_schur-=usecond();
|
||||
for(int ii =0;ii < N_ii;ii++) a2a.return_w(i+ii, tmp_5d, w[ii]);
|
||||
for(int jj =0;jj < N_jj;jj++) a2a.return_v(j+jj, tmp_5d, v[jj]);
|
||||
t_schur+=usecond();
|
||||
|
||||
LOG(Message) << "Found w vectors " << i <<" .. " << i+N_ii-1 << std::endl;
|
||||
LOG(Message) << "Found v vectors " << j <<" .. " << j+N_jj-1 << std::endl;
|
||||
LOG(Message) << "Meson field block "
|
||||
<< j/schurBlock + NBlock_j*i/schurBlock + 1
|
||||
<< "/" << NBlock_i*NBlock_j << " [" << i <<" .. "
|
||||
<< i+N_ii-1 << ", " << j <<" .. " << j+N_jj-1 << "]"
|
||||
<< std::endl;
|
||||
|
||||
///////////////////////////////////////////////////////////////
|
||||
// Series of cache blocked chunks of the contractions within this SchurBlock
|
||||
@ -411,11 +390,11 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
Eigen::Tensor<ComplexD,5> mesonFieldBlocked(nmom,ngamma,nt,N_iii,N_jjj);
|
||||
|
||||
t_contr-=usecond();
|
||||
MesonField(mesonFieldBlocked, &w[ii], &v[jj], gammas, phases,Tp,
|
||||
t_int_0,t_int_1,t_int_2,t_int_3);
|
||||
MesonField(mesonFieldBlocked, &w[i+ii], &v[j+jj], gammas, phases,Tp,
|
||||
t_int_0,t_int_1,t_int_2,t_int_3);
|
||||
t_contr+=usecond();
|
||||
// flops for general N_c & N_s
|
||||
flops += vol * ( 2 * 8.0 + 6.0 + 8.0*nmom) * N_iii*N_jjj*ngamma;
|
||||
|
||||
bytes += vol * (12.0 * sizeof(Complex) ) * N_iii*N_jjj
|
||||
+ vol * ( 2.0 * sizeof(Complex) *nmom ) * N_iii*N_jjj* ngamma;
|
||||
|
||||
@ -435,17 +414,17 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
|
||||
double nodes=grid->NodeCount();
|
||||
double t1 = usecond();
|
||||
LOG(Message) << " Contraction of MesonFields took "<<(t1-t0)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << " Schur "<<(t_schur)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << " Contr "<<(t_contr)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << " Intern0 "<<(t_int_0)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << " Intern1 "<<(t_int_1)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << " Intern2 "<<(t_int_2)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << " Intern3 "<<(t_int_3)/1.0e6<< " seconds " << std::endl;
|
||||
LOG(Message) << "Contraction of MesonFields took "<<(t1-t0)/1.0e6<< " s" << std::endl;
|
||||
LOG(Message) << " Schur " << (t_schur)/1.0e6 << " s" << std::endl;
|
||||
LOG(Message) << " Contr " << (t_contr)/1.0e6 << " s" << std::endl;
|
||||
LOG(Message) << " Intern0 " << (t_int_0)/1.0e6 << " s" << std::endl;
|
||||
LOG(Message) << " Intern1 " << (t_int_1)/1.0e6 << " s" << std::endl;
|
||||
LOG(Message) << " Intern2 " << (t_int_2)/1.0e6 << " s" << std::endl;
|
||||
LOG(Message) << " Intern3 " << (t_int_3)/1.0e6 << " s" << std::endl;
|
||||
|
||||
double t_kernel = t_int_0 + t_int_1;
|
||||
LOG(Message) << " Arith "<<flops/(t_kernel)/1.0e3/nodes<< " Gflop/s / node " << std::endl;
|
||||
LOG(Message) << " Arith "<<bytes/(t_kernel)/1.0e3/nodes<< " GB/s /node " << std::endl;
|
||||
LOG(Message) << " Arith " << flops/(t_kernel)/1.0e3/nodes << " Gflop/s/ node " << std::endl;
|
||||
LOG(Message) << " Arith " << bytes/(t_kernel)/1.0e3/nodes << " GB/s/node " << std::endl;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
// Test: Build the pion correlator (two end)
|
||||
@ -453,8 +432,8 @@ void TA2AMesonField<FImpl>::execute(void)
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
std::vector<ComplexD> corr(nt,ComplexD(0.0));
|
||||
|
||||
for(int i=0;i<N;i++)
|
||||
for(int j=0;j<N;j++)
|
||||
for(int i=0;i<N_i;i++)
|
||||
for(int j=0;j<N_j;j++)
|
||||
{
|
||||
int m=0; // first momentum
|
||||
int g=0; // first gamma in above ordering is gamma5 for pion
|
||||
|
Reference in New Issue
Block a user