From 293df6cd20b36f3bbacb190593a9b2228e07b064 Mon Sep 17 00:00:00 2001 From: Guido Cossu Date: Mon, 10 Oct 2016 11:49:55 +0100 Subject: [PATCH] Generalising the HMCRunner and moving parameters to the user level --- lib/qcd/hmc/GenericHMCrunner.h | 20 ++++---- tests/hmc/Test_hmc_ScalarAction.cc | 10 +++- .../hmc/Test_hmc_WilsonFermionGauge_Binary.cc | 46 ++++++++++++++++--- 3 files changed, 61 insertions(+), 15 deletions(-) diff --git a/lib/qcd/hmc/GenericHMCrunner.h b/lib/qcd/hmc/GenericHMCrunner.h index 5df3f393..8500cfb4 100644 --- a/lib/qcd/hmc/GenericHMCrunner.h +++ b/lib/qcd/hmc/GenericHMCrunner.h @@ -35,7 +35,6 @@ namespace QCD { // Virtual Class for HMC specific for gauge theories // implement a specific theory by defining the BuildTheAction template , class RepresentationsPolicy = NoHirep> class BinaryHmcRunnerTemplate { public: @@ -45,7 +44,8 @@ class BinaryHmcRunnerTemplate { enum StartType_t { ColdStart, HotStart, TepidStart, CheckpointStart }; ActionSet TheAction; - // Add here a vector of HmcObservable + + // A vector of HmcObservable // that can be injected from outside std::vector< HmcObservable* > ObservablesList; @@ -57,7 +57,9 @@ class BinaryHmcRunnerTemplate { virtual void BuildTheAction(int argc, char **argv) = 0; // necessary? - void Run(int argc, char **argv) { +// add here the smearing implementation? +template > + void Run(int argc, char **argv, IOCheckpointer &Checkpoint) { StartType_t StartType = HotStart; std::string arg; @@ -119,13 +121,14 @@ class BinaryHmcRunnerTemplate { IntegratorParameters MDpar(20, 1.0); IntegratorType MDynamics(UGrid, MDpar, TheAction, SmearingPolicy); - // Checkpoint strategy + // Checkpoint strategy + /* int SaveInterval = 1; std::string format = std::string("IEEE64BIG"); std::string conf_prefix = std::string("ckpoint_lat"); std::string rng_prefix = std::string("ckpoint_rng"); IOCheckpointer Checkpoint(conf_prefix, rng_prefix, SaveInterval, format); - + */ HMCparameters HMCpar; HMCpar.StartTrajectory = StartTraj; @@ -159,7 +162,7 @@ class BinaryHmcRunnerTemplate { SmearingPolicy.set_Field(U); HybridMonteCarlo HMC(HMCpar, MDynamics, sRNG, pRNG, U); - HMC.AddObservable(&Checkpoint); + //HMC.AddObservable(&Checkpoint); for (int obs = 0; obs < ObservablesList.size(); obs++) HMC.AddObservable(ObservablesList[obs]); @@ -174,7 +177,7 @@ typedef BinaryHmcRunnerTemplate BinaryHmcRunner; typedef BinaryHmcRunnerTemplate BinaryHmcRunnerF; typedef BinaryHmcRunnerTemplate BinaryHmcRunnerD; -typedef BinaryHmcRunnerTemplate > NerscTestHmcRunner; +//typedef BinaryHmcRunnerTemplate > NerscTestHmcRunner; template @@ -183,7 +186,8 @@ using BinaryHmcRunnerTemplateHirep = -typedef BinaryHmcRunnerTemplate, ScalarFields> ScalarBinaryHmcRunner; +//typedef BinaryHmcRunnerTemplate, ScalarFields> ScalarBinaryHmcRunner; + typedef BinaryHmcRunnerTemplate ScalarBinaryHmcRunner; } } #endif diff --git a/tests/hmc/Test_hmc_ScalarAction.cc b/tests/hmc/Test_hmc_ScalarAction.cc index 8565d001..8a2ff036 100644 --- a/tests/hmc/Test_hmc_ScalarAction.cc +++ b/tests/hmc/Test_hmc_ScalarAction.cc @@ -59,7 +59,15 @@ class HmcRunner : public ScalarBinaryHmcRunner { TheAction.push_back(Level1); - Run(argc, argv); + // Add observables and checkpointers + int SaveInterval = 1; + std::string format = std::string("IEEE64BIG"); + std::string conf_prefix = std::string("ckpoint_scalar_lat"); + std::string rng_prefix = std::string("ckpoint_scalar_rng"); + BinaryHmcCheckpointer Checkpoint(conf_prefix, rng_prefix, SaveInterval, format); + ObservablesList.push_back(&Checkpoint); + + Run(argc, argv, Checkpoint); }; }; } diff --git a/tests/hmc/Test_hmc_WilsonFermionGauge_Binary.cc b/tests/hmc/Test_hmc_WilsonFermionGauge_Binary.cc index e0227b14..3e76c015 100644 --- a/tests/hmc/Test_hmc_WilsonFermionGauge_Binary.cc +++ b/tests/hmc/Test_hmc_WilsonFermionGauge_Binary.cc @@ -7,9 +7,8 @@ Source file: ./tests/Test_hmc_WilsonFermionGauge.cc Copyright (C) 2015 Author: Peter Boyle -Author: Peter Boyle Author: neo -Author: paboyle +Author: Guido Cossu This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -25,7 +24,8 @@ You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -See the full license in the file "LICENSE" in the top level distribution directory +See the full license in the file "LICENSE" in the top level distribution +directory *************************************************************************************/ /* END LEGAL */ #include @@ -37,6 +37,23 @@ using namespace Grid::QCD; namespace Grid { namespace QCD { +class HMCRunnerParameters : Serializable { + public: + GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters, + double, beta, + double, mass, + int, MaxCGIterations, + double, StoppingCondition, + bool, smearedAction, + int, SaveInterval, + std::string, format, + std::string, conf_prefix, + std::string, rng_prefix, + ); + + HMCRunnerParameters() {} +}; + // Derive from the BinaryHmcRunner (templated for gauge fields) class HmcRunner : public BinaryHmcRunner { public: @@ -62,6 +79,9 @@ class HmcRunner : public BinaryHmcRunner { WilsonGaugeActionR Waction(5.6); Real mass = -0.77; + + // Can we define an overloaded operator that does not need U and initialises + // it with zeroes? FermionAction FermOp(U, *FGrid, *FrbGrid, mass); ConjugateGradient CG(1.0e-8, 10000); @@ -82,10 +102,24 @@ class HmcRunner : public BinaryHmcRunner { TheAction.push_back(Level2); // Add observables - PlaquetteLogger PlaqLog(std::string("plaq")); - ObservablesList.push_back(&PlaqLog); + int SaveInterval = 2; + std::string format = std::string("IEEE64BIG"); + std::string conf_prefix = std::string("ckpoint_lat"); + std::string rng_prefix = std::string("ckpoint_rng"); + BinaryHmcCheckpointer Checkpoint( + conf_prefix, rng_prefix, SaveInterval, format); + // Can implement also a specific function in the hmcrunner + // AddCheckpoint (...) that takes the same parameters + a string/tag + // defining the type of the checkpointer + // with tags can be implemented by overloading and no ifs + // Then force all checkpoint to have few common functions + // return an object that is then passed to the Run function - Run(argc, argv); + PlaquetteLogger PlaqLog(std::string("Plaquette")); + ObservablesList.push_back(&PlaqLog); + ObservablesList.push_back(&Checkpoint); + + Run(argc, argv, Checkpoint); }; }; }