1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-09-20 09:15:38 +01:00

Changes nedded to have a current insertion on every second time slice - avoids unnecessary contractions

This commit is contained in:
fionnoh 2019-06-28 15:18:28 +08:00
parent ce29b18dc9
commit 67690df3bd
2 changed files with 62 additions and 21 deletions

View File

@ -69,14 +69,18 @@ public:
int orthogdim, double *t_kernel = nullptr, double *t_gsum = nullptr);
template <typename TensorType>
typename std::enable_if<std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
typename std::enable_if<(std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value ||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
void>::type
static ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd);
template <typename TensorType>
typename std::enable_if<!std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
typename std::enable_if<!(std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value ||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
void>::type
static ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
@ -977,7 +981,9 @@ void A2Autils<FImpl>::AslashField(TensorType &mat,
template <class FImpl>
template <typename TensorType>
typename std::enable_if<std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
typename std::enable_if<(std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value ||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
void>::type
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
@ -1023,11 +1029,13 @@ A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
template <class FImpl>
template <typename TensorType>
typename std::enable_if<!std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
typename std::enable_if<!(std::is_same<Eigen::Tensor<ComplexD, 3>, TensorType>::value ||
std::is_same<Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>>, TensorType>::value),
void>::type
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd)
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd)
{
GridBase *grid = vs[0]._grid;

View File

@ -17,9 +17,11 @@ class A2AFourQuarkContractionPar: Serializable
{
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(A2AFourQuarkContractionPar,
std::string, v1,
std::string, v2,
std::string, mf12);
std::string, v1,
std::string, v2,
std::string, mf12,
bool, allContr,
unsigned int, dt);
};
template <typename FImpl>
@ -38,6 +40,8 @@ class TA2AFourQuarkContraction: public Module<A2AFourQuarkContractionPar>
virtual void setup(void);
// execution
virtual void execute(void);
private:
unsigned int nt_;
};
MODULE_REGISTER_TMP(A2AFourQuarkContraction, TA2AFourQuarkContraction<FIMPL>, MContraction);
@ -72,30 +76,59 @@ std::vector<std::string> TA2AFourQuarkContraction<FImpl>::getOutput(void)
template <typename FImpl>
void TA2AFourQuarkContraction<FImpl>::setup(void)
{
int nt = env().getDim(Tp);
envCreate(std::vector<PropagatorField>, getName(), 1, nt, envGetGrid(PropagatorField));
if (par().allContr)
{
nt_ = env().getDim(Tp);
envTmp(std::vector<PropagatorField>, "tmpWWVV", 1, nt_, envGetGrid(PropagatorField));
envCreate(std::vector<PropagatorField>, getName(), 1, nt_, envGetGrid(PropagatorField));
}
else
{
envTmp(std::vector<PropagatorField>, "tmpWWVV", 1, 1, envGetGrid(PropagatorField));
envCreate(PropagatorField, getName(), 1, envGetGrid(PropagatorField));
}
}
// execution ///////////////////////////////////////////////////////////////////
template <typename FImpl>
void TA2AFourQuarkContraction<FImpl>::execute(void)
{
int nt = env().getDim(Tp);
auto &v1 = envGet(std::vector<FermionField>, par().v1);
auto &v2 = envGet(std::vector<FermionField>, par().v2);
auto &mf12 = envGet(EigenDiskVector<Complex>, par().mf12);
auto &wwvv = envGet(std::vector<PropagatorField>, getName());
envGetTmp(std::vector<PropagatorField>, tmpWWVV);
for (int t = 0; t < nt; t++)
unsigned int dt = par().dt;
unsigned int nt = env().getDim(Tp);
if (par().allContr)
{
wwvv[t] = zero;
}
LOG(Message) << "Computing 4 quark contraction for " << getName()
<< " for all t0 time translations "
<< "with nt = " << nt_ << " and dt = " << dt << std::endl;
LOG(Message) << "Computing 4 quark contraction for: " << getName() << std::endl;
A2Autils<FImpl>::ContractWWVV(wwvv, mf12, &v1[0], &v2[0]);
auto &WWVV = envGet(std::vector<PropagatorField>, getName());
A2Autils<FImpl>::ContractWWVV(tmpWWVV, mf12, &v1[0], &v2[0]);
for(unsigned int t = 0; t < nt_; t++){
unsigned int t0 = (t + dt) % nt_;
WWVV[t] = tmpWWVV[t0];
}
}
else
{
LOG(Message) << "Computing 4 quark contraction for: " << getName()
<< " for time dt = " << dt << std::endl;
auto &WWVV = envGet(PropagatorField, getName());
int ni = v1.size();
int nj = v2.size();
Eigen::Matrix<Complex, -1, -1, Eigen::RowMajor> mf;
mf = mf12[dt];
Eigen::TensorMap<Eigen::Tensor<Complex, 3, Eigen::RowMajor>> mfT(mf.data(), 1, ni, nj);
A2Autils<FImpl>::ContractWWVV(tmpWWVV, mfT, &v1[0], &v2[0]);
WWVV = tmpWWVV[0];
}
}
END_MODULE_NAMESPACE