Mirror of https://github.com/paboyle/Grid.git
Improved I/O; RNGs now survive checkpointing
This commit is contained in:
parent
4b98e524a0
commit
094c3d091a
@@ -133,7 +133,6 @@ class BinaryIO {
       }
 #pragma omp critical
       csum = csum + csum_thr;
-
     }
   }
   // Network is big endian
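
The checksum accumulation visible in this hunk is the standard OpenMP pattern of a thread-private partial sum folded into the shared total inside a critical section. A small self-contained sketch of that pattern (the buffer contents and word type below are made up; only the csum/csum_thr/critical structure comes from the code above):

    // Per-thread checksum reduction: private csum_thr, folded into csum under omp critical.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main(void) {
      std::vector<uint32_t> buf(1 << 20, 0x01020304u);  // stand-in for the file payload
      uint32_t csum = 0;
    #pragma omp parallel
      {
        uint32_t csum_thr = 0;
    #pragma omp for
        for (int64_t i = 0; i < (int64_t)buf.size(); i++) csum_thr = csum_thr + buf[i];
    #pragma omp critical
        csum = csum + csum_thr;
      }
      std::printf("csum = %08x\n", csum);
      return 0;
    }
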
@@ -227,13 +226,20 @@ class BinaryIO {
   // Real action:
   // Read or Write distributed lexico array of ANY object to a specific location in file
   //////////////////////////////////////////////////////////////////////////////////////
+
+  static const int BINARYIO_MASTER_APPEND = 0x10;
+  static const int BINARYIO_UNORDERED     = 0x08;
+  static const int BINARYIO_LEXICOGRAPHIC = 0x04;
+  static const int BINARYIO_READ          = 0x02;
+  static const int BINARYIO_WRITE         = 0x01;
+
   template<class word,class fobj>
-  static inline uint32_t IOobject(word w,
-                                  GridBase *grid,
-                                  std::vector<fobj> &iodata,
-                                  std::string file,
-                                  int offset,
-                                  const std::string &format, int doread)
+  static inline uint32_t IOobject(word w,
+                                  GridBase *grid,
+                                  std::vector<fobj> &iodata,
+                                  std::string file,
+                                  int offset,
+                                  const std::string &format, int control)
   {
     grid->Barrier();
     GridStopWatch timer;
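
The refactored IOobject above replaces the old boolean doread argument with a control bitmask built from the BINARYIO_* constants, so a single entry point can express read vs. write, collective lexicographic transfers, and single-record master appends. A self-contained sketch of how such a mask is composed and tested (the flag values are copied from the hunk above; the describe() helper and the main() driver are illustrative only):

    #include <iostream>

    static const int BINARYIO_MASTER_APPEND = 0x10;
    static const int BINARYIO_UNORDERED     = 0x08;
    static const int BINARYIO_LEXICOGRAPHIC = 0x04;
    static const int BINARYIO_READ          = 0x02;
    static const int BINARYIO_WRITE         = 0x01;

    static void describe(int control) {
      if (control & BINARYIO_READ)          std::cout << "read ";
      if (control & BINARYIO_WRITE)         std::cout << "write ";
      if (control & BINARYIO_LEXICOGRAPHIC) std::cout << "(lexicographic record, all ranks) ";
      if (control & BINARYIO_MASTER_APPEND) std::cout << "(single record appended by the master) ";
      std::cout << std::endl;
    }

    int main(void) {
      describe(BINARYIO_READ  | BINARYIO_LEXICOGRAPHIC);   // lattice fields, parallel RNG states
      describe(BINARYIO_WRITE | BINARYIO_MASTER_APPEND);   // trailing serial RNG record
      return 0;
    }
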
@@ -250,21 +256,24 @@ class BinaryIO {
     std::vector<int> gLattice= grid->GlobalDimensions();
     std::vector<int> lLattice= grid->LocalDimensions();
 
-    std::vector<int> distribs(ndim,MPI_DISTRIBUTE_BLOCK);
-    std::vector<int> dargs   (ndim,MPI_DISTRIBUTE_DFLT_DARG);
-
     std::vector<int> lStart(ndim);
     std::vector<int> gStart(ndim);
 
     // Flatten the file
     uint64_t lsites = grid->lSites();
-    iodata.resize(lsites);
-
+    if ( control & BINARYIO_MASTER_APPEND ) {
+      assert(iodata.size()==1);
+    } else {
+      assert(lsites==iodata.size());
+    }
     for(int d=0;d<ndim;d++){
       gStart[d] = lLattice[d]*pcoor[d];
       lStart[d] = 0;
     }
 
+#ifdef USE_MPI_IO
+    std::vector<int> distribs(ndim,MPI_DISTRIBUTE_BLOCK);
+    std::vector<int> dargs   (ndim,MPI_DISTRIBUTE_DFLT_DARG);
     MPI_Datatype mpiObject;
     MPI_Datatype fileArray;
     MPI_Datatype localArray;
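
The lStart/gStart bookkeeping above is the usual Cartesian block decomposition: each rank owns a window of extent lLattice[d] whose global start is gStart[d] = lLattice[d]*pcoor[d]. A small standalone sketch of that arithmetic (the 4-dimensional layout and the rank-to-coordinate unflattening convention below are made up for illustration):

    #include <cstdio>
    #include <vector>

    int main(void) {
      const int ndim = 4;
      std::vector<int> gLattice = {16, 16, 16, 32};  // global lattice extents
      std::vector<int> procs    = { 1,  1,  2,  2};  // ranks per dimension (4 ranks total)
      std::vector<int> lLattice(ndim), pcoor(ndim);
      for (int d = 0; d < ndim; d++) lLattice[d] = gLattice[d] / procs[d];

      for (int rank = 0; rank < 4; rank++) {
        int r = rank;                        // unflatten rank into a processor coordinate
        for (int d = 0; d < ndim; d++) { pcoor[d] = r % procs[d]; r /= procs[d]; }
        std::printf("rank %d local window starts at (", rank);
        for (int d = 0; d < ndim; d++)
          std::printf("%d%s", lLattice[d] * pcoor[d], d == ndim - 1 ? ")\n" : ",");
      }
      return 0;
    }
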
@@ -281,7 +290,6 @@ class BinaryIO {
       numword = sizeof(fobj)/sizeof(double);
       mpiword = MPI_DOUBLE;
     }
 
-
     //////////////////////////////////////////////////////////////////////////////
     // Sobj in MPI phrasing
@@ -301,6 +309,7 @@ class BinaryIO {
     //////////////////////////////////////////////////////////////////////////////
     ierr=MPI_Type_create_subarray(ndim,&lLattice[0],&lLattice[0],&lStart[0],MPI_ORDER_FORTRAN, mpiObject,&localArray); assert(ierr==0);
     ierr=MPI_Type_commit(&localArray); assert(ierr==0);
+#endif
 
     //////////////////////////////////////////////////////////////////////////////
     // Byte order
@@ -311,55 +320,91 @@ class BinaryIO {
     int ieee64 = (format == std::string("IEEE64"));
 
     //////////////////////////////////////////////////////////////////////////////
-    // Do the MPI I/O read
+    // Do the I/O
     //////////////////////////////////////////////////////////////////////////////
-    if ( doread ) {
-      std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl;
+    if ( control & BINARYIO_READ ) {
+
       timer.Start();
-      ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0);
-      ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
-      ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
-
+      if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) {
+#ifdef USE_MPI_IO
+        std::cout<< GridLogMessage<< "MPI read I/O "<< file<< std::endl;
+        ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDONLY, MPI_INFO_NULL, &fh); assert(ierr==0);
+        ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
+        ierr=MPI_File_read_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
+        MPI_File_close(&fh);
+        MPI_Type_free(&fileArray);
+        MPI_Type_free(&localArray);
+#else
+        assert(0);
+#endif
+      } else {
+        std::cout<< GridLogMessage<< "C++ read I/O "<< file<< std::endl;
+        std::ifstream fin;
+        fin.open(file,std::ios::binary|std::ios::in);
+        if ( control & BINARYIO_MASTER_APPEND ) {
+          fin.seekg(-sizeof(fobj),fin.end);
+        } else {
+          fin.seekg(offset+myrank*lsites*sizeof(fobj));
+        }
+        fin.read((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fin.fail()==0);
+        fin.close();
+      }
      timer.Stop();
 
       grid->Barrier();
 
       bstimer.Start();
-      if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee32)    le32toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64)    le64toh_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
+      if (ieee32big) be32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee32)    le32toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64big) be64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64)    le64toh_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
       bstimer.Stop();
 
-    } else {
-      std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl;
-      bstimer.Start();
-      if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee32)    htole32_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      if (ieee64)    htole64_v((void *)&iodata[0], sizeof(fobj)*lsites,csum);
-      bstimer.Stop();
-
-      grid->Barrier();
-
-      timer.Start();
-      ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0);
-      ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
-      ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
-      timer.Stop();
-
     }
 
-    //////////////////////////////////////////////////////////////////////////////
-    // Finish up MPI I/O
-    //////////////////////////////////////////////////////////////////////////////
-    MPI_File_close(&fh);
-    MPI_Type_free(&fileArray);
-    MPI_Type_free(&localArray);
-
+    if ( control & BINARYIO_WRITE ) {
+
+      bstimer.Start();
+      if (ieee32big) htobe32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee32)    htole32_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64big) htobe64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      if (ieee64)    htole64_v((void *)&iodata[0], sizeof(fobj)*iodata.size(),csum);
+      bstimer.Stop();
+
+      grid->Barrier();
+
+      timer.Start();
+      if ( (control & BINARYIO_LEXICOGRAPHIC) && (nrank > 1) ) {
+#ifdef USE_MPI_IO
+        std::cout<< GridLogMessage<< "MPI write I/O "<< file<< std::endl;
+        ierr=MPI_File_open(grid->communicator, file.c_str(), MPI_MODE_RDWR|MPI_MODE_CREATE,MPI_INFO_NULL, &fh); assert(ierr==0);
+        ierr=MPI_File_set_view(fh, disp, mpiObject, fileArray, "native", MPI_INFO_NULL); assert(ierr==0);
+        ierr=MPI_File_write_all(fh, &iodata[0], 1, localArray, &status); assert(ierr==0);
+        MPI_File_close(&fh);
+        MPI_Type_free(&fileArray);
+        MPI_Type_free(&localArray);
+#else
+        assert(0);
+#endif
+      } else {
+        std::cout<< GridLogMessage<< "C++ write I/O "<< file<< std::endl;
+        std::ofstream fout;
+        fout.open(file,std::ios::binary|std::ios::out|std::ios::in);
+        if ( control & BINARYIO_MASTER_APPEND ) {
+          fout.seekp(0,fout.end);
+        } else {
+          fout.seekp(offset+myrank*lsites*sizeof(fobj));
+        }
+        fout.write((char *)&iodata[0],iodata.size()*sizeof(fobj));assert( fout.fail()==0);
+        fout.close();
+      }
+      timer.Stop();
+    }
 
     std::cout<<GridLogMessage<<"IOobject: ";
-    if ( doread) std::cout << " read ";
-    else         std::cout << " write ";
-    uint64_t bytes = sizeof(fobj)*lsites*nrank;
+    if ( control & BINARYIO_READ) std::cout << " read ";
+    else                          std::cout << " write ";
+    uint64_t bytes = sizeof(fobj)*iodata.size()*nrank;
    std::cout<< bytes <<" bytes in "<<timer.Elapsed() <<" "
              << (double)bytes/ (double)timer.useconds() <<" MB/s "<<std::endl;
 
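
When MPI-IO is not used (single rank, or USE_MPI_IO not defined), the fallback above has every rank seek to offset + myrank*lsites*sizeof(fobj) in the shared file and transfer its own lsites records with plain C++ streams, while the BINARYIO_MASTER_APPEND path instead targets a single record at the end of the file. A standalone sketch of that offset arithmetic (ranks, lsites, header offset and the record type are all made up):

    #include <cstdint>
    #include <cstdio>

    struct Fobj { double payload[18]; };        // stand-in for the real field object

    int main(void) {
      const uint64_t offset = 256;              // e.g. a file header preceding the data
      const uint64_t lsites = 16*16*16*32 / 4;  // local sites per rank for 4 ranks
      for (int myrank = 0; myrank < 4; myrank++) {
        uint64_t seek = offset + myrank * lsites * sizeof(Fobj);
        std::printf("rank %d: seek to byte %llu, transfer %llu bytes\n",
                    myrank, (unsigned long long)seek,
                    (unsigned long long)(lsites * sizeof(Fobj)));
      }
      // The master-append path ignores this formula: it seeks to the end of the file for
      // a write, or to end - sizeof(Fobj) for a read of the single trailing record.
      return 0;
    }
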
@@ -390,8 +435,7 @@ class BinaryIO {
     std::vector<sobj> scalardata(lsites);
     std::vector<fobj> iodata(lsites); // Munge, checksum, byte order in here
 
-    int doread=1;
-    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,doread);
+    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
 
     GridStopWatch timer;
     timer.Start();
@@ -432,8 +476,7 @@ class BinaryIO {
     grid->Barrier();
     timer.Stop();
 
-    int dowrite=0;
-    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,dowrite);
+    uint32_t csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC);
 
     std::cout<<GridLogMessage<<"writeLatticeObject: unvectorize overhead "<<timer.Elapsed() <<std::endl;
 
@@ -461,9 +504,8 @@ class BinaryIO {
 
     std::cout << GridLogMessage << "RNG read I/O on file " << file << std::endl;
 
-    int doread=1;
     std::vector<RNGstate> iodata(lsites);
-    csum= IOobject(w,grid,iodata,file,offset,format,doread);
+    csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_LEXICOGRAPHIC);
 
     timer.Start();
     parallel_for(int lidx=0;lidx<lsites;lidx++){
@@ -473,6 +515,14 @@ class BinaryIO {
     }
     timer.Stop();
 
+    iodata.resize(1);
+    csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_READ|BINARYIO_MASTER_APPEND);
+    {
+      std::vector<RngStateType> tmp(RngStateCount);
+      std::copy(iodata[0].begin(),iodata[0].end(),tmp.begin());
+      serial.SetState(tmp,0);
+    }
+
     std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl;
     std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl;
     return csum;
@@ -507,9 +557,16 @@ class BinaryIO {
     }
     timer.Stop();
 
-    int dowrite=0;
-    csum= IOobject(w,grid,iodata,file,offset,format,dowrite);
+    csum= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_LEXICOGRAPHIC);
+
+    iodata.resize(1);
+    {
+      std::vector<RngStateType> tmp(RngStateCount);
+      serial.GetState(tmp,0);
+      std::copy(tmp.begin(),tmp.end(),iodata[0].begin());
+    }
+    csum+= IOobject(w,grid,iodata,file,offset,format,BINARYIO_WRITE|BINARYIO_MASTER_APPEND);
 
     std::cout << GridLogMessage << "RNG file checksum " << std::hex << csum << std::dec << std::endl;
     std::cout << GridLogMessage << "RNG state overhead " << timer.Elapsed() << std::endl;
     return csum;
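
The two readRNG/writeRNG hunks above give the serial RNG a home in the same file as the parallel RNG states: the lexicographic body holds one record per site, and one extra record is appended by the master rank, then recovered on read with a seek back from the end of the file. A standalone sketch of that round trip with plain std::fstream (the record layout, file name and sizes are made up; only the append/seek-from-end idea comes from the diff):

    #include <cassert>
    #include <cstdint>
    #include <fstream>
    #include <vector>

    struct RNGstate { uint64_t s[4]; };          // stand-in for one engine's state record

    int main(void) {
      const int lsites = 8;
      std::vector<RNGstate> parallel(lsites, RNGstate{{1, 2, 3, 4}});
      RNGstate serial{{9, 9, 9, 9}};

      {  // write: body of per-site states, then one record appended at the end
        std::ofstream fout("rng.bin", std::ios::binary | std::ios::out);
        fout.write((char *)&parallel[0], parallel.size() * sizeof(RNGstate));
        fout.seekp(0, fout.end);
        fout.write((char *)&serial, sizeof(RNGstate));
      }

      {  // read: body as usual, then the trailing record via a seek back from the end
        std::ifstream fin("rng.bin", std::ios::binary | std::ios::in);
        std::vector<RNGstate> body(lsites);
        fin.read((char *)&body[0], body.size() * sizeof(RNGstate));
        RNGstate tail;
        fin.seekg(-(std::streamoff)sizeof(RNGstate), fin.end);
        fin.read((char *)&tail, sizeof(RNGstate));
        assert(body[0].s[0] == 1 && tail.s[0] == 9);
      }
      return 0;
    }
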
@@ -68,11 +68,11 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
     std::string config, rng;
     this->build_filenames(traj, Params, config, rng);
 
-    BinaryIO::BinarySimpleUnmunger<sobj_double, sobj> munge;
+    BinarySimpleUnmunger<sobj_double, sobj> munge;
     truncate(rng);
-    BinaryIO::writeRNGSerial(sRNG, pRNG, rng, 0);
+    BinaryIO::writeRNG(sRNG, pRNG, rng, 0);
     truncate(config);
-    uint32_t csum = BinaryIO::writeObjectParallel<vobj, sobj_double>(
+    uint32_t csum = BinaryIO::writeLatticeObject<vobj, sobj_double>(
         U, config, munge, 0, Params.format);
 
     std::cout << GridLogMessage << "Written Binary Configuration " << config
@@ -85,9 +85,9 @@ class BinaryHmcCheckpointer : public BaseHmcCheckpointer<Impl> {
     std::string config, rng;
     this->build_filenames(traj, Params, config, rng);
 
-    BinaryIO::BinarySimpleMunger<sobj_double, sobj> munge;
-    BinaryIO::readRNGSerial(sRNG, pRNG, rng, 0);
-    uint32_t csum = BinaryIO::readObjectParallel<vobj, sobj_double>(
+    BinarySimpleMunger<sobj_double, sobj> munge;
+    BinaryIO::readRNG(sRNG, pRNG, rng, 0);
+    uint32_t csum = BinaryIO::readLatticeObject<vobj, sobj_double>(
         U, config, munge, 0, Params.format);
 
     std::cout << GridLogMessage << "Read Binary Configuration " << config
@@ -42,9 +42,9 @@ int main (int argc, char ** argv)
 
   std::vector<int> simd_layout = GridDefaultSimd(4,vComplex::Nsimd());
   std::vector<int> mpi_layout  = GridDefaultMpi();
-  std::vector<int> latt_size ({48,48,48,96});
+  //std::vector<int> latt_size ({48,48,48,96});
   //std::vector<int> latt_size ({32,32,32,32});
-  //std::vector<int> latt_size ({16,16,16,32});
+  std::vector<int> latt_size ({16,16,16,32});
   std::vector<int> clatt_size ({4,4,4,8});
   int orthodir=3;
   int orthosz =latt_size[orthodir];