1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-25 21:25:56 +01:00

Adding factories

This commit is contained in:
Guido Cossu 2017-01-16 10:18:09 +00:00
parent 0dfda4bb90
commit c6f59c2933
15 changed files with 583 additions and 171 deletions

View File

@ -492,7 +492,6 @@ namespace QCD {
} //namespace QCD } //namespace QCD
} // Grid } // Grid
#include <Grid/qcd/utils/SpaceTimeGrid.h> #include <Grid/qcd/utils/SpaceTimeGrid.h>
#include <Grid/qcd/spin/Dirac.h> #include <Grid/qcd/spin/Dirac.h>
#include <Grid/qcd/spin/TwoSpinor.h> #include <Grid/qcd/spin/TwoSpinor.h>
@ -517,5 +516,6 @@ namespace QCD {
#include <Grid/qcd/hmc/HMC.h> #include <Grid/qcd/hmc/HMC.h>
#include <Grid/qcd/modules/mods.h>
#endif #endif

View File

@ -10,6 +10,7 @@ Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
Author: Peter Boyle <paboyle@ph.ed.ac.uk> Author: Peter Boyle <paboyle@ph.ed.ac.uk>
Author: neo <cossu@post.kek.jp> Author: neo <cossu@post.kek.jp>
Author: paboyle <paboyle@ph.ed.ac.uk> Author: paboyle <paboyle@ph.ed.ac.uk>
Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -43,11 +44,8 @@ class WilsonGaugeAction : public Action<typename Gimpl::GaugeField> {
public: public:
INHERIT_GIMPL_TYPES(Gimpl); INHERIT_GIMPL_TYPES(Gimpl);
private: /////////////////////////// constructors
RealD beta; explicit WilsonGaugeAction(RealD beta_):beta(beta_){};
public:
explicit WilsonGaugeAction(RealD b) : beta(b){}
virtual std::string action_name() {return "WilsonGaugeAction";} virtual std::string action_name() {return "WilsonGaugeAction";}
@ -85,9 +83,12 @@ class WilsonGaugeAction : public Action<typename Gimpl::GaugeField> {
PokeIndex<LorentzIndex>(dSdU, dSdU_mu, mu); PokeIndex<LorentzIndex>(dSdU, dSdU_mu, mu);
} }
} }
private:
RealD beta;
}; };
} }
} }

View File

