diff --git a/lib/qcd/hmc/HMCResourceManager.h b/lib/qcd/hmc/HMCResourceManager.h index 3e20a8c1..fcfaeaed 100644 --- a/lib/qcd/hmc/HMCResourceManager.h +++ b/lib/qcd/hmc/HMCResourceManager.h @@ -48,6 +48,22 @@ with this program; if not, write to the Free Software Foundation, Inc., } \ } +#define RegisterLoadCheckPointerMetadataFunction(NAME) \ + template < class Metadata > \ + void Load##NAME##Checkpointer(const CheckpointerParameters& Params_, const Metadata& M_) { \ + if (!have_CheckPointer) { \ + std::cout << GridLogDebug << "Loading Metadata Checkpointer " << #NAME \ + << std::endl; \ + CP = std::unique_ptr( \ + new NAME##CPModule(Params_, M_)); \ + have_CheckPointer = true; \ + } else { \ + std::cout << GridLogError << "Checkpointer already loaded " \ + << std::endl; \ + exit(1); \ + } \ + } + namespace Grid { namespace QCD { @@ -77,7 +93,7 @@ class HMCResourceManager { bool have_CheckPointer; // NOTE: operator << is not overloaded for std::vector - // so thsi function is necessary + // so this function is necessary void output_vector_string(const std::vector &vs){ for (auto &i: vs) std::cout << i << " "; @@ -254,6 +270,7 @@ class HMCResourceManager { RegisterLoadCheckPointerFunction(Nersc); #ifdef HAVE_LIME RegisterLoadCheckPointerFunction(ILDG); + RegisterLoadCheckPointerMetadataFunction(Scidac); #endif //////////////////////////////////////////////////////// diff --git a/lib/qcd/hmc/checkpointers/BaseCheckpointer.h b/lib/qcd/hmc/checkpointers/BaseCheckpointer.h index 9be9efca..f4ef252b 100644 --- a/lib/qcd/hmc/checkpointers/BaseCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/BaseCheckpointer.h @@ -76,6 +76,14 @@ class BaseHmcCheckpointer : public HmcObservable { } } + void check_filename(const std::string &filename){ + std::ifstream f(filename.c_str()); + if(!f.good()){ + std::cout << GridLogError << "Filename " << filename << " not found. Aborting. " << std::endl; + abort(); + }; + } + virtual void initialize(const CheckpointerParameters &Params) = 0; virtual void CheckpointRestore(int traj, typename Impl::Field &U, diff --git a/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h b/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h index 59d655ad..025398eb 100644 --- a/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h @@ -93,6 +93,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer { void CheckpointRestore(int traj, Field &U, GridSerialRNG &sRNG, GridParallelRNG &pRNG) { std::string config, rng; this->build_filenames(traj, Params, config, rng); + this->check_filename(rng); + this->check_filename(config); + BinarySimpleMunger munge; diff --git a/lib/qcd/hmc/checkpointers/CheckPointerModules.h b/lib/qcd/hmc/checkpointers/CheckPointerModules.h index 5debedef..d49d6e72 100644 --- a/lib/qcd/hmc/checkpointers/CheckPointerModules.h +++ b/lib/qcd/hmc/checkpointers/CheckPointerModules.h @@ -136,6 +136,22 @@ class ILDGCPModule: public CheckPointerModule< ImplementationPolicy> { }; +template +class ScidacCPModule: public CheckPointerModule< ImplementationPolicy> { + typedef CheckPointerModule< ImplementationPolicy> CPBase; + Metadata M; + + //using CPBase::CPBase; // for constructors + + // acquire resource + virtual void initialize(){ + this->CheckPointPtr.reset(new ScidacHmcCheckpointer(this->Par_, M)); + } +public: + ScidacCPModule(typename CPBase::APar Par, Metadata M_):M(M_), CPBase(Par) {} + template + ScidacCPModule(Reader& Reader) : Parametrized(Reader){}; +}; #endif diff --git a/lib/qcd/hmc/checkpointers/CheckPointers.h b/lib/qcd/hmc/checkpointers/CheckPointers.h index 423ce45c..e7a5fa82 100644 --- a/lib/qcd/hmc/checkpointers/CheckPointers.h +++ b/lib/qcd/hmc/checkpointers/CheckPointers.h @@ -34,6 +34,7 @@ directory #include #include #include +#include //#include diff --git a/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h b/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h index 9bcc33df..f7e6b17e 100644 --- a/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/ILDGCheckpointer.h @@ -95,6 +95,10 @@ class ILDGHmcCheckpointer : public BaseHmcCheckpointer { GridParallelRNG &pRNG) { std::string config, rng; this->build_filenames(traj, Params, config, rng); + this->check_filename(rng); + this->check_filename(config); + + uint32_t nersc_csum,scidac_csuma,scidac_csumb; BinaryIO::readRNG(sRNG, pRNG, rng, 0,nersc_csum,scidac_csuma,scidac_csumb); diff --git a/lib/qcd/hmc/checkpointers/NerscCheckpointer.h b/lib/qcd/hmc/checkpointers/NerscCheckpointer.h index a4b1b480..d452b994 100644 --- a/lib/qcd/hmc/checkpointers/NerscCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/NerscCheckpointer.h @@ -69,6 +69,9 @@ class NerscHmcCheckpointer : public BaseHmcCheckpointer { GridParallelRNG &pRNG) { std::string config, rng; this->build_filenames(traj, Params, config, rng); + this->check_filename(rng); + this->check_filename(config); + FieldMetaData header; NerscIO::readRNGState(sRNG, pRNG, header, rng); diff --git a/lib/qcd/hmc/checkpointers/ScidacCheckpointer.h b/lib/qcd/hmc/checkpointers/ScidacCheckpointer.h new file mode 100644 index 00000000..0867b882 --- /dev/null +++ b/lib/qcd/hmc/checkpointers/ScidacCheckpointer.h @@ -0,0 +1,125 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: ./lib/qcd/hmc/ScidacCheckpointer.h + +Copyright (C) 2018 + +Author: Guido Cossu + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ +#ifndef SCIDAC_CHECKPOINTER +#define SCIDAC_CHECKPOINTER + +#ifdef HAVE_LIME + +#include +#include +#include + +namespace Grid { +namespace QCD { + +// For generic fields +template +class ScidacHmcCheckpointer : public BaseHmcCheckpointer { + private: + CheckpointerParameters Params; + Metadata MData; + + typedef typename Implementation::Field Field; + + public: + //INHERIT_GIMPL_TYPES(Implementation); + + ScidacHmcCheckpointer(const CheckpointerParameters &Params_) { initialize(Params_); } + ScidacHmcCheckpointer(const CheckpointerParameters &Params_, const Metadata& M_):MData(M_) { initialize(Params_); } + + void initialize(const CheckpointerParameters &Params_) { + Params = Params_; + + // check here that the format is valid + int ieee32big = (Params.format == std::string("IEEE32BIG")); + int ieee32 = (Params.format == std::string("IEEE32")); + int ieee64big = (Params.format == std::string("IEEE64BIG")); + int ieee64 = (Params.format == std::string("IEEE64")); + + if (!(ieee64big || ieee32 || ieee32big || ieee64)) { + std::cout << GridLogError << "Unrecognized file format " << Params.format + << std::endl; + std::cout << GridLogError + << "Allowed: IEEE32BIG | IEEE32 | IEEE64BIG | IEEE64" + << std::endl; + + exit(1); + } + } + + void TrajectoryComplete(int traj, Field &U, GridSerialRNG &sRNG, + GridParallelRNG &pRNG) { + if ((traj % Params.saveInterval) == 0) { + std::string config, rng; + this->build_filenames(traj, Params, config, rng); + GridBase *grid = U._grid; + uint32_t nersc_csum,scidac_csuma,scidac_csumb; + BinaryIO::writeRNG(sRNG, pRNG, rng, 0,nersc_csum,scidac_csuma,scidac_csumb); + ScidacWriter _ScidacWriter(grid->IsBoss()); + _ScidacWriter.open(config); + _ScidacWriter.writeScidacFieldRecord(U, MData); + _ScidacWriter.close(); + + std::cout << GridLogMessage << "Written Scidac Configuration on " << config + << " checksum " << std::hex << nersc_csum<<"/" + << scidac_csuma<<"/" << scidac_csumb + << std::dec << std::endl; + } + }; + + void CheckpointRestore(int traj, Field &U, GridSerialRNG &sRNG, + GridParallelRNG &pRNG) { + std::string config, rng; + this->build_filenames(traj, Params, config, rng); + this->check_filename(rng); + this->check_filename(config); + + + uint32_t nersc_csum,scidac_csuma,scidac_csumb; + BinaryIO::readRNG(sRNG, pRNG, rng, 0,nersc_csum,scidac_csuma,scidac_csumb); + + Metadata md_content; + ScidacReader _ScidacReader; + _ScidacReader.open(config); + _ScidacReader.readScidacFieldRecord(U,md_content); // format from the header + _ScidacReader.close(); + + std::cout << GridLogMessage << "Read Scidac Configuration from " << config + << " checksum " << std::hex + << nersc_csum<<"/" + << scidac_csuma<<"/" + << scidac_csumb + << std::dec << std::endl; + }; +}; +} +} + +#endif // HAVE_LIME +#endif // ILDG_CHECKPOINTER diff --git a/tests/hmc/Test_hmc_WG_Production.cc b/tests/hmc/Test_hmc_WG_Production.cc index b99446d5..7f8d8124 100644 --- a/tests/hmc/Test_hmc_WG_Production.cc +++ b/tests/hmc/Test_hmc_WG_Production.cc @@ -33,6 +33,7 @@ namespace Grid{ GRID_SERIALIZABLE_CLASS_MEMBERS(ActionParameters, double, beta) + ActionParameters() = default; template ActionParameters(Reader& Reader){ @@ -68,11 +69,15 @@ int main(int argc, char **argv) { } Serialiser Reader(TheHMC.ParameterFile); - + // Read parameters from input file + ActionParameters WilsonPar(Reader); // Checkpointer definition CheckpointerParameters CPparams(Reader); - TheHMC.Resources.LoadNerscCheckpointer(CPparams); + //TheHMC.Resources.LoadNerscCheckpointer(CPparams); + + // Store metadata in the Scidac checkpointer + TheHMC.Resources.LoadScidacCheckpointer(CPparams, WilsonPar); RNGModuleParameters RNGpar(Reader); TheHMC.Resources.SetRNGSeeds(RNGpar); @@ -91,8 +96,6 @@ int main(int argc, char **argv) { // need wrappers of the fermionic classes // that have a complex construction // standard - ActionParameters WilsonPar(Reader); - //RealD beta = 6.4 ; WilsonGaugeActionR Waction(WilsonPar.beta); ActionLevel Level1(1);