1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-10 14:10:46 +01:00

Changes to A2Autils, A2AMatrix and DiskVector code that is needed for Hadrons 4 quark contraction module

This commit is contained in:
fionnoh 2019-06-27 13:45:20 +08:00
parent ac530636ca
commit 421a0a8a36
3 changed files with 290 additions and 142 deletions

View File

@ -68,8 +68,17 @@ public:
const std::vector<ComplexField> &emB1,
int orthogdim, double *t_kernel = nullptr, double *t_gsum = nullptr);
static void ContractWWVV(std::vector<PropagatorField> &WWVV,
const Eigen::Tensor<ComplexD,3> &WW_sd,
template <typename TensorType>
typename std::enable_if<std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
static ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd);
template <typename TensorType>
typename std::enable_if<!std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
static ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd);
@ -99,6 +108,11 @@ public:
const FermionField *vd,
int orthogdim);
#endif
private:
inline static void OuterProductWWVV(std::vector<PropagatorField> &WWVV,
const vobj &lhs,
const vobj &rhs,
const int Ns, const int ss, const int t);
};
template <class FImpl>
@ -962,11 +976,14 @@ void A2Autils<FImpl>::AslashField(TensorType &mat,
//
template <class FImpl>
void A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const Eigen::Tensor<ComplexD,3> &WW_sd,
template <typename TensorType>
typename std::enable_if<std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
const TensorType &WW_sd,
const FermionField *vs,
const FermionField *vd)
{
std::cout << "Start contraction" << std::endl;
GridBase *grid = vs[0]._grid;
int nd = grid->_ndimension;
@ -998,25 +1015,82 @@ void A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
//////////////////////////
// Fast outer product of tmp1 with a sum of terms suppressed by d_unroll
//////////////////////////
OuterProductWWVV(WWVV, tmp1, tmp2, Ns, ss, t);
}}
}
}
}
/// Contract a W-W matrix set with the V vectors into one propagator field per
/// leading index t, for any container type other than Eigen::Tensor<ComplexD,3>
/// (that exact type is handled by the sibling overload).
///
/// For every t, every outer site ss and every s, the term
///     sum_d WW_sd[t](s,d) * conjugate(vd[d])
/// is built in chunks of d_unroll values of d, and each chunk is combined with
/// vs[s] at the same site through OuterProductWWVV, accumulating into WWVV[t].
///
/// @param WWVV   output fields; entries 0 .. N_t-1 are zeroed and then accumulated into
/// @param WW_sd  container providing dimensions() (indexable, read as {N_t, N_s, N_d})
///               and operator[](t) assignable to a row-major complex Eigen matrix
/// @param vs     array indexed by s in [0, N_s)
/// @param vd     array indexed by d in [0, N_d); entries are conjugated before use
template <class FImpl>
template <typename TensorType>
typename std::enable_if<!std::is_same<Eigen::Tensor<ComplexD,3>, TensorType>::value, void>::type
A2Autils<FImpl>::ContractWWVV(std::vector<PropagatorField> &WWVV,
                              const TensorType &WW_sd,
                              const FermionField *vs,
                              const FermionField *vd)
{
    GridBase *grid = vs[0]._grid;

    // Query the extents once: dimensions() may return by value and may have to
    // touch the underlying storage, so there is no reason to call it per index.
    const auto dims = WW_sd.dimensions();
    const int  N_t  = dims[0];
    const int  N_s  = dims[1];
    const int  N_d  = dims[2];

    const int d_unroll = 32; // Empirical optimisation

    // Holds the (s,d) matrix of the current t so the site loop reads it directly.
    Eigen::Matrix<Complex, -1, -1, Eigen::RowMajor> buf;

    for (int t = 0; t < N_t; t++){
        WWVV[t] = zero;
    }

    for (int t = 0; t < N_t; t++){
        std::cout << GridLogMessage << "Contraction t = " << t << std::endl;
        buf = WW_sd[t];
        parallel_for(int ss=0;ss<grid->oSites();ss++){
            for(int d_o=0;d_o<N_d;d_o+=d_unroll){
            for(int s=0;s<N_s;s++){
                auto tmp1 = vs[s]._odata[ss];
                vobj tmp2 = zero;
                vobj tmp3 = zero;
                // Sum over at most d_unroll values of d before the outer product.
                for(int d=d_o;d<MIN(d_o+d_unroll,N_d);d++){
                    Scalar_v coeff = buf(s,d);
                    tmp3 = conjugate(vd[d]._odata[ss]);
                    mac(&tmp2 ,& coeff, &tmp3 );
                }

                //////////////////////////
                // Outer product of tmp1 with the partial d-sum held in tmp2;
                // the number of outer products is reduced by a factor d_unroll
                //////////////////////////
                OuterProductWWVV(WWVV, tmp1, tmp2, Ns, ss, t);
            }}
        }
    }
}
/// Accumulate the outer product of two site objects into one site of WWVV[t].
///
/// For every (s1, s2) in [0, Ns) x [0, Ns) and every (c1, c2) in {0,1,2} x {0,1,2}:
///     WWVV[t]._odata[ss]()(s1,s2)(c1,c2) += lhs()(s1)(c1) * rhs()(s2)(c2)
///
/// @param WWVV  fields accumulated into; only element t at site ss is modified
/// @param lhs   left factor
/// @param rhs   right factor, used exactly as passed (no conjugation happens here)
/// @param Ns    extent of the s1 and s2 loops
/// @param ss    site index into _odata
/// @param t     index into WWVV
template <class FImpl>
inline void A2Autils<FImpl>::OuterProductWWVV(std::vector<PropagatorField> &WWVV,
                                              const vobj &lhs,
                                              const vobj &rhs,
                                              const int Ns, const int ss, const int t)
{
    for (int s1 = 0; s1 < Ns; s1++){
    for (int s2 = 0; s2 < Ns; s2++){
        // The 3x3 inner block is written out explicitly rather than looped.
        WWVV[t]._odata[ss]()(s1, s2)(0, 0) += lhs()(s1)(0) * rhs()(s2)(0);
        WWVV[t]._odata[ss]()(s1, s2)(0, 1) += lhs()(s1)(0) * rhs()(s2)(1);
        WWVV[t]._odata[ss]()(s1, s2)(0, 2) += lhs()(s1)(0) * rhs()(s2)(2);
        WWVV[t]._odata[ss]()(s1, s2)(1, 0) += lhs()(s1)(1) * rhs()(s2)(0);
        WWVV[t]._odata[ss]()(s1, s2)(1, 1) += lhs()(s1)(1) * rhs()(s2)(1);
        WWVV[t]._odata[ss]()(s1, s2)(1, 2) += lhs()(s1)(1) * rhs()(s2)(2);
        WWVV[t]._odata[ss]()(s1, s2)(2, 0) += lhs()(s1)(2) * rhs()(s2)(0);
        WWVV[t]._odata[ss]()(s1, s2)(2, 1) += lhs()(s1)(2) * rhs()(s2)(1);
        WWVV[t]._odata[ss]()(s1, s2)(2, 2) += lhs()(s1)(2) * rhs()(s2)(2);
    }
    }
}
template<class FImpl>
void A2Autils<FImpl>::ContractFourQuarkColourDiagonal(const PropagatorField &WWVV0,
const PropagatorField &WWVV1,

View File

@ -108,7 +108,7 @@ public:
void saveBlock(const A2AMatrixSet<T> &m, const unsigned int ext, const unsigned int str,
const unsigned int i, const unsigned int j);
template <template <class> class Vec, typename VecT>
void load(Vec<VecT> &v, double *tRead = nullptr);
void load(Vec<VecT> &v, double *tRead = nullptr, GridBase *grid = nullptr);
private:
std::string filename_{""}, dataname_{""};
unsigned int nt_{0}, ni_{0}, nj_{0};
@ -495,15 +495,17 @@ void A2AMatrixIo<T>::saveBlock(const A2AMatrixSet<T> &m,
template <typename T>
template <template <class> class Vec, typename VecT>
void A2AMatrixIo<T>::load(Vec<VecT> &v, double *tRead)
void A2AMatrixIo<T>::load(Vec<VecT> &v, double *tRead, GridBase *grid)
{
#ifdef HAVE_HDF5
Hdf5Reader reader(filename_);
std::vector<hsize_t> hdim;
H5NS::DataSet dataset;
H5NS::DataSpace dataspace;
H5NS::CompType datatype;
if (!(grid) || grid->IsBoss())
{
Hdf5Reader reader(filename_);
push(reader, dataname_);
auto &group = reader.getGroup();
dataset = group.openDataSet(HADRONS_A2AM_NAME);
@ -531,8 +533,15 @@ void A2AMatrixIo<T>::load(Vec<VecT> &v, double *tRead)
ni_ = hdim[1];
nj_ = hdim[2];
}
}
if (grid)
{
grid->Broadcast(grid->BossRank(), &ni_, sizeof(unsigned int));
grid->Broadcast(grid->BossRank(), &nj_, sizeof(unsigned int));
}
A2AMatrix<T> buf(ni_, nj_);
int broadcastSize = sizeof(T) * buf.size();
std::vector<hsize_t> count = {1, static_cast<hsize_t>(ni_),
static_cast<hsize_t>(nj_)},
stride = {1, 1, 1},
@ -554,10 +563,20 @@ void A2AMatrixIo<T>::load(Vec<VecT> &v, double *tRead)
std::cout << " " << t;
std::cout.flush();
}
if (!(grid) || grid->IsBoss())
{
dataspace.selectHyperslab(H5S_SELECT_SET, count.data(), offset.data(),
stride.data(), block.data());
}
if (tRead) *tRead -= usecond();
if (!(grid) || grid->IsBoss())
{
dataset.read(buf.data(), datatype, memspace, dataspace);
}
if (grid)
{
grid->Broadcast(grid->BossRank(), buf.data(), broadcastSize);
}
if (tRead) *tRead += usecond();
v[t] = buf.template cast<VecT>();
}

