1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-11 22:50:45 +01:00

Cleaning up the checkpointers interface

This commit is contained in:
Guido Cossu 2017-01-05 15:52:52 +00:00
parent 1bb8578173
commit 1189ebc8b5
7 changed files with 172 additions and 207 deletions

View File

@ -96,29 +96,23 @@ class StoutSmearingModule: public SmearingModule<ImplementationPolicy>{
SmearedConfiguration<ImplementationPolicy> SmearingPolicy; SmearedConfiguration<ImplementationPolicy> SmearingPolicy;
}; };
// Checkpoint module, owns the Checkpointer // Checkpoint module, owns the Checkpointer
template <class ImplementationPolicy> template <class ImplementationPolicy>
class CheckPointModule{ class CheckPointModule {
std::unique_ptr< BaseHmcCheckpointer<ImplementationPolicy> > cp_; std::unique_ptr<BaseHmcCheckpointer<ImplementationPolicy> > cp_;
public: public:
void set_Checkpointer(BaseHmcCheckpointer<ImplementationPolicy> *cp){ void set_Checkpointer(BaseHmcCheckpointer<ImplementationPolicy>* cp) {
cp_.reset(cp); cp_.reset(cp);
}; };
BaseHmcCheckpointer<ImplementationPolicy>* get_CheckPointer(){
std::cout << "Checkpointer Pointer requested : " << cp_.get() << std::endl;
return cp_.get();
}
void initialize(CheckpointerParameters& P){ BaseHmcCheckpointer<ImplementationPolicy>* get_CheckPointer() {
cp_.initialize(P); return cp_.get();
} }
void initialize(CheckpointerParameters& P) { cp_.initialize(P); }
}; };
} // namespace QCD } // namespace QCD
} // namespace Grid } // namespace Grid

View File

