mirror of
https://github.com/paboyle/Grid.git
synced 2025-12-22 13:44:29 +00:00
Adding a resource manager
This commit is contained in:
@@ -30,176 +30,167 @@ with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
#ifndef GRID_GENERIC_HMC_RUNNER
|
||||
#define GRID_GENERIC_HMC_RUNNER
|
||||
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace Grid {
|
||||
namespace QCD {
|
||||
|
||||
// Virtual Class for HMC specific for gauge theories
|
||||
// implement a specific theory by defining the BuildTheAction
|
||||
template <class Implementation, class RepresentationsPolicy = NoHirep>
|
||||
template <class Implementation,
|
||||
template < typename, typename, typename > class Integrator,
|
||||
class RepresentationsPolicy = NoHirep >
|
||||
class BinaryHmcRunnerTemplate {
|
||||
public:
|
||||
INHERIT_FIELD_TYPES(Implementation);
|
||||
typedef Implementation ImplPolicy;
|
||||
INHERIT_FIELD_TYPES(Implementation);
|
||||
typedef Implementation ImplPolicy; // visible from outside
|
||||
template < typename S = NoSmearing<Implementation> >
|
||||
using IntegratorType = Integrator<Implementation,S,RepresentationsPolicy>;
|
||||
|
||||
enum StartType_t { ColdStart,
|
||||
HotStart,
|
||||
TepidStart,
|
||||
CheckpointStart };
|
||||
enum StartType_t
|
||||
{
|
||||
ColdStart,
|
||||
HotStart,
|
||||
TepidStart,
|
||||
CheckpointStart,
|
||||
FilenameStart
|
||||
};
|
||||
|
||||
ActionSet<Field, RepresentationsPolicy> TheAction;
|
||||
struct HMCPayload
|
||||
{
|
||||
StartType_t StartType;
|
||||
HMCparameters Parameters;
|
||||
|
||||
// A vector of HmcObservable
|
||||
// that can be injected from outside
|
||||
std::vector<HmcObservable<typename Implementation::Field> *>
|
||||
ObservablesList;
|
||||
HMCPayload() { StartType = HotStart; }
|
||||
};
|
||||
|
||||
IntegratorParameters MDparameters;
|
||||
// These can be rationalised, some private
|
||||
HMCPayload Payload; // Parameters
|
||||
HMCResourceManager Resources;
|
||||
IntegratorParameters MDparameters;
|
||||
|
||||
GridCartesian * UGrid;
|
||||
GridRedBlackCartesian *UrbGrid;
|
||||
ActionSet<Field, RepresentationsPolicy> TheAction;
|
||||
|
||||
// A vector of HmcObservable that can be injected from outside
|
||||
std::vector<HmcObservable<typename Implementation::Field> *> ObservablesList;
|
||||
|
||||
//GridCartesian * UGrid;
|
||||
|
||||
// These two are unnecessary, eliminate
|
||||
GridCartesian * FGrid;
|
||||
GridRedBlackCartesian *FrbGrid;
|
||||
// GridRedBlackCartesian *UrbGrid;
|
||||
// GridCartesian * FGrid;
|
||||
// GridRedBlackCartesian *FrbGrid;
|
||||
|
||||
std::vector<int> SerialSeed;
|
||||
std::vector<int> ParallelSeed;
|
||||
void ReadCommandLine(int argc, char ** argv) {
|
||||
std::string arg;
|
||||
|
||||
void RNGSeeds(std::vector<int> S, std::vector<int> P) {
|
||||
SerialSeed = S;
|
||||
ParallelSeed = P;
|
||||
}
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--StartType")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--StartType");
|
||||
if (arg == "HotStart") {
|
||||
Payload.StartType = HotStart;
|
||||
} else if (arg == "ColdStart") {
|
||||
Payload.StartType = ColdStart;
|
||||
} else if (arg == "TepidStart") {
|
||||
Payload.StartType = TepidStart;
|
||||
} else if (arg == "CheckpointStart") {
|
||||
Payload.StartType = CheckpointStart;
|
||||
} else {
|
||||
std::cout << GridLogError << "Unrecognized option in --StartType\n";
|
||||
std::cout
|
||||
<< GridLogError
|
||||
<< "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n";
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void BuildTheAction(int argc, char **argv) = 0; // necessary?
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory");
|
||||
std::vector<int> ivec(0);
|
||||
GridCmdOptionIntVector(arg, ivec);
|
||||
Payload.Parameters.StartTrajectory = ivec[0];
|
||||
}
|
||||
|
||||
// A couple of wrapper classes
|
||||
template <class IOCheckpointer>
|
||||
void Run(int argc, char **argv, IOCheckpointer &Checkpoint) {
|
||||
NoSmearing<Implementation> S;
|
||||
Runner(argc, argv, Checkpoint, S);
|
||||
}
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories");
|
||||
std::vector<int> ivec(0);
|
||||
GridCmdOptionIntVector(arg, ivec);
|
||||
Payload.Parameters.Trajectories = ivec[0];
|
||||
}
|
||||
|
||||
template <class IOCheckpointer, class SmearingPolicy>
|
||||
void Run(int argc, char **argv, IOCheckpointer &CP, SmearingPolicy &S) {
|
||||
Runner(argc, argv, CP, S);
|
||||
}
|
||||
//////////////////////////////
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations");
|
||||
std::vector<int> ivec(0);
|
||||
GridCmdOptionIntVector(arg, ivec);
|
||||
Payload.Parameters.NoMetropolisUntil = ivec[0];
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
// A couple of wrapper functions
|
||||
template <class IOCheckpointer> void Run(IOCheckpointer &CP) {
|
||||
NoSmearing<Implementation> S;
|
||||
Runner(CP, S);
|
||||
}
|
||||
|
||||
template <class SmearingPolicy, class IOCheckpointer>
|
||||
void Runner(int argc,
|
||||
char ** argv,
|
||||
IOCheckpointer &Checkpoint,
|
||||
SmearingPolicy &Smearing) {
|
||||
StartType_t StartType = HotStart;
|
||||
template <class IOCheckpointer, class SmearingPolicy> void Run(IOCheckpointer &CP, SmearingPolicy &S) {
|
||||
Runner(CP, S);
|
||||
}
|
||||
|
||||
std::string arg;
|
||||
//////////////////////////////////////////////////////////////////
|
||||
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--StartType")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--StartType");
|
||||
if (arg == "HotStart") {
|
||||
StartType = HotStart;
|
||||
} else if (arg == "ColdStart") {
|
||||
StartType = ColdStart;
|
||||
} else if (arg == "TepidStart") {
|
||||
StartType = TepidStart;
|
||||
} else if (arg == "CheckpointStart") {
|
||||
StartType = CheckpointStart;
|
||||
} else {
|
||||
std::cout << GridLogError << "Unrecognized option in --StartType\n";
|
||||
std::cout
|
||||
<< GridLogError
|
||||
<< "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n";
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
private:
|
||||
template <class SmearingPolicy, class IOCheckpointer>
|
||||
void Runner(IOCheckpointer &Checkpoint, SmearingPolicy &Smearing) {
|
||||
auto UGrid = Resources.GetCartesian();
|
||||
Resources.AddRNGs();
|
||||
Field U(UGrid);
|
||||
|
||||
int StartTraj = 0;
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory");
|
||||
std::vector<int> ivec(0);
|
||||
GridCmdOptionIntVector(arg, ivec);
|
||||
StartTraj = ivec[0];
|
||||
}
|
||||
typedef IntegratorType<SmearingPolicy> TheIntegrator;
|
||||
TheIntegrator MDynamics(UGrid, MDparameters, TheAction, Smearing);
|
||||
|
||||
int NumTraj = 1;
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories");
|
||||
std::vector<int> ivec(0);
|
||||
GridCmdOptionIntVector(arg, ivec);
|
||||
NumTraj = ivec[0];
|
||||
}
|
||||
|
||||
int NumThermalizations = 10;
|
||||
if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) {
|
||||
arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations");
|
||||
std::vector<int> ivec(0);
|
||||
GridCmdOptionIntVector(arg, ivec);
|
||||
NumThermalizations = ivec[0];
|
||||
}
|
||||
|
||||
GridSerialRNG sRNG;
|
||||
GridParallelRNG pRNG(UGrid);
|
||||
Field U(UGrid);
|
||||
|
||||
|
||||
typedef MinimumNorm2<Implementation, SmearingPolicy, RepresentationsPolicy> IntegratorType; // change here to change the algorithm
|
||||
IntegratorType MDynamics(UGrid, MDparameters, TheAction, Smearing);
|
||||
|
||||
HMCparameters HMCpar;
|
||||
HMCpar.StartTrajectory = StartTraj;
|
||||
HMCpar.Trajectories = NumTraj;
|
||||
HMCpar.NoMetropolisUntil = NumThermalizations;
|
||||
|
||||
if (StartType == HotStart) {
|
||||
if (Payload.StartType == HotStart) {
|
||||
// Hot start
|
||||
HMCpar.MetropolisTest = true;
|
||||
sRNG.SeedFixedIntegers(SerialSeed);
|
||||
pRNG.SeedFixedIntegers(ParallelSeed);
|
||||
Implementation::HotConfiguration(pRNG, U);
|
||||
} else if (StartType == ColdStart) {
|
||||
Payload.Parameters.MetropolisTest = true;
|
||||
Resources.SeedFixedIntegers();
|
||||
Implementation::HotConfiguration(Resources.GetParallelRNG(), U);
|
||||
} else if (Payload.StartType == ColdStart) {
|
||||
// Cold start
|
||||
HMCpar.MetropolisTest = true;
|
||||
sRNG.SeedFixedIntegers(SerialSeed);
|
||||
pRNG.SeedFixedIntegers(ParallelSeed);
|
||||
Implementation::ColdConfiguration(pRNG, U);
|
||||
} else if (StartType == TepidStart) {
|
||||
Payload.Parameters.MetropolisTest = true;
|
||||
Resources.SeedFixedIntegers();
|
||||
Implementation::ColdConfiguration(Resources.GetParallelRNG(), U);
|
||||
} else if (Payload.StartType == TepidStart) {
|
||||
// Tepid start
|
||||
HMCpar.MetropolisTest = true;
|
||||
sRNG.SeedFixedIntegers(SerialSeed);
|
||||
pRNG.SeedFixedIntegers(ParallelSeed);
|
||||
Implementation::TepidConfiguration(pRNG, U);
|
||||
} else if (StartType == CheckpointStart) {
|
||||
HMCpar.MetropolisTest = true;
|
||||
Payload.Parameters.MetropolisTest = true;
|
||||
Resources.SeedFixedIntegers();
|
||||
Implementation::TepidConfiguration(Resources.GetParallelRNG(), U);
|
||||
} else if (Payload.StartType == CheckpointStart) {
|
||||
Payload.Parameters.MetropolisTest = true;
|
||||
// CheckpointRestart
|
||||
Checkpoint.CheckpointRestore(StartTraj, U, sRNG, pRNG);
|
||||
}
|
||||
Checkpoint.CheckpointRestore(Payload.Parameters.StartTrajectory, U, Resources.GetSerialRNG(), Resources.GetParallelRNG());
|
||||
}
|
||||
|
||||
Smearing.set_Field(U);
|
||||
Smearing.set_Field(U);
|
||||
|
||||
HybridMonteCarlo<IntegratorType> HMC(HMCpar, MDynamics, sRNG, pRNG, U);
|
||||
HybridMonteCarlo<TheIntegrator> HMC(Payload.Parameters, MDynamics, Resources.GetSerialRNG(), Resources.GetParallelRNG(), U);
|
||||
|
||||
for (int obs = 0; obs < ObservablesList.size(); obs++)
|
||||
HMC.AddObservable(ObservablesList[obs]);
|
||||
for (int obs = 0; obs < ObservablesList.size(); obs++)
|
||||
HMC.AddObservable(ObservablesList[obs]);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
}
|
||||
HMC.evolve();
|
||||
}
|
||||
};
|
||||
|
||||
// These are for gauge fields
|
||||
typedef BinaryHmcRunnerTemplate<PeriodicGimplR> BinaryHmcRunner;
|
||||
typedef BinaryHmcRunnerTemplate<PeriodicGimplF> BinaryHmcRunnerF;
|
||||
typedef BinaryHmcRunnerTemplate<PeriodicGimplD> BinaryHmcRunnerD;
|
||||
// These are for gauge fields, default integrator MinimumNorm2
|
||||
template <template <typename, typename, typename> class Integrator > using BinaryHmcRunner = BinaryHmcRunnerTemplate<PeriodicGimplR, Integrator > ;
|
||||
template <template <typename, typename, typename> class Integrator > using BinaryHmcRunnerF = BinaryHmcRunnerTemplate<PeriodicGimplF, Integrator > ;
|
||||
template <template <typename, typename, typename> class Integrator > using BinaryHmcRunnerD = BinaryHmcRunnerTemplate<PeriodicGimplD, Integrator > ;
|
||||
|
||||
template <class RepresentationsPolicy>
|
||||
using BinaryHmcRunnerTemplateHirep = BinaryHmcRunnerTemplate<PeriodicGimplR, RepresentationsPolicy>;
|
||||
template <class RepresentationsPolicy, template <typename, typename, typename> class Integrator >
|
||||
using BinaryHmcRunnerTemplateHirep = BinaryHmcRunnerTemplate<PeriodicGimplR, Integrator, RepresentationsPolicy>;
|
||||
|
||||
typedef BinaryHmcRunnerTemplate<ScalarImplR, ScalarFields>
|
||||
ScalarBinaryHmcRunner;
|
||||
typedef BinaryHmcRunnerTemplate<ScalarImplR, MinimumNorm2, ScalarFields> ScalarBinaryHmcRunner;
|
||||
|
||||
} // namespace QCD
|
||||
} // namespace Grid
|
||||
#endif
|
||||
|
||||
#endif // GRID_GENERIC_HMC_RUNNER
|
||||
|
||||
Reference in New Issue
Block a user