/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/qcd/hmc/GenericHmcRunner.h Copyright (C) 2015 Copyright (C) 2016 Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef HMC_RESOURCE_MANAGER_H #define HMC_RESOURCE_MANAGER_H #include // One function per Checkpointer, use a macro to simplify #define RegisterLoadCheckPointerFunction(NAME) \ void Load##NAME##Checkpointer(const CheckpointerParameters& Params_) { \ if (!have_CheckPointer) { \ std::cout << GridLogDebug << "Loading Checkpointer " << #NAME \ << std::endl; \ CP = std::unique_ptr( \ new NAME##CPModule(Params_)); \ have_CheckPointer = true; \ } else { \ std::cout << GridLogError << "Checkpointer already loaded " \ << std::endl; \ exit(1); \ } \ } namespace Grid { namespace QCD { // HMC Resource manager template class HMCResourceManager { typedef HMCModuleBase< QCD::BaseHmcCheckpointer > CheckpointerBaseModule; typedef HMCModuleBase< QCD::HmcObservable > ObservableBaseModule; typedef ActionModuleBase< QCD::Action, GridModule > ActionBaseModule; // Named storage for grid pairs (std + red-black) std::unordered_map Grids; RNGModule RNGs; // SmearingModule Smearing; std::unique_ptr CP; // A vector of HmcObservable modules std::vector > ObservablesList; // A vector of HmcObservable modules std::multimap > ActionsList; std::vector multipliers; bool have_RNG; bool have_CheckPointer; // NOTE: operator << is not overloaded for std::vector // so thsi function is necessary void output_vector_string(const std::vector &vs){ for (auto &i: vs) std::cout << i << " "; std::cout << std::endl; } public: HMCResourceManager() : have_RNG(false), have_CheckPointer(false) {} template void initialize(ReaderClass &Read){ // assumes we are starting from the main node // Geometry GridModuleParameters GridPar(Read); GridFourDimModule GridMod( GridPar) ; AddGrid("gauge", GridMod); // Checkpointer auto &CPfactory = HMC_CPModuleFactory::getInstance(); Read.push("Checkpointer"); std::string cp_type; read(Read,"name", cp_type); std::cout << "Registered types " << std::endl; output_vector_string(CPfactory.getBuilderList()); CP = CPfactory.create(cp_type, Read); CP->print_parameters(); Read.pop(); have_CheckPointer = true; RNGModuleParameters RNGpar(Read); SetRNGSeeds(RNGpar); // Observables auto &ObsFactory = HMC_ObservablesModuleFactory::getInstance(); Read.push(observable_string);// here must check if existing... do { std::string obs_type; read(Read,"name", obs_type); std::cout << "Registered types " << std::endl; output_vector_string(ObsFactory.getBuilderList() ); ObservablesList.emplace_back(ObsFactory.create(obs_type, Read)); ObservablesList[ObservablesList.size() - 1]->print_parameters(); } while (Read.nextElement(observable_string)); Read.pop(); // Loop on levels if(!Read.push("Actions")){ std::cout << "Actions not found" << std::endl; exit(1); } if(!Read.push("Level")){// push must check if the node exist std::cout << "Level not found" << std::endl; exit(1); } do { fill_ActionsLevel(Read); } while(Read.push("Level")); Read.pop(); } template void GetActionSet(ActionSet& Aset){ Aset.resize(multipliers.size()); for(auto it = ActionsList.begin(); it != ActionsList.end(); it++){ (*it).second->acquireResource(Grids["gauge"]); Aset[(*it).first-1].push_back((*it).second->getPtr()); } } ////////////////////////////////////////////////////////////// // Grids ////////////////////////////////////////////////////////////// void AddGrid(std::string s, GridModule& M) { // Check for name clashes auto search = Grids.find(s); if (search != Grids.end()) { std::cout << GridLogError << "Grid with name \"" << search->first << "\" already present. Terminating\n"; exit(1); } Grids[s] = std::move(M); } // Add a named grid set, 4d shortcut void AddFourDimGrid(std::string s) { GridFourDimModule Mod; AddGrid(s, Mod); } GridCartesian* GetCartesian(std::string s = "") { if (s.empty()) s = Grids.begin()->first; std::cout << GridLogDebug << "Getting cartesian grid from: " << s << std::endl; return Grids[s].get_full(); } GridRedBlackCartesian* GetRBCartesian(std::string s = "") { if (s.empty()) s = Grids.begin()->first; std::cout << GridLogDebug << "Getting rb-cartesian grid from: " << s << std::endl; return Grids[s].get_rb(); } ////////////////////////////////////////////////////// // Random number generators ////////////////////////////////////////////////////// void AddRNGs(std::string s = "") { // Couple the RNGs to the GridModule tagged by s // the default is the first grid registered assert(Grids.size() > 0 && !have_RNG); if (s.empty()) s = Grids.begin()->first; std::cout << GridLogDebug << "Adding RNG to grid: " << s << std::endl; RNGs.set_pRNG(new GridParallelRNG(GetCartesian(s))); have_RNG = true; } void SetRNGSeeds(RNGModuleParameters& Params) { RNGs.set_RNGSeeds(Params); } GridSerialRNG& GetSerialRNG() { return RNGs.get_sRNG(); } GridParallelRNG& GetParallelRNG() { assert(have_RNG); return RNGs.get_pRNG(); } void SeedFixedIntegers() { assert(have_RNG); RNGs.seed(); } ////////////////////////////////////////////////////// // Checkpointers ////////////////////////////////////////////////////// BaseHmcCheckpointer* GetCheckPointer() { if (have_CheckPointer) return CP->getPtr(); else { std::cout << GridLogError << "Error: no checkpointer defined" << std::endl; exit(1); } } RegisterLoadCheckPointerFunction(Binary); RegisterLoadCheckPointerFunction(Nersc); #ifdef HAVE_LIME RegisterLoadCheckPointerFunction(ILDG); #endif //////////////////////////////////////////////////////// // Observables //////////////////////////////////////////////////////// void AddObservable(ObservableBaseModule *O){ // acquire resource ObservablesList.push_back(std::unique_ptr(std::move(O))); } std::vector* > GetObservables(){ std::vector* > out; for (auto &i : ObservablesList){ out.push_back(i->getPtr()); } // Add the checkpointer to the observables out.push_back(GetCheckPointer()); return out; } private: // this private template void fill_ActionsLevel(ReaderClass &Read){ // Actions set int m; Read.readDefault("multiplier",m); multipliers.push_back(m); std::cout << "Level : " << multipliers.size() << " with multiplier : " << m << std::endl; // here gauge Read.push("Action"); do{ auto &ActionFactory = HMC_ActionModuleFactory::getInstance(); std::string action_type; Read.readDefault("name", action_type); output_vector_string(ActionFactory.getBuilderList() ); ActionsList.emplace(m, ActionFactory.create(action_type, Read)); } while (Read.nextElement("Action")); ActionsList.find(m)->second->print_parameters(); Read.pop(); } }; } } #endif // HMC_RESOURCE_MANAGER_H