From 094c3d091afb3f29e7e370562cb0def29b3b26f0 Mon Sep 17 00:00:00 2001 From: paboyle Date: Fri, 2 Jun 2017 00:38:58 +0100 Subject: [PATCH] Improved and RNG's now survive checkpoint --- lib/parallelIO/BinaryIO.h | 175 ++++++++++++------ .../hmc/checkpointers/BinaryCheckpointer.h | 12 +- tests/IO/Test_nersc_io.cc | 4 +- 3 files changed, 124 insertions(+), 67 deletions(-) diff --git a/lib/parallelIO/BinaryIO.h b/lib/parallelIO/BinaryIO.h index 13341927..e427a25b 100644 --- a/lib/parallelIO/BinaryIO.h +++ b/lib/parallelIO/BinaryIO.h @@ -133,7 +133,6 @@ class BinaryIO { } #pragma omp critical csum = csum + csum_thr; - } } // Network is big endian @@ -227,13 +226,20 @@ class BinaryIO { // Real action: // Read or Write distributed lexico array of ANY object to a specific location in file ////////////////////////////////////////////////////////////////////////////////////// + + static const int BINARYIO_MASTER_APPEND = 0x10; + static const int BINARYIO_UNORDERED = 0x08; + static const int BINARYIO_LEXICOGRAPHIC = 0x04; + static const int BINARYIO_READ = 0x02; + static const int BINARYIO_WRITE = 0x01; + template - static inline uint32_t IOobject(word w, - GridBase *grid, - std::vector &iodata, - std::string file, - int offset, - const std::string &format, int doread) + static inline uint32_t IOobject(word w, + GridBase *grid, + std::vector &iodata, + std::string file, + int offset, + const std::string &format, int control) { grid->Barrier(); GridStopWatch timer; @@ -250,21 +256,24 @@ class BinaryIO { std::vector gLattice= grid->GlobalDimensions(); std::vector lLattice= grid->LocalDimensions(); - std::vector distribs(ndim,MPI_DISTRIBUTE_BLOCK); - std::vector dargs (ndim,MPI_DISTRIBUTE_DFLT_DARG); - std::vector lStart(ndim); std::vector gStart(ndim); // Flatten the file uint64_t lsites = grid->lSites(); - iodata.resize(lsites); - + if ( control & BINARYIO_MASTER_APPEND ) { + assert(iodata.size()==1); + } else { + assert(lsites==iodata.size()); + } for(int d=0;d distribs(ndim,MPI_DISTRIBUTE_BLOCK); + std::vector dargs (ndim,MPI_DISTRIBUTE_DFLT_DARG); MPI_Datatype mpiObject; MPI_Datatype fileArray; MPI_Datatype localArray; @@ -281,7 +290,6 @@ class BinaryIO { numword = sizeof(fobj)/sizeof(double); mpiword = MPI_DOUBLE; } - ////////////////////////////////////////////////////////////////////////////// // Sobj in MPI phrasing @@ -301,6 +309,7 @@ class BinaryIO { ////////////////////////////////////////////////////////////////////////////// ierr=MPI_Type_create_subarray(ndim,&lLattice[0],&lLattice[0],&lStart[0],MPI_ORDER_FORTRAN, mpiObject,&localArray); assert(ierr==0); ierr=MPI_Type_commit(&localArray); assert(ierr==0); +#endif ////////////////////////////////////////////////////////////////////////////// // Byte order @@ -311,55 +320,91 @@ class BinaryIO { int ieee64 = (format == std::string("IEEE64")); ////////////////////////////////////////////////////////////////////////////// - // Do the MPI I/O read + // Do the I/O ////////////////////////////////////////////////////////////////////////////// - if ( doread ) { - std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl; + if ( control & BINARYIO_READ ) { + timer.Start(); - ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0); - ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0); - ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0); + + if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) { +#ifdef USE_MPI_IO + std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl; + ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0); + ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0); + ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0); + MPI_File_close(&fh); + MPI_Type_free(&fileArray); + MPI_Type_free(&localArray); +#else + assert(0); +#endif + } else { + std::cout<< GridLogMessage<< "C++ read I/O "<< file<< std::endl; + std::ifstream fin; + fin.open(file,std::ios::binary|std::ios::in); + if ( control & BINARYIO_MASTER_APPEND ) { + fin.seekg(-sizeof(fobj),fin.end); + } else { + fin.seekg(offset+myrank*lsites*sizeof(fobj)); + } + fin.read((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fin.fail()==0); + fin.close(); + } timer.Stop(); grid->Barrier(); bstimer.Start(); - if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - if (ieee32) le32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - if (ieee64) le64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); + if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + if (ieee32) le32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + if (ieee64) le64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); bstimer.Stop(); - - } else { - std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl; - bstimer.Start(); - if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - if (ieee32) htole32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - if (ieee64) htole64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum); - bstimer.Stop(); - - grid->Barrier(); - - timer.Start(); - ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0); - ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0); - ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0); - timer.Stop(); - } - - ////////////////////////////////////////////////////////////////////////////// - // Finish up MPI I/O - ////////////////////////////////////////////////////////////////////////////// - MPI_File_close(&fh); - MPI_Type_free(&fileArray); - MPI_Type_free(&localArray); + + if ( control & BINARYIO_WRITE ) { + + bstimer.Start(); + if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + if (ieee32) htole32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + if (ieee64) htole64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum); + bstimer.Stop(); + + grid->Barrier(); + + timer.Start(); + if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) { +#ifdef USE_MPI_IO + std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl; + ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0); + ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0); + ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0); + MPI_File_close(&fh); + MPI_Type_free(&fileArray); + MPI_Type_free(&localArray); +#else + assert(0); +#endif + } else { + std::cout<< GridLogMessage<< "C++ write I/O "<< file<< std::endl; + std::ofstream fout; + fout.open(file,std::ios::binary|std::ios::out|std::ios::in); + if ( control & BINARYIO_MASTER_APPEND ) { + fout.seekp(0,fout.end); + } else { + fout.seekp(offset+myrank*lsites*sizeof(fobj)); + } + fout.write((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fout.fail()==0); + fout.close(); + } + timer.Stop(); + } std::cout< scalardata(lsites); std::vector iodata(lsites); // Munge, checksum, byte order in here - int doread=1; - uint32_t csum= IOobject(w,grid,iodata,file,offset,format,doread); + uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC); GridStopWatch timer; timer.Start(); @@ -432,8 +476,7 @@ class BinaryIO { grid->Barrier(); timer.Stop(); - int dowrite=0; - uint32_t csum= IOobject(w,grid,iodata,file,offset,format,dowrite); + uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC); std::cout< iodata(lsites); - csum= IOobject(w,grid,iodata,file,offset,format,doread); + csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC); timer.Start(); parallel_for(int lidx=0;lidx tmp(RngStateCount); + std::copy(iodata[0].begin(),iodata[0].end(),tmp.begin()); + serial.SetState(tmp,0); + } + std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl; std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl; return csum; @@ -507,9 +557,16 @@ class BinaryIO { } timer.Stop(); - int dowrite=0; - csum= IOobject(w,grid,iodata,file,offset,format,dowrite); + csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC); + iodata.resize(1); + { + std::vector tmp(RngStateCount); + serial.GetState(tmp,0); + std::copy(tmp.begin(),tmp.end(),iodata[0].begin()); + } + csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_MASTER_APPEND); + std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl; std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl; return csum; diff --git a/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h b/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h index 251ed042..6116a46c 100644 --- a/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h +++ b/lib/qcd/hmc/checkpointers/BinaryCheckpointer.h @@ -68,11 +68,11 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer { std::string config, rng; this->build_filenames(traj, Params, config, rng); - BinaryIO::BinarySimpleUnmunger munge; + BinarySimpleUnmunger munge; truncate(rng); - BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0); + BinaryIO::writeRNG(sRNG, pRNG, rng, 0); truncate(config); - uint32_t csum = BinaryIO::writeObjectParallel( + uint32_t csum = BinaryIO::writeLatticeObject( U, config, munge, 0, Params.format); std::cout << GridLogMessage << "Written Binary Configuration " << config @@ -85,9 +85,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer { std::string config, rng; this->build_filenames(traj, Params, config, rng); - BinaryIO::BinarySimpleMunger munge; - BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0); - uint32_t csum = BinaryIO::readObjectParallel( + BinarySimpleMunger munge; + BinaryIO::readRNG(sRNG, pRNG, rng, 0); + uint32_t csum = BinaryIO::readLatticeObject( U, config, munge, 0, Params.format); std::cout << GridLogMessage << "Read Binary Configuration " << config diff --git a/tests/IO/Test_nersc_io.cc b/tests/IO/Test_nersc_io.cc index 0a0f8977..14c6080d 100644 --- a/tests/IO/Test_nersc_io.cc +++ b/tests/IO/Test_nersc_io.cc @@ -42,9 +42,9 @@ int main (int argc, char ** argv) std::vector simd_layout = GridDefaultSimd(4,vComplex::Nsimd()); std::vector mpi_layout = GridDefaultMpi(); - std::vector latt_size ({48,48,48,96}); + //std::vector latt_size ({48,48,48,96}); //std::vector latt_size ({32,32,32,32}); - //std::vector latt_size ({16,16,16,32}); + std::vector latt_size ({16,16,16,32}); std::vector clatt_size ({4,4,4,8}); int orthodir=3; int orthosz =latt_size[orthodir];