1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-27 14:15:55 +01:00

Moved more parameters to serialization

This commit is contained in:
Guido Cossu 2017-01-17 13:22:18 +00:00
parent 0157274762
commit 924130833e
5 changed files with 45 additions and 57 deletions

View File

@ -52,26 +52,10 @@ class HMCWrapperTemplate: public HMCBase {
template <typename S = NoSmearing<Implementation> > template <typename S = NoSmearing<Implementation> >
using IntegratorType = Integrator<Implementation, S, RepresentationsPolicy>; using IntegratorType = Integrator<Implementation, S, RepresentationsPolicy>;
enum StartType_t {
ColdStart,
HotStart,
TepidStart,
CheckpointStart,
FilenameStart
};
struct HMCPayload {
StartType_t StartType;
HMCparameters Parameters; HMCparameters Parameters;
HMCPayload() { StartType = HotStart; }
};
// These can be rationalised, some private
HMCPayload Payload; // Parameters
HMCResourceManager<Implementation> Resources; HMCResourceManager<Implementation> Resources;
IntegratorParameters MDparameters;
// The set of actions
ActionSet<Field, RepresentationsPolicy> TheAction; ActionSet<Field, RepresentationsPolicy> TheAction;
// A vector of HmcObservable that can be injected from outside // A vector of HmcObservable that can be injected from outside
@ -80,44 +64,39 @@ class HMCWrapperTemplate: public HMCBase {
void ReadCommandLine(int argc, char **argv) { void ReadCommandLine(int argc, char **argv) {
std::string arg; std::string arg;
if (GridCmdOptionExists(argv, argv + argc, "--StartType")) { if (GridCmdOptionExists(argv, argv + argc, "--StartingType")) {
arg = GridCmdOptionPayload(argv, argv + argc, "--StartType"); arg = GridCmdOptionPayload(argv, argv + argc, "--StartingType");
if (arg == "HotStart") {
Payload.StartType = HotStart; if (arg != "HotStart" && arg != "ColdStart" && arg != "TepidStart" &&
} else if (arg == "ColdStart") { arg != "CheckpointStart") {
Payload.StartType = ColdStart; std::cout << GridLogError << "Unrecognized option in --StartingType\n";
} else if (arg == "TepidStart") {
Payload.StartType = TepidStart;
} else if (arg == "CheckpointStart") {
Payload.StartType = CheckpointStart;
} else {
std::cout << GridLogError << "Unrecognized option in --StartType\n";
std::cout std::cout
<< GridLogError << GridLogError
<< "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n"; << "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n";
assert(0); exit(1);
} }
Parameters.StartingType = arg;
} }
if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) { if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) {
arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory"); arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory");
std::vector<int> ivec(0); std::vector<int> ivec(0);
GridCmdOptionIntVector(arg, ivec); GridCmdOptionIntVector(arg, ivec);
Payload.Parameters.StartTrajectory = ivec[0]; Parameters.StartTrajectory = ivec[0];
} }
if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) { if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) {
arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories"); arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories");
std::vector<int> ivec(0); std::vector<int> ivec(0);
GridCmdOptionIntVector(arg, ivec); GridCmdOptionIntVector(arg, ivec);
Payload.Parameters.Trajectories = ivec[0]; Parameters.Trajectories = ivec[0];
} }
if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) { if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) {
arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations"); arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations");
std::vector<int> ivec(0); std::vector<int> ivec(0);
GridCmdOptionIntVector(arg, ivec); GridCmdOptionIntVector(arg, ivec);
Payload.Parameters.NoMetropolisUntil = ivec[0]; Parameters.NoMetropolisUntil = ivec[0];
} }
} }
@ -143,30 +122,30 @@ class HMCWrapperTemplate: public HMCBase {
// Can move this outside? // Can move this outside?
typedef IntegratorType<SmearingPolicy> TheIntegrator; typedef IntegratorType<SmearingPolicy> TheIntegrator;
TheIntegrator MDynamics(UGrid, MDparameters, TheAction, Smearing); TheIntegrator MDynamics(UGrid, Parameters.MD, TheAction, Smearing);
if (Payload.StartType == HotStart) { if (Parameters.StartingType == "HotStart") {
// Hot start // Hot start
Resources.SeedFixedIntegers(); Resources.SeedFixedIntegers();
Implementation::HotConfiguration(Resources.GetParallelRNG(), U); Implementation::HotConfiguration(Resources.GetParallelRNG(), U);
} else if (Payload.StartType == ColdStart) { } else if (Parameters.StartingType == "ColdStart") {
// Cold start // Cold start
Resources.SeedFixedIntegers(); Resources.SeedFixedIntegers();
Implementation::ColdConfiguration(Resources.GetParallelRNG(), U); Implementation::ColdConfiguration(Resources.GetParallelRNG(), U);
} else if (Payload.StartType == TepidStart) { } else if (Parameters.StartingType == "TepidStart") {
// Tepid start // Tepid start
Resources.SeedFixedIntegers(); Resources.SeedFixedIntegers();
Implementation::TepidConfiguration(Resources.GetParallelRNG(), U); Implementation::TepidConfiguration(Resources.GetParallelRNG(), U);
} else if (Payload.StartType == CheckpointStart) { } else if (Parameters.StartingType == "CheckpointStart") {
// CheckpointRestart // CheckpointRestart
Resources.GetCheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U, Resources.GetCheckPointer()->CheckpointRestore(Parameters.StartTrajectory, U,
Resources.GetSerialRNG(), Resources.GetSerialRNG(),
Resources.GetParallelRNG()); Resources.GetParallelRNG());
} }
Smearing.set_Field(U); Smearing.set_Field(U);
HybridMonteCarlo<TheIntegrator> HMC(Payload.Parameters, MDynamics, HybridMonteCarlo<TheIntegrator> HMC(Parameters, MDynamics,
Resources.GetSerialRNG(), Resources.GetSerialRNG(),
Resources.GetParallelRNG(), U); Resources.GetParallelRNG(), U);

