diff --git a/lib/qcd/hmc/GenericHMCrunner.h b/lib/qcd/hmc/GenericHMCrunner.h index 5bfe7a90..de9b36d3 100644 --- a/lib/qcd/hmc/GenericHMCrunner.h +++ b/lib/qcd/hmc/GenericHMCrunner.h @@ -52,26 +52,10 @@ class HMCWrapperTemplate: public HMCBase { template > using IntegratorType = Integrator; - enum StartType_t { - ColdStart, - HotStart, - TepidStart, - CheckpointStart, - FilenameStart - }; - - struct HMCPayload { - StartType_t StartType; - HMCparameters Parameters; - - HMCPayload() { StartType = HotStart; } - }; - - // These can be rationalised, some private - HMCPayload Payload; // Parameters + HMCparameters Parameters; HMCResourceManager Resources; - IntegratorParameters MDparameters; + // The set of actions ActionSet TheAction; // A vector of HmcObservable that can be injected from outside @@ -80,44 +64,39 @@ class HMCWrapperTemplate: public HMCBase { void ReadCommandLine(int argc, char **argv) { std::string arg; - if (GridCmdOptionExists(argv, argv + argc, "--StartType")) { - arg = GridCmdOptionPayload(argv, argv + argc, "--StartType"); - if (arg == "HotStart") { - Payload.StartType = HotStart; - } else if (arg == "ColdStart") { - Payload.StartType = ColdStart; - } else if (arg == "TepidStart") { - Payload.StartType = TepidStart; - } else if (arg == "CheckpointStart") { - Payload.StartType = CheckpointStart; - } else { - std::cout << GridLogError << "Unrecognized option in --StartType\n"; + if (GridCmdOptionExists(argv, argv + argc, "--StartingType")) { + arg = GridCmdOptionPayload(argv, argv + argc, "--StartingType"); + + if (arg != "HotStart" && arg != "ColdStart" && arg != "TepidStart" && + arg != "CheckpointStart") { + std::cout << GridLogError << "Unrecognized option in --StartingType\n"; std::cout << GridLogError << "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n"; - assert(0); + exit(1); } + Parameters.StartingType = arg; } if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) { arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory"); std::vector ivec(0); GridCmdOptionIntVector(arg, ivec); - Payload.Parameters.StartTrajectory = ivec[0]; + Parameters.StartTrajectory = ivec[0]; } if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) { arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories"); std::vector ivec(0); GridCmdOptionIntVector(arg, ivec); - Payload.Parameters.Trajectories = ivec[0]; + Parameters.Trajectories = ivec[0]; } if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) { arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations"); std::vector ivec(0); GridCmdOptionIntVector(arg, ivec); - Payload.Parameters.NoMetropolisUntil = ivec[0]; + Parameters.NoMetropolisUntil = ivec[0]; } } @@ -143,30 +122,30 @@ class HMCWrapperTemplate: public HMCBase { // Can move this outside? typedef IntegratorType TheIntegrator; - TheIntegrator MDynamics(UGrid, MDparameters, TheAction, Smearing); + TheIntegrator MDynamics(UGrid, Parameters.MD, TheAction, Smearing); - if (Payload.StartType == HotStart) { + if (Parameters.StartingType == "HotStart") { // Hot start Resources.SeedFixedIntegers(); Implementation::HotConfiguration(Resources.GetParallelRNG(), U); - } else if (Payload.StartType == ColdStart) { + } else if (Parameters.StartingType == "ColdStart") { // Cold start Resources.SeedFixedIntegers(); Implementation::ColdConfiguration(Resources.GetParallelRNG(), U); - } else if (Payload.StartType == TepidStart) { + } else if (Parameters.StartingType == "TepidStart") { // Tepid start Resources.SeedFixedIntegers(); Implementation::TepidConfiguration(Resources.GetParallelRNG(), U); - } else if (Payload.StartType == CheckpointStart) { + } else if (Parameters.StartingType == "CheckpointStart") { // CheckpointRestart - Resources.GetCheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U, + Resources.GetCheckPointer()->CheckpointRestore(Parameters.StartTrajectory, U, Resources.GetSerialRNG(), Resources.GetParallelRNG()); } Smearing.set_Field(U); - HybridMonteCarlo HMC(Payload.Parameters, MDynamics, + HybridMonteCarlo HMC(Parameters, MDynamics, Resources.GetSerialRNG(), Resources.GetParallelRNG(), U); diff --git a/lib/qcd/hmc/HMC.h b/lib/qcd/hmc/HMC.h index c5e9e65b..3304ad4f 100644 --- a/lib/qcd/hmc/HMC.h +++ b/lib/qcd/hmc/HMC.h @@ -50,24 +50,39 @@ struct HMCparameters: Serializable { Integer, Trajectories, /* @brief Number of sweeps in this run */ bool, MetropolisTest, Integer, NoMetropolisUntil, + std::string, StartingType, + IntegratorParameters, MD, ) - // nest here the MDparameters and make all serializable - HMCparameters() { ////////////////////////////// Default values MetropolisTest = true; NoMetropolisUntil = 10; StartTrajectory = 0; Trajectories = 10; + StartingType = "HotStart"; ///////////////////////////////// } + template + HMCparameters(Reader & TheReader){ + initialize(TheReader); + } + + template < class ReaderClass > + void initialize(Reader &TheReader){ + std::cout << "Reading HMC\n"; + read(TheReader, "HMC", *this); + } + + void print_parameters() const { std::cout << GridLogMessage << "[HMC parameters] Trajectories : " << Trajectories << "\n"; std::cout << GridLogMessage << "[HMC parameters] Start trajectory : " << StartTrajectory << "\n"; std::cout << GridLogMessage << "[HMC parameters] Metropolis test (on/off): " << std::boolalpha << MetropolisTest << "\n"; std::cout << GridLogMessage << "[HMC parameters] Thermalization trajs : " << NoMetropolisUntil << "\n"; + std::cout << GridLogMessage << "[HMC parameters] Starting type : " << StartingType << "\n"; + MD.print_parameters(); } }; @@ -209,7 +224,6 @@ class HybridMonteCarlo { Field Ucopy(Ucur._grid); Params.print_parameters(); - TheIntegrator.print_parameters(); TheIntegrator.print_actions(); // Actual updates (evolve a copy Ucopy then copy back eventually) diff --git a/lib/qcd/hmc/integrators/Integrator.h b/lib/qcd/hmc/integrators/Integrator.h index 6101e225..6a72919e 100644 --- a/lib/qcd/hmc/integrators/Integrator.h +++ b/lib/qcd/hmc/integrators/Integrator.h @@ -54,10 +54,12 @@ public: template ::value, int >::type = 0 > IntegratorParameters(ReaderClass & Reader){ + std::cout << "Reading integrator\n"; read(Reader, "Integrator", *this); } - void print_parameters() { + void print_parameters() const { + std::cout << GridLogMessage << "[Integrator] Type : " << name << std::endl; std::cout << GridLogMessage << "[Integrator] Trajectory length : " << trajL << std::endl; std::cout << GridLogMessage << "[Integrator] Number of MD steps : " << MDsteps << std::endl; std::cout << GridLogMessage << "[Integrator] Step size : " << trajL/MDsteps << std::endl; diff --git a/tests/hmc/Test_hmc_Factories.cc b/tests/hmc/Test_hmc_Factories.cc index f33beeb1..841d009d 100644 --- a/tests/hmc/Test_hmc_Factories.cc +++ b/tests/hmc/Test_hmc_Factories.cc @@ -43,11 +43,12 @@ int main(int argc, char **argv) { typedef GenericHMCRunner HMCWrapper; // Uses the default minimum norm typedef Grid::XmlReader InputFileReader; - // Reader + // Reader, file should come from command line InputFileReader Reader("input.wilson_gauge.params.xml"); HMCWrapper TheHMC; + TheHMC.Parameters.initialize(Reader); TheHMC.Resources.initialize(Reader); // Construct observables @@ -67,20 +68,13 @@ int main(int argc, char **argv) { ActionLevel Level1(1); Level1.push_back(&Waction); - //Level1.push_back(WGMod.getPtr()); TheHMC.TheAction.push_back(Level1); ///////////////////////////////////////////////////////////// - // Nest MDparameters in the HMCparameters->HMCPayload - // make it serializable - TheHMC.MDparameters.MDsteps = 20; - TheHMC.MDparameters.trajL = 1.0; - // eventually smearing here // ... //////////////////////////////////////////////////////////////// - TheHMC.ReadCommandLine(argc, argv); // these must be parameters from file TheHMC.Run(); // no smearing Grid_finalize(); diff --git a/tests/hmc/Test_hmc_WilsonGauge_Binary.cc b/tests/hmc/Test_hmc_WilsonGauge_Binary.cc index 9c809e8b..c1e67c0b 100644 --- a/tests/hmc/Test_hmc_WilsonGauge_Binary.cc +++ b/tests/hmc/Test_hmc_WilsonGauge_Binary.cc @@ -83,10 +83,9 @@ int main(int argc, char **argv) { TheHMC.TheAction.push_back(Level1); ///////////////////////////////////////////////////////////// - // Nest MDparameters in the HMCparameters->HMCPayload - // make it serializable - TheHMC.MDparameters.MDsteps = 20; - TheHMC.MDparameters.trajL = 1.0; + // HMC parameters are serialisable + TheHMC.Parameters.MD.MDsteps = 20; + TheHMC.Parameters.MD.trajL = 1.0; TheHMC.ReadCommandLine(argc, argv); // these must be parameters from file TheHMC.Run(); // no smearing