/*************************************************************************************

    Grid physics library, www.github.com/paboyle/Grid 

    Source file: ./lib/parallelIO/NerscIO.h

    Copyright (C) 2015


    Author: Peter Boyle <paboyle@ph.ed.ac.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

    See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/*  END LEGAL */

#include <pwd.h>
#include <sys/utsname.h>
#include <unistd.h>

#include <algorithm>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>

NAMESPACE_BEGIN(Grid);

///////////////////////////////////////////////////////
// Precision mapping
///////////////////////////////////////////////////////
template<class vobj> static std::string getFormatString (void)
{
  std::string format;
  typedef typename getPrecision<vobj>::real_scalar_type stype;
  if ( sizeof(stype) == sizeof(float) ) {
    format = std::string("IEEE32BIG");
  }
  if ( sizeof(stype) == sizeof(double) ) {
    format = std::string("IEEE64BIG");
  }
  return format;
}
////////////////////////////////////////////////////////////////////////////////
// header specification/interpretation
////////////////////////////////////////////////////////////////////////////////
class FieldMetaData : Serializable {
public:

  GRID_SERIALIZABLE_CLASS_MEMBERS(FieldMetaData,
				  int, nd,
				  std::vector<int>, dimension,
				  std::vector<std::string>, boundary,
				  int, data_start,
				  std::string, hdr_version,
				  std::string, storage_format,
				  double, link_trace,
				  double, plaquette,
				  uint32_t, checksum,
				  uint32_t, scidac_checksuma,
				  uint32_t, scidac_checksumb,
				  unsigned int, sequence_number,
				  std::string, data_type,
				  std::string, ensemble_id,
				  std::string, ensemble_label,
				  std::string, ildg_lfn,
				  std::string, creator,
				  std::string, creator_hardware,
				  std::string, creation_date,
				  std::string, archive_date,
				  std::string, floating_point);
  FieldMetaData(void) { 
    nd=4;
    dimension.resize(4);
    boundary.resize(4);
    scidac_checksuma=0;
    scidac_checksumb=0;
    checksum=0;
  }
};



namespace QCD {

  using namespace Grid;


  //////////////////////////////////////////////////////////////////////
  // Bit and Physical Checksumming and QA of data
  //////////////////////////////////////////////////////////////////////
  // Copy the lattice geometry of `grid` into `header`: the number of
  // dimensions, the per-dimension extents taken from grid->_fdimensions, and a
  // "PERIODIC" boundary tag for every direction. data_start is reset to 0.
  inline void GridMetaData(GridBase *grid,FieldMetaData &header)
  {
    const int ndim = grid->_ndimension;

    header.nd         = ndim;
    header.data_start = 0;

    header.dimension.resize(ndim);
    header.boundary.resize(ndim);
    for(int mu=0;mu<ndim;mu++) {
      header.dimension[mu] = grid->_fdimensions[mu];
      header.boundary[mu]  = std::string("PERIODIC");
    }
  }

