mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
HMC checkpointing .
Need a general HMC framework to work in restart.
This commit is contained in:
parent
5710966324
commit
31ca609d12
@ -44,9 +44,10 @@
|
||||
#include <Cshift.h>
|
||||
#include <Stencil.h>
|
||||
#include <Algorithms.h>
|
||||
#include <qcd/QCD.h>
|
||||
#include <parallelIO/BinaryIO.h>
|
||||
#include <qcd/QCD.h>
|
||||
#include <parallelIO/NerscIO.h>
|
||||
#include <qcd/hmc/NerscCheckpointer.h>
|
||||
|
||||
#include <Init.h>
|
||||
|
||||
|
@ -73,7 +73,6 @@ public:
|
||||
typedef typename vobj::scalar_type scalar_type;
|
||||
typedef typename vobj::vector_type vector_type;
|
||||
typedef vobj vector_object;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Expression Template closure support
|
||||
@ -213,8 +212,8 @@ PARALLEL_FOR_LOOP
|
||||
// what about a default grid?
|
||||
//////////////////////////////////////////////////////////////////
|
||||
Lattice(GridBase *grid) : _grid(grid), _odata(_grid->oSites()) {
|
||||
// _odata.reserve(_grid->oSites());
|
||||
// _odata.resize(_grid->oSites());
|
||||
// _odata.reserve(_grid->oSites());
|
||||
// _odata.resize(_grid->oSites());
|
||||
// std::cout << "Constructing lattice object with Grid pointer "<<_grid<<std::endl;
|
||||
assert((((uint64_t)&_odata[0])&0xF) ==0);
|
||||
checkerboard=0;
|
||||
|
@ -107,7 +107,7 @@ class BinaryIO {
|
||||
}
|
||||
}
|
||||
|
||||
template<class vobj,class fobj,class munger> static inline void Uint32Checksum(Lattice<vobj> lat,munger munge,uint32_t &csum)
|
||||
template<class vobj,class fobj,class munger> static inline void Uint32Checksum(Lattice<vobj> &lat,munger munge,uint32_t &csum)
|
||||
{
|
||||
typedef typename vobj::scalar_object sobj;
|
||||
GridBase *grid = lat._grid ;
|
||||
|
@ -3,19 +3,5 @@
|
||||
namespace Grid{
|
||||
namespace QCD{
|
||||
|
||||
HMCparameters::HMCparameters(){
|
||||
// FIXME fill this constructor now just default values
|
||||
|
||||
////////////////////////////// Default values
|
||||
Nsweeps = 200;
|
||||
TotalSweeps = 240;
|
||||
ThermalizationSteps = 40;
|
||||
StartingConfig = 0;
|
||||
SaveInterval = 1;
|
||||
Filename_prefix = "Conf_";
|
||||
/////////////////////////////////
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -17,15 +17,28 @@ namespace Grid{
|
||||
|
||||
|
||||
struct HMCparameters{
|
||||
Integer Nsweeps; /* @brief Number of sweeps in this run */
|
||||
Integer TotalSweeps; /* @brief If provided, the total number of sweeps */
|
||||
Integer ThermalizationSteps;
|
||||
Integer StartingConfig;
|
||||
Integer SaveInterval; //Setting to 0 does not save configurations
|
||||
std::string Filename_prefix; // To save configurations and rng seed
|
||||
|
||||
HMCparameters();
|
||||
|
||||
Integer StartTrajectory;
|
||||
Integer Trajectories; /* @brief Number of sweeps in this run */
|
||||
bool MetropolisTest;
|
||||
Integer NoMetropolisUntil;
|
||||
|
||||
HMCparameters(){
|
||||
////////////////////////////// Default values
|
||||
MetropolisTest = true;
|
||||
NoMetropolisUntil = 10;
|
||||
StartTrajectory = 0;
|
||||
Trajectories = 200;
|
||||
/////////////////////////////////
|
||||
}
|
||||
};
|
||||
|
||||
template<class GaugeField>
|
||||
class HmcObservable {
|
||||
public:
|
||||
virtual void TrajectoryComplete (int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )=0;
|
||||
};
|
||||
|
||||
|
||||
// template <class GaugeField, class Integrator, class Smearer, class Boundary>
|
||||
template <class GaugeField, class IntegratorType>
|
||||
@ -34,10 +47,12 @@ namespace Grid{
|
||||
|
||||
const HMCparameters Params;
|
||||
|
||||
GridSerialRNG &sRNG; // Fixme: need a RNG management strategy.
|
||||
GridSerialRNG &sRNG; // Fixme: need a RNG management strategy.
|
||||
GridParallelRNG &pRNG; // Fixme: need a RNG management strategy.
|
||||
GaugeField & Ucur;
|
||||
|
||||
IntegratorType &TheIntegrator;
|
||||
std::vector<HmcObservable<GaugeField> *> Observables;
|
||||
|
||||
/////////////////////////////////////////////////////////
|
||||
// Metropolis step
|
||||
@ -89,38 +104,47 @@ namespace Grid{
|
||||
/////////////////////////////////////////
|
||||
// Constructor
|
||||
/////////////////////////////////////////
|
||||
HybridMonteCarlo(HMCparameters Pms, IntegratorType &_Int, GridSerialRNG &_sRNG, GridParallelRNG &_pRNG ) :
|
||||
HybridMonteCarlo(HMCparameters Pms, IntegratorType &_Int, GridSerialRNG &_sRNG, GridParallelRNG &_pRNG, GaugeField &_U ) :
|
||||
Params(Pms),
|
||||
TheIntegrator(_Int),
|
||||
sRNG(_sRNG),
|
||||
pRNG(_pRNG)
|
||||
pRNG(_pRNG),
|
||||
Ucur(_U)
|
||||
{
|
||||
}
|
||||
~HybridMonteCarlo(){};
|
||||
|
||||
void AddObservable(HmcObservable<GaugeField> *obs) {
|
||||
Observables.push_back(obs);
|
||||
}
|
||||
|
||||
void evolve(GaugeField& Uin){
|
||||
void evolve(void){
|
||||
|
||||
Real DeltaH;
|
||||
|
||||
// Thermalizations
|
||||
for(int iter=1; iter <= Params.ThermalizationSteps; ++iter){
|
||||
std::cout<<GridLogMessage << "-- # Thermalization step = "<< iter << "\n";
|
||||
|
||||
DeltaH = evolve_step(Uin);
|
||||
std::cout<<GridLogMessage<< "dH = "<< DeltaH << "\n";
|
||||
}
|
||||
|
||||
GaugeField Ucopy(Ucur._grid);
|
||||
|
||||
// Actual updates (evolve a copy Ucopy then copy back eventually)
|
||||
GaugeField Ucopy(Uin._grid);
|
||||
for(int iter=Params.StartingConfig; iter < Params.Nsweeps+Params.StartingConfig; ++iter){
|
||||
std::cout<<GridLogMessage << "-- # Sweep = "<< iter << "\n";
|
||||
for(int traj=Params.StartTrajectory; traj < Params.Trajectories+Params.StartTrajectory; ++traj){
|
||||
|
||||
std::cout<<GridLogMessage << "-- # Trajectory = "<< traj << "\n";
|
||||
|
||||
Ucopy = Uin;
|
||||
Ucopy = Ucur;
|
||||
|
||||
DeltaH = evolve_step(Ucopy);
|
||||
|
||||
if(metropolis_test(DeltaH)) Uin = Ucopy;
|
||||
|
||||
bool accept = true;
|
||||
if ( traj > Params.NoMetropolisUntil) {
|
||||
accept = metropolis_test(DeltaH);
|
||||
}
|
||||
|
||||
if ( accept ) {
|
||||
Ucur = Ucopy;
|
||||
}
|
||||
|
||||
for(int obs = 0;obs<Observables.size();obs++){
|
||||
Observables[obs]->TrajectoryComplete (traj,Ucur,sRNG,pRNG);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
50
lib/qcd/hmc/NerscCheckpointer.h
Normal file
50
lib/qcd/hmc/NerscCheckpointer.h
Normal file
@ -0,0 +1,50 @@
|
||||
#ifndef NERSC_CHECKPOINTER
|
||||
#define NERSC_CHECKPOINTER
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace Grid{
|
||||
namespace QCD{
|
||||
|
||||
|
||||
template<class GaugeField>
|
||||
class NerscHmcCheckpointer : public HmcObservable<GaugeField> {
|
||||
private:
|
||||
std::string configStem;
|
||||
std::string rngStem;
|
||||
int SaveInterval;
|
||||
public:
|
||||
NerscHmcCheckpointer(std::string cf, std::string rn,int savemodulo) {
|
||||
configStem = cf;
|
||||
rngStem = rn;
|
||||
SaveInterval= savemodulo;
|
||||
};
|
||||
|
||||
void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )
|
||||
{
|
||||
if ( (traj % SaveInterval)== 0 ) {
|
||||
std::string rng; { std::ostringstream os; os << rngStem << traj; rng = os.str(); }
|
||||
std::string config;{ std::ostringstream os; os << configStem << traj; config = os.str();}
|
||||
|
||||
int precision32=1;
|
||||
int tworow =0;
|
||||
NerscIO::writeRNGState(sRNG,pRNG,rng);
|
||||
NerscIO::writeConfiguration(U,config,tworow,precision32);
|
||||
}
|
||||
};
|
||||
|
||||
void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG ){
|
||||
std::string rng; { std::ostringstream os; os << rngStem << "rng." << traj; rng = os.str(); }
|
||||
std::string config;{ std::ostringstream os; os << configStem << "rng." << traj; rng = os.str();}
|
||||
|
||||
NerscField header;
|
||||
NerscIO::readRNGState(sRNG,pRNG,header,rng);
|
||||
NerscIO::readConfiguration(U,header,config);
|
||||
};
|
||||
|
||||
};
|
||||
}}
|
||||
#endif
|
@ -1,6 +1,5 @@
|
||||
#include "Grid.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
using namespace Grid;
|
||||
using namespace Grid::QCD;
|
||||
@ -52,10 +51,12 @@ int main (int argc, char ** argv)
|
||||
IntegratorType MDynamics(UGrid,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorType> HMC(HMCpar, MDynamics,sRNG,pRNG);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorType> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve(U);
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -61,10 +61,13 @@ int main (int argc, char ** argv)
|
||||
IntegratorType MDynamics(UGrid,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorType> HMC(HMCpar, MDynamics,sRNG,pRNG);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorType> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve(U);
|
||||
HMC.evolve();
|
||||
|
||||
|
||||
}
|
||||
|
@ -56,10 +56,14 @@ int main (int argc, char ** argv)
|
||||
IntegratorParameters MDpar(16);
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -54,10 +54,14 @@ int main (int argc, char ** argv)
|
||||
IntegratorParameters MDpar(20);
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -55,10 +55,14 @@ int main (int argc, char ** argv)
|
||||
IntegratorParameters MDpar(20);
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -51,9 +51,12 @@ int main (int argc, char ** argv)
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics, sRNG, pRNG);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -55,9 +55,12 @@ int main (int argc, char ** argv)
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics, sRNG, pRNG);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -116,8 +116,9 @@ int main (int argc, char ** argv)
|
||||
|
||||
std::string clone2x3("./ckpoint_clone2x3.4000");
|
||||
std::string clone3x3("./ckpoint_clone3x3.4000");
|
||||
int precision32 = 0;
|
||||
|
||||
int precision32 = 1;
|
||||
int tworow = 1;
|
||||
NerscIO::writeConfiguration(Umu,clone3x3,0,precision32);
|
||||
NerscIO::writeConfiguration(Umu,clone2x3,1,precision32);
|
||||
|
||||
|
@ -56,8 +56,12 @@ int main (int argc, char ** argv)
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics, sRNG, pRNG);
|
||||
HMC.evolve(U);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -58,9 +58,13 @@ int main (int argc, char ** argv)
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics, sRNG, pRNG);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
HMC.evolve(U);
|
||||
|
||||
}
|
||||
|
@ -55,10 +55,14 @@ int main (int argc, char ** argv)
|
||||
IntegratorParameters MDpar(20);
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics, sRNG, pRNG);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
@ -58,9 +58,12 @@ int main (int argc, char ** argv)
|
||||
IntegratorAlgorithm MDynamics(&Fine,MDpar, FullSet);
|
||||
|
||||
// Create HMC
|
||||
NerscHmcCheckpointer<LatticeGaugeField> Checkpoint(std::string("ckpoint_lat"),std::string("ckpoint_rng"),1);
|
||||
HMCparameters HMCpar;
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics, sRNG, pRNG);
|
||||
HybridMonteCarlo<LatticeGaugeField,IntegratorAlgorithm> HMC(HMCpar, MDynamics,sRNG,pRNG,U);
|
||||
HMC.AddObservable(&Checkpoint);
|
||||
|
||||
HMC.evolve(U);
|
||||
// Run it
|
||||
HMC.evolve();
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user