mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-03 21:44:33 +00:00 
			
		
		
		
	HMC checkpointing .
Need a general HMC framework to work in restart.
This commit is contained in:
		@@ -44,9 +44,10 @@
 | 
			
		||||
#include <Cshift.h>       
 | 
			
		||||
#include <Stencil.h>      
 | 
			
		||||
#include <Algorithms.h>   
 | 
			
		||||
#include <qcd/QCD.h>
 | 
			
		||||
#include <parallelIO/BinaryIO.h>
 | 
			
		||||
#include <qcd/QCD.h>
 | 
			
		||||
#include <parallelIO/NerscIO.h>
 | 
			
		||||
#include <qcd/hmc/NerscCheckpointer.h>
 | 
			
		||||
 | 
			
		||||
#include <Init.h>
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -73,7 +73,6 @@ public:
 | 
			
		||||
    typedef typename vobj::scalar_type scalar_type;
 | 
			
		||||
    typedef typename vobj::vector_type vector_type;
 | 
			
		||||
    typedef vobj vector_object;
 | 
			
		||||
 
 | 
			
		||||
   
 | 
			
		||||
  ////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Expression Template closure support
 | 
			
		||||
@@ -213,8 +212,8 @@ PARALLEL_FOR_LOOP
 | 
			
		||||
    // what about a default grid?
 | 
			
		||||
    //////////////////////////////////////////////////////////////////
 | 
			
		||||
    Lattice(GridBase *grid) : _grid(grid), _odata(_grid->oSites()) {
 | 
			
		||||
      //        _odata.reserve(_grid->oSites());
 | 
			
		||||
      //        _odata.resize(_grid->oSites());
 | 
			
		||||
    //        _odata.reserve(_grid->oSites());
 | 
			
		||||
    //        _odata.resize(_grid->oSites());
 | 
			
		||||
    //      std::cout << "Constructing lattice object with Grid pointer "<<_grid<<std::endl;
 | 
			
		||||
        assert((((uint64_t)&_odata[0])&0xF) ==0);
 | 
			
		||||
        checkerboard=0;
 | 
			
		||||
 
 | 
			
		||||
@@ -107,7 +107,7 @@ class BinaryIO {
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template<class vobj,class fobj,class munger> static inline void Uint32Checksum(Lattice<vobj> lat,munger munge,uint32_t &csum)
 | 
			
		||||
  template<class vobj,class fobj,class munger> static inline void Uint32Checksum(Lattice<vobj> &lat,munger munge,uint32_t &csum)
 | 
			
		||||
  {
 | 
			
		||||
    typedef typename vobj::scalar_object sobj;
 | 
			
		||||
    GridBase *grid = lat._grid ;
 | 
			
		||||
 
 | 
			
		||||
@@ -3,19 +3,5 @@
 | 
			
		||||
namespace Grid{
 | 
			
		||||
  namespace QCD{
 | 
			
		||||
 | 
			
		||||
    HMCparameters::HMCparameters(){
 | 
			
		||||
	// FIXME fill this constructor  now just default values
 | 
			
		||||
	  
 | 
			
		||||
	////////////////////////////// Default values
 | 
			
		||||
	Nsweeps             = 200;
 | 
			
		||||
	TotalSweeps         = 240;
 | 
			
		||||
	ThermalizationSteps = 40;
 | 
			
		||||
	StartingConfig      = 0;
 | 
			
		||||
	SaveInterval        = 1;
 | 
			
		||||
	Filename_prefix     = "Conf_";
 | 
			
		||||
	/////////////////////////////////
 | 
			
		||||
	  
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -17,15 +17,28 @@ namespace Grid{
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    struct HMCparameters{
 | 
			
		||||
      Integer Nsweeps; /* @brief Number of sweeps in this run */
 | 
			
		||||
      Integer TotalSweeps; /* @brief If provided, the total number of sweeps */
 | 
			
		||||
      Integer ThermalizationSteps;
 | 
			
		||||
      Integer StartingConfig;
 | 
			
		||||
      Integer SaveInterval; //Setting to 0 does not save configurations
 | 
			
		||||
      std::string Filename_prefix; // To save configurations and rng seed
 | 
			
		||||
      
 | 
			
		||||
      HMCparameters();
 | 
			
		||||
 | 
			
		||||
      Integer StartTrajectory;
 | 
			
		||||
      Integer Trajectories; /* @brief Number of sweeps in this run */
 | 
			
		||||
      bool    MetropolisTest;
 | 
			
		||||
      Integer NoMetropolisUntil;
 | 
			
		||||
 | 
			
		||||
      HMCparameters(){
 | 
			
		||||
	////////////////////////////// Default values
 | 
			
		||||
	MetropolisTest      = true;
 | 
			
		||||
	NoMetropolisUntil   = 10;
 | 
			
		||||
	StartTrajectory     = 0;
 | 
			
		||||
	Trajectories        = 200;
 | 
			
		||||
	/////////////////////////////////
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    template<class GaugeField> 
 | 
			
		||||
    class HmcObservable {
 | 
			
		||||
    public:
 | 
			
		||||
      virtual void TrajectoryComplete (int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )=0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    //    template <class GaugeField, class Integrator, class Smearer, class Boundary> 
 | 
			
		||||
    template <class GaugeField, class IntegratorType>
 | 
			
		||||
@@ -34,10 +47,12 @@ namespace Grid{
 | 
			
		||||
 | 
			
		||||
      const HMCparameters Params;
 | 
			
		||||
      
 | 
			
		||||
      GridSerialRNG   &sRNG;   // Fixme: need a RNG management strategy.
 | 
			
		||||
      GridSerialRNG   &sRNG; // Fixme: need a RNG management strategy.
 | 
			
		||||
      GridParallelRNG &pRNG; // Fixme: need a RNG management strategy.
 | 
			
		||||
      GaugeField      & Ucur;
 | 
			
		||||
 | 
			
		||||
      IntegratorType &TheIntegrator;
 | 
			
		||||
      std::vector<HmcObservable<GaugeField> *> Observables;
 | 
			
		||||
 | 
			
		||||
      /////////////////////////////////////////////////////////
 | 
			
		||||
      // Metropolis step
 | 
			
		||||
@@ -89,38 +104,47 @@ namespace Grid{
 | 
			
		||||
      /////////////////////////////////////////
 | 
			
		||||
      // Constructor
 | 
			
		||||
      /////////////////////////////////////////
 | 
			
		||||
     HybridMonteCarlo(HMCparameters Pms,  IntegratorType &_Int, GridSerialRNG &_sRNG, GridParallelRNG &_pRNG ) :
 | 
			
		||||
      HybridMonteCarlo(HMCparameters Pms,  IntegratorType &_Int, GridSerialRNG &_sRNG, GridParallelRNG &_pRNG, GaugeField &_U ) :
 | 
			
		||||
        Params(Pms), 
 | 
			
		||||
	TheIntegrator(_Int), 
 | 
			
		||||
	sRNG(_sRNG),
 | 
			
		||||
	pRNG(_pRNG)
 | 
			
		||||
	pRNG(_pRNG),
 | 
			
		||||
	Ucur(_U)
 | 
			
		||||
      {
 | 
			
		||||
      }
 | 
			
		||||
      ~HybridMonteCarlo(){};
 | 
			
		||||
 | 
			
		||||
      void AddObservable(HmcObservable<GaugeField> *obs) {
 | 
			
		||||
	Observables.push_back(obs);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      void evolve(GaugeField& Uin){
 | 
			
		||||
      void evolve(void){
 | 
			
		||||
 | 
			
		||||
	Real DeltaH;
 | 
			
		||||
	
 | 
			
		||||
	// Thermalizations
 | 
			
		||||
	for(int iter=1; iter <= Params.ThermalizationSteps; ++iter){
 | 
			
		||||
	  std::cout<<GridLogMessage << "-- # Thermalization step = "<< iter <<  "\n";
 | 
			
		||||
	
 | 
			
		||||
	  DeltaH = evolve_step(Uin);
 | 
			
		||||
	  std::cout<<GridLogMessage<< "dH = "<< DeltaH << "\n";
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	GaugeField Ucopy(Ucur._grid);
 | 
			
		||||
	
 | 
			
		||||
	// Actual updates (evolve a copy Ucopy then copy back eventually)
 | 
			
		||||
	GaugeField Ucopy(Uin._grid);
 | 
			
		||||
	for(int iter=Params.StartingConfig; iter < Params.Nsweeps+Params.StartingConfig; ++iter){
 | 
			
		||||
	  std::cout<<GridLogMessage << "-- # Sweep = "<< iter <<  "\n";
 | 
			
		||||
	for(int traj=Params.StartTrajectory; traj < Params.Trajectories+Params.StartTrajectory; ++traj){
 | 
			
		||||
 | 
			
		||||
	  std::cout<<GridLogMessage << "-- # Trajectory = "<< traj <<  "\n";
 | 
			
		||||
	  
 | 
			
		||||
	  Ucopy = Uin;
 | 
			
		||||
	  Ucopy = Ucur;
 | 
			
		||||
 | 
			
		||||
	  DeltaH = evolve_step(Ucopy);
 | 
			
		||||
		
 | 
			
		||||
	  if(metropolis_test(DeltaH)) Uin = Ucopy;
 | 
			
		||||
 | 
			
		||||
	  bool accept = true;
 | 
			
		||||
	  if ( traj > Params.NoMetropolisUntil) { 
 | 
			
		||||
	    accept = metropolis_test(DeltaH);
 | 
			
		||||
	  }
 | 
			
		||||
	  
 | 
			
		||||
	  if ( accept ) {
 | 
			
		||||
	    Ucur = Ucopy;
 | 
			
		||||
	  }
 | 
			
		||||
 | 
			
		||||
	  for(int obs = 0;obs<Observables.size();obs++){
 | 
			
		||||
	    Observables[obs]->TrajectoryComplete (traj,Ucur,sRNG,pRNG);
 | 
			
		||||
	  }
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
      }
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										50
									
								
								lib/qcd/hmc/NerscCheckpointer.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								lib/qcd/hmc/NerscCheckpointer.h
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
#ifndef NERSC_CHECKPOINTER
 | 
			
		||||
#define NERSC_CHECKPOINTER
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace Grid{
 | 
			
		||||
  namespace QCD{
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    template<class GaugeField> 
 | 
			
		||||
    class NerscHmcCheckpointer : public HmcObservable<GaugeField> {
 | 
			
		||||
    private:
 | 
			
		||||
      std::string configStem;
 | 
			
		||||
      std::string rngStem;
 | 
			
		||||
      int SaveInterval;
 | 
			
		||||
    public:
 | 
			
		||||
      NerscHmcCheckpointer(std::string cf, std::string rn,int savemodulo) {
 | 
			
		||||
        configStem  = cf;
 | 
			
		||||
        rngStem     = rn;
 | 
			
		||||
        SaveInterval= savemodulo;
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )
 | 
			
		||||
      {
 | 
			
		||||
	if ( (traj % SaveInterval)== 0 ) {
 | 
			
		||||
	  std::string rng;   { std::ostringstream os; os << rngStem     << traj; rng = os.str(); }
 | 
			
		||||
	  std::string config;{ std::ostringstream os; os << configStem  << traj; config = os.str();}
 | 
			
		||||
 | 
			
		||||
	  int precision32=1;
 | 
			
		||||
	  int tworow     =0;
 | 
			
		||||
	  NerscIO::writeRNGState(sRNG,pRNG,rng);
 | 
			
		||||
	  NerscIO::writeConfiguration(U,config,tworow,precision32);
 | 
			
		||||
	}
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG ){
 | 
			
		||||
	std::string rng;   { std::ostringstream os; os << rngStem << "rng." << traj; rng = os.str(); }
 | 
			
		||||
	std::string config;{ std::ostringstream os; os << configStem << "rng." << traj; rng = os.str();}
 | 
			
		||||
 | 
			
		||||
	NerscField header;
 | 
			
		||||
	NerscIO::readRNGState(sRNG,pRNG,header,rng);
 | 
			
		||||
	NerscIO::readConfiguration(U,header,config);
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
    };
 | 
			
		||||
}}
 | 
			
		||||
#endif
 | 
			
		||||
		Reference in New Issue
	
	Block a user