1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-09-20 01:05:38 +01:00

Improved and RNG's now survive checkpoint

This commit is contained in:
paboyle 2017-06-02 00:38:58 +01:00
parent 4b98e524a0
commit 094c3d091a
3 changed files with 124 additions and 67 deletions

View File

@ -133,7 +133,6 @@ class BinaryIO {
}
#pragma omp critical
csum = csum + csum_thr;
}
}
// Network is big endian
@ -227,13 +226,20 @@ class BinaryIO {
// Real action:
// Read or Write distributed lexico array of ANY object to a specific location in file
//////////////////////////////////////////////////////////////////////////////////////
static const int BINARYIO_MASTER_APPEND = 0x10;
static const int BINARYIO_UNORDERED = 0x08;
static const int BINARYIO_LEXICOGRAPHIC = 0x04;
static const int BINARYIO_READ = 0x02;
static const int BINARYIO_WRITE = 0x01;
template<class word,class fobj>
static inline uint32_t IOobject(word w,
GridBase *grid,
std::vector<fobj> &iodata,
std::string file,
int offset,
const std::string &format, int doread)
static inline uint32_t IOobject(word w,
GridBase *grid,
std::vector<fobj> &iodata,
std::string file,
int offset,
const std::string &format, int control)
{
grid->Barrier();
GridStopWatch timer;
@ -250,21 +256,24 @@ class BinaryIO {
std::vector<int> gLattice= grid->GlobalDimensions();
std::vector<int> lLattice= grid->LocalDimensions();
std::vector<int> distribs(ndim,MPI_DISTRIBUTE_BLOCK);
std::vector<int> dargs (ndim,MPI_DISTRIBUTE_DFLT_DARG);
std::vector<int> lStart(ndim);
std::vector<int> gStart(ndim);
// Flatten the file
uint64_t lsites = grid->lSites();
iodata.resize(lsites);
if ( control & BINARYIO_MASTER_APPEND ) {
assert(iodata.size()==1);
} else {
assert(lsites==iodata.size());
}
for(int d=0;d<ndim;d++){
gStart[d] = lLattice[d]*pcoor[d];
lStart[d] = 0;
}
#ifdef USE_MPI_IO
std::vector<int> distribs(ndim,MPI_DISTRIBUTE_BLOCK);
std::vector<int> dargs (ndim,MPI_DISTRIBUTE_DFLT_DARG);
MPI_Datatype mpiObject;
MPI_Datatype fileArray;
MPI_Datatype localArray;
@ -281,7 +290,6 @@ class BinaryIO {
numword = sizeof(fobj)/sizeof(double);
mpiword = MPI_DOUBLE;
}
//////////////////////////////////////////////////////////////////////////////
// Sobj in MPI phrasing
@ -301,6 +309,7 @@ class BinaryIO {
//////////////////////////////////////////////////////////////////////////////
ierr=MPI_Type_create_subarray(ndim,&lLattice[0],&lLattice[0],&lStart[0],MPI_ORDER_FORTRAN, mpiObject,&localArray); assert(ierr==0);
ierr=MPI_Type_commit(&localArray); assert(ierr==0);
#endif
//////////////////////////////////////////////////////////////////////////////
// Byte order
@ -311,55 +320,91 @@ class BinaryIO {
int ieee64 = (format == std::string("IEEE64"));
//////////////////////////////////////////////////////////////////////////////
// Do the MPI I/O read
// Do the I/O
//////////////////////////////////////////////////////////////////////////////
if ( doread ) {
std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl;
if ( control & BINARYIO_READ ) {
timer.Start();
ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0);
ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) {
#ifdef USE_MPI_IO
std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl;
ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0);
ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
MPI_File_close(&fh);
MPI_Type_free(&fileArray);
MPI_Type_free(&localArray);
#else
assert(0);
#endif
} else {
std::cout<< GridLogMessage<< "C++ read I/O "<< file<< std::endl;
std::ifstream fin;
fin.open(file,std::ios::binary|std::ios::in);
if ( control & BINARYIO_MASTER_APPEND ) {
fin.seekg(-sizeof(fobj),fin.end);
} else {
fin.seekg(offset+myrank*lsites*sizeof(fobj));
}
fin.read((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fin.fail()==0);
fin.close();
}
timer.Stop();
grid->Barrier();
bstimer.Start();
if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee32) le32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee64) le64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
if (ieee32) le32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
if (ieee64) le64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
bstimer.Stop();
} else {
std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl;
bstimer.Start();
if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee32) htole32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
if (ieee64) htole64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
bstimer.Stop();
grid->Barrier();
timer.Start();
ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0);
ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
timer.Stop();
}
//////////////////////////////////////////////////////////////////////////////
// Finish up MPI I/O
//////////////////////////////////////////////////////////////////////////////
MPI_File_close(&fh);
MPI_Type_free(&fileArray);
MPI_Type_free(&localArray);
if ( control & BINARYIO_WRITE ) {
bstimer.Start();
if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
if (ieee32) htole32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
if (ieee64) htole64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
bstimer.Stop();
grid->Barrier();
timer.Start();
if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) {
#ifdef USE_MPI_IO
std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl;
ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0);
ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
MPI_File_close(&fh);
MPI_Type_free(&fileArray);
MPI_Type_free(&localArray);
#else
assert(0);
#endif
} else {
std::cout<< GridLogMessage<< "C++ write I/O "<< file<< std::endl;
std::ofstream fout;
fout.open(file,std::ios::binary|std::ios::out|std::ios::in);
if ( control & BINARYIO_MASTER_APPEND ) {
fout.seekp(0,fout.end);
} else {
fout.seekp(offset+myrank*lsites*sizeof(fobj));
}
fout.write((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fout.fail()==0);
fout.close();
}
timer.Stop();
}
std::cout<<GridLogMessage<<"IOobject: ";
if ( doread) std::cout << " read ";
else std::cout << " write ";
uint64_t bytes = sizeof(fobj)*lsites*nrank;
if ( control & BINARYIO_READ) std::cout << " read ";
else std::cout << " write ";
uint64_t bytes = sizeof(fobj)*iodata.size()*nrank;
std::cout<< bytes <<" bytes in "<<timer.Elapsed() <<" "
<< (double)bytes/ (double)timer.useconds() <<" MB/s "<<std::endl;
@ -390,8 +435,7 @@ class BinaryIO {
std::vector<sobj> scalardata(lsites);
std::vector<fobj> iodata(lsites); // Munge, checksum, byte order in here
int doread=1;
uint32_t csum= IOobject(w,grid,iodata,file,offset,format,doread);
uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
GridStopWatch timer;
timer.Start();
@ -432,8 +476,7 @@ class BinaryIO {
grid->Barrier();
timer.Stop();
int dowrite=0;
uint32_t csum= IOobject(w,grid,iodata,file,offset,format,dowrite);
uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC);
std::cout<<GridLogMessage<<"writeLatticeObject: unvectorize overhead "<<timer.Elapsed() <<std::endl;
@ -461,9 +504,8 @@ class BinaryIO {
std::cout << GridLogMessage << "RNG read I/O on file " << file << std::endl;
int doread=1;
std::vector<RNGstate> iodata(lsites);
csum= IOobject(w,grid,iodata,file,offset,format,doread);
csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
timer.Start();
parallel_for(int lidx=0;lidx<lsites;lidx++){
@ -473,6 +515,14 @@ class BinaryIO {
}
timer.Stop();
iodata.resize(1);
csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_MASTER_APPEND);
{
std::vector<RngStateType> tmp(RngStateCount);
std::copy(iodata[0].begin(),iodata[0].end(),tmp.begin());
serial.SetState(tmp,0);
}
std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl;
std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl;
return csum;
@ -507,9 +557,16 @@ class BinaryIO {
}
timer.Stop();
int dowrite=0;
csum= IOobject(w,grid,iodata,file,offset,format,dowrite);
csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC);
iodata.resize(1);
{
std::vector<RngStateType> tmp(RngStateCount);
serial.GetState(tmp,0);
std::copy(tmp.begin(),tmp.end(),iodata[0].begin());
}
csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_MASTER_APPEND);
std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl;
std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl;
return csum;

