/*************************************************************************************

Grid physics library, www.github.com/paboyle/Grid

Source file: ./lib/qcd/hmc/HMC.h

Copyright (C) 2015

Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
Author: Peter Boyle <paboyle@ph.ed.ac.uk>
Author: neo <cossu@post.kek.jp>
Author: paboyle <paboyle@ph.ed.ac.uk>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/*  END LEGAL */
//--------------------------------------------------------------------
/*! @file HMC.h
 * @brief Classes for Hybrid Monte Carlo update
 *
 * @author Guido Cossu
 * Time-stamp: <2015-07-30 16:58:26 neo>
 */
//--------------------------------------------------------------------
#ifndef HMC_INCLUDED
#define HMC_INCLUDED

#include <string>

namespace Grid {
namespace QCD {

struct HMCparameters {
  Integer StartTrajectory;
  Integer Trajectories; /* @brief Number of sweeps in this run */
  bool MetropolisTest;
  Integer NoMetropolisUntil;

  HMCparameters() {
    ////////////////////////////// Default values
    MetropolisTest = true;
    NoMetropolisUntil = 10;
    StartTrajectory = 0;
    Trajectories = 200;
    /////////////////////////////////
  }

  void print() const {
    std::cout << GridLogMessage << "[HMC parameter] Trajectories            : " << Trajectories << "\n";
    std::cout << GridLogMessage << "[HMC parameter] Start trajectory        : " << StartTrajectory << "\n";
    std::cout << GridLogMessage << "[HMC parameter] Metropolis test (on/off): " << MetropolisTest << "\n";
    std::cout << GridLogMessage << "[HMC parameter] Thermalization trajs    : " << NoMetropolisUntil << "\n";
  }
  
};

/// Interface for measurements taken after each HMC trajectory.
///
/// Implementations receive the trajectory number, the current gauge field and
/// the serial/parallel RNGs, and may record or log whatever they need.
template <class GaugeField>
class HmcObservable {
 public:
  /// Virtual so that a concrete observable can be destroyed through a pointer
  /// to this interface without undefined behaviour.
  virtual ~HmcObservable() {}

  /// Called once a trajectory has finished.
  /// @param traj  trajectory number reported by the caller
  /// @param U     gauge field after the trajectory
  /// @param sRNG  serial random number generator
  /// @param pRNG  parallel random number generator
  virtual void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG,
                                  GridParallelRNG &pRNG) = 0;
};

/// Observable that records average plaquette and rectangle values after each
/// trajectory. Each value is computed twice: once with the periodic gauge
/// implementation (PeriodicGimplR) and once with the Gimpl template argument.
/// A single line of numbers is written to the file "<Stem>.<traj>" and echoed,
/// under a column header, to GridLogMessage.
template <class Gimpl>
class PlaquetteLogger : public HmcObservable<typename Gimpl::GaugeField> {
 private:
  std::string Stem;  // prefix of the per-trajectory output file name

 public:
  INHERIT_GIMPL_TYPES(Gimpl);

  /// @param cf  output file-name stem; files are named "<cf>.<traj>"
  PlaquetteLogger(std::string cf) : Stem(cf) {}

  void TrajectoryComplete(int traj, GaugeField &U, GridSerialRNG &sRNG,
                          GridParallelRNG &pRNG) {
    // Open "<Stem>.<traj>" for writing.
    std::ostringstream name_builder;
    name_builder << Stem << "." << traj;
    std::ofstream out(name_builder.str());

    // Averages with the periodic gauge implementation...
    const RealD periodic_plaquette = WilsonLoops<PeriodicGimplR>::avgPlaquette(U);
    const RealD periodic_rectangle = WilsonLoops<PeriodicGimplR>::avgRectangle(U);
    // ...and with the implementation this logger was instantiated for.
    const RealD gimpl_plaquette = WilsonLoops<Gimpl>::avgPlaquette(U);
    const RealD gimpl_rectangle = WilsonLoops<Gimpl>::avgRectangle(U);

    out << traj << " " << gimpl_plaquette << " " << gimpl_rectangle << "  "
        << periodic_plaquette << " " << periodic_rectangle << std::endl;

    // Column header, then the same numbers on the log stream.
    std::cout << GridLogMessage << "traj" << " " << "plaq " << " " << " rect  "
              << "  " << "peri_plaq" << " " << "peri_rect" << std::endl;
    std::cout << GridLogMessage << traj << " " << gimpl_plaquette << " "
              << gimpl_rectangle << "  " << periodic_plaquette << " "
              << periodic_rectangle << std::endl;
  }
};

//    template <class GaugeField, class Integrator, class Smearer, class
//    Boundary>
template <class GaugeField, class IntegratorType>
class HybridMonteCarlo {
 private:
  const HMCparameters Params;

