From 1189ebc8b55fa64f71402b1bbb8c3520cb1172a0 Mon Sep 17 00:00:00 2001 From: Guido Cossu Date: Thu, 5 Jan 2017 15:52:52 +0000 Subject: [PATCH] Cleaning up the checkpointers interface --- lib/qcd/hmc/HMCModules.h | 26 ++-- lib/qcd/hmc/HMCResourceManager.h | 49 +++++-- lib/qcd/hmc/checkpointers/BaseCheckpointer.h | 46 ++++-- .../hmc/checkpointers/BinaryCheckpointer.h | 45 ++---- lib/qcd/hmc/checkpointers/ILDGCheckpointer.h | 57 ++------ lib/qcd/hmc/checkpointers/NerscCheckpointer.h | 134 ++++++++---------- tests/hmc/Test_hmc_WilsonGauge_Binary.cc | 22 ++- 7 files changed, 172 insertions(+), 207 deletions(-) diff --git a/lib/qcd/hmc/HMCModules.h b/lib/qcd/hmc/HMCModules.h index 708e6732..5e7e5922 100644 --- a/lib/qcd/hmc/HMCModules.h +++ b/lib/qcd/hmc/HMCModules.h @@ -96,29 +96,23 @@ class StoutSmearingModule: public SmearingModule{ SmearedConfiguration SmearingPolicy; }; - // Checkpoint module, owns the Checkpointer template -class CheckPointModule{ - std::unique_ptr< BaseHmcCheckpointer > cp_; +class CheckPointModule { + std::unique_ptr > cp_; -public: - void set_Checkpointer(BaseHmcCheckpointer *cp){ - cp_.reset(cp); - }; - BaseHmcCheckpointer* get_CheckPointer(){ - std::cout << "Checkpointer Pointer requested : " << cp_.get() << std::endl; - return cp_.get(); - } + public: + void set_Checkpointer(BaseHmcCheckpointer* cp) { + cp_.reset(cp); + }; - void initialize(CheckpointerParameters& P){ - cp_.initialize(P); - } + BaseHmcCheckpointer* get_CheckPointer() { + return cp_.get(); + } + void initialize(CheckpointerParameters& P) { cp_.initialize(P); } }; - - } // namespace QCD } // namespace Grid diff --git a/lib/qcd/hmc/HMCResourceManager.h b/lib/qcd/hmc/HMCResourceManager.h index f748a0d2..97a19b69 100644 --- a/lib/qcd/hmc/HMCResourceManager.h +++ b/lib/qcd/hmc/HMCResourceManager.h @@ -33,22 +33,37 @@ with this program; if not, write to the Free Software Foundation, Inc., #include // One function per Checkpointer, use a macro to simplify - #define RegisterLoadCheckPointerFunction(NAME) \ - void Load##NAME##Checkpointer(CheckpointerParameters& Params_) { \ - if (!have_CheckPointer) { \ - std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \ - << std::endl; \ - CP.set_Checkpointer( \ - new NAME##HmcCheckpointer(Params_)); \ - have_CheckPointer = true; \ - } else { \ - std::cout << GridLogError << "Checkpointer already loaded " \ - << std::endl; \ - exit(1); \ - } \ +#define RegisterLoadCheckPointerFunction(NAME) \ + void Load##NAME##Checkpointer(const CheckpointerParameters& Params_) { \ + if (!have_CheckPointer) { \ + std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \ + << std::endl; \ + CP.set_Checkpointer( \ + new NAME##HmcCheckpointer(Params_)); \ + have_CheckPointer = true; \ + } else { \ + std::cout << GridLogError << "Checkpointer already loaded " \ + << std::endl; \ + exit(1); \ + } \ } - +// One function per Checkpointer using the reader, use a macro to simplify +#define RegisterLoadCheckPointerReaderFunction(NAME) \ + template \ + void Load##NAME##Checkpointer(Reader& Reader_) { \ + if (!have_CheckPointer) { \ + std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \ + << std::endl; \ + CP.set_Checkpointer(new NAME##HmcCheckpointer( \ + CheckpointerParameters(Reader_))); \ + have_CheckPointer = true; \ + } else { \ + std::cout << GridLogError << "Checkpointer already loaded " \ + << std::endl; \ + exit(1); \ + } \ + } namespace Grid { namespace QCD { @@ -141,7 +156,11 @@ class HMCResourceManager{ RegisterLoadCheckPointerFunction (Binary); RegisterLoadCheckPointerFunction (Nersc); - RegisterLoadCheckPointerFunction (ILDG) + RegisterLoadCheckPointerFunction (ILDG); + + RegisterLoadCheckPointerReaderFunction (Binary); + RegisterLoadCheckPointerReaderFunction (Nersc); + RegisterLoadCheckPointerReaderFunction (ILDG); }; } diff --git a/lib/qcd/hmc/checkpointers/BaseCheckpointer.h b/lib/qcd/hmc/checkpointers/BaseCheckpointer.h index ffb3dab3..3939b650 100644 --- a/lib/qcd/hmc/checkpointers/BaseCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/BaseCheckpointer.h @@ -30,25 +30,49 @@ directory #define BASE_CHECKPOINTER namespace Grid { -namespace QCD { + namespace QCD { -class CheckpointerParameters : Serializable { - public: - GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters, std::string, - configStem, std::string, rngStem, int, - SaveInterval, std::string, format, ); + class CheckpointerParameters : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters, + std::string, config_prefix, + std::string, rng_prefix, + int, saveInterval, + std::string, format, ); - CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng", - int savemodulo = 1, const std::string &f = "IEEE64BIG") - : configStem(cf), rngStem(rn), SaveInterval(savemodulo), format(f){}; -}; + CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng", + int savemodulo = 1, const std::string &f = "IEEE64BIG") + : config_prefix(cf), rng_prefix(rn), saveInterval(savemodulo), format(f){}; + + + template + CheckpointerParameters(ReaderClass &Reader){ + read(Reader, "Checkpointer", *this); + } + + }; ////////////////////////////////////////////////////////////////////////////// // Base class for checkpointers template class BaseHmcCheckpointer : public HmcObservable { public: - virtual void initialize(CheckpointerParameters &Params) = 0; + void build_filenames(int traj, CheckpointerParameters &Params, + std::string &conf_file, std::string &rng_file) { + { + std::ostringstream os; + os << Params.rng_prefix << "." << traj; + rng_file = os.str(); + } + + { + std::ostringstream os; + os << Params.config_prefix << "." << traj; + conf_file = os.str(); + } + } + + virtual void initialize(const CheckpointerParameters &Params) = 0; virtual void CheckpointRestore(int traj, typename Impl::Field &U, GridSerialRNG &sRNG, diff --git a/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h b/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h index f05f30d4..251ed042 100644 --- a/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h @@ -2,11 +2,11 @@ Grid physics library, www.github.com/paboyle/Grid -Source file: ./lib/qcd/hmc/NerscCheckpointer.h +Source file: ./lib/qcd/hmc/BinaryCheckpointer.h -Copyright (C) 2015 +Copyright (C) 2016 -Author: paboyle +Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -36,7 +36,6 @@ directory namespace Grid { namespace QCD { - // Simple checkpointer, only binary file template class BinaryHmcCheckpointer : public BaseHmcCheckpointer { @@ -46,17 +45,17 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer { public: INHERIT_FIELD_TYPES(Impl); // Gets the Field type, a Lattice object - // Extract types from the Field + // Extract types from the Field typedef typename Field::vector_object vobj; typedef typename vobj::scalar_object sobj; typedef typename getPrecision::real_scalar_type sobj_stype; typedef typename sobj::DoublePrecision sobj_double; - BinaryHmcCheckpointer(CheckpointerParameters& Params_){ - initialize(Params_); + BinaryHmcCheckpointer(const CheckpointerParameters &Params_) { + initialize(Params_); } - void initialize(CheckpointerParameters& Params_){ Params = Params_; } + void initialize(const CheckpointerParameters &Params_) { Params = Params_; } void truncate(std::string file) { std::ofstream fout(file, std::ios::out); @@ -65,19 +64,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer { void TrajectoryComplete(int traj, Field &U, GridSerialRNG &sRNG, GridParallelRNG &pRNG) { - if ((traj % Params.SaveInterval) == 0) { - std::string rng; - { - std::ostringstream os; - os << Params.rngStem << "." << traj; - rng = os.str(); - } - std::string config; - { - std::ostringstream os; - os << Params.configStem << "." << traj; - config = os.str(); - } + if ((traj % Params.saveInterval) == 0) { + std::string config, rng; + this->build_filenames(traj, Params, config, rng); BinaryIO::BinarySimpleUnmunger munge; truncate(rng); @@ -93,18 +82,8 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer { void CheckpointRestore(int traj, Field &U, GridSerialRNG &sRNG, GridParallelRNG &pRNG) { - std::string rng; - { - std::ostringstream os; - os << Params.rngStem << "." << traj; - rng = os.str(); - } - std::string config; - { - std::ostringstream os; - os << Params.configStem << "." << traj; - config = os.str(); - } + std::string config, rng; + this->build_filenames(traj, Params, config, rng); BinaryIO::BinarySimpleMunger munge; BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0); diff --git a/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h b/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h index 38a02635..8b8f9f23 100644 --- a/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h @@ -40,30 +40,23 @@ namespace QCD { // Only for Gauge fields template -class ILDGHmcCheckpointer - : public BaseHmcCheckpointer { +class ILDGHmcCheckpointer : public BaseHmcCheckpointer { private: - CheckpointerParameters Params; -/* - std::string configStem; - std::string rngStem; - int SaveInterval; - std::string format; -*/ + CheckpointerParameters Params; public: INHERIT_GIMPL_TYPES(Implementation); - ILDGHmcCheckpointer(CheckpointerParameters &Params_) { initialize(Params_); } + ILDGHmcCheckpointer(const CheckpointerParameters &Params_) { initialize(Params_); } - void initialize(CheckpointerParameters &Params_) { + void initialize(const CheckpointerParameters &Params_) { Params = Params_; // check here that the format is valid int ieee32big = (Params.format == std::string("IEEE32BIG")); - int ieee32 = (Params.format == std::string("IEEE32")); + int ieee32 = (Params.format == std::string("IEEE32")); int ieee64big = (Params.format == std::string("IEEE64BIG")); - int ieee64 = (Params.format == std::string("IEEE64")); + int ieee64 = (Params.format == std::string("IEEE64")); if (!(ieee64big || ieee32 || ieee32big || ieee64)) { std::cout << GridLogError << "Unrecognized file format " << Params.format @@ -78,23 +71,13 @@ class ILDGHmcCheckpointer void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG &pRNG) { - if ((traj % Params.SaveInterval) == 0) { - std::string rng; - { - std::ostringstream os; - os << Params.rngStem << "." << traj; - rng = os.str(); - } - std::string config; - { - std::ostringstream os; - os << Params.configStem << "." << traj; - config = os.str(); - } + if ((traj % Params.saveInterval) == 0) { + std::string config, rng; + this->build_filenames(traj, Params, config, rng); ILDGIO IO(config, ILDGwrite); BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0); - uint32_t csum = IO.writeConfiguration(U, Params.format); + uint32_t csum = IO.writeConfiguration(U, Params.format); std::cout << GridLogMessage << "Written ILDG Configuration on " << config << " checksum " << std::hex << csum << std::dec << std::endl; @@ -103,22 +86,12 @@ class ILDGHmcCheckpointer void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG &pRNG) { - std::string rng; - { - std::ostringstream os; - os << Params.rngStem << "." << traj; - rng = os.str(); - } - std::string config; - { - std::ostringstream os; - os << Params.configStem << "." << traj; - config = os.str(); - } + std::string config, rng; + this->build_filenames(traj, Params, config, rng); ILDGIO IO(config, ILDGread); BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0); - uint32_t csum = IO.readConfiguration(U);// format from the header + uint32_t csum = IO.readConfiguration(U); // format from the header std::cout << GridLogMessage << "Read ILDG Configuration from " << config << " checksum " << std::hex << csum << std::dec << std::endl; @@ -127,5 +100,5 @@ class ILDGHmcCheckpointer } } -#endif // HAVE_LIME -#endif // ILDG_CHECKPOINTER +#endif // HAVE_LIME +#endif // ILDG_CHECKPOINTER diff --git a/lib/qcd/hmc/checkpointers/NerscCheckpointer.h b/lib/qcd/hmc/checkpointers/NerscCheckpointer.h index 52803854..395369a0 100644 --- a/lib/qcd/hmc/checkpointers/NerscCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/NerscCheckpointer.h @@ -1,102 +1,80 @@ - /************************************************************************************* +/************************************************************************************* - Grid physics library, www.github.com/paboyle/Grid +Grid physics library, www.github.com/paboyle/Grid - Source file: ./lib/qcd/hmc/NerscCheckpointer.h +Source file: ./lib/qcd/hmc/NerscCheckpointer.h - Copyright (C) 2015 +Copyright (C) 2015 Author: paboyle - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - See the full license in the file "LICENSE" in the top level distribution directory - *************************************************************************************/ - /* END LEGAL */ +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ #ifndef NERSC_CHECKPOINTER #define NERSC_CHECKPOINTER -#include #include #include +#include +namespace Grid { +namespace QCD { -namespace Grid{ - namespace QCD{ - - // Only for Gauge fields - template - class NerscHmcCheckpointer : public BaseHmcCheckpointer { - private: - CheckpointerParameters Params; +// Only for Gauge fields +template +class NerscHmcCheckpointer : public BaseHmcCheckpointer { + private: + CheckpointerParameters Params; - public: - INHERIT_GIMPL_TYPES(Gimpl);// + public: + INHERIT_GIMPL_TYPES(Gimpl); // only for gauge configurations - NerscHmcCheckpointer(CheckpointerParameters& Params_){ - initialize(Params_); - } + NerscHmcCheckpointer(const CheckpointerParameters &Params_) { initialize(Params_); } - void initialize(CheckpointerParameters &Params_) { - Params = Params_; - Params.format = "IEEE64BIG"; // fixed, overwrite any other choice - } + void initialize(const CheckpointerParameters &Params_) { + Params = Params_; + Params.format = "IEEE64BIG"; // fixed, overwrite any other choice + } - void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, - GridParallelRNG &pRNG) { - if ((traj % Params.SaveInterval) == 0) { - std::string rng; - { - std::ostringstream os; - os << Params.rngStem << "." << traj; - rng = os.str(); - } - std::string config; - { - std::ostringstream os; - os << Params.configStem << "." << traj; - config = os.str(); - } + void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, + GridParallelRNG &pRNG) { + if ((traj % Params.saveInterval) == 0) { + std::string config, rng; + this->build_filenames(traj, Params, config, rng); - int precision32 = 1; - int tworow = 0; - NerscIO::writeRNGState(sRNG, pRNG, rng); - NerscIO::writeConfiguration(U, config, tworow, precision32); - } - }; + int precision32 = 1; + int tworow = 0; + NerscIO::writeRNGState(sRNG, pRNG, rng); + NerscIO::writeConfiguration(U, config, tworow, precision32); + } + }; - void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, - GridParallelRNG &pRNG) { - std::string rng; - { - std::ostringstream os; - os << Params.rngStem << "." << traj; - rng = os.str(); - } - std::string config; - { - std::ostringstream os; - os << Params.configStem << "." << traj; - config = os.str(); - } + void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, + GridParallelRNG &pRNG) { + std::string config, rng; + this->build_filenames(traj, Params, config, rng); - NerscField header; - NerscIO::readRNGState(sRNG, pRNG, header, rng); - NerscIO::readConfiguration(U, header, config); - }; - - }; -}} + NerscField header; + NerscIO::readRNGState(sRNG, pRNG, header, rng); + NerscIO::readConfiguration(U, header, config); + }; +}; +} +} #endif diff --git a/tests/hmc/Test_hmc_WilsonGauge_Binary.cc b/tests/hmc/Test_hmc_WilsonGauge_Binary.cc index 9e0d3e12..49f0d741 100644 --- a/tests/hmc/Test_hmc_WilsonGauge_Binary.cc +++ b/tests/hmc/Test_hmc_WilsonGauge_Binary.cc @@ -34,7 +34,7 @@ namespace Grid { namespace QCD { //Change here the type of reader - typedef Grid::TextReader InputFileReader; + typedef Grid::XmlReader InputFileReader; class HMCRunnerParameters : Serializable { @@ -42,11 +42,11 @@ namespace Grid { GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters, double, beta, int, MDsteps, - double, TrajectorLength, - int, SaveInterval, - std::string, format, - std::string, conf_prefix, - std::string, rng_prefix, + double, TrajectoryLength, + //int, SaveInterval, + //std::string, format, + //std::string, conf_prefix, + //std::string, rng_prefix, std::string, serial_seeds, std::string, parallel_seeds, ); @@ -66,6 +66,7 @@ int main(int argc, char **argv) { // Typedefs to simplify notation typedef GenericHMCRunner HMCWrapper; // Uses the default minimum norm + // here make a routine to print all the relevant information on the run std::cout << GridLogMessage << "Grid is setup to use " << threads << " threads" << std::endl; ////////////////////////////////////////////////////////////// @@ -74,21 +75,18 @@ int main(int argc, char **argv) { // now working with the text reader but I should drop this support // i need a structured format where every object is able // to locate the required data: XML, JSON, YAML. + InputFileReader Reader("input.wilson_gauge.params.xml"); HMCRunnerParameters HMCPar; - InputFileReader Reader("input.wilson_gauge.params"); read(Reader, "HMC", HMCPar); - std::cout << GridLogMessage << HMCPar << std::endl; // Seeds for the random number generators // generalise, ugly now std::vector SerSeed = strToVec(HMCPar.serial_seeds); std::vector ParSeed = strToVec(HMCPar.parallel_seeds); - CheckpointerParameters CP_params(HMCPar.conf_prefix, HMCPar.rng_prefix, - HMCPar.SaveInterval, HMCPar.format); HMCWrapper TheHMC; TheHMC.Resources.AddFourDimGrid("gauge"); - TheHMC.Resources.LoadBinaryCheckpointer(CP_params); + TheHMC.Resources.LoadBinaryCheckpointer(Reader); ///////////////////////////////////////////////////////////// // Collect actions, here use more encapsulation @@ -112,7 +110,7 @@ int main(int argc, char **argv) { // here we can simplify a lot if the input file is structured // just pass the input file reader TheHMC.Resources.AddRNGSeeds(SerSeed, ParSeed); - TheHMC.MDparameters.set(HMCPar.MDsteps, HMCPar.TrajectorLength); + TheHMC.MDparameters.set(HMCPar.MDsteps, HMCPar.TrajectoryLength); // eventually smearing here // ...