1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Included eigenPacks and action as references, not inputs, of A2A module. They now now longer need to be parameters in the meson field modules.

This commit is contained in:
fionnoh 2018-06-28 16:14:49 +01:00
parent f7e86f81a0
commit 7fe3974c0a

View File

@ -4,6 +4,7 @@
#include <Grid/Hadrons/Global.hpp> #include <Grid/Hadrons/Global.hpp>
#include <Grid/Hadrons/Module.hpp> #include <Grid/Hadrons/Module.hpp>
#include <Grid/Hadrons/ModuleFactory.hpp> #include <Grid/Hadrons/ModuleFactory.hpp>
#include <Grid/Hadrons/Solver.hpp>
#include <Grid/Hadrons/EigenPack.hpp> #include <Grid/Hadrons/EigenPack.hpp>
#include <Grid/Hadrons/AllToAllVectors.hpp> #include <Grid/Hadrons/AllToAllVectors.hpp>
@ -23,7 +24,7 @@ public:
int, N, int, N,
std::vector<std::string>, sources, std::vector<std::string>, sources,
std::string, action, std::string, action,
std::string, eigenpack, std::string, eigenPack,
std::string, solver); std::string, solver);
}; };
@ -37,7 +38,7 @@ class TA2AVectors : public Module<A2AVectorsPar>
typedef FermionEigenPack<FImpl> EPack; typedef FermionEigenPack<FImpl> EPack;
typedef CoarseFermionEigenPack<FImpl, nBasis> CoarseEPack; typedef CoarseFermionEigenPack<FImpl, nBasis> CoarseEPack;
typedef A2AModesSchurDiagTwo<typename FImpl::FermionField, FMat> A2ABase; typedef A2AModesSchurDiagTwo<typename FImpl::FermionField, FMat, Solver> A2ABase;
public: public:
// constructor // constructor
@ -46,6 +47,7 @@ class TA2AVectors : public Module<A2AVectorsPar>
virtual ~TA2AVectors(void) {}; virtual ~TA2AVectors(void) {};
// dependency relation // dependency relation
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getReference(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// setup // setup
virtual void setup(void); virtual void setup(void);
@ -54,7 +56,7 @@ class TA2AVectors : public Module<A2AVectorsPar>
private: private:
unsigned int Ls_; unsigned int Ls_;
std::string retName_; std::string className_;
}; };
MODULE_REGISTER_TMP(A2AVectors, ARG(TA2AVectors<FIMPL, HADRONS_DEFAULT_LANCZOS_NBASIS>), MSolver); MODULE_REGISTER_TMP(A2AVectors, ARG(TA2AVectors<FIMPL, HADRONS_DEFAULT_LANCZOS_NBASIS>), MSolver);
@ -67,14 +69,17 @@ MODULE_REGISTER_TMP(ZA2AVectors, ARG(TA2AVectors<ZFIMPL, HADRONS_DEFAULT_LANCZOS
template <typename FImpl, int nBasis> template <typename FImpl, int nBasis>
TA2AVectors<FImpl, nBasis>::TA2AVectors(const std::string name) TA2AVectors<FImpl, nBasis>::TA2AVectors(const std::string name)
: Module<A2AVectorsPar>(name) : Module<A2AVectorsPar>(name)
, retName_ (name + "_ret") , className_ (name + "_class")
{} {}
// dependencies/products /////////////////////////////////////////////////////// // dependencies/products ///////////////////////////////////////////////////////
template <typename FImpl, int nBasis> template <typename FImpl, int nBasis>
std::vector<std::string> TA2AVectors<FImpl, nBasis>::getInput(void) std::vector<std::string> TA2AVectors<FImpl, nBasis>::getInput(void)
{ {
std::vector<std::string> in = {par().action, par().solver, par().solver + "_subtract"}; int Nl = par().Nl;
std::string sub_string = "";
if (Nl > 0) sub_string = "_subtract";
std::vector<std::string> in = {par().solver + sub_string};
int n = par().sources.size(); int n = par().sources.size();
@ -86,10 +91,23 @@ std::vector<std::string> TA2AVectors<FImpl, nBasis>::getInput(void)
return in; return in;
} }
template <typename FImpl, int nBasis>
std::vector<std::string> TA2AVectors<FImpl, nBasis>::getReference(void)
{
std::vector<std::string> ref = {par().action};
if (!par().eigenPack.empty())
{
ref.push_back(par().eigenPack);
}
return ref;
}
template <typename FImpl, int nBasis> template <typename FImpl, int nBasis>
std::vector<std::string> TA2AVectors<FImpl, nBasis>::getOutput(void) std::vector<std::string> TA2AVectors<FImpl, nBasis>::getOutput(void)
{ {
std::vector<std::string> out = {getName(), retName_}; std::vector<std::string> out = {getName(), className_};
return out; return out;
} }
@ -106,7 +124,7 @@ void TA2AVectors<FImpl, nBasis>::setup(void)
std::string sub_string = ""; std::string sub_string = "";
if (Nl > 0) sub_string = "_subtract"; if (Nl > 0) sub_string = "_subtract";
auto &solver = envGet(SolverFn, par().solver + sub_string); auto &solver = envGet(Solver, par().solver + sub_string);
Ls_ = env().getObjectLs(par().solver + sub_string); Ls_ = env().getObjectLs(par().solver + sub_string);
auto &action = envGet(FMat, par().action); auto &action = envGet(FMat, par().action);
@ -121,10 +139,10 @@ void TA2AVectors<FImpl, nBasis>::setup(void)
if (Nl > 0) if (Nl > 0)
{ {
// Low modes // Low modes
auto &epack = envGet(EPack, par().eigenpack); auto &epack = envGet(EPack, par().eigenPack);
LOG(Message) << "Creating a2a vectors " << getName() << LOG(Message) << "Creating a2a vectors " << getName() <<
" using eigenpack '" << par().eigenpack << "' (" " using eigenpack '" << par().eigenPack << "' ("
<< epack.evec.size() << " modes)" << << epack.evec.size() << " modes)" <<
" and " << Nh << " high modes." << std::endl; " and " << Nh << " high modes." << std::endl;
evec = &epack.evec; evec = &epack.evec;
@ -135,7 +153,7 @@ void TA2AVectors<FImpl, nBasis>::setup(void)
LOG(Message) << "Creating a2a vectors " << getName() << LOG(Message) << "Creating a2a vectors " << getName() <<
" using " << Nh << " high modes only." << std::endl; " using " << Nh << " high modes only." << std::endl;
} }
envCreate(A2ABase, retName_, Ls_, envCreate(A2ABase, className_, Ls_,
evec, eval, evec, eval,
action, action,
solver, solver,
@ -158,7 +176,7 @@ void TA2AVectors<FImpl, nBasis>::execute(void)
if (Nl > 0) sub_string = "_subtract"; if (Nl > 0) sub_string = "_subtract";
Ls_ = env().getObjectLs(par().solver + sub_string); Ls_ = env().getObjectLs(par().solver + sub_string);
auto &a2areturn = envGet(A2ABase, retName_); auto &a2areturn = envGet(A2ABase, className_);
// High modes // High modes
auto sources = par().sources; auto sources = par().sources;
@ -168,8 +186,6 @@ void TA2AVectors<FImpl, nBasis>::execute(void)
envGetTmp(FermionField, tmp); envGetTmp(FermionField, tmp);
envGetTmp(FermionField, tmp2); envGetTmp(FermionField, tmp2);
// TODO: At the moment weighting only applies to the 4d->5d source path
// similar to how the 5d and 4d srcs are passed in, this needs more work to be less brittle
double weight = 1.0 / sqrt(Ns*Nc*Nsrc); double weight = 1.0 / sqrt(Ns*Nc*Nsrc);
int N_count = 0; int N_count = 0;
for (unsigned int s = 0; s < Ns; ++s) for (unsigned int s = 0; s < Ns; ++s)
@ -184,11 +200,13 @@ void TA2AVectors<FImpl, nBasis>::execute(void)
if (Ls_ == 1) if (Ls_ == 1)
{ {
PropToFerm<FImpl>(ferm_src, prop_src, s, c); PropToFerm<FImpl>(ferm_src, prop_src, s, c);
ferm_src = weight*ferm_src;
tmp = ferm_src;
} }
else else
{ {
PropToFerm<FImpl>(tmp2, prop_src, s, c); PropToFerm<FImpl>(tmp, prop_src, s, c);
tmp = weight*tmp2; tmp = weight*tmp;
action.ImportPhysicalFermionSource(tmp, ferm_src); action.ImportPhysicalFermionSource(tmp, ferm_src);
} }
} }
@ -202,6 +220,8 @@ void TA2AVectors<FImpl, nBasis>::execute(void)
else else
{ {
PropToFerm<FImpl>(ferm_src, prop_src, s, c); PropToFerm<FImpl>(ferm_src, prop_src, s, c);
ferm_src = weight*ferm_src;
action.ExportPhysicalFermionSolution(ferm_src, tmp);
} }
} }
LOG(Message) << "a2areturn.high_modes Ncount = " << N_count << std::endl; LOG(Message) << "a2areturn.high_modes Ncount = " << N_count << std::endl;