1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

Adding factories

This commit is contained in:
Guido Cossu 2017-01-16 10:18:09 +00:00
parent 0dfda4bb90
commit c6f59c2933
15 changed files with 583 additions and 171 deletions

View File

@ -492,7 +492,6 @@ namespace QCD {
} //namespace QCD
} // Grid
#include <Grid/qcd/utils/SpaceTimeGrid.h>
#include <Grid/qcd/spin/Dirac.h>
#include <Grid/qcd/spin/TwoSpinor.h>
@ -517,5 +516,6 @@ namespace QCD {
#include <Grid/qcd/hmc/HMC.h>
#include <Grid/qcd/modules/mods.h>
#endif

View File

@ -10,6 +10,7 @@ Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
Author: Peter Boyle <paboyle@ph.ed.ac.uk>
Author: neo <cossu@post.kek.jp>
Author: paboyle <paboyle@ph.ed.ac.uk>
Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
@ -40,21 +41,18 @@ namespace QCD {
////////////////////////////////////////////////////////////////////////
template <class Gimpl>
class WilsonGaugeAction : public Action<typename Gimpl::GaugeField> {
public:
public:
INHERIT_GIMPL_TYPES(Gimpl);
private:
RealD beta;
public:
explicit WilsonGaugeAction(RealD b) : beta(b){}
/////////////////////////// constructors
explicit WilsonGaugeAction(RealD beta_):beta(beta_){};
virtual std::string action_name() {return "WilsonGaugeAction";}
virtual std::string LogParameters(){
std::stringstream sstream;
sstream << GridLogMessage << "[WilsonGaugeAction] Beta: " << beta << std::endl;
return sstream.str();
std::stringstream sstream;
sstream << GridLogMessage << "[WilsonGaugeAction] Beta: " << beta << std::endl;
return sstream.str();
}
virtual void refresh(const GaugeField &U,
@ -85,9 +83,12 @@ class WilsonGaugeAction : public Action<typename Gimpl::GaugeField> {
PokeIndex<LorentzIndex>(dSdU, dSdU_mu, mu);
}
}
private:
RealD beta;
};
}
}

View File