  // Fill the provenance fields of `header`:
  //   creator          - login name of the calling user (left untouched when
  //                      the password database has no entry for the uid),
  //   creation_date    - current local time formatted with "%c %Z",
  //   archive_date     - same string as creation_date,
  //   creator_hardware - "<nodename>-<machine>-<sysname>-<release>" as
  //                      reported by uname() (left untouched if uname fails).
  // Uses std::localtime, which returns a pointer to shared static storage.
  inline void MachineCharacteristics(FieldMetaData &header)
  {
    // Who
    struct passwd *pw = getpwuid (getuid());
    if (pw) header.creator = std::string(pw->pw_name); 

    // When
    // strftime is used rather than std::put_time so that this builds on
    // standard libraries that do not provide put_time.
    std::time_t now = std::time(nullptr);
    std::tm *local_now = std::localtime(&now);
    char stamp[128];
    if ( (local_now == nullptr) ||
	 (std::strftime(stamp, sizeof(stamp), "%c %Z", local_now) == 0) ) {
      // Conversion failed or did not fit: buffer contents are unspecified,
      // so fall back to an empty date string.
      stamp[0] = '\0';
    }
    header.creation_date = std::string(stamp);
    header.archive_date  = header.creation_date;

    // What
    struct utsname name;
    if ( uname(&name) == 0 ) {
      // Only read the struct when uname succeeded; it is not filled otherwise.
      header.creator_hardware = std::string(name.nodename)+"-";
      header.creator_hardware+= std::string(name.machine)+"-";
      header.creator_hardware+= std::string(name.sysname)+"-";
      header.creator_hardware+= std::string(name.release);
    }
  }

// dump_meta_data(field, s): stream the members of `field` to the output
// stream `s` as "KEY = value" lines enclosed by BEGIN_HEADER / END_HEADER.
//
// Notes for callers:
//  - The macro expands to several statements and loops that are not wrapped
//    in a single compound statement, and it already ends with a semicolon.
//    Do not use it as the unbraced body of an if/for/while.
//  - `field` and `s` are substituted textually many times; pass plain
//    lvalues, not expressions with side effects.
//  - Exactly four DIMENSION_ and BOUNDARY_ entries are written regardless of
//    field.nd, so field.dimension and field.boundary must hold at least four
//    elements.
//  - The three checksums are printed in hexadecimal with a field width of 10
//    and the stream is switched back to decimal after each one.
//  - std::setprecision(10) is applied for LINK_TRACE and PLAQUETTE and is not
//    restored afterwards, so it remains set on `s`.
//  - Every line is terminated with std::endl, which also flushes `s`.
#define dump_meta_data(field, s)					\
  s << "BEGIN_HEADER"      << std::endl;				\
  s << "HDR_VERSION = "    << field.hdr_version    << std::endl;	\
  s << "DATATYPE = "       << field.data_type      << std::endl;	\
  s << "STORAGE_FORMAT = " << field.storage_format << std::endl;	\
  for(int i=0;i<4;i++){							\
    s << "DIMENSION_" << i+1 << " = " << field.dimension[i] << std::endl ; \
  }									\
  s << "LINK_TRACE = " << std::setprecision(10) << field.link_trace << std::endl; \
  s << "PLAQUETTE  = " << std::setprecision(10) << field.plaquette  << std::endl; \
  for(int i=0;i<4;i++){							\
    s << "BOUNDARY_"<<i+1<<" = " << field.boundary[i] << std::endl;	\
  }									\
									\
  s << "CHECKSUM = "<< std::hex << std::setw(10) << field.checksum << std::dec<<std::endl; \
  s << "SCIDAC_CHECKSUMA = "<< std::hex << std::setw(10) << field.scidac_checksuma << std::dec<<std::endl; \
  s << "SCIDAC_CHECKSUMB = "<< std::hex << std::setw(10) << field.scidac_checksumb << std::dec<<std::endl; \
  s << "ENSEMBLE_ID = "     << field.ensemble_id      << std::endl;	\
  s << "ENSEMBLE_LABEL = "  << field.ensemble_label   << std::endl;	\
  s << "SEQUENCE_NUMBER = " << field.sequence_number  << std::endl;	\
  s << "CREATOR = "         << field.creator          << std::endl;	\
  s << "CREATOR_HARDWARE = "<< field.creator_hardware << std::endl;	\
  s << "CREATION_DATE = "   << field.creation_date    << std::endl;	\
  s << "ARCHIVE_DATE = "    << field.archive_date     << std::endl;	\
  s << "FLOATING_POINT = "  << field.floating_point   << std::endl;	\
  s << "END_HEADER"         << std::endl;