@ -33,22 +33,37 @@ with this program; if not, write to the Free Software Foundation, Inc.,
#include <unordered_map> #include <unordered_map>
// One function per Checkpointer, use a macro to simplify // One function per Checkpointer, use a macro to simplify
#define RegisterLoadCheckPointerFunction(NAME) \ #define RegisterLoadCheckPointerFunction(NAME) \
void Load##NAME##Checkpointer(CheckpointerParameters& Params_) { \ void Load##NAME##Checkpointer(const CheckpointerParameters& Params_) { \
if (!have_CheckPointer) { \ if (!have_CheckPointer) { \
std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \ std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \
<< std::endl; \ << std::endl; \
CP.set_Checkpointer( \ CP.set_Checkpointer( \
new NAME##HmcCheckpointer<ImplementationPolicy>(Params_)); \ new NAME##HmcCheckpointer<ImplementationPolicy>(Params_)); \
have_CheckPointer = true; \ have_CheckPointer = true; \
} else { \ } else { \
std::cout << GridLogError << "Checkpointer already loaded " \ std::cout << GridLogError << "Checkpointer already loaded " \
<< std::endl; \ << std::endl; \
exit(1); \ exit(1); \
} \ } \
} }
// One function per Checkpointer using the reader, use a macro to simplify
#define RegisterLoadCheckPointerReaderFunction(NAME) \
template <class Reader> \
void Load##NAME##Checkpointer(Reader& Reader_) { \
if (!have_CheckPointer) { \
std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \
<< std::endl; \
CP.set_Checkpointer(new NAME##HmcCheckpointer<ImplementationPolicy>( \
CheckpointerParameters(Reader_))); \
have_CheckPointer = true; \
} else { \
std::cout << GridLogError << "Checkpointer already loaded " \
<< std::endl; \
exit(1); \
} \
}
namespace Grid { namespace Grid {
namespace QCD { namespace QCD {
@ -141,7 +156,11 @@ class HMCResourceManager{
RegisterLoadCheckPointerFunction (Binary); RegisterLoadCheckPointerFunction (Binary);
RegisterLoadCheckPointerFunction (Nersc); RegisterLoadCheckPointerFunction (Nersc);
RegisterLoadCheckPointerFunction (ILDG) RegisterLoadCheckPointerFunction (ILDG);
RegisterLoadCheckPointerReaderFunction (Binary);
RegisterLoadCheckPointerReaderFunction (Nersc);
RegisterLoadCheckPointerReaderFunction (ILDG);
}; };
} }

View File

@ -30,25 +30,49 @@ directory
#define BASE_CHECKPOINTER #define BASE_CHECKPOINTER
namespace Grid { namespace Grid {
namespace QCD { namespace QCD {
class CheckpointerParameters : Serializable { class CheckpointerParameters : Serializable {
public: public:
GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters, std::string, GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters,
configStem, std::string, rngStem, int, std::string, config_prefix,
SaveInterval, std::string, format, ); std::string, rng_prefix,
int, saveInterval,
std::string, format, );
CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng", CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng",
int savemodulo = 1, const std::string &f = "IEEE64BIG") int savemodulo = 1, const std::string &f = "IEEE64BIG")
: configStem(cf), rngStem(rn), SaveInterval(savemodulo), format(f){}; : config_prefix(cf), rng_prefix(rn), saveInterval(savemodulo), format(f){};
};
template<class ReaderClass>
CheckpointerParameters(ReaderClass &Reader){
read(Reader, "Checkpointer", *this);
}
};
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Base class for checkpointers // Base class for checkpointers
template <class Impl> template <class Impl>
class BaseHmcCheckpointer : public HmcObservable<typename Impl::Field> { class BaseHmcCheckpointer : public HmcObservable<typename Impl::Field> {
public: public:
virtual void initialize(CheckpointerParameters &Params) = 0; void build_filenames(int traj, CheckpointerParameters &Params,
std::string &conf_file, std::string &rng_file) {
{
std::ostringstream os;
os << Params.rng_prefix << "." << traj;
rng_file = os.str();
}
{
std::ostringstream os;
os << Params.config_prefix << "." << traj;
conf_file = os.str();
}
}
virtual void initialize(const CheckpointerParameters &Params) = 0;
virtual void CheckpointRestore(int traj, typename Impl::Field &U, virtual void CheckpointRestore(int traj, typename Impl::Field &U,
GridSerialRNG &sRNG, GridSerialRNG &sRNG,

View File

@ -2,11 +2,11 @@
Grid physics library, www.github.com/paboyle/Grid Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/hmc/NerscCheckpointer.h Source file: ./lib/qcd/hmc/BinaryCheckpointer.h
Copyright (C) 2015 Copyright (C) 2016
Author: paboyle <paboyle@ph.ed.ac.uk> Author: Guido Cossu <guido.cossu@ed.ac.uk>
This program is free software; you can redistribute it and/or modify This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
@ -36,7 +36,6 @@ directory
namespace Grid { namespace Grid {
namespace QCD { namespace QCD {
// Simple checkpointer, only binary file // Simple checkpointer, only binary file
template <class Impl> template <class Impl>
class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> { class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
@ -46,17 +45,17 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
public: public:
INHERIT_FIELD_TYPES(Impl); // Gets the Field type, a Lattice object INHERIT_FIELD_TYPES(Impl); // Gets the Field type, a Lattice object
// Extract types from the Field // Extract types from the Field
typedef typename Field::vector_object vobj; typedef typename Field::vector_object vobj;
typedef typename vobj::scalar_object sobj; typedef typename vobj::scalar_object sobj;
typedef typename getPrecision<sobj>::real_scalar_type sobj_stype; typedef typename getPrecision<sobj>::real_scalar_type sobj_stype;
typedef typename sobj::DoublePrecision sobj_double; typedef typename sobj::DoublePrecision sobj_double;
BinaryHmcCheckpointer(CheckpointerParameters& Params_){ BinaryHmcCheckpointer(const CheckpointerParameters &Params_) {
initialize(Params_); initialize(Params_);
} }
void initialize(CheckpointerParameters& Params_){ Params = Params_; } void initialize(const CheckpointerParameters &Params_) { Params = Params_; }
void truncate(std::string file) { void truncate(std::string file) {
std::ofstream fout(file, std::ios::out); std::ofstream fout(file, std::ios::out);
@ -65,19 +64,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
void TrajectoryComplete(int traj, Field &U, GridSerialRNG &sRNG, void TrajectoryComplete(int traj, Field &U, GridSerialRNG &sRNG,
GridParallelRNG &pRNG) { GridParallelRNG &pRNG) {
if ((traj % Params.SaveInterval) == 0) { if ((traj % Params.saveInterval) == 0) {
std::string rng; std::string config, rng;
{ this->build_filenames(traj, Params, config, rng);
std::ostringstream os;
os << Params.rngStem << "." << traj;
rng = os.str();
}
std::string config;
{
std::ostringstream os;
os << Params.configStem << "." << traj;
config = os.str();
}
BinaryIO::BinarySimpleUnmunger<sobj_double, sobj> munge; BinaryIO::BinarySimpleUnmunger<sobj_double, sobj> munge;
truncate(rng); truncate(rng);
@ -93,18 +82,8 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
void CheckpointRestore(int traj, Field &U, GridSerialRNG &sRNG, void CheckpointRestore(int traj, Field &U, GridSerialRNG &sRNG,
GridParallelRNG &pRNG) { GridParallelRNG &pRNG) {
std::string rng; std::string config, rng;
{ this->build_filenames(traj, Params, config, rng);
std::ostringstream os;
os << Params.rngStem << "." << traj;
rng = os.str();
}
std::string config;
{
std::ostringstream os;
os << Params.configStem << "." << traj;
config = os.str();
}
BinaryIO::BinarySimpleMunger<sobj_double, sobj> munge; BinaryIO::BinarySimpleMunger<sobj_double, sobj> munge;
BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0); BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0);

View File

@ -40,30 +40,23 @@ namespace QCD {
// Only for Gauge fields // Only for Gauge fields
template <class Implementation> template <class Implementation>
class ILDGHmcCheckpointer class ILDGHmcCheckpointer : public BaseHmcCheckpointer<Implementation> {
: public BaseHmcCheckpointer<Implementation> {
private: private:
CheckpointerParameters Params; CheckpointerParameters Params;
/*
std::string configStem;
std::string rngStem;
int SaveInterval;
std::string format;
*/
public: public:
INHERIT_GIMPL_TYPES(Implementation); INHERIT_GIMPL_TYPES(Implementation);
ILDGHmcCheckpointer(CheckpointerParameters &Params_) { initialize(Params_); } ILDGHmcCheckpointer(const CheckpointerParameters &Params_) { initialize(Params_); }
void initialize(CheckpointerParameters &Params_) { void initialize(const CheckpointerParameters &Params_) {
Params = Params_; Params = Params_;
// check here that the format is valid // check here that the format is valid
int ieee32big = (Params.format == std::string("IEEE32BIG")); int ieee32big = (Params.format == std::string("IEEE32BIG"));
int ieee32 = (Params.format == std::string("IEEE32")); int ieee32 = (Params.format == std::string("IEEE32"));
int ieee64big = (Params.format == std::string("IEEE64BIG")); int ieee64big = (Params.format == std::string("IEEE64BIG"));
int ieee64 = (Params.format == std::string("IEEE64")); int ieee64 = (Params.format == std::string("IEEE64"));
if (!(ieee64big || ieee32 || ieee32big || ieee64)) { if (!(ieee64big || ieee32 || ieee32big || ieee64)) {
std::cout << GridLogError << "Unrecognized file format " << Params.format std::cout << GridLogError << "Unrecognized file format " << Params.format
@ -78,23 +71,13 @@ class ILDGHmcCheckpointer
void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG,
GridParallelRNG &pRNG) { GridParallelRNG &pRNG) {
if ((traj % Params.SaveInterval) == 0) { if ((traj % Params.saveInterval) == 0) {
std::string rng; std::string config, rng;
{ this->build_filenames(traj, Params, config, rng);
std::ostringstream os;
os << Params.rngStem << "." << traj;
rng = os.str();
}
std::string config;
{
std::ostringstream os;
os << Params.configStem << "." << traj;
config = os.str();
}
ILDGIO IO(config, ILDGwrite); ILDGIO IO(config, ILDGwrite);
BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0); BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0);
uint32_t csum = IO.writeConfiguration(U, Params.format); uint32_t csum = IO.writeConfiguration(U, Params.format);
std::cout << GridLogMessage << "Written ILDG Configuration on " << config std::cout << GridLogMessage << "Written ILDG Configuration on " << config
<< " checksum " << std::hex << csum << std::dec << std::endl; << " checksum " << std::hex << csum << std::dec << std::endl;
@ -103,22 +86,12 @@ class ILDGHmcCheckpointer
void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG,
GridParallelRNG &pRNG) { GridParallelRNG &pRNG) {
std::string rng; std::string config, rng;
{ this->build_filenames(traj, Params, config, rng);
std::ostringstream os;
os << Params.rngStem << "." << traj;
rng = os.str();
}
std::string config;
{
std::ostringstream os;
os << Params.configStem << "." << traj;
config = os.str();
}
ILDGIO IO(config, ILDGread); ILDGIO IO(config, ILDGread);
BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0); BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0);
uint32_t csum = IO.readConfiguration(U);// format from the header uint32_t csum = IO.readConfiguration(U); // format from the header
std::cout << GridLogMessage << "Read ILDG Configuration from " << config std::cout << GridLogMessage << "Read ILDG Configuration from " << config
<< " checksum " << std::hex << csum << std::dec << std::endl; << " checksum " << std::hex << csum << std::dec << std::endl;
@ -127,5 +100,5 @@ class ILDGHmcCheckpointer
} }
} }
#endif // HAVE_LIME #endif // HAVE_LIME
#endif // ILDG_CHECKPOINTER #endif // ILDG_CHECKPOINTER

View File

@ -1,102 +1,80 @@
/************************************************************************************* /*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/qcd/hmc/NerscCheckpointer.h Source file: ./lib/qcd/hmc/NerscCheckpointer.h
Copyright (C) 2015 Copyright (C) 2015
Author: paboyle <paboyle@ph.ed.ac.uk> Author: paboyle <paboyle@ph.ed.ac.uk>
This program is free software; you can redistribute it and/or modify This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or the Free Software Foundation; either version 2 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details. GNU General Public License for more details.
You should have received a copy of the GNU General Public License along You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc., with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory See the full license in the file "LICENSE" in the top level distribution
*************************************************************************************/ directory
/* END LEGAL */ *************************************************************************************/
/* END LEGAL */
#ifndef NERSC_CHECKPOINTER #ifndef NERSC_CHECKPOINTER
#define NERSC_CHECKPOINTER #define NERSC_CHECKPOINTER
#include <string>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <string>
namespace Grid {
namespace QCD {
namespace Grid{ // Only for Gauge fields
namespace QCD{ template <class Gimpl>
class NerscHmcCheckpointer : public BaseHmcCheckpointer<Gimpl> {
// Only for Gauge fields private:
template<class Gimpl> CheckpointerParameters Params;
class NerscHmcCheckpointer : public BaseHmcCheckpointer<Gimpl> {
private:
CheckpointerParameters Params;
public: public:
INHERIT_GIMPL_TYPES(Gimpl);// INHERIT_GIMPL_TYPES(Gimpl); // only for gauge configurations
NerscHmcCheckpointer(CheckpointerParameters& Params_){ NerscHmcCheckpointer(const CheckpointerParameters &Params_) { initialize(Params_); }
initialize(Params_);
}
void initialize(CheckpointerParameters &Params_) { void initialize(const CheckpointerParameters &Params_) {
Params = Params_; Params = Params_;
Params.format = "IEEE64BIG"; // fixed, overwrite any other choice Params.format = "IEEE64BIG"; // fixed, overwrite any other choice
} }
void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG,
GridParallelRNG &pRNG) { GridParallelRNG &pRNG) {
if ((traj % Params.SaveInterval) == 0) { if ((traj % Params.saveInterval) == 0) {
std::string rng; std::string config, rng;
{ this->build_filenames(traj, Params, config, rng);
std::ostringstream os;
os << Params.rngStem << "." << traj;
rng = os.str();
}
std::string config;
{
std::ostringstream os;
os << Params.configStem << "." << traj;
config = os.str();
}
int precision32 = 1; int precision32 = 1;
int tworow = 0; int tworow = 0;
NerscIO::writeRNGState(sRNG, pRNG, rng); NerscIO::writeRNGState(sRNG, pRNG, rng);
NerscIO::writeConfiguration(U, config, tworow, precision32); NerscIO::writeConfiguration(U, config, tworow, precision32);
} }
}; };
void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG,
GridParallelRNG &pRNG) { GridParallelRNG &pRNG) {
std::string rng; std::string config, rng;
{ this->build_filenames(traj, Params, config, rng);
std::ostringstream os;
os << Params.rngStem << "." << traj;
rng = os.str();
}
std::string config;
{
std::ostringstream os;
os << Params.configStem << "." << traj;
config = os.str();
}
NerscField header; NerscField header;
NerscIO::readRNGState(sRNG, pRNG, header, rng); NerscIO::readRNGState(sRNG, pRNG, header, rng);
NerscIO::readConfiguration(U, header, config); NerscIO::readConfiguration(U, header, config);
}; };
};
}; }
}} }
#endif #endif

View File

@ -34,7 +34,7 @@ namespace Grid {
namespace QCD { namespace QCD {
//Change here the type of reader //Change here the type of reader
typedef Grid::TextReader InputFileReader; typedef Grid::XmlReader InputFileReader;
class HMCRunnerParameters : Serializable { class HMCRunnerParameters : Serializable {
@ -42,11 +42,11 @@ namespace Grid {
GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters, GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters,
double, beta, double, beta,
int, MDsteps, int, MDsteps,
double, TrajectorLength, double, TrajectoryLength,
int, SaveInterval, //int, SaveInterval,
std::string, format, //std::string, format,
std::string, conf_prefix, //std::string, conf_prefix,
std::string, rng_prefix, //std::string, rng_prefix,
std::string, serial_seeds, std::string, serial_seeds,
std::string, parallel_seeds, std::string, parallel_seeds,
); );
@ -66,6 +66,7 @@ int main(int argc, char **argv) {
// Typedefs to simplify notation // Typedefs to simplify notation
typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm
// here make a routine to print all the relevant information on the run
std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl; std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl;
////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////
@ -74,21 +75,18 @@ int main(int argc, char **argv) {
// now working with the text reader but I should drop this support // now working with the text reader but I should drop this support
// i need a structured format where every object is able // i need a structured format where every object is able
// to locate the required data: XML, JSON, YAML. // to locate the required data: XML, JSON, YAML.
InputFileReader Reader("input.wilson_gauge.params.xml");
HMCRunnerParameters HMCPar; HMCRunnerParameters HMCPar;
InputFileReader Reader("input.wilson_gauge.params");
read(Reader, "HMC", HMCPar); read(Reader, "HMC", HMCPar);
std::cout << GridLogMessage << HMCPar << std::endl;
// Seeds for the random number generators // Seeds for the random number generators
// generalise, ugly now // generalise, ugly now
std::vector<int> SerSeed = strToVec<int>(HMCPar.serial_seeds); std::vector<int> SerSeed = strToVec<int>(HMCPar.serial_seeds);
std::vector<int> ParSeed = strToVec<int>(HMCPar.parallel_seeds); std::vector<int> ParSeed = strToVec<int>(HMCPar.parallel_seeds);
CheckpointerParameters CP_params(HMCPar.conf_prefix, HMCPar.rng_prefix,
HMCPar.SaveInterval, HMCPar.format);
HMCWrapper TheHMC; HMCWrapper TheHMC;
TheHMC.Resources.AddFourDimGrid("gauge"); TheHMC.Resources.AddFourDimGrid("gauge");
TheHMC.Resources.LoadBinaryCheckpointer(CP_params); TheHMC.Resources.LoadBinaryCheckpointer(Reader);
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Collect actions, here use more encapsulation // Collect actions, here use more encapsulation
@ -112,7 +110,7 @@ int main(int argc, char **argv) {
// here we can simplify a lot if the input file is structured // here we can simplify a lot if the input file is structured
// just pass the input file reader // just pass the input file reader
TheHMC.Resources.AddRNGSeeds(SerSeed, ParSeed); TheHMC.Resources.AddRNGSeeds(SerSeed, ParSeed);
TheHMC.MDparameters.set(HMCPar.MDsteps, HMCPar.TrajectorLength); TheHMC.MDparameters.set(HMCPar.MDsteps, HMCPar.TrajectoryLength);
// eventually smearing here // eventually smearing here
// ... // ...