@ -151,7 +151,7 @@ class HMCWrapperTemplate {
Implementation::TepidConfiguration(Resources.GetParallelRNG(), U);
} else if (Payload.StartType == CheckpointStart) {
// CheckpointRestart
Resources.get_CheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U,
Resources.GetCheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U,
Resources.GetSerialRNG(),
Resources.GetParallelRNG());
}
@ -164,7 +164,7 @@ class HMCWrapperTemplate {
for (int obs = 0; obs < ObservablesList.size(); obs++)
HMC.AddObservable(ObservablesList[obs]);
HMC.AddObservable(Resources.get_CheckPointer());
HMC.AddObservable(Resources.GetCheckPointer());
// Run it

View File

@ -50,6 +50,8 @@ struct HMCparameters {
bool MetropolisTest;
Integer NoMetropolisUntil;
// nest here the MDparameters and make all serializable
HMCparameters() {
////////////////////////////// Default values
MetropolisTest = true;

View File

@ -33,8 +33,42 @@ directory
namespace Grid {
namespace QCD {
///////////////////////////////////////////////////
// Modules
class GridModuleParameters: Serializable{
GRID_SERIALIZABLE_CLASS_MEMBERS(GridModuleParameters,
std::string, lattice,
std::string, mpi);
public:
// these namings are ugly
// also ugly the distinction between the serializable members
// and this
std::vector<int> lattice_v;
std::vector<int> mpi_v;
GridModuleParameters(const std::vector<int> l_ = std::vector<int>(),
const std::vector<int> mpi_ = std::vector<int>()):lattice_v(l_), mpi_v(mpi_){}
template <class ReaderClass>
GridModuleParameters(Reader<ReaderClass>& Reader) {
read(Reader, "LatticeGrid", *this);
lattice_v = strToVec<int>(lattice);
mpi_v = strToVec<int>(mpi);
if (mpi_v.size() != lattice_v.size()) {
std::cout << "Error in GridModuleParameters: lattice and mpi dimensions "
"do not match"
<< std::endl;
exit(1);
}
}
};
class GridModule {
public:
GridCartesian* get_full() { return grid_.get(); }
@ -46,11 +80,13 @@ class GridModule {
protected:
std::unique_ptr<GridCartesian> grid_;
std::unique_ptr<GridRedBlackCartesian> rbgrid_;
};
// helpers
class GridFourDimModule : public GridModule {
public:
// add a function to create the module from a Reader
GridFourDimModule() {
set_full(SpaceTimeGrid::makeFourDimGrid(
GridDefaultLatt(), GridDefaultSimd(4, vComplex::Nsimd()),
@ -58,50 +94,53 @@ class GridFourDimModule : public GridModule {
set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(grid_.get()));
}
template <class vector_type = vComplex>
GridFourDimModule(GridModuleParameters Params) {
if (Params.lattice_v.size() == 4) {
set_full(SpaceTimeGrid::makeFourDimGrid(
Params.lattice_v, GridDefaultSimd(4, vector_type::Nsimd()),
Params.mpi_v));
set_rb(SpaceTimeGrid::makeFourDimRedBlackGrid(grid_.get()));
} else {
std::cout
<< "Error in GridFourDimModule: lattice dimension different from 4"
<< std::endl;
exit(1);
}
}
};
////////////////////////////////////////////////////////////////////
class RNGModuleParameters: Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(RNGModuleParameters,
std::vector<int>, SerialSeed_,
std::vector<int>, ParallelSeed_,);
std::string, serial_seeds,
std::string, parallel_seeds,);
public:
std::vector<int> SerialSeed;
std::vector<int> ParallelSeed;
RNGModuleParameters(const std::vector<int> S = std::vector<int>(),
const std::vector<int> P = std::vector<int>())
: SerialSeed(S), ParallelSeed(P) {}
// default constructor, needed for the non-Reader
// construction of the module
RNGModuleParameters(){
SerialSeed_.resize(0);
ParallelSeed_.resize(0);
template <class ReaderClass >
RNGModuleParameters(Reader<ReaderClass>& Reader){
read(Reader, "RandomNumberGenerator", *this);
SerialSeed = strToVec<int>(serial_seeds);
ParallelSeed = strToVec<int>(parallel_seeds);
}
RNGModuleParameters(const std::vector<int> S, const std::vector<int> P){
set_RNGSeeds(S,P);
}
template < class ReaderClass >
RNGModuleParameters(ReaderClass &Reader){
read(Reader, "RandomNumberGenerator", *this);
}
void set_RNGSeeds(const std::vector<int>& S, const std::vector<int>& P){
SerialSeed_ = S;
ParallelSeed_ = P;
}
};
// Random number generators module
class RNGModule{
// Random number generators
GridSerialRNG sRNG_;
std::unique_ptr<GridParallelRNG> pRNG_;
RNGModuleParameters Params_;
public:
template < class ReaderClass >
RNGModule(ReaderClass &Reader):Params_(Reader){};
RNGModule(){};
@ -109,36 +148,23 @@ public:
pRNG_.reset(pRNG);
}
void set_RNGSeeds(const std::vector<int> S, const std::vector<int> P) {
Params_.set_RNGSeeds(S,P);
void set_RNGSeeds(RNGModuleParameters& Params) {
Params_ = Params;
}
GridSerialRNG& get_sRNG() { return sRNG_; }
GridParallelRNG& get_pRNG() { return *pRNG_.get(); }
GridSerialRNG& get_sRNG(){
if (Params_.SerialSeed_.size()==0){
std::cout << "Serial seeds not initialized" << std::endl;
void seed() {
if (Params_.SerialSeed.size() == 0 && Params_.ParallelSeed.size() == 0) {
std::cout << "Seeds not initialized" << std::endl;
exit(1);
}
return sRNG_;
}
GridParallelRNG& get_pRNG(){
if (Params_.ParallelSeed_.size()==0){
std::cout << "Parallel seeds not initialized" << std::endl;
exit(1);
}
return *pRNG_.get();
}
void seed(){
sRNG_.SeedFixedIntegers(Params_.SerialSeed_);
pRNG_->SeedFixedIntegers(Params_.ParallelSeed_);
sRNG_.SeedFixedIntegers(Params_.SerialSeed);
pRNG_->SeedFixedIntegers(Params_.ParallelSeed);
}
};
///////////////////////////////////////////////////////////////////
/// Smearing module
template <class ImplementationPolicy>

View File

@ -48,6 +48,7 @@ with this program; if not, write to the Free Software Foundation, Inc.,
} \
}
/*
// One function per Checkpointer using the reader, use a macro to simplify
#define RegisterLoadCheckPointerReaderFunction(NAME) \
template <class Reader> \
@ -64,18 +65,19 @@ with this program; if not, write to the Free Software Foundation, Inc.,
exit(1); \
} \
}
*/
namespace Grid {
namespace QCD {
// HMC Resource manager
template <class ImplementationPolicy>
class HMCResourceManager{
// Storage for grid pairs (std + red-black)
template <class ImplementationPolicy>
class HMCResourceManager {
// Named storage for grid pairs (std + red-black)
std::unordered_map<std::string, GridModule> Grids;
RNGModule RNGs;
//SmearingModule<ImplementationPolicy> Smearing;
// SmearingModule<ImplementationPolicy> Smearing;
CheckPointModule<ImplementationPolicy> CP;
bool have_RNG;
@ -83,6 +85,13 @@ class HMCResourceManager{
public:
HMCResourceManager() : have_RNG(false), have_CheckPointer(false) {}
// Here need a constructor for using the Reader class
//////////////////////////////////////////////////////////////
// Grids
//////////////////////////////////////////////////////////////
void AddGrid(std::string s, GridModule& M) {
// Check for name clashes
auto search = Grids.find(s);
@ -94,12 +103,14 @@ class HMCResourceManager{
Grids[s] = std::move(M);
}
// Add a named grid set
// Add a named grid set, 4d shortcut
void AddFourDimGrid(std::string s) {
GridFourDimModule Mod;
AddGrid(s, Mod);
}
GridCartesian* GetCartesian(std::string s = "") {
if (s.empty()) s = Grids.begin()->first;
std::cout << GridLogDebug << "Getting cartesian grid from: " << s
@ -114,6 +125,10 @@ class HMCResourceManager{
return Grids[s].get_rb();
}
//////////////////////////////////////////////////////
// Random number generators
//////////////////////////////////////////////////////
void AddRNGs(std::string s = "") {
// Couple the RNGs to the GridModule tagged by s
// the default is the first grid registered
@ -124,43 +139,43 @@ class HMCResourceManager{
have_RNG = true;
}
void AddRNGSeeds(const std::vector<int> S, const std::vector<int> P) {
RNGs.set_RNGSeeds(S, P);
}
void SetRNGSeeds(RNGModuleParameters& Params) { RNGs.set_RNGSeeds(Params); }
GridSerialRNG& GetSerialRNG() { return RNGs.get_sRNG(); }
GridParallelRNG& GetParallelRNG() {
assert(have_RNG);
return RNGs.get_pRNG();
}
void SeedFixedIntegers() {
assert(have_RNG);
RNGs.seed();
}
//////////////////////////////////////////////////////
// Checkpointers
//////////////////////////////////////////////////////
BaseHmcCheckpointer<ImplementationPolicy>* get_CheckPointer(){
BaseHmcCheckpointer<ImplementationPolicy>* GetCheckPointer() {
if (have_CheckPointer)
return CP.get_CheckPointer();
else{
std::cout << GridLogError << "Error: no checkpointer defined" << std::endl;
return CP.get_CheckPointer();
else {
std::cout << GridLogError << "Error: no checkpointer defined"
<< std::endl;
exit(1);
}
}
RegisterLoadCheckPointerFunction (Binary);
RegisterLoadCheckPointerFunction (Nersc);
RegisterLoadCheckPointerFunction (ILDG);
RegisterLoadCheckPointerFunction(Binary);
RegisterLoadCheckPointerFunction(Nersc);
RegisterLoadCheckPointerFunction(ILDG);
RegisterLoadCheckPointerReaderFunction (Binary);
RegisterLoadCheckPointerReaderFunction (Nersc);
RegisterLoadCheckPointerReaderFunction (ILDG);
/*
RegisterLoadCheckPointerReaderFunction(Binary);
RegisterLoadCheckPointerReaderFunction(Nersc);
RegisterLoadCheckPointerReaderFunction(ILDG);
*/
};
}

View File

@ -30,27 +30,31 @@ directory
#define BASE_CHECKPOINTER
namespace Grid {
namespace QCD {
namespace QCD {
class CheckpointerParameters : Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters,
std::string, config_prefix,
std::string, rng_prefix,
int, saveInterval,
std::string, format, );
class CheckpointerParameters : Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters,
std::string, config_prefix,
std::string, rng_prefix,
int, saveInterval,
std::string, format, );
CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng",
int savemodulo = 1, const std::string &f = "IEEE64BIG")
: config_prefix(cf), rng_prefix(rn), saveInterval(savemodulo), format(f){};
CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng",
int savemodulo = 1, const std::string &f = "IEEE64BIG")
: config_prefix(cf),
rng_prefix(rn),
saveInterval(savemodulo),
format(f){};
template<class ReaderClass>
CheckpointerParameters(ReaderClass &Reader){
read(Reader, "Checkpointer", *this);
}
template <class ReaderClass, typename std::enable_if< isReader<ReaderClass>::value, int>::type = 0 >
CheckpointerParameters(ReaderClass &Reader) {
read(Reader, "Checkpointer", *this);
}
};
};
//////////////////////////////////////////////////////////////////////////////
// Base class for checkpointers

View File

@ -37,29 +37,30 @@ directory
namespace Grid {
namespace QCD {
struct IntegratorParameters {
unsigned int MDsteps; // number of outer steps
RealD trajL; // trajectory length
RealD stepsize; // trajectory stepsize
class IntegratorParameters: Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(IntegratorParameters,
std::string, name, // name of the integrator
unsigned int, MDsteps, // number of outer steps
RealD, trajL, // trajectory length
)
IntegratorParameters(int MDsteps_ = 10, RealD trajL_ = 1.0)
: MDsteps(MDsteps_),
trajL(trajL_),
stepsize(trajL / MDsteps){
// empty body constructor
trajL(trajL_){
// empty body constructor
};
void set(int MDsteps_, RealD trajL_){
MDsteps = MDsteps_;
trajL = trajL_;
stepsize = trajL/MDsteps;
}
template <class ReaderClass, typename std::enable_if<isReader<ReaderClass>::value, int >::type = 0 >
IntegratorParameters(ReaderClass & Reader){
read(Reader, "Integrator", *this);
}
void print_parameters() {
std::cout << GridLogMessage << "[Integrator] Trajectory length : " << trajL << std::endl;
std::cout << GridLogMessage << "[Integrator] Number of MD steps : " << MDsteps << std::endl;
std::cout << GridLogMessage << "[Integrator] Step size : " << stepsize << std::endl;
std::cout << GridLogMessage << "[Integrator] Step size : " << trajL/MDsteps << std::endl;
}
};

View File

@ -117,7 +117,7 @@ class LeapFrog : public Integrator<FieldImplementation, SmearingPolicy,
// eps : current step size
// Get current level step size
RealD eps = this->Params.stepsize;
RealD eps = this->Params.trajL/this->Params.MDsteps;
for (int l = 0; l <= level; ++l) eps /= this->as[l].multiplier;
int multiplier = this->as[level].multiplier;
@ -166,7 +166,7 @@ class MinimumNorm2 : public Integrator<FieldImplementation, SmearingPolicy,
int fl = this->as.size() - 1;
RealD eps = this->Params.stepsize * 2.0;
RealD eps = this->Params.trajL/this->Params.MDsteps * 2.0;
for (int l = 0; l <= level; ++l) eps /= 2.0 * this->as[l].multiplier;
// Nesting: 2xupdate_U of size eps/2
@ -247,7 +247,7 @@ class ForceGradient : public Integrator<FieldImplementation, SmearingPolicy,
}
void step(Field& U, int level, int _first, int _last) {
RealD eps = this->Params.stepsize * 2.0;
RealD eps = this->Params.trajL/this->Params.MDsteps * 2.0;
for (int l = 0; l <= level; ++l) eps /= 2.0 * this->as[l].multiplier;
RealD Chi = chi * eps * eps * eps;

100
lib/qcd/modules/Factory.h Normal file
View File

@ -0,0 +1,100 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: extras/Hadrons/Factory.hpp
Copyright (C) 2015
Copyright (C) 2016
Author: Antonin Portelli <antonin.portelli@me.com>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/* END LEGAL */
#ifndef Factory_hpp_
#define Factory_hpp_
namespace Grid{
/******************************************************************************
* abstract factory class *
******************************************************************************/
template <typename T, typename ProductCreator>
class Factory
{
public:
typedef std::function< std::unique_ptr<T> (const ProductCreator&)> Func;
public:
// constructor
Factory(void) = default;
// destructor
virtual ~Factory(void) = default;
// registration
void registerBuilder(const std::string type, const Func &f);
// get builder list
std::vector<std::string> getBuilderList(void) const;
// factory
std::unique_ptr<T> create(const std::string type,
const ProductCreator& name) const;
private:
std::map<std::string, Func> builder_;
};
/******************************************************************************
* template implementation *
******************************************************************************/
// registration ////////////////////////////////////////////////////////////////
template <typename T, typename ProductCreator>
void Factory<T, ProductCreator>::registerBuilder(const std::string type, const Func &f)
{
builder_[type] = f;
}
// get module list /////////////////////////////////////////////////////////////
template <typename T, typename ProductCreator>
std::vector<std::string> Factory<T, ProductCreator>::getBuilderList(void) const
{
std::vector<std::string> list;
for (auto &b: builder_)
{
list.push_back(b.first);
}
return list;
}
// factory /////////////////////////////////////////////////////////////////////
template <typename T, typename ProductCreator>
std::unique_ptr<T> Factory<T, ProductCreator>::create(const std::string type,
const ProductCreator& name) const
{
Func func;
try
{
func = builder_.at(type);
}
catch (std::out_of_range &)
{
//HADRON_ERROR("object of type '" + type + "' unknown");
}
return func(name);
}
}
#endif // Factory_hpp_

198
lib/qcd/modules/Modules.h Normal file
View File

@ -0,0 +1,198 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h
Copyright (C) 2016
Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/* END LEGAL */
#ifndef HMC_MODULES_H
#define HMC_MODULES_H
/*
Define loadable, serializable modules
for the HMC execution
*/
namespace Grid {
/*
Base class for modules with parameters
*/
template < class P >
class Parametrized{
public:
typedef P Parameters;
Parametrized(Parameters Par):Par_(Par){};
template <class ReaderClass>
Parametrized(Reader<ReaderClass> & Reader){
read(Reader, section_name(), Par_);
}
protected:
Parameters Par_;
private:
// identifies the section name
// override in derived classes if needed
virtual std::string section_name(){
return std::string("parameters"); //default
}
};
/*
Lowest level abstract module class
*/
template < class Prod >
class HMCModuleBase{
public:
typedef Prod Product;
virtual Prod* getPtr() = 0;
};
//////////////////////////////////////////////
// Actions
//////////////////////////////////////////////
template <class ActionType, class APar>
class ActionModule
: public Parametrized<APar>,
public HMCModuleBase<QCD::Action<typename ActionType::GaugeField> > {
public:
typedef HMCModuleBase< QCD::Action<typename ActionType::GaugeField> > Base;
typedef typename Base::Product Product;
std::unique_ptr<ActionType> ActionPtr;
ActionModule(APar Par) : Parametrized<APar>(Par) {}
template <class ReaderClass>
ActionModule(Reader<ReaderClass>& Reader) : Parametrized<APar>(Reader){};
Product* getPtr() {
if (!ActionPtr) initialize();
return ActionPtr.get();
}
private:
virtual void initialize() = 0;
};
namespace QCD{
class WilsonGaugeActionParameters : Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(WilsonGaugeActionParameters,
RealD, beta);
};
template<class Impl>
class WilsonGModule: public ActionModule<WilsonGaugeAction<Impl>, WilsonGaugeActionParameters> {
typedef ActionModule<WilsonGaugeAction<Impl>, WilsonGaugeActionParameters> ActionBase;
using ActionBase::ActionBase;
// acquire resource
virtual void initialize(){
ActionBase::ActionPtr.reset(new WilsonGaugeAction<Impl>(ActionBase::Par_.beta));
}
};
typedef WilsonGModule<PeriodicGimplR> WilsonGMod;
}// QCD temporarily here
// use the same classed defined by Antonin, does not make sense to rewrite
// Factory is perfectly fine
// Registar must be changed because I do not want to use the ModuleFactory
/*
define
*/
typedef HMCModuleBase< QCD::Action< QCD::LatticeGaugeField > > HMCModBase;
template <class ReaderClass >
class HMCActionModuleFactory
: public Factory < HMCModBase , Reader<ReaderClass> > {
public:
typedef Reader<ReaderClass> TheReader;
// use SINGLETON FUNCTOR MACRO HERE
HMCActionModuleFactory(const HMCActionModuleFactory& e) = delete;
void operator=(const HMCActionModuleFactory& e) = delete;
static HMCActionModuleFactory& getInstance(void) {
static HMCActionModuleFactory e;
return e;
}
private:
HMCActionModuleFactory(void) = default;
};
/*
then rewrite the registar
when this is done we have all the modules that contain the pointer to the objects
(actions, integrators, checkpointers, solvers)
factory will create only the modules and prepare the parameters
when needed a pointer is released
*/
template <class T, class TheFactory>
class Registrar {
public:
Registrar(std::string className) {
// register the class factory function
TheFactory::getInstance().registerBuilder(className, [&](typename TheFactory::TheReader Reader)
{ return std::unique_ptr<T>(new T(Reader));});
}
};
Registrar<QCD::WilsonGMod, HMCActionModuleFactory<XmlReader> > __WGmodInit("WilsonGaugeAction");
}
#endif //HMC_MODULES_H

38
lib/qcd/modules/mods.h Normal file
View File

@ -0,0 +1,38 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/action/gauge/WilsonGaugeAction.h
Copyright (C) 2016
Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/* END LEGAL */
#ifndef MODS_H
#define MODS_H
// Modules files
#include <Grid/qcd/modules/Factory.h>
#include <Grid/qcd/modules/Modules.h>
#endif //MODS_H

View File

@ -69,6 +69,9 @@ namespace Grid {
class Serializable {};
// static polymorphism implemented using CRTP idiom
// Static abstract writer
@ -121,11 +124,19 @@ namespace Grid {
private:
T *upcast;
};
// type traits
// What is the vtype
template<typename T> struct isReader {
static const bool value = false;
};
template<typename T> struct isWriter {
static const bool value = false;
};
// Generic writer interface
template <typename T>
inline void push(Writer<T> &w, const std::string &s)
{
inline void push(Writer<T> &w, const std::string &s) {
w.push(s);
}

View File

@ -79,6 +79,18 @@ namespace Grid
std::string fileName_;
};
template <>
struct isReader< XmlReader > {
static const bool value = true;
};
template <>
struct isWriter<XmlWriter > {
static const bool value = true;
};
// Writer template implementation ////////////////////////////////////////////
template <typename U>
void XmlWriter::writeDefault(const std::string &s, const U &x)

View File

@ -30,62 +30,59 @@ directory
/* END LEGAL */
#include <Grid/Grid.h>
namespace Grid {
namespace QCD {
//Change here the type of reader
typedef Grid::XmlReader InputFileReader;
class HMCRunnerParameters : Serializable {
public:
GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters,
double, beta,
int, MDsteps,
double, TrajectoryLength,
std::string, serial_seeds,
std::string, parallel_seeds,
);
HMCRunnerParameters() {}
};
}
}
int main(int argc, char **argv) {
using namespace Grid;
using namespace Grid::QCD;
Grid_init(&argc, &argv);
int threads = GridThread::GetThreads();
// Typedefs to simplify notation
typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm
// here make a routine to print all the relevant information on the run
std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl;
//////////////////////////////////////////////////////////////
// Input file section
// make input file name general
// now working with the text reader but I should drop this support
// i need a structured format where every object is able
// to locate the required data: XML, JSON, YAML.
InputFileReader Reader("input.wilson_gauge.params.xml");
HMCRunnerParameters HMCPar;
read(Reader, "HMC", HMCPar);
// Typedefs to simplify notation
typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm
typedef Grid::XmlReader InputFileReader;
// Seeds for the random number generators
// generalise, ugly now
std::vector<int> SerSeed = strToVec<int>(HMCPar.serial_seeds);
std::vector<int> ParSeed = strToVec<int>(HMCPar.parallel_seeds);
// Reader, now not necessary
InputFileReader Reader("input.wilson_gauge.params.xml");
HMCWrapper TheHMC;
// Grid from the command line
TheHMC.Resources.AddFourDimGrid("gauge");
// Grid from the Reader
/*
GridModuleParameters GridPar(Reader);
GridFourDimModule GridMod( GridPar) ;
TheHMC.Resources.AddGrid("gauge", GridMod);
*/
// Checkpointer definition
CheckpointerParameters CPparams;
CPparams.config_prefix = "ckpoint_lat";
CPparams.rng_prefix = "ckpoint_rng";
CPparams.saveInterval = 5;
CPparams.format = "IEEE64BIG";
// here using the Reader but an overloaded function to pass the
// parameters class is provided
TheHMC.Resources.LoadBinaryCheckpointer(Reader);
// can also use the reader constructor
// CheckpointerParameters CPparams(Reader);
TheHMC.Resources.LoadBinaryCheckpointer(CPparams);
// Fill resources
// Seeds for the random number generators
// Can also initialize using the Reader
RNGModuleParameters RNGpar;
RNGpar.SerialSeed = {1,2,3,4,5};
RNGpar.ParallelSeed = {6,7,8,9,10};
TheHMC.Resources.SetRNGSeeds(RNGpar);
// Construct observables
// here there is too much indirection
PlaquetteLogger<HMCWrapper::ImplPolicy> PlaqLog("Plaquette");
TheHMC.ObservablesList.push_back(&PlaqLog);
//////////////////////////////////////////////
/////////////////////////////////////////////////////////////
// Collect actions, here use more encapsulation
@ -93,23 +90,30 @@ int main(int argc, char **argv) {
// that have a complex construction
// Gauge action
WilsonGaugeActionR Waction(HMCPar.beta);
// as module
/*
WilsonGMod::Parameters WPar;
WPar.beta = 6.0;
WilsonGMod WGMod(WPar);
auto testAction = WGMod.getPtr();// test to pass to the action set
HMCModuleBase<Action<LatticeGaugeField>>* HMB = &WGMod;
*/
// standard
RealD beta = 5.6 ;
WilsonGaugeActionR Waction(beta);
ActionLevel<HMCWrapper::Field> Level1(1);
Level1.push_back(&Waction);
//Level1.push_back(WGMod.getPtr());
TheHMC.TheAction.push_back(Level1);
/////////////////////////////////////////////////////////////
// Construct observables
PlaquetteLogger<HMCWrapper::ImplPolicy> PlaqLog("Plaquette");
TheHMC.ObservablesList.push_back(&PlaqLog);
//////////////////////////////////////////////
// Fill resources
// here we can simplify a lot if the input file is structured
// just pass the input file reader
TheHMC.Resources.AddRNGSeeds(SerSeed, ParSeed);
TheHMC.MDparameters.set(HMCPar.MDsteps, HMCPar.TrajectoryLength);
// Nest MDparameters in the HMCparameters->HMCPayload
// make it serializable
TheHMC.MDparameters.MDsteps = 20;
TheHMC.MDparameters.trajL = 1.0;
// eventually smearing here
// ...