mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 05:54:32 +00:00 
			
		
		
		
	Added module for checkpointers
This commit is contained in:
		@@ -84,9 +84,10 @@ Author: paboyle <paboyle@ph.ed.ac.uk>
 | 
			
		||||
#include <Grid/parallelIO/BinaryIO.h>
 | 
			
		||||
#include <Grid/parallelIO/IldgIO.h>
 | 
			
		||||
#include <Grid/parallelIO/NerscIO.h>
 | 
			
		||||
#include <Grid/qcd/hmc/NerscCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/BinaryCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/ILDGCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/checkpointers/BaseCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/checkpointers/NerscCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/checkpointers/BinaryCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/checkpointers/ILDGCheckpointer.h>
 | 
			
		||||
#include <Grid/qcd/hmc/HMCModules.h>
 | 
			
		||||
#include <Grid/qcd/hmc/HMCResourceManager.h>
 | 
			
		||||
#include <Grid/qcd/hmc/HmcRunner.h>
 | 
			
		||||
 
 | 
			
		||||
@@ -30,165 +30,163 @@ with this program; if not, write to the Free Software Foundation, Inc.,
 | 
			
		||||
#ifndef GRID_GENERIC_HMC_RUNNER
 | 
			
		||||
#define GRID_GENERIC_HMC_RUNNER
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
namespace QCD {
 | 
			
		||||
 | 
			
		||||
template <class Implementation, 
 | 
			
		||||
          template < typename, typename, typename > class Integrator, 
 | 
			
		||||
          class RepresentationsPolicy = NoHirep >
 | 
			
		||||
class BinaryHmcRunnerTemplate {
 | 
			
		||||
public:
 | 
			
		||||
 INHERIT_FIELD_TYPES(Implementation);
 | 
			
		||||
	    typedef Implementation ImplPolicy; // visible from outside
 | 
			
		||||
      template < typename S = NoSmearing<Implementation> > 
 | 
			
		||||
      using IntegratorType = Integrator<Implementation,S,RepresentationsPolicy>;
 | 
			
		||||
template <class Implementation,
 | 
			
		||||
          template <typename, typename, typename> class Integrator,
 | 
			
		||||
          class RepresentationsPolicy = NoHirep>
 | 
			
		||||
class HMCWrapperTemplate {
 | 
			
		||||
 public:
 | 
			
		||||
  INHERIT_FIELD_TYPES(Implementation);
 | 
			
		||||
  typedef Implementation ImplPolicy;  // visible from outside
 | 
			
		||||
  template <typename S = NoSmearing<Implementation> >
 | 
			
		||||
  using IntegratorType = Integrator<Implementation, S, RepresentationsPolicy>;
 | 
			
		||||
 | 
			
		||||
      enum StartType_t 
 | 
			
		||||
      { 
 | 
			
		||||
        ColdStart,
 | 
			
		||||
        HotStart,
 | 
			
		||||
        TepidStart,
 | 
			
		||||
        CheckpointStart,
 | 
			
		||||
        FilenameStart 
 | 
			
		||||
      };
 | 
			
		||||
  enum StartType_t {
 | 
			
		||||
    ColdStart,
 | 
			
		||||
    HotStart,
 | 
			
		||||
    TepidStart,
 | 
			
		||||
    CheckpointStart,
 | 
			
		||||
    FilenameStart
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
      struct HMCPayload
 | 
			
