1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-19 08:17:05 +01:00

Changes to A2Autils, A2AMatirx and DiskVector code that is needed for Hadrons 4 quark contraction module

This commit is contained in:
fionnoh
2019-06-27 13:45:20 +08:00
parent ac530636ca
commit 421a0a8a36
3 changed files with 290 additions and 142 deletions

View File

@ -68,8 +68,17 @@ public:
const std::vector<ComplexField> &emB1,
int orthogdim, double *t_kernel = nullptr, double *t_gsum = nullptr);
static void ContractWWVV(std::vector<PropagatorField> &WWVV,
const Eigen::Tensor<ComplexD,3> &WW_sd,
template <typename TensorType>
typename std::enable_if<std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
static ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd);
template <typename TensorType>
typename std::enable_if<!std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
static ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd);
@ -99,6 +108,11 @@ public:
const FermionField *vd,
int orthogdim);
#endif
private:
inline static void OuterProductWWVV(std::vector<PropagatorField> &WWVV,
const vobj &lhs,
const vobj &rhs,
const int Ns, const int ss, const int t);
};
template <class FImpl>
@ -961,12 +975,15 @@ void A2Autils<FImpl>::AslashField(TensorType &mat,
// Take WW_sd v^dag_d (x) v_s
//
template<class FImpl>
void A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const Eigen::Tensor<ComplexD,3> &WW_sd,
template <class FImpl>
template <typename TensorType>
typename std::enable_if<std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd)
{
std::cout << "Start contraction" << std::endl;
GridBase *grid = vs[0]._grid;
int nd = grid->_ndimension;
@ -989,33 +1006,90 @@ void A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
vobj tmp2 = zero;
vobj tmp3 = zero;
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
Scalar_v coeff = WW_sd(t,s,d);
tmp3 = conjugate(vd[d]._odata[ss]);
mac(&tmp2, &coeff, &tmp3);
}
tmp3 = conjugate(vd[d]._odata[ss]);
mac(&tmp2 ,& coeff, &tmp3 );
}
//////////////////////////
// Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
//////////////////////////
for(int s1=0;s1<Ns;s1++){
for(int s2=0;s2<Ns;s2++){
WWVV[t]._odata[ss]()(s1,s2)(0,0) += tmp1()(s1)(0)*tmp2()(s2)(0);
WWVV[t]._odata[ss]()(s1,s2)(0,1) += tmp1()(s1)(0)*tmp2()(s2)(1);
WWVV[t]._odata[ss]()(s1,s2)(0,2) += tmp1()(s1)(0)*tmp2()(s2)(2);
WWVV[t]._odata[ss]()(s1,s2)(1,0) += tmp1()(s1)(1)*tmp2()(s2)(0);
WWVV[t]._odata[ss]()(s1,s2)(1,1) += tmp1()(s1)(1)*tmp2()(s2)(1);
WWVV[t]._odata[ss]()(s1,s2)(1,2) += tmp1()(s1)(1)*tmp2()(s2)(2);
WWVV[t]._odata[ss]()(s1,s2)(2,0) += tmp1()(s1)(2)*tmp2()(s2)(0);
WWVV[t]._odata[ss]()(s1,s2)(2,1) += tmp1()(s1)(2)*tmp2()(s2)(1);
WWVV[t]._odata[ss]()(s1,s2)(2,2) += tmp1()(s1)(2)*tmp2()(s2)(2);
}}
OuterProductWWVV(WWVV, tmp1, tmp2, Ns, ss, t);
}}
}
}
}
template <class FImpl>
template <typename TensorType>
typename std::enable_if<!std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd)
{
GridBase *grid = vs[0]._grid;
int nd = grid->_ndimension;
int Nsimd = grid->Nsimd();
int N_t = WW_sd.dimensions()[0];
int N_s = WW_sd.dimensions()[1];
int N_d = WW_sd.dimensions()[2];
int d_unroll = 32;// Empirical optimisation
Eigen::Matrix<Complex, -1, -1, Eigen::RowMajor> buf;
for(int t=0;t<N_t;t++){
WWVV[t] = zero;
}
for (int t = 0; t < N_t; t++){
std::cout << GridLogMessage << "Contraction t = " << t << std::endl;
buf = WW_sd[t];
parallel_for(int ss=0;ss<grid->oSites();ss++){
for(int d_o=0;d_o<N_d;d_o+=d_unroll){
for(int s=0;s<N_s;s++){
auto tmp1 = vs[s]._odata[ss];
vobj tmp2 = zero;
vobj tmp3 = zero;
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
Scalar_v coeff = buf(s,d);
tmp3 = conjugate(vd[d]._odata[ss]);
mac(&tmp2 ,& coeff, &tmp3 );
}
//////////////////////////
// Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
//////////////////////////
OuterProductWWVV(WWVV, tmp1, tmp2, Ns, ss, t);
}}
}
}
}
template <class FImpl>
inline void A2Autils<FImpl>::OuterProductWWVV(std::vector<PropagatorField> &WWVV,
const vobj &lhs,
const vobj &rhs,
const int Ns, const int ss, const int t)
{
for (int s1 = 0; s1 < Ns; s1++){
for (int s2 = 0; s2 < Ns; s2++){
WWVV[t]._odata[ss]()(s1, s2)(0, 0) += lhs()(s1)(0) * rhs()(s2)(0);
WWVV[t]._odata[ss]()(s1, s2)(0, 1) += lhs()(s1)(0) * rhs()(s2)(1);
WWVV[t]._odata[ss]()(s1, s2)(0, 2) += lhs()(s1)(0) * rhs()(s2)(2);
WWVV[t]._odata[ss]()(s1, s2)(1, 0) += lhs()(s1)(1) * rhs()(s2)(0);
WWVV[t]._odata[ss]()(s1, s2)(1, 1) += lhs()(s1)(1) * rhs()(s2)(1);
WWVV[t]._odata[ss]()(s1, s2)(1, 2) += lhs()(s1)(1) * rhs()(s2)(2);
WWVV[t]._odata[ss]()(s1, s2)(2, 0) += lhs()(s1)(2) * rhs()(s2)(0);
WWVV[t]._odata[ss]()(s1, s2)(2, 1) += lhs()(s1)(2) * rhs()(s2)(1);
WWVV[t]._odata[ss]()(s1, s2)(2, 2) += lhs()(s1)(2) * rhs()(s2)(2);
}
}
}
template<class FImpl>
void A2Autils<FImpl>::ContractFourQuarkColourDiagonal(const PropagatorField &WWVV0,