mirror of
https://github.com/paboyle/Grid.git
synced 2025-08-02 20:57:06 +01:00
Reproducibility checks for inner product
This commit is contained in:
@@ -34,21 +34,24 @@ namespace Grid {
|
||||
|
||||
template <class T>
|
||||
struct ReproducibilityState {
|
||||
int n_call;
|
||||
typedef typename T::vector_type vector_type;
|
||||
unsigned int n_call;
|
||||
bool do_check;
|
||||
std::vector<std::vector<T, alignedAllocator<T> > > th_states;
|
||||
bool enable_reprocheck;
|
||||
std::vector<std::vector<vector_type, alignedAllocator<vector_type> > >
|
||||
th_states;
|
||||
|
||||
void reset(){
|
||||
void reset() {
|
||||
th_states.clear();
|
||||
do_check = false;
|
||||
enable_reprocheck = false;
|
||||
n_call = 0;
|
||||
}
|
||||
|
||||
ReproducibilityState(){
|
||||
reset();
|
||||
}
|
||||
};
|
||||
void reset_counter() { n_call = 0; }
|
||||
|
||||
ReproducibilityState() { reset(); }
|
||||
};
|
||||
|
||||
#ifdef GRID_WARN_SUBOPTIMAL
|
||||
#warning "Optimisation alert all these reduction loops are NOT threaded "
|
||||
@@ -58,15 +61,27 @@ struct ReproducibilityState {
|
||||
// Deterministic Reduction operations
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<class vobj> inline RealD norm2(const Lattice<vobj> &arg){
|
||||
ComplexD nrm = innerProduct(arg,arg);
|
||||
return std::real(nrm);
|
||||
ReproducibilityState<vobj> repr;
|
||||
ComplexD nrm = innerProduct(arg, arg, repr);
|
||||
return std::real(nrm);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <class vobj>
|
||||
inline RealD norm2(const Lattice<vobj> &arg, ReproducibilityState<vobj>& rep) {
|
||||
ComplexD nrm = innerProduct(arg, arg, rep);
|
||||
return std::real(nrm);
|
||||
}
|
||||
|
||||
template<class vobj>
|
||||
inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &right)
|
||||
inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &right){
|
||||
ReproducibilityState<vobj> repr;
|
||||
return innerProduct(left, right, repr);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class vobj>
|
||||
inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &right, ReproducibilityState<vobj>& repr)
|
||||
{
|
||||
typedef typename vobj::scalar_type scalar_type;
|
||||
typedef typename vobj::vector_type vector_type;
|
||||
@@ -81,30 +96,45 @@ struct ReproducibilityState {
|
||||
|
||||
// accumulation done in the same precision ad vobj...
|
||||
// may need to froce higher precision
|
||||
PARALLEL_FOR_LOOP
|
||||
PARALLEL_FOR_LOOP_STATIC //request statically scheduled threads for reproducibility
|
||||
for(int thr=0;thr<grid->SumArraySize();thr++){
|
||||
int nwork, mywork, myoff;
|
||||
GridThread::GetWork(left._grid->oSites(),thr,mywork,myoff);
|
||||
|
||||
decltype(innerProduct(left._odata[0],right._odata[0])) vnrm=zero; // private to thread; sub summation
|
||||
GridThread::GetWork(left._grid->oSites(),thr,mywork,myoff);
|
||||
|
||||
decltype(innerProduct(left._odata[0],right._odata[0])) vnrm = zero; // private to thread; sub summation
|
||||
for(int ss=myoff;ss<mywork+myoff; ss++){
|
||||
vnrm = vnrm + innerProduct(left._odata[ss],right._odata[ss]);
|
||||
}
|
||||
sumarray[thr]=TensorRemove(vnrm) ;
|
||||
vnrm = vnrm + innerProduct(left._odata[ss],right._odata[ss]);// accumulate here in higher precision
|
||||
}
|
||||
sumarray[thr]=TensorRemove(vnrm) ;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
if (repr.do_check){
|
||||
if (sumarray!=repr.th_states[n_call]){
|
||||
std::cout << GridLogMessage << "Reproducibility failure on node " << grid->ThisRank() << std::endl;
|
||||
exit(1);
|
||||
/////////////////////// Reproducibility section
|
||||
if (repr.enable_reprocheck) {
|
||||
if (repr.do_check) {
|
||||
//std::cout << GridLogMessage << "Checking thread state for inner product. Call n. " << repr.n_call << std::endl;
|
||||
for (int thread = 0; thread < sumarray.size(); thread++) {
|
||||
if (sumarray[thread] != repr.th_states[repr.n_call][thread]) {
|
||||
std::cout << GridLogMessage << "Reproducibility failure on node " << grid->ThisRank() << std::endl;
|
||||
std::cout << GridLogMessage << "Call: "<< repr.n_call << " Thread: " << thread << std::endl;
|
||||
std::cout << GridLogMessage << "Size of states: " << repr.th_states.size() << std::endl;
|
||||
std::cout << GridLogMessage << sumarray[thread] << std::endl;
|
||||
std::cout << GridLogMessage << repr.th_states[repr.n_call][thread] << std::endl;
|
||||
//exit(1);
|
||||
}
|
||||
}
|
||||
repr.n_call++;
|
||||
} else
|
||||
{
|
||||
//std::cout << GridLogMessage << "Saving thread state for inner product. Call n. " << repr.n_call << std::endl;
|
||||
repr.th_states.resize(repr.n_call+1);
|
||||
repr.th_states[repr.n_call].resize(grid->SumArraySize());
|
||||
repr.th_states[repr.n_call] = sumarray; // save threads state
|
||||
repr.n_call++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
repr.th_states.push_back(sumarray);//save threads state
|
||||
repr.n_call +=1;
|
||||
}
|
||||
*/
|
||||
////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
vector_type vvnrm; vvnrm=zero; // sum across threads
|
||||
for(int i=0;i<grid->SumArraySize();i++){
|
||||
@@ -154,13 +184,13 @@ struct ReproducibilityState {
|
||||
PARALLEL_FOR_LOOP
|
||||
for(int thr=0;thr<grid->SumArraySize();thr++){
|
||||
int nwork, mywork, myoff;
|
||||
GridThread::GetWork(grid->oSites(),thr,mywork,myoff);
|
||||
GridThread::GetWork(grid->oSites(),thr,mywork,myoff);
|
||||
|
||||
vobj vvsum=zero;
|
||||
vobj vvsum=zero;
|
||||
for(int ss=myoff;ss<mywork+myoff; ss++){
|
||||
vvsum = vvsum + arg._odata[ss];
|
||||
}
|
||||
sumarray[thr]=vvsum;
|
||||
}
|
||||
sumarray[thr]=vvsum;
|
||||
}
|
||||
|
||||
vobj vsum=zero; // sum across threads
|
||||
|
Reference in New Issue
Block a user