1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-10 07:55:35 +00:00

Hadrons: general guesser factory

This commit is contained in:
Antonin Portelli 2018-09-10 17:36:54 +01:00
parent 6d1d28955e
commit 6d912f6c67
2 changed files with 10 additions and 40 deletions

View File

@ -63,6 +63,9 @@ inline Environment & env(void) const\
return Environment::getInstance();\
}
#define DEFINE_ENV_LAMBDA \
auto env = [](void)->Environment &{return Environment::getInstance();}
class Environment
{
SINGLETON(Environment);

View File

@ -35,6 +35,7 @@ See the full license in the file "LICENSE" in the top level distribution directo
#include <Hadrons/ModuleFactory.hpp>
#include <Hadrons/Solver.hpp>
#include <Hadrons/EigenPack.hpp>
#include <Hadrons/Modules/MSolver/Guesser.hpp>
BEGIN_HADRONS_NAMESPACE
@ -59,13 +60,6 @@ class TRBPrecCG: public Module<RBPrecCGPar>
public:
FG_TYPE_ALIASES(FImpl,);
SOLVER_TYPE_ALIASES(FImpl,);
typedef FermionEigenPack<FImpl> EPack;
typedef CoarseFermionEigenPack<FImpl, nBasis> CoarseEPack;
typedef std::shared_ptr<Guesser<FermionField>> GuesserPt;
typedef DeflatedGuesser<typename FImpl::FermionField> FineGuesser;
typedef LocalCoherenceDeflatedGuesser<
typename FImpl::FermionField,
typename CoarseEPack::CoarseField> CoarseGuesser;
public:
// constructor
TRBPrecCG(const std::string name);
@ -138,45 +132,18 @@ void TRBPrecCG<FImpl, nBasis>::setup(void)
<< par().residual << ", maximum iteration "
<< par().maxIteration << std::endl;
auto Ls = env().getObjectLs(par().action);
auto &mat = envGet(FMat, par().action);
std::string guesserName = getName() + "_guesser";
GuesserPt guesser{nullptr};
auto Ls = env().getObjectLs(par().action);
auto &mat = envGet(FMat, par().action);
auto guesserPt = makeGuesser<FImpl, nBasis>(par().eigenPack);
if (par().eigenPack.empty())
{
guesser.reset(new ZeroGuesser<FermionField>());
}
else
{
try
{
auto &epack = envGetDerived(EPack, CoarseEPack, par().eigenPack);
LOG(Message) << "using low-mode deflation with coarse eigenpack '"
<< par().eigenPack << "' ("
<< epack.evecCoarse.size() << " modes)" << std::endl;
guesser.reset(new CoarseGuesser(epack.evec, epack.evecCoarse,
epack.evalCoarse));
}
catch (Exceptions::ObjectType &e)
{
auto &epack = envGet(EPack, par().eigenPack);
LOG(Message) << "using low-mode deflation with eigenpack '"
<< par().eigenPack << "' ("
<< epack.evec.size() << " modes)" << std::endl;
guesser.reset(new FineGuesser(epack.evec, epack.eval));
}
}
auto makeSolver = [&mat, guesser, this](bool subGuess) {
return [&mat, guesser, subGuess, this](FermionField &sol,
auto makeSolver = [&mat, guesserPt, this](bool subGuess) {
return [&mat, guesserPt, subGuess, this](FermionField &sol,
const FermionField &source) {
ConjugateGradient<FermionField> cg(par().residual,
par().maxIteration);
HADRONS_DEFAULT_SCHUR_SOLVE<FermionField> schurSolver(cg);
schurSolver.subtractGuess(subGuess);
schurSolver(mat, source, sol, *guesser);
schurSolver(mat, source, sol, *guesserPt);
};
};
auto solver = makeSolver(false);