mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-10-30 11:34:32 +00:00 
			
		
		
		
	Generalising the HMCRunner and moving parameters to the user level
This commit is contained in:
		| @@ -35,7 +35,6 @@ namespace QCD { | ||||
| // Virtual Class for HMC specific for gauge theories | ||||
| // implement a specific theory by defining the BuildTheAction | ||||
| template <class Implementation, | ||||
|           class IOCheckpointer = BinaryHmcCheckpointer<Implementation>, | ||||
|           class RepresentationsPolicy = NoHirep> | ||||
| class BinaryHmcRunnerTemplate { | ||||
|  public: | ||||
| @@ -45,7 +44,8 @@ class BinaryHmcRunnerTemplate { | ||||
|   enum StartType_t { ColdStart, HotStart, TepidStart, CheckpointStart }; | ||||
|  | ||||
|   ActionSet<Field, RepresentationsPolicy> TheAction; | ||||
|   // Add here a vector of HmcObservable  | ||||
|    | ||||
|   // A vector of HmcObservable  | ||||
|   // that can be injected from outside   | ||||
|   std::vector< HmcObservable<typename Implementation::Field>* > ObservablesList; | ||||
|  | ||||
| @@ -57,7 +57,9 @@ class BinaryHmcRunnerTemplate { | ||||
|  | ||||
|   virtual void BuildTheAction(int argc, char **argv) = 0;  // necessary? | ||||
|  | ||||
|   void Run(int argc, char **argv) { | ||||
| // add here the smearing implementation? | ||||
| template <class IOCheckpointer = BinaryHmcCheckpointer<Implementation> > | ||||
|   void Run(int argc, char **argv, IOCheckpointer &Checkpoint) { | ||||
|     StartType_t StartType = HotStart; | ||||
|  | ||||
|     std::string arg; | ||||
| @@ -120,12 +122,13 @@ class BinaryHmcRunnerTemplate { | ||||
|     IntegratorType MDynamics(UGrid, MDpar, TheAction, SmearingPolicy); | ||||
|  | ||||
|     // Checkpoint strategy    | ||||
|     /*  | ||||
|     int SaveInterval = 1; | ||||
|     std::string format = std::string("IEEE64BIG"); | ||||
|     std::string conf_prefix = std::string("ckpoint_lat"); | ||||
|     std::string rng_prefix = std::string("ckpoint_rng"); | ||||
|     IOCheckpointer Checkpoint(conf_prefix, rng_prefix, SaveInterval, format); | ||||
|      | ||||
|     */ | ||||
|  | ||||
|     HMCparameters HMCpar; | ||||
|     HMCpar.StartTrajectory = StartTraj; | ||||
| @@ -159,7 +162,7 @@ class BinaryHmcRunnerTemplate { | ||||
|     SmearingPolicy.set_Field(U); | ||||
|  | ||||
|     HybridMonteCarlo<IntegratorType> HMC(HMCpar, MDynamics, sRNG, pRNG, U); | ||||
|     HMC.AddObservable(&Checkpoint); | ||||
|     //HMC.AddObservable(&Checkpoint); | ||||
|  | ||||
|     for (int obs = 0; obs < ObservablesList.size(); obs++) | ||||
|       HMC.AddObservable(ObservablesList[obs]);  | ||||
| @@ -174,7 +177,7 @@ typedef BinaryHmcRunnerTemplate<PeriodicGimplR> BinaryHmcRunner; | ||||
| typedef BinaryHmcRunnerTemplate<PeriodicGimplF> BinaryHmcRunnerF; | ||||
| typedef BinaryHmcRunnerTemplate<PeriodicGimplD> BinaryHmcRunnerD; | ||||
|  | ||||
| typedef BinaryHmcRunnerTemplate<PeriodicGimplR, NerscHmcCheckpointer<PeriodicGimplR> > NerscTestHmcRunner; | ||||
| //typedef BinaryHmcRunnerTemplate<PeriodicGimplR, NerscHmcCheckpointer<PeriodicGimplR> > NerscTestHmcRunner; | ||||
|  | ||||
|  | ||||
| template <class RepresentationsPolicy> | ||||
| @@ -183,7 +186,8 @@ using BinaryHmcRunnerTemplateHirep = | ||||
|  | ||||
|  | ||||
|  | ||||
| typedef BinaryHmcRunnerTemplate<ScalarImplR, BinaryHmcCheckpointer<ScalarImplR>, ScalarFields> ScalarBinaryHmcRunner; | ||||
| //typedef BinaryHmcRunnerTemplate<ScalarImplR, BinaryHmcCheckpointer<ScalarImplR>, ScalarFields> ScalarBinaryHmcRunner; | ||||
|     typedef BinaryHmcRunnerTemplate<ScalarImplR, ScalarFields> ScalarBinaryHmcRunner; | ||||
| } | ||||
| } | ||||
| #endif | ||||
|   | ||||
| @@ -59,7 +59,15 @@ class HmcRunner : public ScalarBinaryHmcRunner { | ||||
|     TheAction.push_back(Level1); | ||||
|  | ||||
|  | ||||
|     Run(argc, argv); | ||||
|     // Add observables and checkpointers | ||||
|     int SaveInterval = 1; | ||||
|     std::string format = std::string("IEEE64BIG"); | ||||
|     std::string conf_prefix = std::string("ckpoint_scalar_lat"); | ||||
|     std::string rng_prefix = std::string("ckpoint_scalar_rng"); | ||||
|     BinaryHmcCheckpointer<ScalarBinaryHmcRunner::ImplPolicy> Checkpoint(conf_prefix, rng_prefix, SaveInterval, format); | ||||
|     ObservablesList.push_back(&Checkpoint); | ||||
|  | ||||
|     Run(argc, argv, Checkpoint); | ||||
|   };   | ||||
| }; | ||||
| } | ||||
|   | ||||
| @@ -7,9 +7,8 @@ Source file: ./tests/Test_hmc_WilsonFermionGauge.cc | ||||
| Copyright (C) 2015 | ||||
|  | ||||
| Author: Peter Boyle <pabobyle@ph.ed.ac.uk> | ||||
| Author: Peter Boyle <paboyle@ph.ed.ac.uk> | ||||
| Author: neo <cossu@post.kek.jp> | ||||
| Author: paboyle <paboyle@ph.ed.ac.uk> | ||||
| Author: Guido Cossu <guido.cossu@ed.ac.uk> | ||||
|  | ||||
| This program is free software; you can redistribute it and/or modify | ||||
| it under the terms of the GNU General Public License as published by | ||||
| @@ -25,7 +24,8 @@ You should have received a copy of the GNU General Public License along | ||||
| with this program; if not, write to the Free Software Foundation, Inc., | ||||
| 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. | ||||
|  | ||||
| See the full license in the file "LICENSE" in the top level distribution directory | ||||
| See the full license in the file "LICENSE" in the top level distribution | ||||
| directory | ||||
| *************************************************************************************/ | ||||
| /*  END LEGAL */ | ||||
| #include <Grid/Grid.h> | ||||
| @@ -37,6 +37,23 @@ using namespace Grid::QCD; | ||||
| namespace Grid { | ||||
| namespace QCD { | ||||
|  | ||||
| class HMCRunnerParameters : Serializable { | ||||
|  public: | ||||
|   GRID_SERIALIZABLE_CLASS_MEMBERS(HMCRunnerParameters, | ||||
|                                   double, beta, | ||||
|                                   double, mass, | ||||
|                                   int, MaxCGIterations, | ||||
|                                   double, StoppingCondition, | ||||
|                                   bool, smearedAction, | ||||
|                                   int, SaveInterval, | ||||
|                                   std::string, format, | ||||
|                                   std::string, conf_prefix, | ||||
|                                   std::string, rng_prefix, | ||||
|                                   ); | ||||
|  | ||||
|   HMCRunnerParameters() {} | ||||
| }; | ||||
|  | ||||
| // Derive from the BinaryHmcRunner (templated for gauge fields) | ||||
| class HmcRunner : public BinaryHmcRunner { | ||||
|  public: | ||||
| @@ -62,6 +79,9 @@ class HmcRunner : public BinaryHmcRunner { | ||||
|     WilsonGaugeActionR Waction(5.6); | ||||
|  | ||||
|     Real mass = -0.77; | ||||
|  | ||||
|     // Can we define an overloaded operator that does not need U and initialises | ||||
|     // it with zeroes? | ||||
|     FermionAction FermOp(U, *FGrid, *FrbGrid, mass); | ||||
|  | ||||
|     ConjugateGradient<FermionField> CG(1.0e-8, 10000); | ||||
| @@ -82,10 +102,24 @@ class HmcRunner : public BinaryHmcRunner { | ||||
|     TheAction.push_back(Level2); | ||||
|  | ||||
|     // Add observables | ||||
|     PlaquetteLogger<BinaryHmcRunner::ImplPolicy> PlaqLog(std::string("plaq")); | ||||
|     ObservablesList.push_back(&PlaqLog); | ||||
|     int SaveInterval = 2; | ||||
|     std::string format = std::string("IEEE64BIG"); | ||||
|     std::string conf_prefix = std::string("ckpoint_lat"); | ||||
|     std::string rng_prefix = std::string("ckpoint_rng"); | ||||
|     BinaryHmcCheckpointer<BinaryHmcRunner::ImplPolicy> Checkpoint( | ||||
|         conf_prefix, rng_prefix, SaveInterval, format); | ||||
|     // Can implement also a specific function in the hmcrunner | ||||
|     // AddCheckpoint (...) that takes the same parameters + a string/tag | ||||
|     // defining the type of the checkpointer | ||||
|     // with tags can be implemented by overloading and no ifs | ||||
|     // Then force all checkpoint to have few common functions | ||||
|     // return an object that is then passed to the Run function | ||||
|  | ||||
|     Run(argc, argv); | ||||
|     PlaquetteLogger<BinaryHmcRunner::ImplPolicy> PlaqLog(std::string("Plaquette")); | ||||
|     ObservablesList.push_back(&PlaqLog); | ||||
|     ObservablesList.push_back(&Checkpoint); | ||||
|  | ||||
|     Run(argc, argv, Checkpoint); | ||||
|   }; | ||||
| }; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user