mirror of https://github.com/paboyle/Grid.git
Improved, and RNGs now survive checkpoint

@@ -133,7 +133,6 @@ class BinaryIO {
       }
 #pragma omp critical
       csum = csum + csum_thr;
-
     }
   }
   // Network is big endian
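The hunk above ends a threaded checksum reduction: each OpenMP thread
accumulates a private csum_thr and folds it into the shared csum inside the
critical section. A minimal self-contained sketch of that pattern (function
name, data layout and loop body are assumed here, not Grid's):

    #include <cstdint>
    #include <vector>

    uint32_t checksum(const std::vector<uint32_t> &data) {
      uint32_t csum = 0;
    #pragma omp parallel
      {
        uint32_t csum_thr = 0;                 // thread-private partial sum
    #pragma omp for
        for (long i = 0; i < (long)data.size(); i++) csum_thr += data[i];
    #pragma omp critical
        csum = csum + csum_thr;                // serialized fold into the shared total
      }
      return csum;
    }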
@@ -227,13 +226,20 @@ class BinaryIO {
   // Real action:
   // Read or Write distributed lexico array of ANY object to a specific location in file 
   //////////////////////////////////////////////////////////////////////////////////////
+
+  static const int BINARYIO_MASTER_APPEND = 0x10;
+  static const int BINARYIO_UNORDERED     = 0x08;
+  static const int BINARYIO_LEXICOGRAPHIC = 0x04;
+  static const int BINARYIO_READ          = 0x02;
+  static const int BINARYIO_WRITE         = 0x01;
+
   template<class word,class fobj>
   static inline uint32_t IOobject(word w,
				  GridBase *grid,
				  std::vector<fobj> &iodata,
				  std::string file,
				  int offset,
-				  const std::string &format, int doread)
+				  const std::string &format, int control)
   {
     grid->Barrier();
     GridStopWatch timer; 
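The single int control word introduced above replaces the old doread boolean:
one mode bit (BINARYIO_READ or BINARYIO_WRITE) is OR-ed with a layout bit
(BINARYIO_LEXICOGRAPHIC for the distributed per-site records,
BINARYIO_MASTER_APPEND for a single trailing record). For orientation, the
readRNG hunk later in this diff composes it like so (w, grid, iodata, file,
offset and format as set up by the surrounding routine):

    // Distributed read of the per-site records, then a re-read of the
    // single master-appended record.
    csum = IOobject(w,grid,iodata,file,offset,format,
                    BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
    iodata.resize(1);
    csum+= IOobject(w,grid,iodata,file,offset,format,
                    BINARYIO_READ|BINARYIO_MASTER_APPEND);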
@@ -250,21 +256,24 @@ class BinaryIO {
     std::vector<int> gLattice= grid->GlobalDimensions();
     std::vector<int> lLattice= grid->LocalDimensions();
 
-    std::vector<int> distribs(ndim,MPI_DISTRIBUTE_BLOCK);
-    std::vector<int> dargs   (ndim,MPI_DISTRIBUTE_DFLT_DARG);
-
     std::vector<int> lStart(ndim);
     std::vector<int> gStart(ndim);
 
     // Flatten the file
     uint64_t lsites = grid->lSites();
-    iodata.resize(lsites);
-
+    if ( control & BINARYIO_MASTER_APPEND )  {
+      assert(iodata.size()==1);
+    } else {
+      assert(lsites==iodata.size());
+    }
     for(int d=0;d<ndim;d++){
       gStart[d] = lLattice[d]*pcoor[d];
       lStart[d] = 0;
     }
 
+#ifdef USE_MPI_IO
+    std::vector<int> distribs(ndim,MPI_DISTRIBUTE_BLOCK);
+    std::vector<int> dargs   (ndim,MPI_DISTRIBUTE_DFLT_DARG);
     MPI_Datatype mpiObject;
     MPI_Datatype fileArray;
     MPI_Datatype localArray;
@@ -282,7 +291,6 @@ class BinaryIO {
       mpiword = MPI_DOUBLE;
     }
 
-
     //////////////////////////////////////////////////////////////////////////////
     // Sobj in MPI phrasing
     //////////////////////////////////////////////////////////////////////////////
@@ -301,6 +309,7 @@ class BinaryIO {
     //////////////////////////////////////////////////////////////////////////////
     ierr=MPI_Type_create_subarray(ndim,&lLattice[0],&lLattice[0],&lStart[0],MPI_ORDER_FORTRAN, mpiObject,&localArray);    assert(ierr==0);
     ierr=MPI_Type_commit(&localArray);    assert(ierr==0);
+#endif
 
     //////////////////////////////////////////////////////////////////////////////
     // Byte order
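The visible hunk builds localArray, the MPI view of the rank-local buffer; a
matching fileArray (created just above this hunk, not shown in this view)
selects the same block inside the global lattice on disk. A standalone sketch
of the pairing, with MPI_DOUBLE standing in for the committed mpiObject
element type and a trivial 4-d decomposition assumed:

    #include <mpi.h>

    int ndim = 4;
    int gdim[4]   = {16,16,16,32};   // global lattice
    int ldim[4]   = {16,16,16,16};   // this rank's block
    int gstart[4] = { 0, 0, 0, 0};   // lLattice[d]*pcoor[d] in the diff
    int lstart[4] = { 0, 0, 0, 0};
    MPI_Datatype fileArray, localArray;
    // The block as it sits in the file: subarray of the global lattice.
    MPI_Type_create_subarray(ndim,gdim,ldim,gstart,MPI_ORDER_FORTRAN,MPI_DOUBLE,&fileArray);
    MPI_Type_commit(&fileArray);
    // The block as it sits in memory: here the whole local buffer.
    MPI_Type_create_subarray(ndim,ldim,ldim,lstart,MPI_ORDER_FORTRAN,MPI_DOUBLE,&localArray);
    MPI_Type_commit(&localArray);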
@@ -311,55 +320,91 @@ class BinaryIO {
     int ieee64    = (format == std::string("IEEE64"));
 
     //////////////////////////////////////////////////////////////////////////////
-    // Do the MPI I/O read
+    // Do the I/O
     //////////////////////////////////////////////////////////////////////////////
-    if ( doread ) { 
-      std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl;
+    if ( control & BINARYIO_READ ) { 
 
       timer.Start();
 
+      if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) {
+#ifdef USE_MPI_IO
+	std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl;
 	ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh);    assert(ierr==0);
 	ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL);    assert(ierr==0);
 	ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status);    assert(ierr==0);
+	MPI_File_close(&fh);
+	MPI_Type_free(&fileArray);
+	MPI_Type_free(&localArray);
+#else 
+	assert(0);
+#endif
+      } else { 
+	std::cout<< GridLogMessage<< "C++ read I/O "<< file<< std::endl;
+	std::ifstream fin;
+	fin.open(file,std::ios::binary|std::ios::in);
+	if ( control & BINARYIO_MASTER_APPEND )  {
+	  fin.seekg(-sizeof(fobj),fin.end);
+	} else { 
+	  fin.seekg(offset+myrank*lsites*sizeof(fobj));
+	}
+	fin.read((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fin.fail()==0);
+	fin.close();
+      }
       timer.Stop();
 
       grid->Barrier();
 
       bstimer.Start();
-      if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee32)    le32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64)    le64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
+      if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee32)    le32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64)    le64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
       bstimer.Stop();
-    } else { 
-      std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl;
+    }
+    
+    if ( control & BINARYIO_WRITE ) { 
+
       bstimer.Start();
-      if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee32)    htole32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64)    htole64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
+      if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee32)    htole32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64)    htole64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
       bstimer.Stop();
 
       grid->Barrier();
 
       timer.Start();
+      if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) {
+#ifdef USE_MPI_IO
+	std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl;
 	ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0);
 	ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL);                        assert(ierr==0);
 	ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status);                                        assert(ierr==0);
-      timer.Stop();
-    
-    }
-   
-    //////////////////////////////////////////////////////////////////////////////
-    // Finish up MPI I/O
-    //////////////////////////////////////////////////////////////////////////////
 	MPI_File_close(&fh);
 	MPI_Type_free(&fileArray);
 	MPI_Type_free(&localArray);
+#else 
+	assert(0);
+#endif
+      } else { 
+	std::cout<< GridLogMessage<< "C++ write I/O "<< file<< std::endl;
+	std::ofstream fout;
+	fout.open(file,std::ios::binary|std::ios::out|std::ios::in);
+	if ( control & BINARYIO_MASTER_APPEND )  {
+	  fout.seekp(0,fout.end);
+	} else {
+	  fout.seekp(offset+myrank*lsites*sizeof(fobj));
+	}
+	fout.write((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fout.fail()==0);
+	fout.close();
+      }
+      timer.Stop();
     }
 
     std::cout<<GridLogMessage<<"IOobject: ";
-    if ( doread) std::cout << " read  ";
+    if ( control & BINARYIO_READ) std::cout << " read  ";
     else                          std::cout << " write ";
-    uint64_t bytes = sizeof(fobj)*lsites*nrank;
+    uint64_t bytes = sizeof(fobj)*iodata.size()*nrank;
     std::cout<< bytes <<" bytes in "<<timer.Elapsed() <<" "
	     << (double)bytes/ (double)timer.useconds() <<" MB/s "<<std::endl;
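Throughout the hunk above, the byte count passed to the byte-order helpers
changes from sizeof(fobj)*lsites to sizeof(fobj)*iodata.size(), which is what
lets the single-record MASTER_APPEND path convert and checksum correctly. The
helpers themselves are Grid internals; a hedged sketch of what a
be32toh_v-style routine plausibly does on a little-endian host (the real
implementation and checksum rule may differ):

    #include <cstdint>

    // Swap 32-bit words from big-endian file order to host order in place,
    // accumulating a simple additive checksum over the converted words.
    static void be32toh_v_sketch(void *file_object, uint64_t bytes, uint32_t &csum) {
      uint32_t *f = (uint32_t *)file_object;
      for (uint64_t i = 0; i < bytes/sizeof(uint32_t); i++) {
        uint32_t v = f[i];
        v = (v>>24) | ((v>>8)&0x0000ff00u) | ((v<<8)&0x00ff0000u) | (v<<24);
        f[i]  = v;      // host order in place
        csum += v;      // checksum over converted words
      }
    }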
@@ -390,8 +435,7 @@ class BinaryIO {
     std::vector<sobj> scalardata(lsites); 
     std::vector<fobj>     iodata(lsites); // Munge, checksum, byte order in here
     
-    int doread=1;
-    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,doread);
+    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
 
     GridStopWatch timer; 
     timer.Start();
@@ -432,8 +476,7 @@ class BinaryIO {
     grid->Barrier();
     timer.Stop();
 
-    int dowrite=0;
-    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,dowrite);
+    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC);
 
     std::cout<<GridLogMessage<<"writeLatticeObject: unvectorize overhead "<<timer.Elapsed()  <<std::endl;
 
@@ -461,9 +504,8 @@ class BinaryIO {
 
     std::cout << GridLogMessage << "RNG read I/O on file " << file << std::endl;
 
-    int doread=1;
     std::vector<RNGstate> iodata(lsites);
-    csum= IOobject(w,grid,iodata,file,offset,format,doread);
+    csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
 
     timer.Start();
     parallel_for(int lidx=0;lidx<lsites;lidx++){
@@ -473,6 +515,14 @@ class BinaryIO {
     }
     timer.Stop();
 
+    iodata.resize(1);
+    csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_MASTER_APPEND);
+    {
+      std::vector<RngStateType> tmp(RngStateCount);
+      std::copy(iodata[0].begin(),iodata[0].end(),tmp.begin());
+      serial.SetState(tmp,0);
+    }
+
     std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl;
     std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl;
     return csum;
@@ -507,8 +557,15 @@ class BinaryIO {
     }
     timer.Stop();
 
-    int dowrite=0;
-    csum= IOobject(w,grid,iodata,file,offset,format,dowrite);
+    csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC);
+
+    iodata.resize(1);
+    {
+      std::vector<RngStateType> tmp(RngStateCount);
+      serial.GetState(tmp,0);
+      std::copy(tmp.begin(),tmp.end(),iodata[0].begin());
+    }
+    csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_MASTER_APPEND);
     
     std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl;
     std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl;
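Taken together, the paired readRNG/writeRNG hunks imply this checkpoint file
layout (an inference from the diff, not a documented format):

    [ offset bytes                     ]  caller-specified header/skip
    [ nrank*lsites records of RNGstate ]  per-site parallel RNG states, lexicographic
    [ 1 trailing record of RNGstate    ]  serial RNG state, master append

On read, the BINARYIO_MASTER_APPEND path recovers the trailing record by
seeking sizeof(fobj) back from end-of-file (fin.seekg(-sizeof(fobj),fin.end)
in the IOobject hunk), which is how the serial RNG now survives the
checkpoint alongside the parallel ones.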
@@ -68,11 +68,11 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
       std::string config, rng;
       this->build_filenames(traj, Params, config, rng);
 
-      BinaryIO::BinarySimpleUnmunger<sobj_double, sobj> munge;
+      BinarySimpleUnmunger<sobj_double, sobj> munge;
       truncate(rng);
-      BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0);
+      BinaryIO::writeRNG(sRNG, pRNG, rng, 0);
       truncate(config);
-      uint32_t csum = BinaryIO::writeObjectParallel<vobj, sobj_double>(
+      uint32_t csum = BinaryIO::writeLatticeObject<vobj, sobj_double>(
           U, config, munge, 0, Params.format);
 
       std::cout << GridLogMessage << "Written Binary Configuration " << config
@@ -85,9 +85,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
     std::string config, rng;
     this->build_filenames(traj, Params, config, rng);
 
-    BinaryIO::BinarySimpleMunger<sobj_double, sobj> munge;
-    BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0);
-    uint32_t csum = BinaryIO::readObjectParallel<vobj, sobj_double>(
+    BinarySimpleMunger<sobj_double, sobj> munge;
+    BinaryIO::readRNG(sRNG, pRNG, rng, 0);
+    uint32_t csum = BinaryIO::readLatticeObject<vobj, sobj_double>(
         U, config, munge, 0, Params.format);
 
     std::cout << GridLogMessage << "Read Binary Configuration " << config
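The two checkpointer hunks above reduce to a rename of the BinaryIO entry
points, summarized here for callers (old name on the left, as implied by this
commit):

    BinaryIO::readRNGSerial(...)                  -> BinaryIO::readRNG(...)
    BinaryIO::writeRNGSerial(...)                 -> BinaryIO::writeRNG(...)
    BinaryIO::readObjectParallel<vobj,sobj>(...)  -> BinaryIO::readLatticeObject<vobj,sobj>(...)
    BinaryIO::writeObjectParallel<vobj,sobj>(...) -> BinaryIO::writeLatticeObject<vobj,sobj>(...)

The munger types are now referenced at namespace scope (BinarySimpleMunger,
BinarySimpleUnmunger) rather than as BinaryIO:: members.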
@@ -42,9 +42,9 @@ int main (int argc, char ** argv)
 
   std::vector<int> simd_layout = GridDefaultSimd(4,vComplex::Nsimd());
   std::vector<int> mpi_layout  = GridDefaultMpi();
-  std::vector<int> latt_size  ({48,48,48,96});
+  //std::vector<int> latt_size  ({48,48,48,96});
   //std::vector<int> latt_size  ({32,32,32,32});
-  //std::vector<int> latt_size  ({16,16,16,32});
+  std::vector<int> latt_size  ({16,16,16,32});
   std::vector<int> clatt_size  ({4,4,4,8});
   int orthodir=3;
   int orthosz =latt_size[orthodir];