View File

@ -87,13 +87,20 @@ public:
};
public:
DiskVectorBase(const std::string dirname, const unsigned int size = 0,
const unsigned int cacheSize = 1, const bool clean = true);
const unsigned int cacheSize = 1, const bool clean = true,
GridBase *grid = nullptr);
DiskVectorBase(DiskVectorBase<T> &&v) = default;
virtual ~DiskVectorBase(void);
const T & operator[](const unsigned int i) const;
RwAccessHelper operator[](const unsigned int i);
double hitRatio(void) const;
void resetStat(void);
void setSize(unsigned int size_);
unsigned int getSize() const;
unsigned int dvSize;
void setGrid(GridBase *grid_);
GridBase *getGrid() const;
GridBase *dvGrid;
private:
virtual void load(T &obj, const std::string filename) const = 0;
virtual void save(const std::string filename, const T &obj) const = 0;
@ -107,6 +114,7 @@ private:
unsigned int size_, cacheSize_;
double access_{0.}, hit_{0.};
bool clean_;
GridBase *grid_;
// using pointers to allow modifications when class is const
// semantic: const means data unmodified, but cache modification allowed
std::unique_ptr<std::vector<T>> cachePtr_;
@ -158,8 +166,20 @@ public:
{
return (*this)[i](j, k);
}
// Returns three extents: { getSize(), rows of element 0, columns of element 0 }.
// Element 0 is accessed through operator[] to obtain the row and column counts.
std::vector<int> dimensions() const
{
    const int nMat = static_cast<int>(this->getSize());
    const int nRow = static_cast<int>((*this)[0].rows());
    const int nCol = static_cast<int>((*this)[0].cols());

    return std::vector<int>{nMat, nRow, nCol};
}
private:
virtual void load(EigenDiskVectorMat<T> &obj, const std::string filename) const
{
GridBase *loadGrid;
loadGrid = (*this).getGrid();
if (!(loadGrid) || loadGrid->IsBoss())
{
std::ifstream f(filename, std::ios::binary);
uint32_t crc, check;
@ -190,8 +210,20 @@ private:
HADRONS_ERROR(Io, "checksum failed")
}
}
int broadcastSize;
broadcastSize = sizeof(T)*obj.size();
if (loadGrid)
{
loadGrid->Broadcast(loadGrid->BossRank(), obj.data(), broadcastSize);
loadGrid->Barrier();
}
}
virtual void save(const std::string filename, const EigenDiskVectorMat<T> &obj) const
{
GridBase *saveGrid;
saveGrid = (*this).getGrid();
if (!(saveGrid) || saveGrid->IsBoss())
{
std::ofstream f(filename, std::ios::binary);
uint32_t crc;
@ -219,6 +251,8 @@ private:
DV_DEBUG_MSG(this, "Eigen crc32 " << std::hex << crc << std::dec
<< " " << tHash/1.0e6 << " sec " << matSize/tHash*1.0e6/1024/1024 << " MB/s");
}
if (saveGrid) saveGrid->Barrier();
}
};
/******************************************************************************
@ -228,8 +262,9 @@ template <typename T>
DiskVectorBase<T>::DiskVectorBase(const std::string dirname,
const unsigned int size,
const unsigned int cacheSize,
const bool clean)
: dirname_(dirname), size_(size), cacheSize_(cacheSize), clean_(clean)
const bool clean,
GridBase *grid)
: dirname_(dirname), size_(size), cacheSize_(cacheSize), clean_(clean), grid_(grid)
, cachePtr_(new std::vector<T>(size))
, modifiedPtr_(new std::vector<bool>(size, false))
, indexPtr_(new std::map<unsigned int, unsigned int>())
@ -238,15 +273,21 @@ DiskVectorBase<T>::DiskVectorBase(const std::string dirname,
{
struct stat s;
if (!(grid_) || grid_->IsBoss())
{
if(stat(dirname.c_str(), &s) == 0)
{
HADRONS_ERROR(Io, "directory '" + dirname + "' already exists")
}
mkdir(dirname);
}
if (grid_) grid_->Barrier();
for (unsigned int i = 0; i < cacheSize_; ++i)
{
freePtr_->push(i);
}
setSize(size_);
setGrid(grid_);
}
template <typename T>
@ -258,6 +299,30 @@ DiskVectorBase<T>::~DiskVectorBase(void)
}
}
// Stores the given value in the public dvSize member (the value getSize() returns).
// Inside this function the parameter name hides the private member size_.
template <typename T>
void DiskVectorBase<T>::setSize(unsigned int size_)
{
    this->dvSize = size_;
}
// Returns the value currently held in dvSize.
template <typename T>
unsigned int DiskVectorBase<T>::getSize() const
{
    return this->dvSize;
}
// Stores the given grid pointer in the public dvGrid member (the value getGrid() returns).
// Inside this function the parameter name hides the private member grid_.
template <typename T>
void DiskVectorBase<T>::setGrid(GridBase *grid_)
{
    this->dvGrid = grid_;
}
// Returns the grid pointer currently held in dvGrid (may be nullptr).
template <typename T>
GridBase *DiskVectorBase<T>::getGrid() const
{
    return this->dvGrid;
}
template <typename T>
const T & DiskVectorBase<T>::operator[](const unsigned int i) const
{
@ -299,7 +364,7 @@ const T & DiskVectorBase<T>::operator[](const unsigned int i) const
}
DV_DEBUG_MSG(this, "in cache: " << msg);
#endif
if (grid_) grid_->Barrier();
return cache[index.at(i)];
}
@ -358,6 +423,7 @@ void DiskVectorBase<T>::evict(void) const
index.erase(i);
loads.pop_front();
}
if (grid_) grid_->Barrier();
}
template <typename T>
@ -395,27 +461,14 @@ void DiskVectorBase<T>::cacheInsert(const unsigned int i, const T &obj) const
auto &freeInd = *freePtr_;
auto &loads = *loadsPtr_;
// cache miss, evict and store
if (index.find(i) == index.end())
{
evict();
index[i] = freeInd.top();
freeInd.pop();
cache[index.at(i)] = obj;
loads.push_back(i);
modified[index.at(i)] = false;
}
// cache hit, modify current value
else
{
auto pos = std::find(loads.begin(), loads.end(), i);
cache[index.at(i)] = obj;
modified[index.at(i)] = true;
loads.erase(pos);
loads.push_back(i);
}
if (grid_) grid_->Barrier();
#ifdef DV_DEBUG
std::string msg;
@ -434,15 +487,15 @@ void DiskVectorBase<T>::cacheInsert(const unsigned int i, const T &obj) const
template <typename T>
void DiskVectorBase<T>::clean(void)
{
auto unlink = [](const char *fpath, const struct stat *sb,
int typeflag, struct FTW *ftwbuf)
if (!(grid_) || grid_->IsBoss())
{
auto unlink = [](const char *fpath, const struct stat *sb,
int typeflag, struct FTW *ftwbuf) {
int rv = remove(fpath);
if (rv)
{
HADRONS_ERROR(Io, "cannot remove '" + std::string(fpath) + "': "
+ std::string(std::strerror(errno)));
HADRONS_ERROR(Io, "cannot remove '" + std::string(fpath) + "': " + std::string(std::strerror(errno)));
}
return rv;
@ -450,6 +503,8 @@ void DiskVectorBase<T>::clean(void)
nftw(dirname_.c_str(), unlink, 64, FTW_DEPTH | FTW_PHYS);
}
if (grid_) grid_->Barrier();
}
END_HADRONS_NAMESPACE