  // Generic header preparation for an arbitrary lattice field: records the
  // floating point format of `vobj`, zeroes the checksum, then fills in the
  // grid geometry and the machine/user provenance. No gauge observables are
  // computed in this generic version.
  template<class vobj> inline void PrepareMetaData(Lattice<vobj> & field, FieldMetaData &header)
  {
    header.floating_point = getFormatString<vobj>();
    header.checksum       = 0x0; // Nersc checksum unused in ILDG, Scidac

    GridMetaData(field._grid,header);
    MachineCharacteristics(header);
  }
  // Single precision gauge field: store the link trace and the average
  // plaquette, as returned by the periodic single precision WilsonLoops
  // helpers, in `header`.
  inline void GaugeStatistics(Lattice<vLorentzColourMatrixF> & data,FieldMetaData &header)
  {
    typedef Grid::QCD::WilsonLoops<PeriodicGimplF> LoopsF;

    header.link_trace = LoopsF::linkTrace(data);
    header.plaquette  = LoopsF::avgPlaquette(data);
  }
  // Double precision gauge field: store the link trace and the average
  // plaquette, as returned by the periodic double precision WilsonLoops
  // helpers, in `header`.
  inline void GaugeStatistics(Lattice<vLorentzColourMatrixD> & data,FieldMetaData &header)
  {
    typedef Grid::QCD::WilsonLoops<PeriodicGimplD> LoopsD;

    header.link_trace = LoopsD::linkTrace(data);
    header.plaquette  = LoopsD::avgPlaquette(data);
  }
  // Specialisation for single precision gauge fields: same steps as the
  // generic version, plus the link trace and plaquette computed by
  // GaugeStatistics between the geometry and provenance steps.
  template<> inline void PrepareMetaData<vLorentzColourMatrixF>(Lattice<vLorentzColourMatrixF> & field, FieldMetaData &header)
  {
    header.floating_point = getFormatString<vLorentzColourMatrixF>();
    header.checksum       = 0x0; // Nersc checksum unused in ILDG, Scidac

    GridMetaData(field._grid,header);
    GaugeStatistics(field,header);
    MachineCharacteristics(header);
  }
  // Specialisation for double precision gauge fields: same steps as the
  // generic version, plus the link trace and plaquette computed by
  // GaugeStatistics between the geometry and provenance steps.
  template<> inline void PrepareMetaData<vLorentzColourMatrixD>(Lattice<vLorentzColourMatrixD> & field, FieldMetaData &header)
  {
    header.floating_point = getFormatString<vLorentzColourMatrixD>();
    header.checksum       = 0x0; // Nersc checksum unused in ILDG, Scidac

    GridMetaData(field._grid,header);
    GaugeStatistics(field,header);
    MachineCharacteristics(header);
  }

