mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-19 00:07:05 +01:00
Merge pull request #243 from fionnoh/feature/A2A_current_insertion
Feature/a2 a current insertion
This commit is contained in:
@ -67,8 +67,21 @@ public:
|
||||
const std::vector<ComplexField> &emB1,
|
||||
int orthogdim, double *t_kernel = nullptr, double *t_gsum = nullptr);
|
||||
|
||||
static void ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
const Eigen::Tensor<ComplexD,3> &WW_sd,
|
||||
template <typename TensorType>
|
||||
typename std::enable_if<(std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value ||
|
||||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
|
||||
void>::type
|
||||
static ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
const TensorType &WW_sd,
|
||||
const FermionField *vs,
|
||||
const FermionField *vd);
|
||||
|
||||
template <typename TensorType>
|
||||
typename std::enable_if<!(std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value ||
|
||||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
|
||||
void>::type
|
||||
static ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
const TensorType &WW_sd,
|
||||
const FermionField *vs,
|
||||
const FermionField *vd);
|
||||
|
||||
@ -98,6 +111,11 @@ public:
|
||||
const FermionField *vd,
|
||||
int orthogdim);
|
||||
#endif
|
||||
private:
|
||||
inline static void OuterProductWWVV(PropagatorField &WWVV,
|
||||
const vobj &lhs,
|
||||
const vobj &rhs,
|
||||
const int Ns, const int ss);
|
||||
};
|
||||
|
||||
template <class FImpl>
|
||||
@ -968,9 +986,13 @@ void A2Autils<FImpl>::AslashField(TensorType &mat,
|
||||
// Take WW_sd v^dag_d (x) v_s
|
||||
//
|
||||
|
||||
template<class FImpl>
|
||||
void A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
const Eigen::Tensor<ComplexD,3> &WW_sd,
|
||||
template <class FImpl>
|
||||
template <typename TensorType>
|
||||
typename std::enable_if<(std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value ||
|
||||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
|
||||
void>::type
|
||||
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
const TensorType &WW_sd,
|
||||
const FermionField *vs,
|
||||
const FermionField *vd)
|
||||
{
|
||||
@ -992,39 +1014,100 @@ void A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
for(int d_o=0;d_o<N_d;d_o+=d_unroll){
|
||||
for(int t=0;t<N_t;t++){
|
||||
for(int s=0;s<N_s;s++){
|
||||
auto vs_v = vs[s].View();
|
||||
auto tmp1 = vs_v[ss];
|
||||
vobj tmp2 = Zero();
|
||||
vobj tmp3 = Zero();
|
||||
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
|
||||
auto vd_v = vd[d].View();
|
||||
Scalar_v coeff = WW_sd(t,s,d);
|
||||
tmp3 = conjugate(vd_v[ss]);
|
||||
mac(&tmp2, &coeff, &tmp3);
|
||||
}
|
||||
auto vs_v = vs[s].View();
|
||||
auto tmp1 = vs_v[ss];
|
||||
vobj tmp2 = Zero();
|
||||
vobj tmp3 = Zero();
|
||||
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
|
||||
auto vd_v = vd[d].View();
|
||||
Scalar_v coeff = WW_sd(t,s,d);
|
||||
tmp3 = conjugate(vd_v[ss]);
|
||||
mac(&tmp2, &coeff, &tmp3);
|
||||
}
|
||||
|
||||
//////////////////////////
|
||||
// Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
|
||||
//////////////////////////
|
||||
auto WWVV_v = WWVV[t].View();
|
||||
for(int s1=0;s1<Ns;s1++){
|
||||
for(int s2=0;s2<Ns;s2++){
|
||||
WWVV_v[ss]()(s1,s2)(0,0) += tmp1()(s1)(0)*tmp2()(s2)(0);
|
||||
WWVV_v[ss]()(s1,s2)(0,1) += tmp1()(s1)(0)*tmp2()(s2)(1);
|
||||
WWVV_v[ss]()(s1,s2)(0,2) += tmp1()(s1)(0)*tmp2()(s2)(2);
|
||||
WWVV_v[ss]()(s1,s2)(1,0) += tmp1()(s1)(1)*tmp2()(s2)(0);
|
||||
WWVV_v[ss]()(s1,s2)(1,1) += tmp1()(s1)(1)*tmp2()(s2)(1);
|
||||
WWVV_v[ss]()(s1,s2)(1,2) += tmp1()(s1)(1)*tmp2()(s2)(2);
|
||||
WWVV_v[ss]()(s1,s2)(2,0) += tmp1()(s1)(2)*tmp2()(s2)(0);
|
||||
WWVV_v[ss]()(s1,s2)(2,1) += tmp1()(s1)(2)*tmp2()(s2)(1);
|
||||
WWVV_v[ss]()(s1,s2)(2,2) += tmp1()(s1)(2)*tmp2()(s2)(2);
|
||||
}}
|
||||
//////////////////////////
|
||||
// Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
|
||||
//////////////////////////
|
||||
OuterProductWWVV(WWVV[t], tmp1, tmp2, Ns, ss);
|
||||
|
||||
}}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <class FImpl>
|
||||
template <typename TensorType>
|
||||
typename std::enable_if<!(std::is_same<Eigen::Tensor<ComplexD, 3>, TensorType>::value ||
|
||||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
|
||||
void>::type
|
||||
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
|
||||
const TensorType &WW_sd,
|
||||
const FermionField *vs,
|
||||
const FermionField *vd)
|
||||
{
|
||||
GridBase *grid = vs[0].Grid();
|
||||
|
||||
int nd = grid->_ndimension;
|
||||
int Nsimd = grid->Nsimd();
|
||||
int N_t = WW_sd.dimensions()[0];
|
||||
int N_s = WW_sd.dimensions()[1];
|
||||
int N_d = WW_sd.dimensions()[2];
|
||||
|
||||
int d_unroll = 32;// Empirical optimisation
|
||||
|
||||
Eigen::Matrix<Complex, -1, -1, Eigen::RowMajor> buf;
|
||||
|
||||
for(int t=0;t<N_t;t++){
|
||||
WWVV[t] = Zero();
|
||||
}
|
||||
|
||||
for (int t = 0; t < N_t; t++){
|
||||
std::cout << GridLogMessage << "Contraction t = " << t << std::endl;
|
||||
buf = WW_sd[t];
|
||||
thread_for(ss,grid->oSites(),{
|
||||
for(int d_o=0;d_o<N_d;d_o+=d_unroll){
|
||||
for(int s=0;s<N_s;s++){
|
||||
auto vs_v = vs[s].View();
|
||||
auto tmp1 = vs_v[ss];
|
||||
vobj tmp2 = Zero();
|
||||
vobj tmp3 = Zero();
|
||||
for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
|
||||
auto vd_v = vd[d].View();
|
||||
Scalar_v coeff = buf(s,d);
|
||||
tmp3 = conjugate(vd_v[ss]);
|
||||
mac(&tmp2, &coeff, &tmp3);
|
||||
}
|
||||
|
||||
//////////////////////////
|
||||
// Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
|
||||
//////////////////////////
|
||||
OuterProductWWVV(WWVV[t], tmp1, tmp2, Ns, ss);
|
||||
}}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <class FImpl>
|
||||
inline void A2Autils<FImpl>::OuterProductWWVV(PropagatorField &WWVV,
|
||||
const vobj &lhs,
|
||||
const vobj &rhs,
|
||||
const int Ns, const int ss)
|
||||
{
|
||||
auto WWVV_v = WWVV.View();
|
||||
for (int s1 = 0; s1 < Ns; s1++){
|
||||
for (int s2 = 0; s2 < Ns; s2++){
|
||||
WWVV_v[ss]()(s1,s2)(0, 0) += lhs()(s1)(0) * rhs()(s2)(0);
|
||||
WWVV_v[ss]()(s1,s2)(0, 1) += lhs()(s1)(0) * rhs()(s2)(1);
|
||||
WWVV_v[ss]()(s1,s2)(0, 2) += lhs()(s1)(0) * rhs()(s2)(2);
|
||||
WWVV_v[ss]()(s1,s2)(1, 0) += lhs()(s1)(1) * rhs()(s2)(0);
|
||||
WWVV_v[ss]()(s1,s2)(1, 1) += lhs()(s1)(1) * rhs()(s2)(1);
|
||||
WWVV_v[ss]()(s1,s2)(1, 2) += lhs()(s1)(1) * rhs()(s2)(2);
|
||||
WWVV_v[ss]()(s1,s2)(2, 0) += lhs()(s1)(2) * rhs()(s2)(0);
|
||||
WWVV_v[ss]()(s1,s2)(2, 1) += lhs()(s1)(2) * rhs()(s2)(1);
|
||||
WWVV_v[ss]()(s1,s2)(2, 2) += lhs()(s1)(2) * rhs()(s2)(2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class FImpl>
|
||||
void A2Autils<FImpl>::ContractFourQuarkColourDiagonal(const PropagatorField &WWVV0,
|
||||
|
Reference in New Issue
Block a user