@ -151,7 +151,7 @@ class HMCWrapperTemplate {
Implementation::TepidConfiguration(Resources.GetParallelRNG(), U); Implementation::TepidConfiguration(Resources.GetParallelRNG(), U);
} else if (Payload.StartType == CheckpointStart) { } else if (Payload.StartType == CheckpointStart) {
// CheckpointRestart // CheckpointRestart
Resources.get_CheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U, Resources.GetCheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U,
Resources.GetSerialRNG(), Resources.GetSerialRNG(),
Resources.GetParallelRNG()); Resources.GetParallelRNG());
} }
@ -164,7 +164,7 @@ class HMCWrapperTemplate {
for (int obs = 0; obs < ObservablesList.size(); obs++) for (int obs = 0; obs < ObservablesList.size(); obs++)
HMC.AddObservable(ObservablesList[obs]); HMC.AddObservable(ObservablesList[obs]);
HMC.AddObservable(Resources.get_CheckPointer()); HMC.AddObservable(Resources.GetCheckPointer());
// Run it // Run it

View File

@ -50,6 +50,8 @@ struct HMCparameters {
bool MetropolisTest; bool MetropolisTest;
Integer NoMetropolisUntil; Integer NoMetropolisUntil;
// nest here the MDparameters and make all serializable
HMCparameters() { HMCparameters() {
////////////////////////////// Default values ////////////////////////////// Default values
MetropolisTest = true; MetropolisTest = true;

View File

@ -33,8 +33,42 @@ directory
namespace Grid { namespace Grid {
namespace QCD { namespace QCD {
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
// Modules // Modules
class GridModuleParameters: Serializable{
GRID_SERIALIZABLE_CLASS_MEMBERS(GridModuleParameters,
std::string, lattice,
std::string, mpi);
public:
// these namings are ugly
// also ugly the distinction between the serializable members
// and this
std::vector<int> lattice_v;
std::vector<int> mpi_v;
GridModuleParameters(const std::vector<int> l_ = std::vector<int>(),
const std::vector<int> mpi_ = std::vector<int>()):lattice_v(l_), mpi_v(mpi_){}
template <class ReaderClass>
GridModuleParameters(Reader<ReaderClass>& Reader) {
read(Reader, "LatticeGrid", *this);
lattice_v = strToVec<int>(lattice);
mpi_v = strToVec<int>(mpi);
if (mpi_v.size() != lattice_v.size()) {
std::cout << "Error in GridModuleParameters: lattice and mpi dimensions "
"do not match"
<< std::endl;
exit(1);
}
}
};
class GridModule { class GridModule {
public: public:
GridCartesian* get_full() { return grid_.get(); } GridCartesian* get_full() { return grid_.get(); }
@ -46,11 +80,13 @@ class GridModule {
protected: protected:
std::unique_ptr<GridCartesian> grid_; std::unique_ptr<GridCartesian> grid_;
std::unique_ptr<GridRedBlackCartesian> rbgrid_; std::unique_ptr<GridRedBlackCartesian> rbgrid_;
}; };
// helpers // helpers
class GridFourDimModule : public GridModule { class GridFourDimModule : public GridModule {
public: public:
// add a function to create the module from a Reader
GridFourDimModule() { GridFourDimModule() {
set_full(SpaceTimeGrid::makeFourDimGrid( set_full(SpaceTimeGrid::makeFourDimGrid(
GridDefaultLatt(), GridDefaultSimd(4, vComplex::Nsimd()), GridDefaultLatt(), GridDefaultSimd(4, vComplex::Nsimd()),
@ -58,50 +94,53 @@ class GridFourDimModule : public GridModule {
set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(grid_.get())); set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(grid_.get()));
} }
template <class vector_type = vComplex>
GridFourDimModule(GridModuleParameters Params) {
if (Params.lattice_v.size() == 4) {
set_full(SpaceTimeGrid::makeFourDimGrid(
Params.lattice_v, GridDefaultSimd(4, vector_type::Nsimd()),
Params.mpi_v));
set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(grid_.get()));
} else {
std::cout
<< "Error in GridFourDimModule: lattice dimension different from 4"
<< std::endl;
exit(1);
}
}
}; };
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
class RNGModuleParameters: Serializable { class RNGModuleParameters: Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(RNGModuleParameters, GRID_SERIALIZABLE_CLASS_MEMBERS(RNGModuleParameters,
std::vector<int>, SerialSeed_, std::string, serial_seeds,
std::vector<int>, ParallelSeed_,); std::string, parallel_seeds,);
public:
std::vector<int> SerialSeed;
std::vector<int> ParallelSeed;
RNGModuleParameters(const std::vector<int> S = std::vector<int>(),
const std::vector<int> P = std::vector<int>())
: SerialSeed(S), ParallelSeed(P) {}
// default constructor, needed for the non-Reader
// construction of the module
RNGModuleParameters(){
SerialSeed_.resize(0);
ParallelSeed_.resize(0);
}
RNGModuleParameters(const std::vector<int> S, const std::vector<int> P){
set_RNGSeeds(S,P);
}
template <class ReaderClass > template <class ReaderClass >
RNGModuleParameters(ReaderClass &Reader){ RNGModuleParameters(Reader<ReaderClass>& Reader){
read(Reader, "RandomNumberGenerator", *this); read(Reader, "RandomNumberGenerator", *this);
} SerialSeed = strToVec<int>(serial_seeds);
ParallelSeed = strToVec<int>(parallel_seeds);
void set_RNGSeeds(const std::vector<int>& S, const std::vector<int>& P){
SerialSeed_ = S;
ParallelSeed_ = P;
} }
}; };
// Random number generators module
class RNGModule{ class RNGModule{
// Random number generators
GridSerialRNG sRNG_; GridSerialRNG sRNG_;
std::unique_ptr<GridParallelRNG> pRNG_; std::unique_ptr<GridParallelRNG> pRNG_;
RNGModuleParameters Params_; RNGModuleParameters Params_;
public: public:
template < class ReaderClass >
RNGModule(ReaderClass &Reader):Params_(Reader){};
RNGModule(){}; RNGModule(){};
@ -109,36 +148,23 @@ public:
pRNG_.reset(pRNG); pRNG_.reset(pRNG);
} }
void set_RNGSeeds(const std::vector<int> S, const std::vector<int> P) { void set_RNGSeeds(RNGModuleParameters& Params) {
Params_.set_RNGSeeds(S,P); Params_ = Params;
} }
GridSerialRNG& get_sRNG() { return sRNG_; }
GridParallelRNG& get_pRNG() { return *pRNG_.get(); }
GridSerialRNG& get_sRNG(){
if (Params_.SerialSeed_.size()==0){
std::cout << "Serial seeds not initialized" << std::endl;
exit(1);
}
return sRNG_;
}
GridParallelRNG& get_pRNG(){
if (Params_.ParallelSeed_.size()==0){
std::cout << "Parallel seeds not initialized" << std::endl;
exit(1);
}
return *pRNG_.get();
}
void seed() { void seed() {
sRNG_.SeedFixedIntegers(Params_.SerialSeed_); if (Params_.SerialSeed.size() == 0 && Params_.ParallelSeed.size() == 0) {
pRNG_->SeedFixedIntegers(Params_.ParallelSeed_); std::cout << "Seeds not initialized" << std::endl;
exit(1);
}
sRNG_.SeedFixedIntegers(Params_.SerialSeed);
pRNG_->SeedFixedIntegers(Params_.ParallelSeed);
} }
}; };
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
/// Smearing module /// Smearing module
template <class ImplementationPolicy> template <class ImplementationPolicy>

View File

@ -48,6 +48,7 @@ with this program; if not, write to the Free Software Foundation, Inc.,
} \ } \
} }
/*
// One function per Checkpointer using the reader, use a macro to simplify // One function per Checkpointer using the reader, use a macro to simplify
#define RegisterLoadCheckPointerReaderFunction(NAME) \ #define RegisterLoadCheckPointerReaderFunction(NAME) \
template <class Reader> \ template <class Reader> \
@ -64,6 +65,7 @@ with this program; if not, write to the Free Software Foundation, Inc.,
exit(1); \ exit(1); \
} \ } \
} }
*/
namespace Grid { namespace Grid {
namespace QCD { namespace QCD {
@ -71,7 +73,7 @@ namespace QCD {
// HMC Resource manager // HMC Resource manager
template <class ImplementationPolicy> template <class ImplementationPolicy>
class HMCResourceManager { class HMCResourceManager {
// Storage for grid pairs (std + red-black) // Named storage for grid pairs (std + red-black)
std::unordered_map<std::string, GridModule> Grids; std::unordered_map<std::string, GridModule> Grids;
RNGModule RNGs; RNGModule RNGs;
@ -83,6 +85,13 @@ class HMCResourceManager{
public: public:
HMCResourceManager() : have_RNG(false), have_CheckPointer(false) {} HMCResourceManager() : have_RNG(false), have_CheckPointer(false) {}
// Here need a constructor for using the Reader class
//////////////////////////////////////////////////////////////
// Grids
//////////////////////////////////////////////////////////////
void AddGrid(std::string s, GridModule& M) { void AddGrid(std::string s, GridModule& M) {
// Check for name clashes // Check for name clashes
auto search = Grids.find(s); auto search = Grids.find(s);
@ -94,12 +103,14 @@ class HMCResourceManager{
Grids[s] = std::move(M); Grids[s] = std::move(M);
} }
// Add a named grid set // Add a named grid set, 4d shortcut
void AddFourDimGrid(std::string s) { void AddFourDimGrid(std::string s) {
GridFourDimModule Mod; GridFourDimModule Mod;
AddGrid(s, Mod); AddGrid(s, Mod);
} }
GridCartesian* GetCartesian(std::string s = "") { GridCartesian* GetCartesian(std::string s = "") {
if (s.empty()) s = Grids.begin()->first; if (s.empty()) s = Grids.begin()->first;
std::cout << GridLogDebug << "Getting cartesian grid from: " << s std::cout << GridLogDebug << "Getting cartesian grid from: " << s
@ -114,6 +125,10 @@ class HMCResourceManager{
return Grids[s].get_rb(); return Grids[s].get_rb();
} }
//////////////////////////////////////////////////////
// Random number generators
//////////////////////////////////////////////////////
void AddRNGs(std::string s = "") { void AddRNGs(std::string s = "") {
// Couple the RNGs to the GridModule tagged by s // Couple the RNGs to the GridModule tagged by s
// the default is the first grid registered // the default is the first grid registered
@ -124,11 +139,10 @@ class HMCResourceManager{
have_RNG = true; have_RNG = true;
} }
void AddRNGSeeds(const std::vector<int> S, const std::vector<int> P) { void SetRNGSeeds(RNGModuleParameters& Params) { RNGs.set_RNGSeeds(Params); }
RNGs.set_RNGSeeds(S, P);
}
GridSerialRNG& GetSerialRNG() { return RNGs.get_sRNG(); } GridSerialRNG& GetSerialRNG() { return RNGs.get_sRNG(); }
GridParallelRNG& GetParallelRNG() { GridParallelRNG& GetParallelRNG() {
assert(have_RNG); assert(have_RNG);
return RNGs.get_pRNG(); return RNGs.get_pRNG();
@ -139,17 +153,16 @@ class HMCResourceManager{
RNGs.seed(); RNGs.seed();
} }
////////////////////////////////////////////////////// //////////////////////////////////////////////////////
// Checkpointers // Checkpointers
////////////////////////////////////////////////////// //////////////////////////////////////////////////////
BaseHmcCheckpointer<ImplementationPolicy>* get_CheckPointer(){ BaseHmcCheckpointer<ImplementationPolicy>* GetCheckPointer() {
if (have_CheckPointer) if (have_CheckPointer)
return CP.get_CheckPointer(); return CP.get_CheckPointer();
else { else {
std::cout << GridLogError << "Error: no checkpointer defined" << std::endl; std::cout << GridLogError << "Error: no checkpointer defined"
<< std::endl;
exit(1); exit(1);
} }
} }
@ -158,9 +171,11 @@ class HMCResourceManager{
RegisterLoadCheckPointerFunction(Nersc); RegisterLoadCheckPointerFunction(Nersc);
RegisterLoadCheckPointerFunction(ILDG); RegisterLoadCheckPointerFunction(ILDG);
/*
RegisterLoadCheckPointerReaderFunction(Binary); RegisterLoadCheckPointerReaderFunction(Binary);
RegisterLoadCheckPointerReaderFunction(Nersc); RegisterLoadCheckPointerReaderFunction(Nersc);
RegisterLoadCheckPointerReaderFunction(ILDG); RegisterLoadCheckPointerReaderFunction(ILDG);
*/
}; };
} }

View File

@ -42,14 +42,18 @@ namespace Grid {
CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng", CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng",
int savemodulo = 1, const std::string &f = "IEEE64BIG") int savemodulo = 1, const std::string &f = "IEEE64BIG")
: config_prefix(cf), rng_prefix(rn), saveInterval(savemodulo), format(f){}; : config_prefix(cf),
rng_prefix(rn),
saveInterval(savemodulo),
format(f){};
template<class ReaderClass> template <class ReaderClass, typename std::enable_if< isReader<ReaderClass>::value, int>::type = 0 >
CheckpointerParameters(ReaderClass &Reader) { CheckpointerParameters(ReaderClass &Reader) {
read(Reader, "Checkpointer", *this); read(Reader, "Checkpointer", *this);
} }
}; };
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -37,29 +37,30 @@ directory
namespace Grid { namespace Grid {
namespace QCD { namespace QCD {
struct IntegratorParameters { class IntegratorParameters: Serializable {
unsigned int MDsteps; // number of outer steps public:
RealD trajL; // trajectory length GRID_SERIALIZABLE_CLASS_MEMBERS(IntegratorParameters,
RealD stepsize; // trajectory stepsize std::string, name, // name of the integrator
unsigned int, MDsteps, // number of outer steps
RealD, trajL, // trajectory length
)
IntegratorParameters(int MDsteps_ = 10, RealD trajL_ = 1.0) IntegratorParameters(int MDsteps_ = 10, RealD trajL_ = 1.0)
: MDsteps(MDsteps_), : MDsteps(MDsteps_),
trajL(trajL_), trajL(trajL_){
stepsize(trajL / MDsteps){
// empty body constructor // empty body constructor
}; };
void set(int MDsteps_, RealD trajL_){
MDsteps = MDsteps_;
trajL = trajL_;
stepsize = trajL/MDsteps;
}
template <class ReaderClass, typename std::enable_if<isReader<ReaderClass>::value, int >::type = 0 >
IntegratorParameters(ReaderClass & Reader){
read(Reader, "Integrator", *this);
}
void print_parameters() { void print_parameters() {
std::cout << GridLogMessage << "[Integrator] Trajectory length : " << trajL << std::endl; std::cout << GridLogMessage << "[Integrator] Trajectory length : " << trajL << std::endl;
std::cout << GridLogMessage << "[Integrator] Number of MD steps : " << MDsteps << std::endl; std::cout << GridLogMessage << "[Integrator] Number of MD steps : " << MDsteps << std::endl;
std::cout << GridLogMessage << "[Integrator] Step size : " << stepsize << std::endl; std::cout << GridLogMessage << "[Integrator] Step size : " << trajL/MDsteps << std::endl;
} }
}; };

View File

@ -117,7 +117,7 @@ class LeapFrog : public Integrator<FieldImplementation, SmearingPolicy,
// eps : current step size // eps : current step size
// Get current level step size // Get current level step size
RealD eps = this->Params.stepsize; RealD eps = this->Params.trajL/this->Params.MDsteps;
for (int l = 0; l <= level; ++l) eps /= this->as[l].multiplier; for (int l = 0; l <= level; ++l) eps /= this->as[l].multiplier;
int multiplier = this->as[level].multiplier; int multiplier = this->as[level].multiplier;
@ -166,7 +166,7 @@ class MinimumNorm2 : public Integrator<FieldImplementation, SmearingPolicy,
int fl = this->as.size() - 1; int fl = this->as.size() - 1;
RealD eps = this->Params.stepsize * 2.0; RealD eps = this->Params.trajL/this->Params.MDsteps * 2.0;
for (int l = 0; l <= level; ++l) eps /= 2.0 * this->as[l].multiplier; for (int l = 0; l <= level; ++l) eps /= 2.0 * this->as[l].multiplier;
// Nesting: 2xupdate_U of size eps/2 // Nesting: 2xupdate_U of size eps/2
@ -247,7 +247,7 @@ class ForceGradient : public Integrator<FieldImplementation, SmearingPolicy,
} }
void step(Field& U, int level, int _first, int _last) { void step(Field& U, int level, int _first, int _last) {
RealD eps = this->Params.stepsize * 2.0; RealD eps = this->Params.trajL/this->Params.MDsteps * 2.0;
for (int l = 0; l <= level; ++l) eps /= 2.0 * this->as[l].multiplier; for (int l = 0; l <= level; ++l) eps /= 2.0 * this->as[l].multiplier;
RealD Chi = chi * eps * eps * eps; RealD Chi = chi * eps * eps * eps;

100
lib/qcd/modules/Factory.h Normal file
View File

@ -0,0 +1,100 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: extras/Hadrons/Factory.hpp
Copyright (C) 2015
Copyright (C) 2016
Author: Antonin Portelli <antonin.portelli@me.com>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/* END LEGAL */
#ifndef Factory_hpp_
#define Factory_hpp_
namespace Grid{
/******************************************************************************
* abstract factory class *
******************************************************************************/
template <typename T, typename ProductCreator>
class Factory
{
public:
typedef std::function< std::unique_ptr<T> (const ProductCreator&)> Func;
public:
// constructor
Factory(void) = default;
// destructor
virtual ~Factory(void) = default;
// registration
void registerBuilder(const std::string type, const Func &f);
// get builder list
std::vector<std::string> getBuilderList(void) const;
// factory
std::unique_ptr<T> create(const std::string type,
const ProductCreator& name) const;
private:
std::map<std::string, Func> builder_;
};
/******************************************************************************
* template implementation *
******************************************************************************/
// registration ////////////////////////////////////////////////////////////////
template <typename T, typename ProductCreator>
void Factory<T, ProductCreator>::registerBuilder(const std::string type, const Func &f)
{
builder_[type] = f;
}
// get module list /////////////////////////////////////////////////////////////
template <typename T, typename ProductCreator>
std::vector<std::string> Factory<T, ProductCreator>::getBuilderList(void) const
{
std::vector<std::string> list;
for (auto &b: builder_)
{
list.push_back(b.first);
}
return list;
}
// factory /////////////////////////////////////////////////////////////////////
template <typename T, typename ProductCreator>
std::unique_ptr<T> Factory<T, ProductCreator>::create(const std::string type,
const ProductCreator& name) const
{
Func func;
try
{
func = builder_.at(type);
}
catch (std::out_of_range &)
{
//HADRON_ERROR("object of type '" + type + "' unknown");
}
return func(name);
}
}
#endif // Factory_hpp_

198
lib/qcd/modules/Modules.h Normal file
View File

@ -0,0 +1,198 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h
Copyright (C) 2016
Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/* END LEGAL */
#ifndef HMC_MODULES_H
#define HMC_MODULES_H
/*
Define loadable, serializable modules
for the HMC execution
*/
namespace Grid {
/*
Base class for modules with parameters
*/
template < class P >
class Parametrized{
public:
typedef P Parameters;
Parametrized(Parameters Par):Par_(Par){};
template <class ReaderClass>
Parametrized(Reader<ReaderClass> & Reader){
read(Reader, section_name(), Par_);
}
protected:
Parameters Par_;
private:
// identifies the section name
// override in derived classes if needed
virtual std::string section_name(){
return std::string("parameters"); //default
}
};
/*
Lowest level abstract module class
*/
template < class Prod >
class HMCModuleBase{
public:
typedef Prod Product;
virtual Prod* getPtr() = 0;
};
//////////////////////////////////////////////
// Actions
//////////////////////////////////////////////
template <class ActionType, class APar>
class ActionModule
: public Parametrized<APar>,
public HMCModuleBase<QCD::Action<typename ActionType::GaugeField> > {
public:
typedef HMCModuleBase< QCD::Action<typename ActionType::GaugeField> > Base;
typedef typename Base::Product Product;
std::unique_ptr<ActionType> ActionPtr;
ActionModule(APar Par) : Parametrized<APar>(Par) {}
template <class ReaderClass>
ActionModule(Reader<ReaderClass>& Reader) : Parametrized<APar>(Reader){};
Product* getPtr() {
if (!ActionPtr) initialize();
return ActionPtr.get();
}
private:
virtual void initialize() = 0;
};
namespace QCD{
class WilsonGaugeActionParameters : Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(WilsonGaugeActionParameters,
RealD, beta);
};
template<class Impl>
class WilsonGModule: public ActionModule<WilsonGaugeAction<Impl>, WilsonGaugeActionParameters> {
typedef ActionModule<WilsonGaugeAction<Impl>, WilsonGaugeActionParameters> ActionBase;
using ActionBase::ActionBase;
// acquire resource
virtual void initialize(){
ActionBase::ActionPtr.reset(new WilsonGaugeAction<Impl>(ActionBase::Par_.beta));
}
};
typedef WilsonGModule<PeriodicGimplR> WilsonGMod;
}// QCD temporarily here
// use the same classed defined by Antonin, does not make sense to rewrite
// Factory is perfectly fine
// Registar must be changed because I do not want to use the ModuleFactory
/*
define
*/
typedef HMCModuleBase< QCD::Action< QCD::LatticeGaugeField > > HMCModBase;
template <class ReaderClass >
class HMCActionModuleFactory
: public Factory < HMCModBase , Reader<ReaderClass> > {
public:
typedef Reader<ReaderClass> TheReader;
// use SINGLETON FUNCTOR MACRO HERE
HMCActionModuleFactory(const HMCActionModuleFactory& e) = delete;
void operator=(const HMCActionModuleFactory& e) = delete;
static HMCActionModuleFactory& getInstance(void) {
static HMCActionModuleFactory e;
return e;
}
private:
HMCActionModuleFactory(void) = default;
};
/*
then rewrite the registar
when this is done we have all the modules that contain the pointer to the objects
(actions, integrators, checkpointers, solvers)
factory will create only the modules and prepare the parameters
when needed a pointer is released
*/
template <class T, class TheFactory>
class Registrar {
public:
Registrar(std::string className) {
// register the class factory function
TheFactory::getInstance().registerBuilder(className, [&](typename TheFactory::TheReader Reader)
{ return std::unique_ptr<T>(new T(Reader));});
}
};
Registrar<QCD::WilsonGMod, HMCActionModuleFactory<XmlReader> > __WGmodInit("WilsonGaugeAction");
}
#endif //HMC_MODULES_H

38
lib/qcd/modules/mods.h Normal file
View File

@ -0,0 +1,38 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h
Copyright (C) 2016
Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/* END LEGAL */
#ifndef MODS_H
#define MODS_H
// Modules files
#include <Grid/qcd/modules/Factory.h>
#include <Grid/qcd/modules/Modules.h>
#endif //MODS_H

View File

@ -69,6 +69,9 @@ namespace Grid {
class Serializable {}; class Serializable {};
// static polymorphism implemented using CRTP idiom // static polymorphism implemented using CRTP idiom
// Static abstract writer // Static abstract writer
@ -122,10 +125,18 @@ namespace Grid {
T *upcast; T *upcast;
}; };
// type traits
// What is the vtype
template<typename T> struct isReader {
static const bool value = false;
};
template<typename T> struct isWriter {
static const bool value = false;
};
// Generic writer interface // Generic writer interface
template <typename T> template <typename T>
inline void push(Writer<T> &w, const std::string &s) inline void push(Writer<T> &w, const std::string &s) {
{
w.push(s); w.push(s);
} }

View File

@ -79,6 +79,18 @@ namespace Grid
std::string fileName_; std::string fileName_;
}; };
template <>
struct isReader< XmlReader > {
static const bool value = true;
};
template <>
struct isWriter<XmlWriter > {
static const bool value = true;
};
// Writer template implementation //////////////////////////////////////////// // Writer template implementation ////////////////////////////////////////////
template <typename U> template <typename U>
void XmlWriter::writeDefault(const std::string &s, const U &x) void XmlWriter::writeDefault(const std::string &s, const U &x)

View File

@ -30,62 +30,59 @@ directory
/* END LEGAL */ /* END LEGAL */
#include <Grid/Grid.h> #include <Grid/Grid.h>
namespace Grid {
namespace QCD {
//Change here the type of reader
typedef Grid::XmlReader InputFileReader;
class HMCRunnerParameters : Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters,
double, beta,
int, MDsteps,
double, TrajectoryLength,
std::string, serial_seeds,
std::string, parallel_seeds,
);
HMCRunnerParameters() {}
};
}
}
int main(int argc, char **argv) { int main(int argc, char **argv) {
using namespace Grid; using namespace Grid;
using namespace Grid::QCD; using namespace Grid::QCD;
Grid_init(&argc, &argv); Grid_init(&argc, &argv);
int threads = GridThread::GetThreads(); int threads = GridThread::GetThreads();
// Typedefs to simplify notation
typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm
// here make a routine to print all the relevant information on the run // here make a routine to print all the relevant information on the run
std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl; std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl;
////////////////////////////////////////////////////////////// // Typedefs to simplify notation
// Input file section typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm
// make input file name general typedef Grid::XmlReader InputFileReader;
// now working with the text reader but I should drop this support
// i need a structured format where every object is able
// to locate the required data: XML, JSON, YAML.
InputFileReader Reader("input.wilson_gauge.params.xml");
HMCRunnerParameters HMCPar;
read(Reader, "HMC", HMCPar);
// Seeds for the random number generators // Reader, now not necessary
// generalise, ugly now InputFileReader Reader("input.wilson_gauge.params.xml");
std::vector<int> SerSeed = strToVec<int>(HMCPar.serial_seeds);
std::vector<int> ParSeed = strToVec<int>(HMCPar.parallel_seeds);
HMCWrapper TheHMC; HMCWrapper TheHMC;
// Grid from the command line
TheHMC.Resources.AddFourDimGrid("gauge"); TheHMC.Resources.AddFourDimGrid("gauge");
// here using the Reader but an overloaded function to pass the // Grid from the Reader
// parameters class is provided /*
TheHMC.Resources.LoadBinaryCheckpointer(Reader); GridModuleParameters GridPar(Reader);
GridFourDimModule GridMod( GridPar) ;
TheHMC.Resources.AddGrid("gauge", GridMod);
*/
// Checkpointer definition
CheckpointerParameters CPparams;
CPparams.config_prefix = "ckpoint_lat";
CPparams.rng_prefix = "ckpoint_rng";
CPparams.saveInterval = 5;
CPparams.format = "IEEE64BIG";
// can also use the reader constructor
// CheckpointerParameters CPparams(Reader);
TheHMC.Resources.LoadBinaryCheckpointer(CPparams);
// Fill resources
// Seeds for the random number generators
// Can also initialize using the Reader
RNGModuleParameters RNGpar;
RNGpar.SerialSeed = {1,2,3,4,5};
RNGpar.ParallelSeed = {6,7,8,9,10};
TheHMC.Resources.SetRNGSeeds(RNGpar);
// Construct observables
// here there is too much indirection
PlaquetteLogger<HMCWrapper::ImplPolicy> PlaqLog("Plaquette");
TheHMC.ObservablesList.push_back(&PlaqLog);
//////////////////////////////////////////////
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Collect actions, here use more encapsulation // Collect actions, here use more encapsulation
@ -93,23 +90,30 @@ int main(int argc, char **argv) {
// that have a complex construction // that have a complex construction
// Gauge action // Gauge action
WilsonGaugeActionR Waction(HMCPar.beta); // as module
/*
WilsonGMod::Parameters WPar;
WPar.beta = 6.0;
WilsonGMod WGMod(WPar);
auto testAction = WGMod.getPtr();// test to pass to the action set
HMCModuleBase<Action<LatticeGaugeField>>* HMB = &WGMod;
*/
// standard
RealD beta = 5.6 ;
WilsonGaugeActionR Waction(beta);
ActionLevel<HMCWrapper::Field> Level1(1); ActionLevel<HMCWrapper::Field> Level1(1);
Level1.push_back(&Waction); Level1.push_back(&Waction);
//Level1.push_back(WGMod.getPtr());
TheHMC.TheAction.push_back(Level1); TheHMC.TheAction.push_back(Level1);
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Construct observables // Nest MDparameters in the HMCparameters->HMCPayload
PlaquetteLogger<HMCWrapper::ImplPolicy> PlaqLog("Plaquette"); // make it serializable
TheHMC.ObservablesList.push_back(&PlaqLog); TheHMC.MDparameters.MDsteps = 20;
////////////////////////////////////////////// TheHMC.MDparameters.trajL = 1.0;
// Fill resources
// here we can simplify a lot if the input file is structured
// just pass the input file reader
TheHMC.Resources.AddRNGSeeds(SerSeed, ParSeed);
TheHMC.MDparameters.set(HMCPar.MDsteps, HMCPar.TrajectoryLength);
// eventually smearing here // eventually smearing here
// ... // ...