1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-12 20:27:06 +01:00

HMC checkpointing .

Need a general HMC framework to work in restart.
This commit is contained in:
paboyle
2015-12-20 02:29:51 +00:00
parent 5710966324
commit 31ca609d12
18 changed files with 175 additions and 77 deletions

View File

@ -3,19 +3,5 @@
namespace Grid{
namespace QCD{
HMCparameters::HMCparameters(){
// FIXME fill this constructor now just default values
////////////////////////////// Default values
Nsweeps = 200;
TotalSweeps = 240;
ThermalizationSteps = 40;
StartingConfig = 0;
SaveInterval = 1;
Filename_prefix = "Conf_";
/////////////////////////////////
}
}
}

View File

@ -17,15 +17,28 @@ namespace Grid{
struct HMCparameters{
Integer Nsweeps; /* @brief Number of sweeps in this run */
Integer TotalSweeps; /* @brief If provided, the total number of sweeps */
Integer ThermalizationSteps;
Integer StartingConfig;
Integer SaveInterval; //Setting to 0 does not save configurations
std::string Filename_prefix; // To save configurations and rng seed
HMCparameters();
Integer StartTrajectory;
Integer Trajectories; /* @brief Number of sweeps in this run */
bool MetropolisTest;
Integer NoMetropolisUntil;
HMCparameters(){
////////////////////////////// Default values
MetropolisTest = true;
NoMetropolisUntil = 10;
StartTrajectory = 0;
Trajectories = 200;
/////////////////////////////////
}
};
template<class GaugeField>
class HmcObservable {
public:
virtual void TrajectoryComplete (int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )=0;
};
// template <class GaugeField, class Integrator, class Smearer, class Boundary>
template <class GaugeField, class IntegratorType>
@ -34,10 +47,12 @@ namespace Grid{
const HMCparameters Params;
GridSerialRNG &sRNG; // Fixme: need a RNG management strategy.
GridSerialRNG &sRNG; // Fixme: need a RNG management strategy.
GridParallelRNG &pRNG; // Fixme: need a RNG management strategy.
GaugeField & Ucur;
IntegratorType &TheIntegrator;
std::vector<HmcObservable<GaugeField> *> Observables;
/////////////////////////////////////////////////////////
// Metropolis step
@ -89,38 +104,47 @@ namespace Grid{
/////////////////////////////////////////
// Constructor
/////////////////////////////////////////
HybridMonteCarlo(HMCparameters Pms, IntegratorType &_Int, GridSerialRNG &_sRNG, GridParallelRNG &_pRNG ) :
HybridMonteCarlo(HMCparameters Pms, IntegratorType &_Int, GridSerialRNG &_sRNG, GridParallelRNG &_pRNG, GaugeField &_U ) :
Params(Pms),
TheIntegrator(_Int),
sRNG(_sRNG),
pRNG(_pRNG)
pRNG(_pRNG),
Ucur(_U)
{
}
~HybridMonteCarlo(){};
void AddObservable(HmcObservable<GaugeField> *obs) {
Observables.push_back(obs);
}
void evolve(GaugeField& Uin){
void evolve(void){
Real DeltaH;
// Thermalizations
for(int iter=1; iter <= Params.ThermalizationSteps; ++iter){
std::cout<<GridLogMessage << "-- # Thermalization step = "<< iter << "\n";
DeltaH = evolve_step(Uin);
std::cout<<GridLogMessage<< "dH = "<< DeltaH << "\n";
}
GaugeField Ucopy(Ucur._grid);
// Actual updates (evolve a copy Ucopy then copy back eventually)
GaugeField Ucopy(Uin._grid);
for(int iter=Params.StartingConfig; iter < Params.Nsweeps+Params.StartingConfig; ++iter){
std::cout<<GridLogMessage << "-- # Sweep = "<< iter << "\n";
for(int traj=Params.StartTrajectory; traj < Params.Trajectories+Params.StartTrajectory; ++traj){
std::cout<<GridLogMessage << "-- # Trajectory = "<< traj << "\n";
Ucopy = Uin;
Ucopy = Ucur;
DeltaH = evolve_step(Ucopy);
if(metropolis_test(DeltaH)) Uin = Ucopy;
bool accept = true;
if ( traj > Params.NoMetropolisUntil) {
accept = metropolis_test(DeltaH);
}
if ( accept ) {
Ucur = Ucopy;
}
for(int obs = 0;obs<Observables.size();obs++){
Observables[obs]->TrajectoryComplete (traj,Ucur,sRNG,pRNG);
}
}
}

View File

@ -0,0 +1,50 @@
#ifndef NERSC_CHECKPOINTER
#define NERSC_CHECKPOINTER
#include <string>
#include <iostream>
#include <sstream>
namespace Grid{
namespace QCD{
template<class GaugeField>
class NerscHmcCheckpointer : public HmcObservable<GaugeField> {
private:
std::string configStem;
std::string rngStem;
int SaveInterval;
public:
NerscHmcCheckpointer(std::string cf, std::string rn,int savemodulo) {
configStem = cf;
rngStem = rn;
SaveInterval= savemodulo;
};
void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )
{
if ( (traj % SaveInterval)== 0 ) {
std::string rng; { std::ostringstream os; os << rngStem << traj; rng = os.str(); }
std::string config;{ std::ostringstream os; os << configStem << traj; config = os.str();}
int precision32=1;
int tworow =0;
NerscIO::writeRNGState(sRNG,pRNG,rng);
NerscIO::writeConfiguration(U,config,tworow,precision32);
}
};
void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG ){
std::string rng; { std::ostringstream os; os << rngStem << "rng." << traj; rng = os.str(); }
std::string config;{ std::ostringstream os; os << configStem << "rng." << traj; rng = os.str();}
NerscField header;
NerscIO::readRNGState(sRNG,pRNG,header,rng);
NerscIO::readConfiguration(U,header,config);
};
};
}}
#endif