View File

@ -50,24 +50,39 @@ struct HMCparameters: Serializable {
Integer, Trajectories, /* @brief Number of sweeps in this run */ Integer, Trajectories, /* @brief Number of sweeps in this run */
bool, MetropolisTest, bool, MetropolisTest,
Integer, NoMetropolisUntil, Integer, NoMetropolisUntil,
std::string, StartingType,
IntegratorParameters, MD,
) )
// nest here the MDparameters and make all serializable
HMCparameters() { HMCparameters() {
////////////////////////////// Default values ////////////////////////////// Default values
MetropolisTest = true; MetropolisTest = true;
NoMetropolisUntil = 10; NoMetropolisUntil = 10;
StartTrajectory = 0; StartTrajectory = 0;
Trajectories = 10; Trajectories = 10;
StartingType = "HotStart";
///////////////////////////////// /////////////////////////////////
} }
template <class ReaderClass >
HMCparameters(Reader<ReaderClass> & TheReader){
initialize(TheReader);
}
template < class ReaderClass >
void initialize(Reader<ReaderClass> &TheReader){
std::cout << "Reading HMC\n";
read(TheReader, "HMC", *this);
}
void print_parameters() const { void print_parameters() const {
std::cout << GridLogMessage << "[HMC parameters] Trajectories : " << Trajectories << "\n"; std::cout << GridLogMessage << "[HMC parameters] Trajectories : " << Trajectories << "\n";
std::cout << GridLogMessage << "[HMC parameters] Start trajectory : " << StartTrajectory << "\n"; std::cout << GridLogMessage << "[HMC parameters] Start trajectory : " << StartTrajectory << "\n";
std::cout << GridLogMessage << "[HMC parameters] Metropolis test (on/off): " << std::boolalpha << MetropolisTest << "\n"; std::cout << GridLogMessage << "[HMC parameters] Metropolis test (on/off): " << std::boolalpha << MetropolisTest << "\n";
std::cout << GridLogMessage << "[HMC parameters] Thermalization trajs : " << NoMetropolisUntil << "\n"; std::cout << GridLogMessage << "[HMC parameters] Thermalization trajs : " << NoMetropolisUntil << "\n";
std::cout << GridLogMessage << "[HMC parameters] Starting type : " << StartingType << "\n";
MD.print_parameters();
} }
}; };
@ -209,7 +224,6 @@ class HybridMonteCarlo {
Field Ucopy(Ucur._grid); Field Ucopy(Ucur._grid);
Params.print_parameters(); Params.print_parameters();
TheIntegrator.print_parameters();
TheIntegrator.print_actions(); TheIntegrator.print_actions();
// Actual updates (evolve a copy Ucopy then copy back eventually) // Actual updates (evolve a copy Ucopy then copy back eventually)

View File

@ -54,10 +54,12 @@ public:
template <class ReaderClass, typename std::enable_if<isReader<ReaderClass>::value, int >::type = 0 > template <class ReaderClass, typename std::enable_if<isReader<ReaderClass>::value, int >::type = 0 >
IntegratorParameters(ReaderClass & Reader){ IntegratorParameters(ReaderClass & Reader){
std::cout << "Reading integrator\n";
read(Reader, "Integrator", *this); read(Reader, "Integrator", *this);
} }
void print_parameters() { void print_parameters() const {
std::cout << GridLogMessage << "[Integrator] Type : " << name << std::endl;
std::cout << GridLogMessage << "[Integrator] Trajectory length : " << trajL << std::endl; std::cout << GridLogMessage << "[Integrator] Trajectory length : " << trajL << std::endl;
std::cout << GridLogMessage << "[Integrator] Number of MD steps : " << MDsteps << std::endl; std::cout << GridLogMessage << "[Integrator] Number of MD steps : " << MDsteps << std::endl;
std::cout << GridLogMessage << "[Integrator] Step size : " << trajL/MDsteps << std::endl; std::cout << GridLogMessage << "[Integrator] Step size : " << trajL/MDsteps << std::endl;

View File

@ -43,11 +43,12 @@ int main(int argc, char **argv) {
typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm typedef GenericHMCRunner<MinimumNorm2> HMCWrapper; // Uses the default minimum norm
typedef Grid::XmlReader InputFileReader; typedef Grid::XmlReader InputFileReader;
// Reader // Reader, file should come from command line
InputFileReader Reader("input.wilson_gauge.params.xml"); InputFileReader Reader("input.wilson_gauge.params.xml");
HMCWrapper TheHMC; HMCWrapper TheHMC;
TheHMC.Parameters.initialize(Reader);
TheHMC.Resources.initialize(Reader); TheHMC.Resources.initialize(Reader);
// Construct observables // Construct observables
@ -67,20 +68,13 @@ int main(int argc, char **argv) {
ActionLevel<HMCWrapper::Field> Level1(1); ActionLevel<HMCWrapper::Field> Level1(1);
Level1.push_back(&Waction); Level1.push_back(&Waction);
//Level1.push_back(WGMod.getPtr());
TheHMC.TheAction.push_back(Level1); TheHMC.TheAction.push_back(Level1);
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Nest MDparameters in the HMCparameters->HMCPayload
// make it serializable
TheHMC.MDparameters.MDsteps = 20;
TheHMC.MDparameters.trajL = 1.0;
// eventually smearing here // eventually smearing here
// ... // ...
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
TheHMC.ReadCommandLine(argc, argv); // these must be parameters from file
TheHMC.Run(); // no smearing TheHMC.Run(); // no smearing
Grid_finalize(); Grid_finalize();

View File

@ -83,10 +83,9 @@ int main(int argc, char **argv) {
TheHMC.TheAction.push_back(Level1); TheHMC.TheAction.push_back(Level1);
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Nest MDparameters in the HMCparameters->HMCPayload // HMC parameters are serialisable
// make it serializable TheHMC.Parameters.MD.MDsteps = 20;
TheHMC.MDparameters.MDsteps = 20; TheHMC.Parameters.MD.trajL = 1.0;
TheHMC.MDparameters.trajL = 1.0;
TheHMC.ReadCommandLine(argc, argv); // these must be parameters from file TheHMC.ReadCommandLine(argc, argv); // these must be parameters from file
TheHMC.Run(); // no smearing TheHMC.Run(); // no smearing