  //////////////////////////////////////////////////////////////////////
  // Utilities ; these are QCD aware
  //////////////////////////////////////////////////////////////////////
  // For every direction mu, overwrite the third row (index 2) of the colour
  // matrix with adj() of the cross-product-like combination of rows 0 and 1:
  //   row2[c0] = adj(row0[c1]*row1[c2] - row0[c2]*row1[c1])   and cyclic.
  // Rows 0 and 1 are only read.
  inline void reconstruct3(LorentzColourMatrix & cm)
  {
    const int c0=0;
    const int c1=1;
    const int c2=2;
    for(int mu=0;mu<Nd;mu++){
      auto &U = cm(mu)(); // colour matrix for this direction
      U(2,c0) = adj(U(0,c1)*U(1,c2) - U(0,c2)*U(1,c1)); // 0 = 12-21
      U(2,c1) = adj(U(0,c2)*U(1,c0) - U(0,c0)*U(1,c2)); // 1 = 20-02
      U(2,c2) = adj(U(0,c0)*U(1,c1) - U(0,c1)*U(1,c0)); // 2 = 01-10
    }
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Some data types for intermediate storage
  ////////////////////////////////////////////////////////////////////////////////
  // Nd directions, each holding 2 rows of Nc colour entries: the layout used
  // by Gauge3x2munger / Gauge3x2unmunger for two-row gauge links.
  template<typename vtype>
  using iLorentzColour2x3 = iVector<iVector<iVector<vtype, Nc>, 2>, Nd >;

  using LorentzColour2x3  = iLorentzColour2x3<Complex>;
  using LorentzColour2x3F = iLorentzColour2x3<ComplexF>;
  using LorentzColour2x3D = iLorentzColour2x3<ComplexD>;

  /////////////////////////////////////////////////////////////////////////////////
  // Simple classes for precision conversion
  /////////////////////////////////////////////////////////////////////////////////
  template <class fobj, class sobj>
  struct BinarySimpleUnmunger {
    typedef typename getPrecision<fobj>::real_scalar_type fobj_stype;
    typedef typename getPrecision<sobj>::real_scalar_type sobj_stype;
  
    void operator()(sobj &in, fobj &out) {
      // take word by word and transform accoding to the status
      fobj_stype *out_buffer = (fobj_stype *)&out;
      sobj_stype *in_buffer = (sobj_stype *)&in;
      size_t fobj_words = sizeof(out) / sizeof(fobj_stype);
      size_t sobj_words = sizeof(in) / sizeof(sobj_stype);
      assert(fobj_words == sobj_words);
    
      for (unsigned int word = 0; word < sobj_words; word++)
	out_buffer[word] = in_buffer[word];  // type conversion on the fly
    
    }
  };

  template <class fobj, class sobj>
  struct BinarySimpleMunger {
    typedef typename getPrecision<fobj>::real_scalar_type fobj_stype;
    typedef typename getPrecision<sobj>::real_scalar_type sobj_stype;

    void operator()(fobj &in, sobj &out) {
      // take word by word and transform accoding to the status
      fobj_stype *in_buffer = (fobj_stype *)&in;
      sobj_stype *out_buffer = (sobj_stype *)&out;
      size_t fobj_words = sizeof(in) / sizeof(fobj_stype);
      size_t sobj_words = sizeof(out) / sizeof(sobj_stype);
      assert(fobj_words == sobj_words);
    
      for (unsigned int word = 0; word < sobj_words; word++)
	out_buffer[word] = in_buffer[word];  // type conversion on the fly
    
    }
  };


  // Element-wise copy of a full Nc x Nc colour matrix per direction from the
  // file object `in` to the scalar object `out`; any precision change happens
  // through the element assignment.
  template<class fobj,class sobj>
  struct GaugeSimpleMunger{
    void operator()(fobj &in, sobj &out) {
      for (int mu = 0; mu < Nd; mu++) {
        for (int row = 0; row < Nc; row++) {
          for (int col = 0; col < Nc; col++) {
            out(mu)()(row, col) = in(mu)()(row, col);
          }
        }
      }
    }
  };

  // Element-wise copy of a full Nc x Nc colour matrix per direction from the
  // scalar object `in` to the file object `out`; any precision change happens
  // through the element assignment.
  template <class fobj, class sobj>
  struct GaugeSimpleUnmunger {
    void operator()(sobj &in, fobj &out) {
      for (int mu = 0; mu < Nd; mu++) {
        for (int row = 0; row < Nc; row++) {
          for (int col = 0; col < Nc; col++) {
            out(mu)()(row, col) = in(mu)()(row, col);
          }
        }
      }
    }
  };

  // Expands a two-row file object into a full colour matrix: the two stored
  // rows (three entries each) are copied per direction, then reconstruct3()
  // fills in the third row of `out`.
  template<class fobj,class sobj>
  struct Gauge3x2munger{
    void operator() (fobj &in,sobj &out){
      const int stored_rows = 2;
      const int row_length  = 3;
      for(int mu=0;mu<Nd;mu++){
        for(int row=0;row<stored_rows;row++){
          for(int col=0;col<row_length;col++){
            out(mu)()(row,col) = in(mu)(row)(col);
          }
        }
      }
      reconstruct3(out);
    }
  };

  // Compresses a full colour matrix into a two-row file object: only rows 0
  // and 1 (three entries each) of every direction are copied to `out`; the
  // third row of `in` is not read.
  template<class fobj,class sobj>
  struct Gauge3x2unmunger{
    void operator() (sobj &in,fobj &out){
      const int stored_rows = 2;
      const int row_length  = 3;
      for(int mu=0;mu<Nd;mu++){
        for(int row=0;row<stored_rows;row++){
          for(int col=0;col<row_length;col++){
            out(mu)(row)(col) = in(mu)()(row,col);
          }
        }
      }
    }
  };
}

NAMESPACE_END(Grid);