		||||
      {
 | 
			
		||||
        StartType_t StartType;
 | 
			
		||||
        HMCparameters Parameters;
 | 
			
		||||
  struct HMCPayload {
 | 
			
		||||
    StartType_t StartType;
 | 
			
		||||
    HMCparameters Parameters;
 | 
			
		||||
 | 
			
		||||
        HMCPayload() { StartType = HotStart; }
 | 
			
		||||
      };
 | 
			
		||||
    HMCPayload() { StartType = HotStart; }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
     // These can be rationalised, some private
 | 
			
		||||
	  HMCPayload Payload; // Parameters 
 | 
			
		||||
	  HMCResourceManager Resources;
 | 
			
		||||
	  IntegratorParameters MDparameters;
 | 
			
		||||
  // These can be rationalised, some private
 | 
			
		||||
  HMCPayload Payload;  // Parameters
 | 
			
		||||
  HMCResourceManager<Implementation> Resources;
 | 
			
		||||
  IntegratorParameters MDparameters;
 | 
			
		||||
 | 
			
		||||
	  ActionSet<Field, RepresentationsPolicy> TheAction;
 | 
			
		||||
  ActionSet<Field, RepresentationsPolicy> TheAction;
 | 
			
		||||
 | 
			
		||||
    // A vector of HmcObservable that can be injected from outside
 | 
			
		||||
	  std::vector<HmcObservable<typename Implementation::Field> *> ObservablesList;
 | 
			
		||||
  // A vector of HmcObservable that can be injected from outside
 | 
			
		||||
  std::vector<HmcObservable<typename Implementation::Field> *> ObservablesList;
 | 
			
		||||
 | 
			
		||||
	  //GridCartesian *        UGrid;
 | 
			
		||||
  void ReadCommandLine(int argc, char **argv) {
 | 
			
		||||
    std::string arg;
 | 
			
		||||
 | 
			
		||||
    // These two are unnecessary, eliminate
 | 
			
		||||
	  // GridRedBlackCartesian *UrbGrid;
 | 
			
		||||
	  // GridCartesian *        FGrid;
 | 
			
		||||
	  // GridRedBlackCartesian *FrbGrid;
 | 
			
		||||
    if (GridCmdOptionExists(argv, argv + argc, "--StartType")) {
 | 
			
		||||
      arg = GridCmdOptionPayload(argv, argv + argc, "--StartType");
 | 
			
		||||
      if (arg == "HotStart") {
 | 
			
		||||
        Payload.StartType = HotStart;
 | 
			
		||||
      } else if (arg == "ColdStart") {
 | 
			
		||||
        Payload.StartType = ColdStart;
 | 
			
		||||
      } else if (arg == "TepidStart") {
 | 
			
		||||
        Payload.StartType = TepidStart;
 | 
			
		||||
      } else if (arg == "CheckpointStart") {
 | 
			
		||||
        Payload.StartType = CheckpointStart;
 | 
			
		||||
      } else {
 | 
			
		||||
        std::cout << GridLogError << "Unrecognized option in --StartType\n";
 | 
			
		||||
        std::cout
 | 
			
		||||
            << GridLogError
 | 
			
		||||
            << "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n";
 | 
			
		||||
        assert(0);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
	  void ReadCommandLine(int argc, char ** argv) {
 | 
			
		||||
	  	std::string arg;
 | 
			
		||||
    if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) {
 | 
			
		||||
      arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory");
 | 
			
		||||
      std::vector<int> ivec(0);
 | 
			
		||||
      GridCmdOptionIntVector(arg, ivec);
 | 
			
		||||
      Payload.Parameters.StartTrajectory = ivec[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
	  	if (GridCmdOptionExists(argv, argv + argc, "--StartType")) {
 | 
			
		||||
	  		arg = GridCmdOptionPayload(argv, argv + argc, "--StartType");
 | 
			
		||||
	  		if (arg == "HotStart") {
 | 
			
		||||
	  			Payload.StartType = HotStart;
 | 
			
		||||
	  		} else if (arg == "ColdStart") {
 | 
			
		||||
	  			Payload.StartType = ColdStart;
 | 
			
		||||
	  		} else if (arg == "TepidStart") {
 | 
			
		||||
	  			Payload.StartType = TepidStart;
 | 
			
		||||
	  		} else if (arg == "CheckpointStart") {
 | 
			
		||||
	  			Payload.StartType = CheckpointStart;
 | 
			
		||||
	  		} else {
 | 
			
		||||
	  			std::cout << GridLogError << "Unrecognized option in --StartType\n";
 | 
			
		||||
	  			std::cout
 | 
			
		||||
	  			<< GridLogError
 | 
			
		||||
	  			<< "Valid [HotStart, ColdStart, TepidStart, CheckpointStart]\n";
 | 
			
		||||
	  			assert(0);
 | 
			
		||||
	  		}
 | 
			
		||||
	  	}
 | 
			
		||||
    if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) {
 | 
			
		||||
      arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories");
 | 
			
		||||
      std::vector<int> ivec(0);
 | 
			
		||||
      GridCmdOptionIntVector(arg, ivec);
 | 
			
		||||
      Payload.Parameters.Trajectories = ivec[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
	  	if (GridCmdOptionExists(argv, argv + argc, "--StartTrajectory")) {
 | 
			
		||||
	  		arg = GridCmdOptionPayload(argv, argv + argc, "--StartTrajectory");
 | 
			
		||||
	  		std::vector<int> ivec(0);
 | 
			
		||||
	  		GridCmdOptionIntVector(arg, ivec);
 | 
			
		||||
	  		Payload.Parameters.StartTrajectory = ivec[0];
 | 
			
		||||
	  	}
 | 
			
		||||
 | 
			
		||||
	  	if (GridCmdOptionExists(argv, argv + argc, "--Trajectories")) {
 | 
			
		||||
	  		arg = GridCmdOptionPayload(argv, argv + argc, "--Trajectories");
 | 
			
		||||
	  		std::vector<int> ivec(0);
 | 
			
		||||
	  		GridCmdOptionIntVector(arg, ivec);
 | 
			
		||||
	  		Payload.Parameters.Trajectories = ivec[0];
 | 
			
		||||
	  	}
 | 
			
		||||
 | 
			
		||||
	  	if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) {
 | 
			
		||||
	  		arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations");
 | 
			
		||||
	  		std::vector<int> ivec(0);
 | 
			
		||||
	  		GridCmdOptionIntVector(arg, ivec);
 | 
			
		||||
	  		Payload.Parameters.NoMetropolisUntil = ivec[0];
 | 
			
		||||
	  	}
 | 
			
		||||
 | 
			
		||||
	  }
 | 
			
		||||
 | 
			
		||||
  // A couple of wrapper functions
 | 
			
		||||
  template <class IOCheckpointer> void Run(IOCheckpointer &CP)  {
 | 
			
		||||
    NoSmearing<Implementation> S;
 | 
			
		||||
    Runner(CP, S);
 | 
			
		||||
    if (GridCmdOptionExists(argv, argv + argc, "--Thermalizations")) {
 | 
			
		||||
      arg = GridCmdOptionPayload(argv, argv + argc, "--Thermalizations");
 | 
			
		||||
      std::vector<int> ivec(0);
 | 
			
		||||
      GridCmdOptionIntVector(arg, ivec);
 | 
			
		||||
      Payload.Parameters.NoMetropolisUntil = ivec[0];
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <class IOCheckpointer, class SmearingPolicy> void Run(IOCheckpointer &CP, SmearingPolicy &S) {
 | 
			
		||||
    Runner(CP, S);
 | 
			
		||||
 | 
			
		||||
  template <class SmearingPolicy>
 | 
			
		||||
  void Run(SmearingPolicy &S) {
 | 
			
		||||
    Runner(S);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void Run(){
 | 
			
		||||
    NoSmearing<Implementation> S;
 | 
			
		||||
    Runner(S);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  //////////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
  template <class SmearingPolicy, class IOCheckpointer>
 | 
			
		||||
    void Runner(IOCheckpointer &Checkpoint,	SmearingPolicy &Smearing) {
 | 
			
		||||
  auto UGrid = Resources.GetCartesian();
 | 
			
		||||
  Resources.AddRNGs();
 | 
			
		||||
  Field U(UGrid);
 | 
			
		||||
 private:
 | 
			
		||||
  template <class SmearingPolicy>
 | 
			
		||||
  void Runner(SmearingPolicy &Smearing) {
 | 
			
		||||
    auto UGrid = Resources.GetCartesian();
 | 
			
		||||
    Resources.AddRNGs();
 | 
			
		||||
    Field U(UGrid);
 | 
			
		||||
 | 
			
		||||
  typedef IntegratorType<SmearingPolicy> TheIntegrator;
 | 
			
		||||
  TheIntegrator MDynamics(UGrid, MDparameters, TheAction, Smearing);
 | 
			
		||||
    typedef IntegratorType<SmearingPolicy> TheIntegrator;
 | 
			
		||||
    TheIntegrator MDynamics(UGrid, MDparameters, TheAction, Smearing);
 | 
			
		||||
 | 
			
		||||
  if (Payload.StartType == HotStart) {
 | 
			
		||||
        // Hot start
 | 
			
		||||
   Payload.Parameters.MetropolisTest = true;
 | 
			
		||||
   Resources.SeedFixedIntegers();
 | 
			
		||||
   Implementation::HotConfiguration(Resources.GetParallelRNG(), U);
 | 
			
		||||
 } else if (Payload.StartType == ColdStart) {
 | 
			
		||||
        // Cold start
 | 
			
		||||
   Payload.Parameters.MetropolisTest = true;
 | 
			
		||||
   Resources.SeedFixedIntegers();
 | 
			
		||||
   Implementation::ColdConfiguration(Resources.GetParallelRNG(), U);
 | 
			
		||||
 } else if (Payload.StartType == TepidStart) {
 | 
			
		||||
        // Tepid start
 | 
			
		||||
   Payload.Parameters.MetropolisTest = true;
 | 
			
		||||
   Resources.SeedFixedIntegers();
 | 
			
		||||
   Implementation::TepidConfiguration(Resources.GetParallelRNG(), U);
 | 
			
		||||
 } else if (Payload.StartType == CheckpointStart) {
 | 
			
		||||
   Payload.Parameters.MetropolisTest = true;
 | 
			
		||||
        // CheckpointRestart
 | 
			
		||||
   Checkpoint.CheckpointRestore(Payload.Parameters.StartTrajectory, U, Resources.GetSerialRNG(), Resources.GetParallelRNG());
 | 
			
		||||
 }
 | 
			
		||||
    if (Payload.StartType == HotStart) {
 | 
			
		||||
      // Hot start
 | 
			
		||||
      Resources.SeedFixedIntegers();
 | 
			
		||||
      Implementation::HotConfiguration(Resources.GetParallelRNG(), U);
 | 
			
		||||
    } else if (Payload.StartType == ColdStart) {
 | 
			
		||||
      // Cold start
 | 
			
		||||
      Resources.SeedFixedIntegers();
 | 
			
		||||
      Implementation::ColdConfiguration(Resources.GetParallelRNG(), U);
 | 
			
		||||
    } else if (Payload.StartType == TepidStart) {
 | 
			
		||||
      // Tepid start
 | 
			
		||||
      Resources.SeedFixedIntegers();
 | 
			
		||||
      Implementation::TepidConfiguration(Resources.GetParallelRNG(), U);
 | 
			
		||||
    } else if (Payload.StartType == CheckpointStart) {
 | 
			
		||||
      // CheckpointRestart
 | 
			
		||||
      Resources.get_CheckPointer()->CheckpointRestore(Payload.Parameters.StartTrajectory, U,
 | 
			
		||||
                                   Resources.GetSerialRNG(),
 | 
			
		||||
                                   Resources.GetParallelRNG());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 Smearing.set_Field(U);
 | 
			
		||||
    Smearing.set_Field(U);
 | 
			
		||||
 | 
			
		||||
 HybridMonteCarlo<TheIntegrator> HMC(Payload.Parameters, MDynamics, Resources.GetSerialRNG(), Resources.GetParallelRNG(), U);
 | 
			
		||||
    HybridMonteCarlo<TheIntegrator> HMC(Payload.Parameters, MDynamics,
 | 
			
		||||
                                        Resources.GetSerialRNG(),
 | 
			
		||||
                                        Resources.GetParallelRNG(), U);
 | 
			
		||||
 | 
			
		||||
 for (int obs = 0; obs < ObservablesList.size(); obs++)
 | 
			
		||||
   HMC.AddObservable(ObservablesList[obs]);
 | 
			
		||||
    for (int obs = 0; obs < ObservablesList.size(); obs++)
 | 
			
		||||
      HMC.AddObservable(ObservablesList[obs]);
 | 
			
		||||
    HMC.AddObservable(Resources.get_CheckPointer());
 | 
			
		||||
 | 
			
		||||
      // Run it
 | 
			
		||||
 HMC.evolve();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    // Run it
 | 
			
		||||
    HMC.evolve();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// These are for gauge fields, default integrator MinimumNorm2
 | 
			
		||||
template <template <typename, typename, typename> class Integrator >  using BinaryHmcRunner = BinaryHmcRunnerTemplate<PeriodicGimplR, Integrator > ;
 | 
			
		||||
template <template <typename, typename, typename> class Integrator >  using BinaryHmcRunnerF = BinaryHmcRunnerTemplate<PeriodicGimplF, Integrator > ;
 | 
			
		||||
template <template <typename, typename, typename> class Integrator >  using BinaryHmcRunnerD = BinaryHmcRunnerTemplate<PeriodicGimplD, Integrator > ;
 | 
			
		||||
template <template <typename, typename, typename> class Integrator>
 | 
			
		||||
using GenericHMCRunner = HMCWrapperTemplate<PeriodicGimplR, Integrator>;
 | 
			
		||||
template <template <typename, typename, typename> class Integrator>
 | 
			
		||||
using GenericHMCRunnerF = HMCWrapperTemplate<PeriodicGimplF, Integrator>;
 | 
			
		||||
template <template <typename, typename, typename> class Integrator>
 | 
			
		||||
using GenericHMCRunnerD = HMCWrapperTemplate<PeriodicGimplD, Integrator>;
 | 
			
		||||
 | 
			
		||||
template <class RepresentationsPolicy, template <typename, typename, typename> class Integrator >
 | 
			
		||||
using BinaryHmcRunnerTemplateHirep = BinaryHmcRunnerTemplate<PeriodicGimplR, Integrator, RepresentationsPolicy>;
 | 
			
		||||
template <class RepresentationsPolicy,
 | 
			
		||||
          template <typename, typename, typename> class Integrator>
 | 
			
		||||
using GenericHMCRunnerHirep =
 | 
			
		||||
    HMCWrapperTemplate<PeriodicGimplR, Integrator, RepresentationsPolicy>;
 | 
			
		||||
 | 
			
		||||
typedef BinaryHmcRunnerTemplate<ScalarImplR, MinimumNorm2, ScalarFields> ScalarBinaryHmcRunner;
 | 
			
		||||
typedef HMCWrapperTemplate<ScalarImplR, MinimumNorm2, ScalarFields>
 | 
			
		||||
    ScalarGenericHMCRunner;
 | 
			
		||||
 | 
			
		||||
}  // namespace QCD
 | 
			
		||||
}  // namespace Grid
 | 
			
		||||
 
 | 
			
		||||
@@ -52,17 +52,17 @@ struct HMCparameters {
 | 
			
		||||
 | 
			
		||||
  HMCparameters() {
 | 
			
		||||
    ////////////////////////////// Default values
 | 
			
		||||
    MetropolisTest = true;
 | 
			
		||||
    MetropolisTest    = true;
 | 
			
		||||
    NoMetropolisUntil = 10;
 | 
			
		||||
    StartTrajectory = 0;
 | 
			
		||||
    Trajectories = 10;
 | 
			
		||||
    StartTrajectory   = 0;
 | 
			
		||||
    Trajectories      = 10;
 | 
			
		||||
    /////////////////////////////////
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void print_parameters() const {
 | 
			
		||||
    std::cout << GridLogMessage << "[HMC parameters] Trajectories            : " << Trajectories << "\n";
 | 
			
		||||
    std::cout << GridLogMessage << "[HMC parameters] Start trajectory        : " << StartTrajectory << "\n";
 | 
			
		||||
    std::cout << GridLogMessage << "[HMC parameters] Metropolis test (on/off): " << MetropolisTest << "\n";
 | 
			
		||||
    std::cout << GridLogMessage << "[HMC parameters] Metropolis test (on/off): " << std::boolalpha << MetropolisTest << "\n";
 | 
			
		||||
    std::cout << GridLogMessage << "[HMC parameters] Thermalization trajs    : " << NoMetropolisUntil << "\n";
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
@@ -209,24 +209,39 @@ class HybridMonteCarlo {
 | 
			
		||||
    TheIntegrator.print_actions();
 | 
			
		||||
 | 
			
		||||
    // Actual updates (evolve a copy Ucopy then copy back eventually)
 | 
			
		||||
    for (int traj = Params.StartTrajectory; traj < Params.Trajectories + Params.StartTrajectory; ++traj) {
 | 
			
		||||
    unsigned int FinalTrajectory = Params.Trajectories + Params.NoMetropolisUntil + Params.StartTrajectory;
 | 
			
		||||
    for (int traj = Params.StartTrajectory; traj < FinalTrajectory; ++traj) {
 | 
			
		||||
      std::cout << GridLogMessage << "-- # Trajectory = " << traj << "\n";
 | 
			
		||||
      if (traj < Params.StartTrajectory + Params.NoMetropolisUntil) {
 | 
			
		||||
      	std::cout << GridLogMessage << "-- Thermalization" << std::endl;
 | 
			
		||||
    	}
 | 
			
		||||
 | 
			
		||||
    	double t0=usecond();
 | 
			
		||||
      Ucopy = Ucur;
 | 
			
		||||
 | 
			
		||||
      DeltaH = evolve_step(Ucopy);
 | 
			
		||||
 | 
			
		||||
      bool accept = true;
 | 
			
		||||
      if (traj >= Params.NoMetropolisUntil) {
 | 
			
		||||
      if (traj >= Params.StartTrajectory + Params.NoMetropolisUntil) {
 | 
			
		||||
        accept = metropolis_test(DeltaH);
 | 
			
		||||
      } else {
 | 
			
		||||
      	std::cout << GridLogMessage << "Skipping Metropolis test" << std::endl;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      if (accept) {
 | 
			
		||||
        Ucur = Ucopy;
 | 
			
		||||
      }
 | 
			
		||||
      double t1=usecond();
 | 
			
		||||
      std::cout << GridLogMessage << "Total time for trajectory (s): " << (t1-t0)/1e6 << std::endl;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
      for (int obs = 0; obs < Observables.size(); obs++) {
 | 
			
		||||
      	std::cout << GridLogDebug << "Observables # " << obs << std::endl;
 | 
			
		||||
      	std::cout << GridLogDebug << "Observables total " << Observables.size() << std::endl;
 | 
			
		||||
      	std::cout << GridLogDebug << "Observables pointer " << Observables[obs] << std::endl;
 | 
			
		||||
        Observables[obs]->TrajectoryComplete(traj + 1, Ucur, sRNG, pRNG);
 | 
			
		||||
      }
 | 
			
		||||
      std::cout << GridLogMessage << ":::::::::::::::::::::::::::::::::::::::::::" << std::endl;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 
 | 
			
		||||
@@ -85,8 +85,37 @@ public:
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// Smearing module
 | 
			
		||||
template <class ImplementationPolicy>
 | 
			
		||||
class SmearingModule{
 | 
			
		||||
   virtual void get_smearing();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <class ImplementationPolicy>
 | 
			
		||||
class StoutSmearingModule: public SmearingModule<ImplementationPolicy>{
 | 
			
		||||
   SmearedConfiguration<ImplementationPolicy> SmearingPolicy;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// Checkpoint module, owns the Checkpointer
 | 
			
		||||
template <class ImplementationPolicy>
 | 
			
		||||
class CheckPointModule{
 | 
			
		||||
   std::unique_ptr< BaseHmcCheckpointer<ImplementationPolicy> > cp_;
 | 
			
		||||
 | 
			
		||||
public:
 | 
			
		||||
   void set_Checkpointer(BaseHmcCheckpointer<ImplementationPolicy> *cp){
 | 
			
		||||
      cp_.reset(cp);
 | 
			
		||||
   };
 | 
			
		||||
   BaseHmcCheckpointer<ImplementationPolicy>* get_CheckPointer(){
 | 
			
		||||
      std::cout << "Checkpointer Pointer requested : " << cp_.get() << std::endl;
 | 
			
		||||
      return cp_.get();
 | 
			
		||||
   }
 | 
			
		||||
 | 
			
		||||
   void initialize(CheckpointerParameters& P){
 | 
			
		||||
      cp_.initialize(P);
 | 
			
		||||
   }
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,8 +5,8 @@ Grid physics library, www.github.com/paboyle/Grid
 | 
			
		||||
Source file: ./lib/qcd/hmc/GenericHmcRunner.h
 | 
			
		||||
 | 
			
		||||
Copyright (C) 2015
 | 
			
		||||
Copyright (C) 2016
 | 
			
		||||
 | 
			
		||||
Author: paboyle <paboyle@ph.ed.ac.uk>
 | 
			
		||||
Author: Guido Cossu <guido.cossu@ed.ac.uk>
 | 
			
		||||
 | 
			
		||||
This program is free software; you can redistribute it and/or modify
 | 
			
		||||
@@ -30,28 +30,51 @@ with this program; if not, write to the Free Software Foundation, Inc.,
 | 
			
		||||
#ifndef HMC_RESOURCE_MANAGER_H
 | 
			
		||||
#define HMC_RESOURCE_MANAGER_H
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
 | 
			
		||||
// One function per Checkpointer, use a macro to simplify
 | 
			
		||||
  #define RegisterLoadCheckPointerFunction(NAME)                                       \
 | 
			
		||||
  void Load##NAME##Checkpointer(CheckpointerParameters& Params_) {   \
 | 
			
		||||
    if (!have_CheckPointer) {                                        \
 | 
			
		||||
      std::cout << GridLogDebug << "Loading Checkpointer " << #NAME  \
 | 
			
		||||
                << std::endl;                                        \
 | 
			
		||||
      CP.set_Checkpointer(                                           \
 | 
			
		||||
          new NAME##HmcCheckpointer<ImplementationPolicy>(Params_)); \
 | 
			
		||||
      have_CheckPointer = true;                                      \
 | 
			
		||||
    } else {                                                         \
 | 
			
		||||
      std::cout << GridLogError << "Checkpointer already loaded "    \
 | 
			
		||||
                << std::endl;                                        \
 | 
			
		||||
      exit(1);                                                       \
 | 
			
		||||
    }                                                                \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
namespace QCD {
 | 
			
		||||
 | 
			
		||||
  // HMC Resource manager
 | 
			
		||||
  class HMCResourceManager {
 | 
			
		||||
	// Storage for grid pairs (std + red-black)
 | 
			
		||||
   std::unordered_map<std::string, GridModule> Grids;
 | 
			
		||||
   RNGModule RNGs; 
 | 
			
		||||
// HMC Resource manager
 | 
			
		||||
  template <class ImplementationPolicy>
 | 
			
		||||
class HMCResourceManager{
 | 
			
		||||
  // Storage for grid pairs (std + red-black)
 | 
			
		||||
  std::unordered_map<std::string, GridModule> Grids;
 | 
			
		||||
  RNGModule RNGs;
 | 
			
		||||
 | 
			
		||||
   bool have_RNG;
 | 
			
		||||
  //SmearingModule<ImplementationPolicy> Smearing;
 | 
			
		||||
  CheckPointModule<ImplementationPolicy> CP;
 | 
			
		||||
 | 
			
		||||
 public:	
 | 
			
		||||
   HMCResourceManager():have_RNG(false){}
 | 
			
		||||
   void AddGrid(std::string s, GridModule& M){
 | 
			
		||||
  bool have_RNG;
 | 
			
		||||
  bool have_CheckPointer;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  HMCResourceManager() : have_RNG(false), have_CheckPointer(false) {}
 | 
			
		||||
  void AddGrid(std::string s, GridModule& M) {
 | 
			
		||||
    // Check for name clashes
 | 
			
		||||
    auto search = Grids.find(s);
 | 
			
		||||
    if(search != Grids.end()) {
 | 
			
		||||
      std::cout << GridLogError << "Grid with name \"" << search->first << "\" already present. Terminating\n" ;
 | 
			
		||||
      exit(1);        
 | 
			
		||||
    if (search != Grids.end()) {
 | 
			
		||||
      std::cout << GridLogError << "Grid with name \"" << search->first
 | 
			
		||||
                << "\" already present. Terminating\n";
 | 
			
		||||
      exit(1);
 | 
			
		||||
    }
 | 
			
		||||
    Grids[s] = std::move(M);
 | 
			
		||||
  }
 | 
			
		||||
@@ -59,52 +82,69 @@ namespace QCD {
 | 
			
		||||
  // Add a named grid set
 | 
			
		||||
  void AddFourDimGrid(std::string s) {
 | 
			
		||||
    GridFourDimModule Mod;
 | 
			
		||||
    AddGrid(s,Mod);
 | 
			
		||||
    AddGrid(s, Mod);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GridCartesian* GetCartesian(std::string s="") {
 | 
			
		||||
  GridCartesian* GetCartesian(std::string s = "") {
 | 
			
		||||
    if (s.empty()) s = Grids.begin()->first;
 | 
			
		||||
    std::cout << GridLogDebug << "Getting cartesian grid from: "<< s << std::endl;
 | 
			
		||||
    std::cout << GridLogDebug << "Getting cartesian grid from: " << s
 | 
			
		||||
              << std::endl;
 | 
			
		||||
    return Grids[s].get_full();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GridRedBlackCartesian* GetRBCartesian(std::string s="") {
 | 
			
		||||
  GridRedBlackCartesian* GetRBCartesian(std::string s = "") {
 | 
			
		||||
    if (s.empty()) s = Grids.begin()->first;
 | 
			
		||||
    std::cout << GridLogDebug << "Getting rb-cartesian grid from: "<< s << std::endl;
 | 
			
		||||
    std::cout << GridLogDebug << "Getting rb-cartesian grid from: " << s
 | 
			
		||||
              << std::endl;
 | 
			
		||||
    return Grids[s].get_rb();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void AddRNGs(std::string s="") {
 | 
			
		||||
		// Couple the RNGs to the GridModule tagged by s
 | 
			
		||||
    // default is the first grid set
 | 
			
		||||
    assert(Grids.size()>0 && !have_RNG );
 | 
			
		||||
  void AddRNGs(std::string s = "") {
 | 
			
		||||
    // Couple the RNGs to the GridModule tagged by s
 | 
			
		||||
    // the default is the first grid registered
 | 
			
		||||
    assert(Grids.size() > 0 && !have_RNG);
 | 
			
		||||
    if (s.empty()) s = Grids.begin()->first;
 | 
			
		||||
    std::cout << GridLogDebug << "Adding RNG to grid: "<< s << std::endl;
 | 
			
		||||
    std::cout << GridLogDebug << "Adding RNG to grid: " << s << std::endl;
 | 
			
		||||
    RNGs.set_pRNG(new GridParallelRNG(GetCartesian(s)));
 | 
			
		||||
    //pRNG.reset(new GridParallelRNG(GetCartesian(s)));
 | 
			
		||||
    have_RNG = true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void AddRNGSeeds(const std::vector<int> S, const std::vector<int> P) {
 | 
			
		||||
    RNGs.set_RNGSeeds(S,P);
 | 
			
		||||
    //SerialSeed   = S;
 | 
			
		||||
    //ParallelSeed = P;
 | 
			
		||||
    RNGs.set_RNGSeeds(S, P);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GridSerialRNG& GetSerialRNG() {return RNGs.get_sRNG();}
 | 
			
		||||
  GridSerialRNG& GetSerialRNG() { return RNGs.get_sRNG(); }
 | 
			
		||||
  GridParallelRNG& GetParallelRNG() {
 | 
			
		||||
    assert(have_RNG);
 | 
			
		||||
    return RNGs.get_pRNG();
 | 
			
		||||
  }
 | 
			
		||||
  
 | 
			
		||||
  void SeedFixedIntegers() {
 | 
			
		||||
    assert(have_RNG);
 | 
			
		||||
    RNGs.seed();
 | 
			
		||||
    //sRNG.SeedFixedIntegers(SerialSeed);
 | 
			
		||||
    //pRNG->SeedFixedIntegers(ParallelSeed);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  //////////////////////////////////////////////////////
 | 
			
		||||
  // Checkpointers
 | 
			
		||||
  //////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
  BaseHmcCheckpointer<ImplementationPolicy>* get_CheckPointer(){
 | 
			
		||||
    if (have_CheckPointer)
 | 
			
		||||
    return CP.get_CheckPointer();
 | 
			
		||||
    else{
 | 
			
		||||
      std::cout << GridLogError << "Error: no checkpointer defined" << std::endl;
 | 
			
		||||
      exit(1);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  RegisterLoadCheckPointerFunction (Binary);
 | 
			
		||||
  RegisterLoadCheckPointerFunction (Nersc);
 | 
			
		||||
  RegisterLoadCheckPointerFunction (ILDG)
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // HMC_RESOURCE_MANAGER_H
 | 
			
		||||
#endif  // HMC_RESOURCE_MANAGER_H
 | 
			
		||||
@@ -1,81 +0,0 @@
 | 
			
		||||
    /*************************************************************************************
 | 
			
		||||
 | 
			
		||||
    Grid physics library, www.github.com/paboyle/Grid 
 | 
			
		||||
 | 
			
		||||
    Source file: ./lib/qcd/hmc/NerscCheckpointer.h
 | 
			
		||||
 | 
			
		||||
    Copyright (C) 2015
 | 
			
		||||
 | 
			
		||||
Author: paboyle <paboyle@ph.ed.ac.uk>
 | 
			
		||||
 | 
			
		||||
    This program is free software; you can redistribute it and/or modify
 | 
			
		||||
    it under the terms of the GNU General Public License as published by
 | 
			
		||||
    the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
    (at your option) any later version.
 | 
			
		||||
 | 
			
		||||
    This program is distributed in the hope that it will be useful,
 | 
			
		||||
    but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
    GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
    You should have received a copy of the GNU General Public License along
 | 
			
		||||
    with this program; if not, write to the Free Software Foundation, Inc.,
 | 
			
		||||
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
 | 
			
		||||
    See the full license in the file "LICENSE" in the top level distribution directory
 | 
			
		||||
    *************************************************************************************/
 | 
			
		||||
    /*  END LEGAL */
 | 
			
		||||
#ifndef NERSC_CHECKPOINTER
 | 
			
		||||
#define NERSC_CHECKPOINTER
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace Grid{
 | 
			
		||||
  namespace QCD{
 | 
			
		||||
    
 | 
			
		||||
    // Only for Gauge fields
 | 
			
		||||
    template<class Gimpl> 
 | 
			
		||||
    class NerscHmcCheckpointer : public HmcObservable<typename Gimpl::GaugeField> {
 | 
			
		||||
    private:
 | 
			
		||||
      std::string configStem;
 | 
			
		||||
      std::string rngStem;
 | 
			
		||||
      int SaveInterval;
 | 
			
		||||
    public:
 | 
			
		||||
      INHERIT_GIMPL_TYPES(Gimpl);// 
 | 
			
		||||
 | 
			
		||||
      NerscHmcCheckpointer(std::string cf, std::string rn,int savemodulo, std::string format = "") {
 | 
			
		||||
        configStem  = cf;
 | 
			
		||||
        rngStem     = rn;
 | 
			
		||||
        SaveInterval= savemodulo;
 | 
			
		||||
        // format is fixed to IEEE64BIG for NERSC
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG )
 | 
			
		||||
      {
 | 
			
		||||
	if ( (traj % SaveInterval)== 0 ) {
 | 
			
		||||
	  std::string rng;   { std::ostringstream os; os << rngStem     <<"."<< traj; rng = os.str(); }
 | 
			
		||||
	  std::string config;{ std::ostringstream os; os << configStem  <<"."<< traj; config = os.str();}
 | 
			
		||||
 | 
			
		||||
	  int precision32=1;
 | 
			
		||||
	  int tworow     =0;
 | 
			
		||||
	  NerscIO::writeRNGState(sRNG,pRNG,rng);
 | 
			
		||||
	  NerscIO::writeConfiguration(U,config,tworow,precision32);
 | 
			
		||||
	}
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
      void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG, GridParallelRNG & pRNG ){
 | 
			
		||||
 | 
			
		||||
	std::string rng;   { std::ostringstream os; os << rngStem     <<"."<< traj; rng = os.str(); }
 | 
			
		||||
	std::string config;{ std::ostringstream os; os << configStem  <<"."<< traj; config = os.str();}
 | 
			
		||||
 | 
			
		||||
	NerscField header;
 | 
			
		||||
	NerscIO::readRNGState(sRNG,pRNG,header,rng);
 | 
			
		||||
	NerscIO::readConfiguration(U,header,config);
 | 
			
		||||
      };
 | 
			
		||||
 | 
			
		||||
    };
 | 
			
		||||
}}
 | 
			
		||||
#endif
 | 
			
		||||
							
								
								
									
										61
									
								
								lib/qcd/hmc/checkpointers/BaseCheckpointer.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								lib/qcd/hmc/checkpointers/BaseCheckpointer.h
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
			
		||||
/*************************************************************************************
 | 
			
		||||
 | 
			
		||||
Grid physics library, www.github.com/paboyle/Grid
 | 
			
		||||
 | 
			
		||||
Source file: ./lib/qcd/hmc/BaseCheckpointer.h
 | 
			
		||||
 | 
			
		||||
Copyright (C) 2015
 | 
			
		||||
 | 
			
		||||
Author: Guido Cossu <guido.cossu@ed.ac.uk>
 | 
			
		||||
 | 
			
		||||
This program is free software; you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
(at your option) any later version.
 | 
			
		||||
 | 
			
		||||
This program is distributed in the hope that it will be useful,
 | 
			
		||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
You should have received a copy of the GNU General Public License along
 | 
			
		||||
with this program; if not, write to the Free Software Foundation, Inc.,
 | 
			
		||||
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
 | 
			
		||||
See the full license in the file "LICENSE" in the top level distribution
 | 
			
		||||
directory
 | 
			
		||||
*************************************************************************************/
 | 
			
		||||
/*  END LEGAL */
 | 
			
		||||
#ifndef BASE_CHECKPOINTER
 | 
			
		||||
#define BASE_CHECKPOINTER
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
namespace QCD {
 | 
			
		||||
 | 
			
		||||
class CheckpointerParameters : Serializable {
 | 
			
		||||
 public:
 | 
			
		||||
  GRID_SERIALIZABLE_CLASS_MEMBERS(CheckpointerParameters, std::string,
 | 
			
		||||
                                  configStem, std::string, rngStem, int,
 | 
			
		||||
                                  SaveInterval, std::string, format, );
 | 
			
		||||
 | 
			
		||||
  CheckpointerParameters(std::string cf = "cfg", std::string rn = "rng",
 | 
			
		||||
                         int savemodulo = 1, const std::string &f = "IEEE64BIG")
 | 
			
		||||
      : configStem(cf), rngStem(rn), SaveInterval(savemodulo), format(f){};
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
//////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
// Base class for checkpointers
 | 
			
		||||
template <class Impl>
 | 
			
		||||
class BaseHmcCheckpointer : public HmcObservable<typename Impl::Field> {
 | 
			
		||||
 public:
 | 
			
		||||
  virtual void initialize(CheckpointerParameters &Params) = 0;
 | 
			
		||||
 | 
			
		||||
  virtual void CheckpointRestore(int traj, typename Impl::Field &U,
 | 
			
		||||
                                 GridSerialRNG &sRNG,
 | 
			
		||||
                                 GridParallelRNG &pRNG) = 0;
 | 
			
		||||
 | 
			
		||||
};  // class BaseHmcCheckpointer
 | 
			
		||||
///////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
@@ -36,14 +36,12 @@ directory
 | 
			
		||||
namespace Grid {
 | 
			
		||||
namespace QCD {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// Simple checkpointer, only binary file
 | 
			
		||||
template <class Impl>
 | 
			
		||||
class BinaryHmcCheckpointer : public HmcObservable<typename Impl::Field> {
 | 
			
		||||
class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
 | 
			
		||||
 private:
 | 
			
		||||
  std::string configStem;
 | 
			
		||||
  std::string rngStem;
 | 
			
		||||
  int SaveInterval;
 | 
			
		||||
  std::string format;
 | 
			
		||||
  CheckpointerParameters Params;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  INHERIT_FIELD_TYPES(Impl);  // Gets the Field type, a Lattice object
 | 
			
		||||
@@ -54,9 +52,11 @@ class BinaryHmcCheckpointer : public HmcObservable<typename Impl::Field> {
 | 
			
		||||
  typedef typename getPrecision<sobj>::real_scalar_type sobj_stype;
 | 
			
		||||
  typedef typename sobj::DoublePrecision sobj_double;
 | 
			
		||||
 | 
			
		||||
  BinaryHmcCheckpointer(std::string cf, std::string rn, int savemodulo,
 | 
			
		||||
                        const std::string &f)
 | 
			
		||||
      : configStem(cf), rngStem(rn), SaveInterval(savemodulo), format(f){};
 | 
			
		||||
  BinaryHmcCheckpointer(CheckpointerParameters& Params_){
 | 
			
		||||
  	initialize(Params_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void initialize(CheckpointerParameters& Params_){ Params = Params_; }
 | 
			
		||||
 | 
			
		||||
  void truncate(std::string file) {
 | 
			
		||||
    std::ofstream fout(file, std::ios::out);
 | 
			
		||||
@@ -65,17 +65,17 @@ class BinaryHmcCheckpointer : public HmcObservable<typename Impl::Field> {
 | 
			
		||||
 | 
			
		||||
  void TrajectoryComplete(int traj, Field &U, GridSerialRNG &sRNG,
 | 
			
		||||
                          GridParallelRNG &pRNG) {
 | 
			
		||||
    if ((traj % SaveInterval) == 0) {
 | 
			
		||||
    if ((traj % Params.SaveInterval) == 0) {
 | 
			
		||||
      std::string rng;
 | 
			
		||||
      {
 | 
			
		||||
        std::ostringstream os;
 | 
			
		||||
        os << rngStem << "." << traj;
 | 
			
		||||
        os << Params.rngStem << "." << traj;
 | 
			
		||||
        rng = os.str();
 | 
			
		||||
      }
 | 
			
		||||
      std::string config;
 | 
			
		||||
      {
 | 
			
		||||
        std::ostringstream os;
 | 
			
		||||
        os << configStem << "." << traj;
 | 
			
		||||
        os << Params.configStem << "." << traj;
 | 
			
		||||
        config = os.str();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
@@ -84,7 +84,7 @@ class BinaryHmcCheckpointer : public HmcObservable<typename Impl::Field> {
 | 
			
		||||
      BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0);
 | 
			
		||||
      truncate(config);
 | 
			
		||||
      uint32_t csum = BinaryIO::writeObjectParallel<vobj, sobj_double>(
 | 
			
		||||
          U, config, munge, 0, format);
 | 
			
		||||
          U, config, munge, 0, Params.format);
 | 
			
		||||
 | 
			
		||||
      std::cout << GridLogMessage << "Written Binary Configuration " << config
 | 
			
		||||
                << " checksum " << std::hex << csum << std::dec << std::endl;
 | 
			
		||||
@@ -96,20 +96,20 @@ class BinaryHmcCheckpointer : public HmcObservable<typename Impl::Field> {
 | 
			
		||||
    std::string rng;
 | 
			
		||||
    {
 | 
			
		||||
      std::ostringstream os;
 | 
			
		||||
      os << rngStem << "." << traj;
 | 
			
		||||
      os << Params.rngStem << "." << traj;
 | 
			
		||||
      rng = os.str();
 | 
			
		||||
    }
 | 
			
		||||
    std::string config;
 | 
			
		||||
    {
 | 
			
		||||
      std::ostringstream os;
 | 
			
		||||
      os << configStem << "." << traj;
 | 
			
		||||
      os << Params.configStem << "." << traj;
 | 
			
		||||
      config = os.str();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    BinaryIO::BinarySimpleMunger<sobj_double, sobj> munge;
 | 
			
		||||
    BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0);
 | 
			
		||||
    uint32_t csum = BinaryIO::readObjectParallel<vobj, sobj_double>(
 | 
			
		||||
        U, config, munge, 0, format);
 | 
			
		||||
        U, config, munge, 0, Params.format);
 | 
			
		||||
 | 
			
		||||
    std::cout << GridLogMessage << "Read Binary Configuration " << config
 | 
			
		||||
              << " checksum " << std::hex << csum << std::dec << std::endl;
 | 
			
		||||
@@ -4,9 +4,9 @@ Grid physics library, www.github.com/paboyle/Grid
 | 
			
		||||
 | 
			
		||||
Source file: ./lib/qcd/hmc/ILDGCheckpointer.h
 | 
			
		||||
 | 
			
		||||
Copyright (C) 2015
 | 
			
		||||
Copyright (C) 2016
 | 
			
		||||
 | 
			
		||||
Author: Guido Cossu
 | 
			
		||||
Author: Guido Cossu <guido.cossu@ed.ac.uk>
 | 
			
		||||
 | 
			
		||||
This program is free software; you can redistribute it and/or modify
 | 
			
		||||
it under the terms of the GNU General Public License as published by
 | 
			
		||||
@@ -41,59 +41,60 @@ namespace QCD {
 | 
			
		||||
// Only for Gauge fields
 | 
			
		||||
template <class Implementation>
 | 
			
		||||
class ILDGHmcCheckpointer
 | 
			
		||||
    : public HmcObservable<typename Implementation::GaugeField> {
 | 
			
		||||
    : public BaseHmcCheckpointer<Implementation> {
 | 
			
		||||
 private:
 | 
			
		||||
 	CheckpointerParameters Params;
 | 
			
		||||
/*
 | 
			
		||||
  std::string configStem;
 | 
			
		||||
  std::string rngStem;
 | 
			
		||||
  int SaveInterval;
 | 
			
		||||
  std::string format;
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  INHERIT_GIMPL_TYPES(Implementation);  
 | 
			
		||||
  INHERIT_GIMPL_TYPES(Implementation);
 | 
			
		||||
 | 
			
		||||
  ILDGHmcCheckpointer(std::string cf, std::string rn, int savemodulo,
 | 
			
		||||
                       std::string form = "IEEE64BIG") {
 | 
			
		||||
    configStem = cf;
 | 
			
		||||
    rngStem = rn;
 | 
			
		||||
    SaveInterval = savemodulo;
 | 
			
		||||
    format = form;
 | 
			
		||||
  ILDGHmcCheckpointer(CheckpointerParameters &Params_) { initialize(Params_); }
 | 
			
		||||
 | 
			
		||||
  void initialize(CheckpointerParameters &Params_) {
 | 
			
		||||
    Params = Params_;
 | 
			
		||||
 | 
			
		||||
    // check here that the format is valid
 | 
			
		||||
    int ieee32big = (format == std::string("IEEE32BIG"));
 | 
			
		||||
    int ieee32    = (format == std::string("IEEE32"));
 | 
			
		||||
    int ieee64big = (format == std::string("IEEE64BIG"));
 | 
			
		||||
    int ieee64    = (format == std::string("IEEE64"));
 | 
			
		||||
    int ieee32big = (Params.format == std::string("IEEE32BIG"));
 | 
			
		||||
    int ieee32    = (Params.format == std::string("IEEE32"));
 | 
			
		||||
    int ieee64big = (Params.format == std::string("IEEE64BIG"));
 | 
			
		||||
    int ieee64    = (Params.format == std::string("IEEE64"));
 | 
			
		||||
 | 
			
		||||
    if (!(ieee64big || ieee32 || ieee32big || ieee64)) {
 | 
			
		||||
      std::cout << GridLogError << "Unrecognized file format " << format
 | 
			
		||||
      std::cout << GridLogError << "Unrecognized file format " << Params.format
 | 
			
		||||
                << std::endl;
 | 
			
		||||
      std::cout << GridLogError
 | 
			
		||||
                << "Allowed: IEEE32BIG | IEEE32 | IEEE64BIG | IEEE64"
 | 
			
		||||
                << std::endl;
 | 
			
		||||
 | 
			
		||||
      exit(0);
 | 
			
		||||
      exit(1);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG,
 | 
			
		||||
                          GridParallelRNG &pRNG) {
 | 
			
		||||
    if ((traj % SaveInterval) == 0) {
 | 
			
		||||
    if ((traj % Params.SaveInterval) == 0) {
 | 
			
		||||
      std::string rng;
 | 
			
		||||
      {
 | 
			
		||||
        std::ostringstream os;
 | 
			
		||||
        os << rngStem << "." << traj;
 | 
			
		||||
        os << Params.rngStem << "." << traj;
 | 
			
		||||
        rng = os.str();
 | 
			
		||||
      }
 | 
			
		||||
      std::string config;
 | 
			
		||||
      {
 | 
			
		||||
        std::ostringstream os;
 | 
			
		||||
        os << configStem << "." << traj;
 | 
			
		||||
        os << Params.configStem << "." << traj;
 | 
			
		||||
        config = os.str();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      ILDGIO IO(config, ILDGwrite);
 | 
			
		||||
      BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0);
 | 
			
		||||
      uint32_t csum  = IO.writeConfiguration(U, format);
 | 
			
		||||
      uint32_t csum  = IO.writeConfiguration(U, Params.format);
 | 
			
		||||
 | 
			
		||||
      std::cout << GridLogMessage << "Written ILDG Configuration on " << config
 | 
			
		||||
                << " checksum " << std::hex << csum << std::dec << std::endl;
 | 
			
		||||
@@ -105,13 +106,13 @@ class ILDGHmcCheckpointer
 | 
			
		||||
    std::string rng;
 | 
			
		||||
    {
 | 
			
		||||
      std::ostringstream os;
 | 
			
		||||
      os << rngStem << "." << traj;
 | 
			
		||||
      os << Params.rngStem << "." << traj;
 | 
			
		||||
      rng = os.str();
 | 
			
		||||
    }
 | 
			
		||||
    std::string config;
 | 
			
		||||
    {
 | 
			
		||||
      std::ostringstream os;
 | 
			
		||||
      os << configStem << "." << traj;
 | 
			
		||||
      os << Params.configStem << "." << traj;
 | 
			
		||||
      config = os.str();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -126,5 +127,5 @@ class ILDGHmcCheckpointer
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif
 | 
			
		||||
#endif
 | 
			
		||||
#endif // HAVE_LIME
 | 
			
		||||
#endif // ILDG_CHECKPOINTER
 | 
			
		||||
							
								
								
									
										102
									
								
								lib/qcd/hmc/checkpointers/NerscCheckpointer.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								lib/qcd/hmc/checkpointers/NerscCheckpointer.h
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,102 @@
 | 
			
		||||
    /*************************************************************************************
 | 
			
		||||
 | 
			
		||||
    Grid physics library, www.github.com/paboyle/Grid 
 | 
			
		||||
 | 
			
		||||
    Source file: ./lib/qcd/hmc/NerscCheckpointer.h
 | 
			
		||||
 | 
			
		||||
    Copyright (C) 2015
 | 
			
		||||
 | 
			
		||||
Author: paboyle <paboyle@ph.ed.ac.uk>
 | 
			
		||||
 | 
			
		||||
    This program is free software; you can redistribute it and/or modify
 | 
			
		||||
    it under the terms of the GNU General Public License as published by
 | 
			
		||||
    the Free Software Foundation; either version 2 of the License, or
 | 
			
		||||
    (at your option) any later version.
 | 
			
		||||
 | 
			
		||||
    This program is distributed in the hope that it will be useful,
 | 
			
		||||
    but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
    GNU General Public License for more details.
 | 
			
		||||
 | 
			
		||||
    You should have received a copy of the GNU General Public License along
 | 
			
		||||
    with this program; if not, write to the Free Software Foundation, Inc.,
 | 
			
		||||
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
 | 
			
		||||
    See the full license in the file "LICENSE" in the top level distribution directory
 | 
			
		||||
    *************************************************************************************/
 | 
			
		||||
    /*  END LEGAL */
 | 
			
		||||
#ifndef NERSC_CHECKPOINTER
 | 
			
		||||
#define NERSC_CHECKPOINTER
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace Grid{
 | 
			
		||||
  namespace QCD{
 | 
			
		||||
    
 | 
			
		||||
    // Only for Gauge fields
 | 
			
		||||
    template<class Gimpl> 
 | 
			
		||||
    class NerscHmcCheckpointer : public BaseHmcCheckpointer<Gimpl> {
 | 
			
		||||
    private:
 | 
			
		||||
     CheckpointerParameters Params;
 | 
			
		||||
 | 
			
		||||
    public:
 | 
			
		||||
      INHERIT_GIMPL_TYPES(Gimpl);//
 | 
			
		||||
 | 
			
		||||
        NerscHmcCheckpointer(CheckpointerParameters& Params_){
 | 
			
		||||
           initialize(Params_);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        void initialize(CheckpointerParameters &Params_) { 
 | 
			
		||||
        	Params = Params_; 
 | 
			
		||||
        	Params.format = "IEEE64BIG"; // fixed, overwrite any other choice
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG,
 | 
			
		||||
                                GridParallelRNG &pRNG) {
 | 
			
		||||
          if ((traj % Params.SaveInterval) == 0) {
 | 
			
		||||
            std::string rng;
 | 
			
		||||
            {
 | 
			
		||||
              std::ostringstream os;
 | 
			
		||||
              os << Params.rngStem << "." << traj;
 | 
			
		||||
              rng = os.str();
 | 
			
		||||
            }
 | 
			
		||||
            std::string config;
 | 
			
		||||
            {
 | 
			
		||||
              std::ostringstream os;
 | 
			
		||||
              os << Params.configStem << "." << traj;
 | 
			
		||||
              config = os.str();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            int precision32 = 1;
 | 
			
		||||
            int tworow = 0;
 | 
			
		||||
            NerscIO::writeRNGState(sRNG, pRNG, rng);
 | 
			
		||||
            NerscIO::writeConfiguration(U, config, tworow, precision32);
 | 
			
		||||
          }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        void CheckpointRestore(int traj, GaugeField &U, GridSerialRNG &sRNG,
 | 
			
		||||
                               GridParallelRNG &pRNG) {
 | 
			
		||||
          std::string rng;
 | 
			
		||||
          {
 | 
			
		||||
            std::ostringstream os;
 | 
			
		||||
            os << Params.rngStem << "." << traj;
 | 
			
		||||
            rng = os.str();
 | 
			
		||||
          }
 | 
			
		||||
          std::string config;
 | 
			
		||||
          {
 | 
			
		||||
            std::ostringstream os;
 | 
			
		||||
            os << Params.configStem << "." << traj;
 | 
			
		||||
            config = os.str();
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          NerscField header;
 | 
			
		||||
          NerscIO::readRNGState(sRNG, pRNG, header, rng);
 | 
			
		||||
          NerscIO::readConfiguration(U, header, config);
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
    };
 | 
			
		||||
}}
 | 
			
		||||
#endif
 | 
			
		||||
		Reference in New Issue
	
	Block a user