1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Hadrons: memory management for fermion matrices, dynamic ownership in garbage collector

This commit is contained in:
Antonin Portelli 2016-05-04 19:11:03 -07:00
parent cbe52b0659
commit 75cd72a421
23 changed files with 253 additions and 111 deletions

View File

@ -60,14 +60,16 @@ std::vector<std::string> AWilson::getOutput(void)
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void AWilson::execute(Environment &env) void AWilson::execute()
{ {
auto &U = *env.get<LatticeGaugeField>(par_.gauge); auto &U = *env().get<LatticeGaugeField>(par_.gauge);
auto &grid = *env.getGrid(); auto &grid = *env().getGrid();
auto &gridRb = *env.getRbGrid(); auto &gridRb = *env().getRbGrid();
auto fMatPt = new WilsonFermionR(U, grid, gridRb, par_.mass);
unsigned int size;
LOG(Message) << "Setting up Wilson fermion matrix with m= " << par_.mass LOG(Message) << "Setting up Wilson fermion matrix with m= " << par_.mass
<< " using gauge field '" << par_.gauge << "'" << std::endl; << " using gauge field '" << par_.gauge << "'" << std::endl;
env.addFermionMatrix(getName(), size = 3*env().lattice4dSize<WilsonFermionR::DoubledGaugeField>();
new WilsonFermionR(U, grid, gridRb, par_.mass)); env().addFermionMatrix(getName(), fMatPt, size);
} }

View File

@ -57,7 +57,7 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
}; };

View File