View File

@ -68,11 +68,11 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
std::string config, rng;
this->build_filenames(traj, Params, config, rng);
BinaryIO::BinarySimpleUnmunger<sobj_double, sobj> munge;
BinarySimpleUnmunger<sobj_double, sobj> munge;
truncate(rng);
BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0);
BinaryIO::writeRNG(sRNG, pRNG, rng, 0);
truncate(config);
uint32_t csum = BinaryIO::writeObjectParallel<vobj, sobj_double>(
uint32_t csum = BinaryIO::writeLatticeObject<vobj, sobj_double>(
U, config, munge, 0, Params.format);
std::cout << GridLogMessage << "Written Binary Configuration " << config
@ -85,9 +85,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
std::string config, rng;
this->build_filenames(traj, Params, config, rng);
BinaryIO::BinarySimpleMunger<sobj_double, sobj> munge;
BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0);
uint32_t csum = BinaryIO::readObjectParallel<vobj, sobj_double>(
BinarySimpleMunger<sobj_double, sobj> munge;
BinaryIO::readRNG(sRNG, pRNG, rng, 0);
uint32_t csum = BinaryIO::readLatticeObject<vobj, sobj_double>(
U, config, munge, 0, Params.format);
std::cout << GridLogMessage << "Read Binary Configuration " << config

View File

@ -42,9 +42,9 @@ int main (int argc, char ** argv)
std::vector<int> simd_layout = GridDefaultSimd(4,vComplex::Nsimd());
std::vector<int> mpi_layout = GridDefaultMpi();
std::vector<int> latt_size ({48,48,48,96});
//std::vector<int> latt_size ({48,48,48,96});
//std::vector<int> latt_size ({32,32,32,32});
//std::vector<int> latt_size ({16,16,16,32});
std::vector<int> latt_size ({16,16,16,32});
std::vector<int> clatt_size ({4,4,4,8});
int orthodir=3;
int orthosz =latt_size[orthodir];