  GridSerialRNG &sRNG;    // Fixme: need a RNG management strategy.
  GridParallelRNG &pRNG;  // Fixme: need a RNG management strategy.
  GaugeField &Ucur;

  IntegratorType &TheIntegrator;
  std::vector<HmcObservable<GaugeField> *> Observables;

  /////////////////////////////////////////////////////////
  // Metropolis step
  /////////////////////////////////////////////////////////
  bool metropolis_test(const RealD DeltaH) {
    RealD rn_test;

    RealD prob = std::exp(-DeltaH);

    random(sRNG, rn_test);

    std::cout << GridLogMessage
              << "--------------------------------------------------\n";
    std::cout << GridLogMessage << "exp(-dH) = " << prob
              << "  Random = " << rn_test << "\n";
    std::cout << GridLogMessage
              << "Acc. Probability = " << ((prob < 1.0) ? prob : 1.0) << "\n";

    if ((prob > 1.0) || (rn_test <= prob)) {  // accepted
      std::cout << GridLogMessage << "Metropolis_test -- ACCEPTED\n";
      std::cout << GridLogMessage
                << "--------------------------------------------------\n";
      return true;
    } else {  // rejected
      std::cout << GridLogMessage << "Metropolis_test -- REJECTED\n";
      std::cout << GridLogMessage
                << "--------------------------------------------------\n";
      return false;
    }
  }

  /////////////////////////////////////////////////////////
  // Evolution
  /////////////////////////////////////////////////////////
  RealD evolve_step(GaugeField &U) {
    TheIntegrator.refresh(U, pRNG);  // set U and initialize P and phi's

    RealD H0 = TheIntegrator.S(U);  // initial state action

    std::streamsize current_precision = std::cout.precision();
    std::cout.precision(17);
    std::cout << GridLogMessage << "Total H before trajectory = " << H0 << "\n";
    std::cout.precision(current_precision);

    TheIntegrator.integrate(U);

    RealD H1 = TheIntegrator.S(U);  // updated state action

    std::cout.precision(17);
    std::cout << GridLogMessage << "Total H after trajectory  = " << H1
              << "  dH = " << H1 - H0 << "\n";
    std::cout.precision(current_precision);

    return (H1 - H0);
  }

 public:
  /////////////////////////////////////////
  // Constructor
  /////////////////////////////////////////
  HybridMonteCarlo(HMCparameters Pams, IntegratorType &_Int,
                   GridSerialRNG &_sRNG, GridParallelRNG &_pRNG, GaugeField &_U)
      : Params(Pams), TheIntegrator(_Int), sRNG(_sRNG), pRNG(_pRNG), Ucur(_U) {}
  ~HybridMonteCarlo(){};

  void AddObservable(HmcObservable<GaugeField> *obs) {
    Observables.push_back(obs);
  }

  void evolve(void) {
    Real DeltaH;

    GaugeField Ucopy(Ucur._grid);

    Params.print();

    // Actual updates (evolve a copy Ucopy then copy back eventually)
    for (int traj = Params.StartTrajectory;
         traj < Params.Trajectories + Params.StartTrajectory; ++traj) {
      std::cout << GridLogMessage << "-- # Trajectory = " << traj << "\n";
      Ucopy = Ucur;

      DeltaH = evolve_step(Ucopy);

      bool accept = true;
      if (traj >= Params.NoMetropolisUntil) {
        accept = metropolis_test(DeltaH);
      }

      if (accept) {
        Ucur = Ucopy;
      }

      for (int obs = 0; obs < Observables.size(); obs++) {
        Observables[obs]->TrajectoryComplete(traj + 1, Ucur, sRNG, pRNG);
      }
    }
  }
};

}  // QCD
}  // Grid

#endif