@ -189,9 +189,11 @@ void Application::configLoop(void)
unsigned int Application::execute(const std::vector<std::string> &program) unsigned int Application::execute(const std::vector<std::string> &program)
{ {
unsigned int memPeak = 0, size; unsigned int memPeak = 0, size;
std::vector<std::vector<std::string>> freeProg; std::vector<std::set<std::string>> freeProg;
bool continueCollect;
// build garbage collection schedule
freeProg.resize(program.size()); freeProg.resize(program.size());
for (auto &n: associatedModule_) for (auto &n: associatedModule_)
{ {
@ -205,27 +207,54 @@ unsigned int Application::execute(const std::vector<std::string> &program)
auto it = std::find_if(program.rbegin(), program.rend(), pred); auto it = std::find_if(program.rbegin(), program.rend(), pred);
if (it != program.rend()) if (it != program.rend())
{ {
freeProg[program.rend() - it - 1].push_back(n.first); freeProg[program.rend() - it - 1].insert(n.first);
} }
} }
// program execution
for (unsigned int i = 0; i < program.size(); ++i) for (unsigned int i = 0; i < program.size(); ++i)
{ {
// execute module
LOG(Message) << "---------- Measurement step " << i+1 << "/" LOG(Message) << "---------- Measurement step " << i+1 << "/"
<< program.size() << " (module '" << program[i] << "')" << program.size() << " (module '" << program[i] << "')"
<< " ----------" << std::endl; << " ----------" << std::endl;
(*module_[program[i]])(env_); (*module_[program[i]])();
size = env_.getTotalSize(); size = env_.getTotalSize();
// print used memory after execution
LOG(Message) << "Allocated objects: " << sizeString(size*locVol_) LOG(Message) << "Allocated objects: " << sizeString(size*locVol_)
<< " (" << sizeString(size) << "/site)" << std::endl; << " (" << sizeString(size) << "/site)" << std::endl;
if (size > memPeak) if (size > memPeak)
{ {
memPeak = size; memPeak = size;
} }
// garbage collection for step i
LOG(Message) << "Garbage collection..." << std::endl; LOG(Message) << "Garbage collection..." << std::endl;
for (auto &n: freeProg[i]) do
{ {
env_.free(n); continueCollect = false;
auto toFree = freeProg[i];
for (auto &n: toFree)
{
// continue garbage collection while there are still
// objects without owners
continueCollect = continueCollect or !env_.hasOwners(n);
if(env_.free(n))
{
// if an object has been freed, remove it from
// the garbage collection schedule
freeProg[i].erase(n);
}
}
} while (continueCollect);
// any remaining objects in step i garbage collection schedule
// is scheduled for step i + 1
if (i + 1 < program.size())
{
for (auto &n: freeProg[i])
{
freeProg[i + 1].insert(n);
}
} }
// print used memory after garbage collection
size = env_.getTotalSize(); size = env_.getTotalSize();
LOG(Message) << "Allocated objects: " << sizeString(size*locVol_) LOG(Message) << "Allocated objects: " << sizeString(size*locVol_)
<< " (" << sizeString(size) << "/site)" << std::endl; << " (" << sizeString(size) << "/site)" << std::endl;

View File

@ -61,20 +61,20 @@ std::vector<std::string> CMeson::getOutput(void)
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void CMeson::execute(Environment &env) void CMeson::execute(void)
{ {
LOG(Message) << "Computing meson contraction '" << getName() << "' using" LOG(Message) << "Computing meson contraction '" << getName() << "' using"
<< " quarks '" << par_.q1 << " and '" << par_.q2 << "'" << " quarks '" << par_.q1 << " and '" << par_.q2 << "'"
<< std::endl; << std::endl;
XmlWriter writer(par_.output); XmlWriter writer(par_.output);
LatticePropagator &q1 = *env.get<LatticePropagator>(par_.q1); LatticePropagator &q1 = *env().get<LatticePropagator>(par_.q1);
LatticePropagator &q2 = *env.get<LatticePropagator>(par_.q2); LatticePropagator &q2 = *env().get<LatticePropagator>(par_.q2);
LatticeComplex c(env.getGrid()); LatticeComplex c(env().getGrid());
SpinMatrix g[Ns*Ns], g5; SpinMatrix g[Ns*Ns], g5;
std::vector<TComplex> buf; std::vector<TComplex> buf;
Result result; Result result;
unsigned int nt = env.getGrid()->GlobalDimensions()[Tp]; unsigned int nt = env().getGrid()->GlobalDimensions()[Tp];
g5 = makeGammaProd(Ns*Ns - 1); g5 = makeGammaProd(Ns*Ns - 1);
result.corr.resize(Ns*Ns); result.corr.resize(Ns*Ns);

View File

@ -65,7 +65,7 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
}; };

View File

@ -106,18 +106,20 @@ GridRedBlackCartesian * Environment::getRbGrid(const unsigned int Ls) const
} }
// fermion actions ///////////////////////////////////////////////////////////// // fermion actions /////////////////////////////////////////////////////////////
void Environment::addFermionMatrix(const std::string name, FMat *fMat) void Environment::addFermionMatrix(const std::string name, FMat *fMat,
const unsigned int size)
{ {
fMat_[name].reset(fMat); fMat_[name].reset(fMat);
addSize(name, size);
} }
Environment::FMat * Environment::getFermionMatrix(const std::string name) const Environment::FMat * Environment::getFermionMatrix(const std::string name) const
{ {
try if (hasFermionMatrix(name))
{ {
return fMat_.at(name).get(); return fMat_.at(name).get();
} }
catch(std::out_of_range &) else
{ {
try try
{ {
@ -130,6 +132,25 @@ Environment::FMat * Environment::getFermionMatrix(const std::string name) const
} }
} }
void Environment::freeFermionMatrix(const std::string name)
{
if (hasFermionMatrix(name))
{
LOG(Message) << "freeing fermion matrix '" << name << "'" << std::endl;
fMat_.erase(name);
objectSize_.erase(name);
}
else
{
HADRON_ERROR("trying to free undefined fermion matrix '" + name + "'");
}
}
bool Environment::hasFermionMatrix(const std::string name) const
{
return (fMat_.find(name) != fMat_.end());
}
// solvers ///////////////////////////////////////////////////////////////////// // solvers /////////////////////////////////////////////////////////////////////
void Environment::addSolver(const std::string name, Solver s, void Environment::addSolver(const std::string name, Solver s,
const std::string actionName) const std::string actionName)
@ -138,14 +159,31 @@ void Environment::addSolver(const std::string name, Solver s,
solverAction_[name] = actionName; solverAction_[name] = actionName;
} }
bool Environment::hasSolver(const std::string name) const
{
return (solver_.find(name) != solver_.end());
}
std::string Environment::getSolverAction(const std::string name) const
{
if (hasSolver(name))
{
return solverAction_.at(name);
}
else
{
HADRON_ERROR("no solver with name '" << name << "'");
}
}
void Environment::callSolver(const std::string name, LatticeFermion &sol, void Environment::callSolver(const std::string name, LatticeFermion &sol,
const LatticeFermion &source) const const LatticeFermion &source) const
{ {
try if (hasSolver(name))
{ {
solver_.at(name)(sol, source); solver_.at(name)(sol, source);
} }
catch(std::out_of_range &) else
{ {
HADRON_ERROR("no solver with name '" << name << "'"); HADRON_ERROR("no solver with name '" << name << "'");
} }
@ -162,7 +200,7 @@ GridParallelRNG * Environment::get4dRng(void) const
return rng4d_.get(); return rng4d_.get();
} }
// data store ////////////////////////////////////////////////////////////////// // lattice store ///////////////////////////////////////////////////////////////
void Environment::freeLattice(const std::string name) void Environment::freeLattice(const std::string name)
{ {
if (hasLattice(name)) if (hasLattice(name))
@ -210,28 +248,62 @@ unsigned int Environment::getLatticeLs(const std::string name) const
} }
// general memory management /////////////////////////////////////////////////// // general memory management ///////////////////////////////////////////////////
void Environment::free(const std::string name) void Environment::addOwnership(const std::string owner,
const std::string property)
{ {
if (hasLattice(name)) owners_[property].insert(owner);
properties_[owner].insert(property);
}
bool Environment::hasOwners(const std::string name) const
{
try
{ {
freeLattice(name); return (!owners_.at(name).empty());
}
catch (std::out_of_range &)
{
return false;
}
}
bool Environment::free(const std::string name)
{
if (!hasOwners(name))
{
for (auto &p: properties_[name])
{
owners_[p].erase(name);
}
properties_[name].clear();
if (hasLattice(name))
{
freeLattice(name);
}
else if (hasFermionMatrix(name))
{
freeFermionMatrix(name);
}
return true;
}
else
{
return false;
} }
} }
void Environment::freeAll(void) void Environment::freeAll(void)
{ {
lattice_.clear(); lattice_.clear();
fMat_.clear();
solver_.clear();
objectSize_.clear(); objectSize_.clear();
} }
void Environment::addSize(const std::string name, const unsigned int size)
{
objectSize_[name] = size;
}
unsigned int Environment::getSize(const std::string name) const unsigned int Environment::getSize(const std::string name) const
{ {
if (hasLattice(name)) if (hasLattice(name) or hasFermionMatrix(name))
{ {
return objectSize_.at(name); return objectSize_.at(name);
} }
@ -254,3 +326,8 @@ long unsigned int Environment::getTotalSize(void) const
return size; return size;
} }
void Environment::addSize(const std::string name, const unsigned int size)
{
objectSize_[name] = size;
}

View File

@ -58,11 +58,16 @@ public:
GridCartesian * getGrid(const unsigned int Ls = 1) const; GridCartesian * getGrid(const unsigned int Ls = 1) const;
GridRedBlackCartesian * getRbGrid(const unsigned int Ls = 1) const; GridRedBlackCartesian * getRbGrid(const unsigned int Ls = 1) const;
// fermion actions // fermion actions
void addFermionMatrix(const std::string name, FMat *mat); void addFermionMatrix(const std::string name, FMat *mat,
const unsigned int size);
FMat * getFermionMatrix(const std::string name) const; FMat * getFermionMatrix(const std::string name) const;
void freeFermionMatrix(const std::string name);
bool hasFermionMatrix(const std::string name) const;
// solvers // solvers
void addSolver(const std::string name, Solver s, void addSolver(const std::string name, Solver s,
const std::string actionName); const std::string actionName);
bool hasSolver(const std::string name) const;
std::string getSolverAction(const std::string name) const;
void callSolver(const std::string name, void callSolver(const std::string name,
LatticeFermion &sol, LatticeFermion &sol,
const LatticeFermion &src) const; const LatticeFermion &src) const;
@ -71,22 +76,26 @@ public:
GridParallelRNG * get4dRng(void) const; GridParallelRNG * get4dRng(void) const;
// lattice store // lattice store
template <typename T> template <typename T>
unsigned int lattice4dSize(void) const;
template <typename T>
void create(const std::string name, void create(const std::string name,
const unsigned int Ls = 1); const unsigned int Ls = 1);
template <typename T> template <typename T>
T * get(const std::string name) const; T * get(const std::string name) const;
void freeLattice(const std::string name); void freeLattice(const std::string name);
bool hasLattice(const std::string name) const; bool hasLattice(const std::string name) const;
bool isLattice5d(const std::string name) const; bool isLattice5d(const std::string name) const;
unsigned int getLatticeLs(const std::string name) const; unsigned int getLatticeLs(const std::string name) const;
// general memory management // general memory management
void free(const std::string name); void addOwnership(const std::string owner,
const std::string property);
bool hasOwners(const std::string name) const;
bool free(const std::string name);
void freeAll(void); void freeAll(void);
void addSize(const std::string name,
const unsigned int size);
unsigned int getSize(const std::string name) const; unsigned int getSize(const std::string name) const;
long unsigned int getTotalSize(void) const; long unsigned int getTotalSize(void) const;
private:
void addSize(const std::string name, const unsigned int size);
private: private:
bool dryRun_{false}; bool dryRun_{false};
unsigned int traj_; unsigned int traj_;
@ -100,11 +109,19 @@ private:
std::map<std::string, std::string> solverAction_; std::map<std::string, std::string> solverAction_;
std::map<std::string, LatticePt> lattice_; std::map<std::string, LatticePt> lattice_;
std::map<std::string, unsigned int> objectSize_; std::map<std::string, unsigned int> objectSize_;
std::map<std::string, std::set<std::string>> owners_;
std::map<std::string, std::set<std::string>> properties_;
}; };
/****************************************************************************** /******************************************************************************
* template implementation * * template implementation *
******************************************************************************/ ******************************************************************************/
template <typename T>
unsigned int Environment::lattice4dSize(void) const
{
return sizeof(typename T::vector_object)/getGrid()->Nsimd();
}
template <typename T> template <typename T>
void Environment::create(const std::string name, const unsigned int Ls) void Environment::create(const std::string name, const unsigned int Ls)
{ {
@ -140,7 +157,7 @@ void Environment::create(const std::string name, const unsigned int Ls)
{ {
lattice_[name].reset(nullptr); lattice_[name].reset(nullptr);
} }
objectSize_[name] = sizeof(typename T::vector_object)/g->Nsimd()*Ls; addSize(name, lattice4dSize<T>()*Ls);
} }
template <typename T> template <typename T>

View File

@ -60,18 +60,18 @@ std::vector<std::string> GLoad::getOutput(void)
} }
// allocation ////////////////////////////////////////////////////////////////// // allocation //////////////////////////////////////////////////////////////////
void GLoad::allocate(Environment &env) void GLoad::allocate(void)
{ {
env.create<LatticeGaugeField>(getName()); env().create<LatticeGaugeField>(getName());
gauge_ = env.get<LatticeGaugeField>(getName()); gauge_ = env().get<LatticeGaugeField>(getName());
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void GLoad::execute(Environment &env) void GLoad::execute(void)
{ {
NerscField header; NerscField header;
std::string fileName = par_.file + "." std::string fileName = par_.file + "."
+ std::to_string(env.getTrajectory()); + std::to_string(env().getTrajectory());
LOG(Message) << "Loading NERSC configuration from file '" << fileName LOG(Message) << "Loading NERSC configuration from file '" << fileName
<< "'" << std::endl; << "'" << std::endl;

View File

@ -56,9 +56,9 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// allocation // allocation
virtual void allocate(Environment &env); virtual void allocate(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
LatticeGaugeField *gauge_ = nullptr; LatticeGaugeField *gauge_ = nullptr;

View File

@ -52,15 +52,15 @@ std::vector<std::string> GRandom::getOutput(void)
} }
// allocation ////////////////////////////////////////////////////////////////// // allocation //////////////////////////////////////////////////////////////////
void GRandom::allocate(Environment &env) void GRandom::allocate(void)
{ {
env.create<LatticeGaugeField>(getName()); env().create<LatticeGaugeField>(getName());
gauge_ = env.get<LatticeGaugeField>(getName()); gauge_ = env().get<LatticeGaugeField>(getName());
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void GRandom::execute(Environment &env) void GRandom::execute(void)
{ {
LOG(Message) << "Generating random gauge configuration" << std::endl; LOG(Message) << "Generating random gauge configuration" << std::endl;
SU3::HotConfiguration(*env.get4dRng(), *gauge_); SU3::HotConfiguration(*env().get4dRng(), *gauge_);
} }

View File

@ -48,9 +48,9 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// allocation // allocation
virtual void allocate(Environment &env); virtual void allocate(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
LatticeGaugeField *gauge_ = nullptr; LatticeGaugeField *gauge_ = nullptr;
}; };

View File

@ -52,15 +52,15 @@ std::vector<std::string> GUnit::getOutput(void)
} }
// allocation ////////////////////////////////////////////////////////////////// // allocation //////////////////////////////////////////////////////////////////
void GUnit::allocate(Environment &env) void GUnit::allocate(void)
{ {
env.create<LatticeGaugeField>(getName()); env().create<LatticeGaugeField>(getName());
gauge_ = env.get<LatticeGaugeField>(getName()); gauge_ = env().get<LatticeGaugeField>(getName());
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void GUnit::execute(Environment &env) void GUnit::execute(void)
{ {
LOG(Message) << "Creating unit gauge configuration" << std::endl; LOG(Message) << "Creating unit gauge configuration" << std::endl;
SU3::ColdConfiguration(*env.get4dRng(), *gauge_); SU3::ColdConfiguration(*env().get4dRng(), *gauge_);
} }

View File

@ -48,9 +48,9 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// allocation // allocation
virtual void allocate(Environment &env); virtual void allocate(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
LatticeGaugeField *gauge_ = nullptr; LatticeGaugeField *gauge_ = nullptr;
}; };

View File

@ -61,9 +61,9 @@ std::vector<std::string> MQuark::getOutput(void)
} }
// setup /////////////////////////////////////////////////////////////////////// // setup ///////////////////////////////////////////////////////////////////////
void MQuark::setup(Environment &env) void MQuark::setup(void)
{ {
auto dim = env.getFermionMatrix(par_.solver)->Grid()->GlobalDimensions(); auto dim = env().getFermionMatrix(par_.solver)->Grid()->GlobalDimensions();
if (dim.size() == Nd) if (dim.size() == Nd)
{ {
@ -76,30 +76,30 @@ void MQuark::setup(Environment &env)
} }
// allocation ////////////////////////////////////////////////////////////////// // allocation //////////////////////////////////////////////////////////////////
void MQuark::allocate(Environment &env) void MQuark::allocate(void)
{ {
env.create<LatticePropagator>(getName()); env().create<LatticePropagator>(getName());
quark_ = env.get<LatticePropagator>(getName()); quark_ = env().get<LatticePropagator>(getName());
if (Ls_ > 1) if (Ls_ > 1)
{ {
env.create<LatticePropagator>(getName() + "_5d", Ls_); env().create<LatticePropagator>(getName() + "_5d", Ls_);
quark5d_ = env.get<LatticePropagator>(getName() + "_5d"); quark5d_ = env().get<LatticePropagator>(getName() + "_5d");
} }
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void MQuark::execute(Environment &env) void MQuark::execute(void)
{ {
LatticePropagator *fullSource; LatticePropagator *fullSource;
LatticeFermion source(env.getGrid(Ls_)), sol(env.getGrid(Ls_)); LatticeFermion source(env().getGrid(Ls_)), sol(env().getGrid(Ls_));
LOG(Message) << "Computing quark propagator '" << getName() << "'" LOG(Message) << "Computing quark propagator '" << getName() << "'"
<< std::endl; << std::endl;
if (!env.isLattice5d(par_.source)) if (!env().isLattice5d(par_.source))
{ {
if (Ls_ == 1) if (Ls_ == 1)
{ {
fullSource = env.get<LatticePropagator>(par_.source); fullSource = env().get<LatticePropagator>(par_.source);
} }
else else
{ {
@ -112,16 +112,15 @@ void MQuark::execute(Environment &env)
{ {
HADRON_ERROR("MQuark not implemented with 5D actions"); HADRON_ERROR("MQuark not implemented with 5D actions");
} }
else if (Ls_ != env.getLatticeLs(par_.source)) else if (Ls_ != env().getLatticeLs(par_.source))
{ {
HADRON_ERROR("MQuark not implemented with 5D actions"); HADRON_ERROR("MQuark not implemented with 5D actions");
} }
else else
{ {
fullSource = env.get<LatticePropagator>(par_.source); fullSource = env().get<LatticePropagator>(par_.source);
} }
} }
LOG(Message) << "Inverting using solver '" << par_.solver LOG(Message) << "Inverting using solver '" << par_.solver
<< "' on source '" << par_.source << "'" << std::endl; << "' on source '" << par_.source << "'" << std::endl;
for (unsigned int s = 0; s < Ns; ++s) for (unsigned int s = 0; s < Ns; ++s)
@ -129,7 +128,7 @@ void MQuark::execute(Environment &env)
{ {
PropToFerm(source, *fullSource, s, c); PropToFerm(source, *fullSource, s, c);
sol = zero; sol = zero;
env.callSolver(par_.solver, sol, source); env().callSolver(par_.solver, sol, source);
if (Ls_ == 1) if (Ls_ == 1)
{ {
FermToProp(*quark_, sol, s, c); FermToProp(*quark_, sol, s, c);

View File

@ -43,8 +43,8 @@ public:
class Par: Serializable class Par: Serializable
{ {
public: public:
GRID_SERIALIZABLE_CLASS_MEMBERS(Par, std::string , source, GRID_SERIALIZABLE_CLASS_MEMBERS(Par, std::string, source,
std::string , solver); std::string, solver);
}; };
public: public:
// constructor // constructor
@ -57,11 +57,11 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// setup // setup
virtual void setup(Environment &env); virtual void setup(void);
// allocation // allocation
virtual void allocate(Environment &env); virtual void allocate(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
unsigned int Ls_; unsigned int Ls_;

View File

@ -37,6 +37,7 @@ using namespace Hadrons;
// constructor ///////////////////////////////////////////////////////////////// // constructor /////////////////////////////////////////////////////////////////
Module::Module(const std::string name) Module::Module(const std::string name)
: name_(name) : name_(name)
, env_(Environment::getInstance())
{} {}
// access ////////////////////////////////////////////////////////////////////// // access //////////////////////////////////////////////////////////////////////
@ -45,12 +46,18 @@ std::string Module::getName(void) const
return name_; return name_;
} }
void Module::operator()(Environment &env) Environment & Module::env(void) const
{ {
setup(env); return env_;
allocate(env); }
if (!env.isDryRun())
// execution ///////////////////////////////////////////////////////////////////
void Module::operator()(void)
{
setup();
allocate();
if (!env().isDryRun())
{ {
execute(env); execute();
} }
} }

View File

@ -61,20 +61,22 @@ public:
virtual ~Module(void) = default; virtual ~Module(void) = default;
// access // access
std::string getName(void) const; std::string getName(void) const;
Environment &env(void) const;
// parse parameters // parse parameters
virtual void parseParameters(XmlReader &reader, const std::string name) {}; virtual void parseParameters(XmlReader &reader, const std::string name) {};
// dependencies/products // dependencies/products
virtual std::vector<std::string> getInput(void) = 0; virtual std::vector<std::string> getInput(void) = 0;
virtual std::vector<std::string> getOutput(void) = 0; virtual std::vector<std::string> getOutput(void) = 0;
// setup // setup
virtual void setup(Environment &env) {}; virtual void setup(void) {};
// allocation // allocation
virtual void allocate(Environment &env) {}; virtual void allocate(void) {};
// execution // execution
void operator()(Environment &env); void operator()(void);
virtual void execute(Environment &env) = 0; virtual void execute(void) = 0;
private: private:
std::string name_; std::string name_;
Environment &env_;
}; };
END_HADRONS_NAMESPACE END_HADRONS_NAMESPACE

View File

@ -60,10 +60,17 @@ std::vector<std::string> SolRBPrecCG::getOutput(void)
return out; return out;
} }
// execution /////////////////////////////////////////////////////////////////// // setup ///////////////////////////////////////////////////////////////////////
void SolRBPrecCG::execute(Environment &env) void SolRBPrecCG::setup(void)
{ {
auto &mat = *(env.getFermionMatrix(par_.action)); env().addOwnership(getName(), par_.action);
}
// execution ///////////////////////////////////////////////////////////////////
void SolRBPrecCG::execute(void)
{
auto &mat = *(env().getFermionMatrix(par_.action));
auto solver = [&mat, this](LatticeFermion &sol, auto solver = [&mat, this](LatticeFermion &sol,
const LatticeFermion &source) const LatticeFermion &source)
{ {
@ -76,5 +83,5 @@ void SolRBPrecCG::execute(Environment &env)
LOG(Message) << "setting up Schur red-black preconditioned CG for" LOG(Message) << "setting up Schur red-black preconditioned CG for"
<< " action '" << par_.action << "' with residual " << " action '" << par_.action << "' with residual "
<< par_.residual << std::endl; << par_.residual << std::endl;
env.addSolver(getName(), solver, par_.action); env().addSolver(getName(), solver, par_.action);
} }

View File

@ -56,8 +56,10 @@ public:
// dependencies/products // dependencies/products
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// setup
virtual void setup(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
}; };

View File

@ -60,14 +60,14 @@ std::vector<std::string> SrcPoint::getOutput(void)
} }
// allocation ////////////////////////////////////////////////////////////////// // allocation //////////////////////////////////////////////////////////////////
void SrcPoint::allocate(Environment &env) void SrcPoint::allocate(void)
{ {
env.create<LatticePropagator>(getName()); env().create<LatticePropagator>(getName());
src_ = env.get<LatticePropagator>(getName()); src_ = env().get<LatticePropagator>(getName());
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void SrcPoint::execute(Environment &env) void SrcPoint::execute(void)
{ {
std::vector<int> position = strToVec<int>(par_.position); std::vector<int> position = strToVec<int>(par_.position);
SpinColourMatrix id; SpinColourMatrix id;

View File

@ -67,9 +67,9 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// allocation // allocation
virtual void allocate(Environment &env); virtual void allocate(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
LatticePropagator *src_{nullptr}; LatticePropagator *src_{nullptr};

View File

@ -60,18 +60,18 @@ std::vector<std::string> SrcZ2::getOutput(void)
} }
// allocation ////////////////////////////////////////////////////////////////// // allocation //////////////////////////////////////////////////////////////////
void SrcZ2::allocate(Environment &env) void SrcZ2::allocate(void)
{ {
env.create<LatticePropagator>(getName()); env().create<LatticePropagator>(getName());
src_ = env.get<LatticePropagator>(getName()); src_ = env().get<LatticePropagator>(getName());
} }
// execution /////////////////////////////////////////////////////////////////// // execution ///////////////////////////////////////////////////////////////////
void SrcZ2::execute(Environment &env) void SrcZ2::execute(void)
{ {
Lattice<iScalar<vInteger>> t(env.getGrid()); Lattice<iScalar<vInteger>> t(env().getGrid());
LatticeComplex eta(env.getGrid()); LatticeComplex eta(env().getGrid());
LatticeFermion phi(env.getGrid()); LatticeFermion phi(env().getGrid());
Complex shift(1., 1.); Complex shift(1., 1.);
if (par_.tA == par_.tB) if (par_.tA == par_.tB)
@ -85,7 +85,7 @@ void SrcZ2::execute(Environment &env)
<< par_.tB << std::endl; << par_.tB << std::endl;
} }
LatticeCoordinate(t, Tp); LatticeCoordinate(t, Tp);
bernoulli(*env.get4dRng(), eta); bernoulli(*env().get4dRng(), eta);
eta = (2.*eta - shift)*(1./::sqrt(2.)); eta = (2.*eta - shift)*(1./::sqrt(2.));
eta = where((t >= par_.tA) and (t <= par_.tB), eta, 0.*eta); eta = where((t >= par_.tA) and (t <= par_.tB), eta, 0.*eta);
*src_ = 1.; *src_ = 1.;

View File

@ -69,9 +69,9 @@ public:
virtual std::vector<std::string> getInput(void); virtual std::vector<std::string> getInput(void);
virtual std::vector<std::string> getOutput(void); virtual std::vector<std::string> getOutput(void);
// allocation // allocation
virtual void allocate(Environment &env); virtual void allocate(void);
// execution // execution
virtual void execute(Environment &env); virtual void execute(void);
private: private:
Par par_; Par par_;
LatticePropagator *src_; LatticePropagator *src_;