From 67690df3bdaa6c32f082798c99272c102ffa77f1 Mon Sep 17 00:00:00 2001 From: fionnoh Date: Fri, 28 Jun 2019 15:18:28 +0800 Subject: [PATCH] Changes nedded to have a current insertion on every second time slice - avoids unnecessary contractions --- Grid/qcd/utils/A2Autils.h | 22 ++++--- .../MContraction/A2AFourQuarkContraction.hpp | 61 ++++++++++++++----- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/Grid/qcd/utils/A2Autils.h b/Grid/qcd/utils/A2Autils.h index 966d84f7..e163b9c6 100644 --- a/Grid/qcd/utils/A2Autils.h +++ b/Grid/qcd/utils/A2Autils.h @@ -69,14 +69,18 @@ public: int orthogdim, double *t_kernel = nullptr, double *t_gsum = nullptr); template - typename std::enable_if, TensorType>::value, void>::type + typename std::enable_if<(std::is_same, TensorType>::value || + std::is_same>, TensorType>::value), + void>::type static ContractWWVV(std::vector &WWVV, const TensorType &WW_sd, const FermionField *vs, const FermionField *vd); template - typename std::enable_if, TensorType>::value, void>::type + typename std::enable_if, TensorType>::value || + std::is_same>, TensorType>::value), + void>::type static ContractWWVV(std::vector &WWVV, const TensorType &WW_sd, const FermionField *vs, @@ -977,7 +981,9 @@ void A2Autils::AslashField(TensorType &mat, template template -typename std::enable_if, TensorType>::value, void>::type +typename std::enable_if<(std::is_same, TensorType>::value || + std::is_same>, TensorType>::value), + void>::type A2Autils::ContractWWVV(std::vector &WWVV, const TensorType &WW_sd, const FermionField *vs, @@ -1023,11 +1029,13 @@ A2Autils::ContractWWVV(std::vector &WWVV, template template -typename std::enable_if, TensorType>::value, void>::type +typename std::enable_if, TensorType>::value || + std::is_same>, TensorType>::value), + void>::type A2Autils::ContractWWVV(std::vector &WWVV, - const TensorType &WW_sd, - const FermionField *vs, - const FermionField *vd) + const TensorType &WW_sd, + const FermionField *vs, + const FermionField *vd) { GridBase *grid = vs[0]._grid; diff --git a/Hadrons/Modules/MContraction/A2AFourQuarkContraction.hpp b/Hadrons/Modules/MContraction/A2AFourQuarkContraction.hpp index f977caab..7a8112e1 100644 --- a/Hadrons/Modules/MContraction/A2AFourQuarkContraction.hpp +++ b/Hadrons/Modules/MContraction/A2AFourQuarkContraction.hpp @@ -17,9 +17,11 @@ class A2AFourQuarkContractionPar: Serializable { public: GRID_SERIALIZABLE_CLASS_MEMBERS(A2AFourQuarkContractionPar, - std::string, v1, - std::string, v2, - std::string, mf12); + std::string, v1, + std::string, v2, + std::string, mf12, + bool, allContr, + unsigned int, dt); }; template @@ -38,6 +40,8 @@ class TA2AFourQuarkContraction: public Module virtual void setup(void); // execution virtual void execute(void); + private: + unsigned int nt_; }; MODULE_REGISTER_TMP(A2AFourQuarkContraction, TA2AFourQuarkContraction, MContraction); @@ -72,30 +76,59 @@ std::vector TA2AFourQuarkContraction::getOutput(void) template void TA2AFourQuarkContraction::setup(void) { - int nt = env().getDim(Tp); - - envCreate(std::vector, getName(), 1, nt, envGetGrid(PropagatorField)); + if (par().allContr) + { + nt_ = env().getDim(Tp); + envTmp(std::vector, "tmpWWVV", 1, nt_, envGetGrid(PropagatorField)); + envCreate(std::vector, getName(), 1, nt_, envGetGrid(PropagatorField)); + } + else + { + envTmp(std::vector, "tmpWWVV", 1, 1, envGetGrid(PropagatorField)); + envCreate(PropagatorField, getName(), 1, envGetGrid(PropagatorField)); + } } // execution /////////////////////////////////////////////////////////////////// template void TA2AFourQuarkContraction::execute(void) { - int nt = env().getDim(Tp); - auto &v1 = envGet(std::vector, par().v1); auto &v2 = envGet(std::vector, par().v2); auto &mf12 = envGet(EigenDiskVector, par().mf12); - auto &wwvv = envGet(std::vector, getName()); + envGetTmp(std::vector, tmpWWVV); - for (int t = 0; t < nt; t++) + unsigned int dt = par().dt; + unsigned int nt = env().getDim(Tp); + + if (par().allContr) { - wwvv[t] = zero; - } + LOG(Message) << "Computing 4 quark contraction for " << getName() + << " for all t0 time translations " + << "with nt = " << nt_ << " and dt = " << dt << std::endl; - LOG(Message) << "Computing 4 quark contraction for: " << getName() << std::endl; - A2Autils::ContractWWVV(wwvv, mf12, &v1[0], &v2[0]); + auto &WWVV = envGet(std::vector, getName()); + A2Autils::ContractWWVV(tmpWWVV, mf12, &v1[0], &v2[0]); + for(unsigned int t = 0; t < nt_; t++){ + unsigned int t0 = (t + dt) % nt_; + WWVV[t] = tmpWWVV[t0]; + } + } + else + { + LOG(Message) << "Computing 4 quark contraction for: " << getName() + << " for time dt = " << dt << std::endl; + + auto &WWVV = envGet(PropagatorField, getName()); + int ni = v1.size(); + int nj = v2.size(); + Eigen::Matrix mf; + mf = mf12[dt]; + Eigen::TensorMap> mfT(mf.data(), 1, ni, nj); + A2Autils::ContractWWVV(tmpWWVV, mfT, &v1[0], &v2[0]); + WWVV = tmpWWVV[0]; + } } END_MODULE_NAMESPACE