
Finished the four quark optimisation for Bag parameters.

To do:

   Abstract the cache blocking from the contraction with lambda functions (a sketch follows this list).
   Share code between the PionFieldXX routines with and without momentum, and share with the meson field code somehow.
   Assemble the WWVV in a standalone routine.
   Apply the same lambda-function trick to the four-quark operator.
   As a first hack, do the MesonField routine in here too.
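
A minimal sketch of the lambda abstraction mentioned in the first to-do item, in plain C++. The helper name blocked_slice_loop and the toy index arithmetic are hypothetical placeholders, not existing Grid/Hadrons API: the point is only that the cache-blocked slice traversal is written once and each contraction (pion field, meson field, four-quark operator) supplies its per-site work as a lambda.

#include <cstddef>
#include <vector>

// Hypothetical helper: owns the blocked traversal of one time plane and knows
// nothing about the contraction; the per-site work arrives as a callable.
template <typename Kernel>
void blocked_slice_loop(std::size_t n_block, std::size_t block, std::size_t stride,
                        std::size_t base, Kernel &&site_kernel)
{
  for (std::size_t n = 0; n < n_block; ++n)      // outer (cache) blocks
    for (std::size_t b = 0; b < block; ++b) {    // contiguous inner block
      std::size_t ss = base + n * stride + b;    // site index inside the plane
      site_kernel(ss);                           // contraction-specific kernel
    }
}

int main()
{
  std::vector<double> pion(64, 0.0), fourq(64, 0.0);

  // Two different "contractions" sharing one traversal, as toy lambdas.
  blocked_slice_loop(8, 8, 8, 0, [&](std::size_t ss) { pion[ss]  += 1.0; });
  blocked_slice_loop(8, 8, 8, 0, [&](std::size_t ss) { fourq[ss] += 2.0; });
  return 0;
}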
Peter Boyle 2018-08-28 14:11:03 +01:00
parent 81287133f3
commit 5b8b630919


@@ -57,21 +57,13 @@ class TA2APionField : public Module<A2APionFieldPar>
virtual void DeltaFeq2(int dt_min,int dt_max,
Eigen::Tensor<ComplexD,2> &dF2_fig8,
Eigen::Tensor<ComplexD,2> &dF2_trtr,
+ Eigen::Tensor<ComplexD,1> &den0,
+ Eigen::Tensor<ComplexD,1> &den1,
Eigen::Tensor<ComplexD,3> &WW_sd,
const LatticeFermion *vs,
const LatticeFermion *vd,
int orthogdim);
- virtual void DeltaFeq2_alt(int dt_min,int dt_max,
- Eigen::Tensor<ComplexD,2> &dF2_fig8,
- Eigen::Tensor<ComplexD,2> &dF2_trtr,
- Eigen::Tensor<ComplexD,1> &den0,
- Eigen::Tensor<ComplexD,1> &den1,
- Eigen::Tensor<ComplexD,3> &WW_sd,
- const LatticeFermion *vs,
- const LatticeFermion *vd,
- int orthogdim);
///////////////////////////////////////
// Arithmetic help. Move to Grid??
///////////////////////////////////////
@@ -671,299 +663,8 @@ inline iScalar<vtype> traceGammaXGamma5(const iMatrix<vtype, Ns> &rhs)
//
// = sum_x Trace( WW[t0] VV[t,x] WW[t1] VV[t,x] )
//
// Might as well form Ns x Nj x Ngamma matrix
//
///////////////////////////////////////////////////////////////////////////////
//
template <typename FImpl>
void TA2APionField<FImpl>::DeltaFeq2(int dt_min,int dt_max,
Eigen::Tensor<ComplexD,2> &dF2_fig8,
Eigen::Tensor<ComplexD,2> &dF2_trtr,
Eigen::Tensor<ComplexD,3> &WW_sd,
const LatticeFermion *vs,
const LatticeFermion *vd,
int orthogdim)
{
LOG(Message) << "Computing A2A DeltaF=2 graph" << std::endl;
int dt = dt_min; // HACK ; should loop over dt
typedef typename FImpl::SiteSpinor vobj;
typedef typename vobj::scalar_object sobj;
typedef typename vobj::scalar_type scalar_type;
typedef typename vobj::vector_type vector_type;
typedef iSpinMatrix<vector_type> SpinMatrix_v;
typedef iSpinMatrix<scalar_type> SpinMatrix_s;
typedef iSinglet<vector_type> Scalar_v;
typedef iSinglet<scalar_type> Scalar_s;
int N_s = WW_sd.dimension(1);
int N_d = WW_sd.dimension(2);
GridBase *grid = vs[0]._grid;
const int nd = grid->_ndimension;
const int Nsimd = grid->Nsimd();
int Nt = grid->GlobalDimensions()[orthogdim];
dF2_trtr.resize(Nt,16);
dF2_fig8.resize(Nt,16);
for(int t=0;t<Nt;t++){
for(int g=0;g<dF2_trtr.dimension(1);g++) dF2_trtr(t,g)= ComplexD(0.0);
for(int g=0;g<dF2_fig8.dimension(1);g++) dF2_fig8(t,g)= ComplexD(0.0);
}
// std::vector<LatticePropagator> WWVV (Nt,grid)
int fd=grid->_fdimensions[orthogdim];
int ld=grid->_ldimensions[orthogdim];
int rd=grid->_rdimensions[orthogdim];
//////////////////////////////////////
// will locally sum vectors first
// sum across these down to scalars
// splitting the SIMD
//////////////////////////////////////
int MFrvol = rd*N_s*N_d;
int MFlvol = ld*N_s*N_d;
Vector<vector_type > lvSum(MFrvol);
parallel_for (int r = 0; r < MFrvol; r++){
lvSum[r] = zero;
}
Vector<scalar_type > lsSum(MFlvol);
parallel_for (int r = 0; r < MFlvol; r++){
lsSum[r]=scalar_type(0.0);
}
int e1= grid->_slice_nblock[orthogdim];
int e2= grid->_slice_block [orthogdim];
int stride=grid->_slice_stride[orthogdim];
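// Standard Grid sliced-reduction traversal: for the reduced coordinate r on
// orthogdim, the e1 x e2 loop below with ss = so + n*stride + b visits every
// outer (SIMD-packed) site of that time plane once.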
Eigen::Tensor<Scalar_v,3> VgV_sd(N_s,N_d,16); // trace with dirac structure
Eigen::Tensor<Scalar_s,4> VgV_sd_l(N_s,N_d,16,Nsimd);
int Ng;
LOG(Message) << "Computing A2A DeltaF=2 graph entering site loop" << std::endl;
double t_tot =0;
double t_vv =0;
double t_extr =0;
double t_transp=0;
double t_WW =0;
double t_trtr =0;
double t_fig8 =0;
t_tot -=usecond();
for(int r=0;r<rd;r++){
LOG(Message) << "Computing A2A DeltaF=2 timeslice "<< r << "/"<<rd << std::endl;
int so=r*grid->_ostride[orthogdim]; // base offset for start of plane
for(int n=0;n<e1;n++){
for(int b=0;b<e2;b++){
int ss= so+n*stride+b;
///////////////////////////////////
// _ _
// O_VV+AA = s V_mu d s V_mu d
// _ _
// + s A_mu d s A_mu d
///////////////////////////////////
// assemble the v_s v_d spin outer, colour inner product matrix
// the #vecs will be large
t_vv -=usecond();
parallel_for(int d=0;d<N_d;d++){
auto _vd = conjugate(vd[d]._odata[ss]);
for(int s=0;s<N_s;s++){
SpinMatrix_v vv;
auto _vs = vs[s]._odata[ss];
for(int s1=0;s1<Ns;s1++){
for(int s2=0;s2<Ns;s2++){
vv()(s1,s2)() = _vd()(s2)(0) * _vs()(s1)(0)
+ _vd()(s2)(1) * _vs()(s1)(1)
+ _vd()(s2)(2) * _vs()(s1)(2);
}}
int g=0;
// VgV_sd(s,d,g++) = trace(vv); // S
// VgV_sd(s,d,g++)() = traceGamma5(vv());// P
VgV_sd(s,d,g++)() = traceGammaX(vv());// Vmu
VgV_sd(s,d,g++)() = traceGammaY(vv());
VgV_sd(s,d,g++)() = traceGammaZ(vv());
VgV_sd(s,d,g++)() = traceGammaT(vv());
VgV_sd(s,d,g++)() = traceGammaXGamma5(vv());// Amu
VgV_sd(s,d,g++)() = traceGammaYGamma5(vv());
VgV_sd(s,d,g++)() = traceGammaZGamma5(vv());
VgV_sd(s,d,g++)() = traceGammaTGamma5(vv());
/*
VgV_sd(s,d,g++)() = traceSigmaXY(vv);// Sigma_munu unimplemented
VgV_sd(s,d,g++)() = traceSigmaXZ(vv);
VgV_sd(s,d,g++)() = traceSigmaXT(vv);
VgV_sd(s,d,g++)() = traceSigmaYZ(vv);
VgV_sd(s,d,g++)() = traceSigmaYT(vv);
VgV_sd(s,d,g++)() = traceSigmaZT(vv);
*/
Ng = g;
}
}
t_vv +=usecond();
/////////////////////////////////////////////////////////////////////
// PLAN: Make use of the fact that the trace of the product of two
// matrices is a_ij b_ji = a_ij (b^T)_ij and write as a "ZDOT"
// with one of the matrices transposed.
//
// Further since we want to take the trace-product with nt
// other matrices can do these in a single pass and gain another 2x.
// Can do this data parallel on all SIMD lanes of VgV_sd, with a
// broadcast of WW[t0] and WW[t1].
//
//
// Strategies:
//
// Wick1: TR TR(tx) = sum_x [ Trace(WW[t0] VgV(t,x) ) x Trace( WW_[t1] VgV(t,x) ) ]
//
// Take Tr (WW[t'] . VgV^T) for all "t'", store in array.
//
// Accumulate (WW[t0].VgV). (WW[t1].VgV) for dt and (t-t0+nt) % nt.
//
// Wick2:
//
// Fig8(tx) = sum_x Trace( VV[t,x] WW[t0] VV[t,x] WW[t1] )
//
// for(t0)
// form (VV^T WW VV^T)^T = VV WW^T VV = M0
// for(t1)
// Accumulate Tr(M0 . WW[t1] in dt and (t-t0+nt)%nt
//
/////////////////////////////////////////////////////////////////////
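// Concretely (an illustrative Eigen identity, not part of the routine): for
// A of shape (N_d,N_s) and B of shape (N_s,N_d),
//   trace(A*B) = sum_ij A(i,j)*B(j,i) = (A.transpose().array() * B.array()).sum()
// i.e. one element-wise pass ("ZDOT") instead of a matrix-matrix product.
// This is what the (VgV_T.array()*WW0.array()).sum() contractions below implement.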
// Loop over t0 accumulate the dT matrix elements.
for(int g=0;g<Ng;g++){
/////////////////////////////////////////////////
// break out into Nsimd sites worth of scalar code.
/////////////////////////////////////////////////
t_extr -=usecond();
parallel_for(int d=0;d<N_d;d++){
std::vector<Scalar_s> extracted(Nsimd);
Scalar_v temp;
for(int s=0;s<N_s;s++){
for(int l=0;l<Nsimd;l++){
temp = VgV_sd(s,d,g);
extract(temp,extracted);
for(int l=0;l<Nsimd;l++){
VgV_sd_l(s,d,g,l) = extracted[l]()()();
}
}// lane
}
}// s,d
t_extr +=usecond();
////////////////////////////////////////////
// Work on a series of scalar problems
////////////////////////////////////////////
for(int l=0;l<Nsimd;l++){
std::vector<int> icoor(nd);
grid->iCoorFromIindex(icoor,l);
int ttt = r+icoor[orthogdim]*rd;
Eigen::MatrixXcd VgV(N_d,N_s);
Eigen::MatrixXcd VgV_T(N_s,N_d);
Eigen::MatrixXcd WW0(N_s,N_d);
Eigen::MatrixXcd WW1(N_s,N_d);
/////////////////////////////////////////
// Single site VgV , scalar order
/////////////////////////////////////////
t_transp -=usecond();
parallel_for(int d=0;d<N_d;d++){ for(int s=0;s<N_s;s++){
// Note pre-transpose VgV in the copy out
VgV(d,s) = VgV_sd_l(s,d,g,l);
VgV_T(s,d) = VgV(d,s);
}}
t_transp +=usecond();
/////////////////////////////////////////
// loop over time planes of meson
/////////////////////////////////////////
for(int t0=0;t0<Nt;t0++){
int t1 = (t0+dt)%Nt; // Future loop over dT
int tt = (ttt-t0+Nt)%Nt; // Time of this site relative to t0
/////////////////////////////////////////////////////////////////
// Extract this pair of WW matrices infuture loop over dT
/////////////////////////////////////////////////////////////////
t_WW -=usecond();
parallel_for(int d=0;d<N_d;d++){ for(int s=0;s<N_s;s++){
WW0(s,d) = WW_sd(t0,s,d);
WW1(s,d) = WW_sd(t1,s,d);
}}
t_WW +=usecond();
/////////////////////////////////////////
// Wick1 -- transpose
/////////////////////////////////////////
// VgV_ds WW_sd
t_trtr -=usecond();
ComplexD trWW0VV = (VgV_T.array()*WW0.array()).sum();
ComplexD trWW1VV = (VgV_T.array()*WW1.array()).sum();
dF2_trtr(tt,g) += trWW0VV * trWW1VV;
t_trtr +=usecond();
/////////////////////////////////////////
// Wick2 -- transpose
/////////////////////////////////////////
// VgV_ds WW_sd VgV_d's'
//
// This is the time consuming loop.
t_fig8 -=usecond();
Eigen::MatrixXcd VVWW0VV = VgV * WW0 * VgV ;
Eigen::MatrixXcd VVWW0VV_T =VVWW0VV.transpose();
auto trVVWW0VVWW1 =(VVWW0VV_T.array()*WW1.array()).sum();
dF2_fig8(tt,g) += trVVWW0VVWW1;
t_fig8 +=usecond();
}// t0 loop
}// l loop
}// gamma
}// close loop over time plane and block stride loop within timeplane
}
}
grid->GlobalSumVector(&dF2_fig8(0),Nt*Ng);
grid->GlobalSumVector(&dF2_trtr(0),Nt*Ng);
t_tot +=usecond();
LOG(Message) << "Computing A2A DeltaF=2 graph t_tot " << t_tot << " us "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_vv " << t_vv << " us "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_extr " << t_extr << " us "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_transp " << t_transp << " us "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_WW " << t_WW << " us "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_trtr " << t_trtr << " us "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_fig8 " << t_fig8 << " us "<< std::endl;
}
// WW is w_s^dag (x) w_d (G5 implicitly absorbed)
//
// WWVV will have spin-col (x) spin-col tensor.
@@ -975,19 +676,19 @@ void TA2APionField<FImpl>::DeltaFeq2(int dt_min,int dt_max,
//
template <typename FImpl>
- void TA2APionField<FImpl>::DeltaFeq2_alt(int dt_min,int dt_max,
+ void TA2APionField<FImpl>::DeltaFeq2(int dt_min,int dt_max,
Eigen::Tensor<ComplexD,2> &dF2_fig8,
Eigen::Tensor<ComplexD,2> &dF2_trtr,
Eigen::Tensor<ComplexD,1> &den0,
Eigen::Tensor<ComplexD,1> &den1,
Eigen::Tensor<ComplexD,3> &WW_sd,
const LatticeFermion *vs,
const LatticeFermion *vd,
int orthogdim)
{
LOG(Message) << "Computing A2A DeltaF=2 graph" << std::endl;
- int dt = dt_min; // HACK ; should loop over dt
+ int dt = dt_min;
auto G5 = Gamma(Gamma::Algebra::Gamma5);
@@ -1027,7 +728,6 @@ void TA2APionField<FImpl>::DeltaFeq2_alt(int dt_min,int dt_max,
den1(t) =ComplexD(0.0);
}
LatticeComplex D0(grid); // <P|A0> correlator from each wall
LatticeComplex D1(grid);
@@ -1064,70 +764,28 @@ void TA2APionField<FImpl>::DeltaFeq2_alt(int dt_min,int dt_max,
WWVV[t] = zero;
}
- //////////////////////////////////////////////////////////
- // Easily cache blocked and unrolled.
- //////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////
+ // Method-5 - wrap this assembly in a distinct routine for reuse
+ //////////////////////////////////////////////////////////////////
//
// Ways to speed up?? 20 min in 8^3 x 16 local volume with 400 modes.
//
// Flops are (cmul=6 + 1 (add))* 12*12 * N_s * N_d
// Bytes are read Nc x Ns x 2 + read/write Nc^2 x Ns^2 complex
//
double t_outer= -usecond();
double t_outer_0 =0.0;
double t_outer_1 =0.0;
#undef METHOD_1
#define METHOD_5
#ifdef METHOD_1
LOG(Message) << "METHOD_1" << std::endl;
// Method-1
// i) calculate flop rate and data rate. 17 GF/s and 80 GB/s
// Dominated by accumulating the N_t summands
parallel_for(int ss=0;ss<grid->oSites();ss++){
for(int s=0;s<N_s;s++){
for(int d=0;d<N_d;d++){
double t0=-usecond();
// RHS is conjugated
auto tmp = outerProduct(vs[s]._odata[ss],vd[d]._odata[ss]);
t0+=usecond();
double t1=-usecond();
for(int t=0;t<N_t;t++) {
WWVV[t]._odata[ss] += WW_sd(t,s,d) * tmp;
}
t1+=usecond();
#pragma omp critical
{ t_outer_0+=t0; t_outer_1+=t1;}
}
}
}
#endif
double b_outer = (12+12+2*12*12 ) * N_s * N_d * vol * sizeof(Complex);
double f_outer = 8 * 12*12 * N_s * N_d * vol;
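// Rough arithmetic intensity implied by these counts:
// f_outer/b_outer ~ 1152 / (312*sizeof(Complex)) ~ 0.23 flop/byte for a 16-byte
// Complex, so this outer-product phase is bandwidth bound, consistent with the
// 17 GF/s and 80 GB/s noted for Method-1 above.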
#ifdef METHOD_5
LOG(Message) << "METHOD_5" << std::endl;
// Method-5
parallel_for(int ss=0;ss<grid->oSites();ss++){
for(int d_o=0;d_o<N_d;d_o+=d_unroll){
for(int t=0;t<N_t;t++){
for(int s=0;s<N_s;s++){
auto tmp1 = vs[s]._odata[ss];
vobj tmp2 = zero;
- // Surprisingly slow
+ ////////////////////////////////////////
+ // Surprisingly slow with d_unroll = 32
+ ////////////////////////////////////////
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
Scalar_v coeff = WW_sd(t,s,d);
mac(&tmp2 ,& coeff, & vd[d]._odata[ss]);
}
- // Outer product of tmp1 with a sum of terms suppressed by d_unroll
+ //////////////////////////
+ // Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
+ //////////////////////////
tmp2 = conjugate(tmp2);
for(int s1=0;s1<Ns;s1++){
for(int s2=0;s2<Ns;s2++){
@@ -1145,8 +803,6 @@ void TA2APionField<FImpl>::DeltaFeq2_alt(int dt_min,int dt_max,
}}
}
}
- #endif
t_outer+=usecond();
//////////////////////////////
@@ -1328,15 +984,8 @@ void TA2APionField<FImpl>::DeltaFeq2_alt(int dt_min,int dt_max,
double million=1.0e6;
LOG(Message) << "Computing A2A DeltaF=2 graph t_tot " << t_tot /million << " s "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_outer " << t_outer /million << " s "<< std::endl;
- LOG(Message) << "Computing A2A DeltaF=2 graph t_outer_0 " << t_outer_0 /million << " s "<< std::endl;
- LOG(Message) << "Computing A2A DeltaF=2 graph t_outer_1 " << t_outer_1 /million << " s "<< std::endl;
LOG(Message) << "Computing A2A DeltaF=2 graph t_contr " << t_contr /million << " s "<< std::endl;
- LOG(Message) << "Computing A2A DeltaF=2 graph mflops/s outer " << f_outer/t_outer << " MF/s "<< std::endl;
- LOG(Message) << "Computing A2A DeltaF=2 graph MB/s outer " << b_outer/t_outer << " MB/s "<< std::endl;
- LOG(Message) << "Computing A2A DeltaF=2 graph mflops/s outer/node " << f_outer/t_outer/nodes << " MF/s "<< std::endl;
- LOG(Message) << "Computing A2A DeltaF=2 graph MB/s outer/node " << b_outer/t_outer/nodes << " MB/s "<< std::endl;
}
@@ -1614,16 +1263,12 @@ void TA2APionField<FImpl>::execute(void)
const int dT=16;
- DeltaFeq2_alt (dT,dT,DeltaF2_fig8,DeltaF2_trtr,
+ DeltaFeq2 (dT,dT,DeltaF2_fig8,DeltaF2_trtr,
denom0,denom1,
pionFieldWW_ij,&vi[0],&vj[0],Tp);
- for(int t=0;t<nt;t++) LOG(Message) << " denom0 [" << t << "] " << denom0(t)<<std::endl;
- for(int t=0;t<nt;t++) LOG(Message) << " denom1 [" << t << "] " << denom1(t)<<std::endl;
- for(int g=0;g<4;g++){
- for(int t=0;t<nt;t++) LOG(Message) << " DeltaF2_fig8 [" << t << ","<<g<<"] " << DeltaF2_fig8(t,g)<<std::endl;
- for(int t=0;t<nt;t++) LOG(Message) << " DeltaF2_trtr [" << t << ","<<g<<"] " << DeltaF2_trtr(t,g)<<std::endl;
- }
- for(int g=0;g<4;g++){
+ {
+ int g=0; // O_{VV+AA}
for(int t=0;t<nt;t++)
LOG(Message) << " Bag [" << t << ","<<g<<"] "
<< (DeltaF2_fig8(t,g)+DeltaF2_trtr(t,g))
@@ -1710,7 +1355,7 @@ void TA2APionField<FImpl>::execute(void)
LOG(Message) << " Wick2["<<g<<","<<t<< "] "<< C2[t]<<std::endl;
}
}
- if( (g==9) || (g==7) ){
+ if( (g==9) || (g==7) ){ // P and At in above ordering
for(int t=0;t<C3.size();t++){
LOG(Message) << " <G|P>["<<g<<","<<t<< "] "<< C3[t]<<std::endl;
}