2018-10-17 20:26:48 +01:00
|
|
|
/*************************************************************************************
|
|
|
|
|
|
|
|
Grid physics library, www.github.com/paboyle/Grid
|
|
|
|
|
|
|
|
Source file: Hadrons/Utilities/Contractor.cc
|
|
|
|
|
|
|
|
Copyright (C) 2015-2018
|
|
|
|
|
|
|
|
|
|
|
|
This program is free software; you can redistribute it and/or modify
|
|
|
|
it under the terms of the GNU General Public License as published by
|
|
|
|
the Free Software Foundation; either version 2 of the License, or
|
|
|
|
(at your option) any later version.
|
|
|
|
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
GNU General Public License for more details.
|
|
|
|
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
|
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
|
|
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
|
|
|
|
|
|
See the full license in the file "LICENSE" in the top level distribution directory
|
|
|
|
*************************************************************************************/
|
|
|
|
/* END LEGAL */
|
|
|
|
#include <Hadrons/Global.hpp>
|
|
|
|
#include <Hadrons/A2AMatrix.hpp>
|
|
|
|
#include <Hadrons/DiskVector.hpp>
|
2018-11-09 16:23:53 +00:00
|
|
|
#include <Hadrons/TimerArray.hpp>
|
2018-10-17 20:26:48 +01:00
|
|
|
|
|
|
|
using namespace Grid;
|
|
|
|
using namespace QCD;
|
|
|
|
using namespace Hadrons;
|
|
|
|
|
2018-11-07 19:16:55 +00:00
|
|
|
// Wrap a time index into [0, nt) with nt = par.global.nt.
// The macro reads a variable named `par` (a ContractorPar) from the scope where
// it is expanded; in this file that is main(), after the parameter file is read.
// Adding nt before the modulo keeps unsigned differences such as (tLast - dt),
// with dt < nt, correct: the wrapped unsigned value plus nt wraps back to the
// true value tLast - dt + nt.
#define TIME_MOD(t) (((t) + par.global.nt) % par.global.nt)
|
|
|
|
|
2018-10-17 20:26:48 +01:00
|
|
|
// Parameter classes for this program. Each one is a Grid serializable class that
// main() fills from the XML parameter file with read(reader, <node>, ...).
namespace Contractor
{
    // Global settings, read from the "global" node.
    //   nt            - number of time slices: passed to every disk vector and
    //                   A2AMatrixIo, sizes the per-time buffers in main(), and
    //                   is the modulus used by TIME_MOD
    //   diskVectorDir - parent directory; each matrix gets the subdirectory
    //                   diskVectorDir + "/" + name
    //   output        - NOTE(review): not read anywhere in this file; results
    //                   are only printed to standard output
    class GlobalPar: Serializable
    {
    public:
        GRID_SERIALIZABLE_CLASS_MEMBERS(GlobalPar,
                                        unsigned int, nt,
                                        std::string, diskVectorDir,
                                        std::string, output);
    };

    // One A2A matrix to load (one element of the "a2aMatrix" node).
    //   file, dataset - passed to A2AMatrixIo to locate the data to load
    //   cacheSize     - passed to the EigenDiskVector constructor (presumably
    //                   the number of in-memory cached entries; confirm)
    //   name          - key under which products refer to this matrix; should
    //                   be unique since main() stores matrices in a std::map
    //                   keyed by name
    class A2AMatrixPar: Serializable
    {
    public:
        GRID_SERIALIZABLE_CLASS_MEMBERS(A2AMatrixPar,
                                        std::string, file,
                                        std::string, dataset,
                                        unsigned int, cacheSize,
                                        std::string, name);
    };

    // One traced product to compute (one element of the "product" node).
    //   terms        - names of the matrices multiplied inside the trace, in
    //                  order; split into names with strToVec
    //   times        - one time expression per term except the last (main()
    //                  requires #terms == #times + 1); each is parsed by
    //                  parseTimeRange into single times "t" / ranges "ta..tb"
    //   translations - time expression listing the shifts dt applied to every
    //                  time position of the product
    class ProductPar: Serializable
    {
    public:
        GRID_SERIALIZABLE_CLASS_MEMBERS(ProductPar,
                                        std::string, terms,
                                        std::vector<std::string>, times,
                                        std::string, translations);
    };
}
|
|
|
|
|
|
|
|
// Complete parameter set of the program, filled in main() from the XML nodes
// "global", "a2aMatrix" and "product".
struct ContractorPar
{
    // global settings (time extent, disk vector directory, ...)
    Contractor::GlobalPar global;
    // one entry per A2A matrix to load
    std::vector<Contractor::A2AMatrixPar> a2aMatrix;
    // one entry per traced product to compute
    std::vector<Contractor::ProductPar> product;
};
|
|
|
|
|
2018-11-07 19:16:55 +00:00
|
|
|
// Recursive helper of makeTimeSeq.
// Fills positions times.size() - depth, ..., times.size() - 1 of `current` with
// every combination of values taken from the corresponding sets of `times`, and
// appends each completed sequence to `timeSeq`.
// Requirements: depth <= times.size() and current.size() >= times.size().
void makeTimeSeq(std::vector<std::vector<unsigned int>> &timeSeq,
                 const std::vector<std::set<unsigned int>> &times,
                 std::vector<unsigned int> &current,
                 const unsigned int depth)
{
    if (depth > 0)
    {
        // position of `current` assigned at this recursion level
        const std::size_t pos = times.size() - depth;

        for (auto t: times[pos])
        {
            current[pos] = t;
            makeTimeSeq(timeSeq, times, current, depth - 1);
        }
    }
    else
    {
        timeSeq.push_back(current);
    }
}

// Append to `timeSeq` the Cartesian product of the sets in `times`: one
// sequence (t_0, ..., t_{n-1}) with t_i taken from times[i] for every possible
// choice, in lexicographic order of the (sorted) sets. Existing content of
// `timeSeq` is kept. An empty `times` produces a single empty sequence; any
// empty set in `times` produces no sequence.
void makeTimeSeq(std::vector<std::vector<unsigned int>> &timeSeq,
                 const std::vector<std::set<unsigned int>> &times)
{
    std::vector<unsigned int> current(times.size());

    makeTimeSeq(timeSeq, times, current, times.size());
}
|
|
|
|
|
|
|
|
|
2018-10-18 17:50:35 +01:00
|
|
|
std::set<unsigned int> parseTimeRange(const std::string str, const unsigned int nt)
|
2018-10-17 20:26:48 +01:00
|
|
|
{
|
|
|
|
std::regex rex("([0-9]+)|(([0-9]+)\\.\\.([0-9]+))");
|
|
|
|
std::smatch sm;
|
|
|
|
std::vector<std::string> rstr = strToVec<std::string>(str);
|
2018-10-18 17:50:35 +01:00
|
|
|
std::set<unsigned int> tSet;
|
|
|
|
|
2018-10-17 20:26:48 +01:00
|
|
|
|
|
|
|
for (auto &s: rstr)
|
|
|
|
{
|
|
|
|
std::regex_match(s, sm, rex);
|
|
|
|
if (sm[1].matched)
|
|
|
|
{
|
2018-10-18 17:50:35 +01:00
|
|
|
unsigned int t;
|
|
|
|
|
|
|
|
t = std::stoi(sm[1].str());
|
|
|
|
if (t >= nt)
|
|
|
|
{
|
|
|
|
HADRONS_ERROR(Range, "time out of range (from expression '" + str + "')");
|
|
|
|
}
|
|
|
|
tSet.insert(t);
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
|
|
|
else if (sm[2].matched)
|
|
|
|
{
|
|
|
|
unsigned int ta, tb;
|
|
|
|
|
|
|
|
ta = std::stoi(sm[3].str());
|
|
|
|
tb = std::stoi(sm[4].str());
|
2018-10-18 17:50:35 +01:00
|
|
|
if ((ta >= nt) or (tb >= nt))
|
|
|
|
{
|
|
|
|
HADRONS_ERROR(Range, "time out of range (from expression '" + str + "')");
|
|
|
|
}
|
2018-10-17 20:26:48 +01:00
|
|
|
for (unsigned int ti = ta; ti <= tb; ++ti)
|
|
|
|
{
|
2018-10-18 17:50:35 +01:00
|
|
|
tSet.insert(ti);
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-10-18 17:50:35 +01:00
|
|
|
return tSet;
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
|
|
|
|
2018-11-09 16:23:53 +00:00
|
|
|
// A duration given in microseconds, printed in seconds.
struct Sec
{
    explicit Sec(const double usec)
    : seconds(usec/1.0e6)
    {}

    double seconds; // duration in seconds
};

// Print a duration as "<seconds> sec", with the number right-aligned in a
// 10-character field. Taking a const reference lets both temporaries such as
// Sec(x) and named Sec objects be streamed.
inline std::ostream & operator<< (std::ostream& s, const Sec &sec)
{
    s << std::setw(10) << sec.seconds << " sec";

    return s;
}
|
|
|
|
|
2018-11-08 18:46:28 +00:00
|
|
|
// Floating-point throughput computed from a flop count and a duration in
// microseconds. flops/fusec is in Mflop/s, hence the extra 1e3 to get GFlop/s.
// A zero duration gives inf (or nan when flops is also zero).
struct Flops
{
    Flops(const double flops, const double fusec)
    : gFlopsPerSec(flops/fusec/1.0e3)
    {}

    double gFlopsPerSec; // throughput in GFlop/s (1 GFlop = 1e9 flops)
};

// Print a throughput as "<value> GFlop/s", with the number right-aligned in a
// 10-character field. Taking a const reference lets both temporaries and named
// Flops objects be streamed.
inline std::ostream & operator<< (std::ostream& s, const Flops &f)
{
    s << std::setw(10) << f.gFlopsPerSec << " GFlop/s";

    return s;
}
|
|
|
|
|
|
|
|
// Memory bandwidth computed from a byte count and a duration in microseconds.
// NOTE: "GB" here means 2^30 bytes (GiB); the divisor is 1024^3, unlike Flops,
// which uses decimal units.
struct Bytes
{
    Bytes(const double bytes, const double busec)
    : gBytesPerSec(bytes/busec*1.0e6/1024/1024/1024)
    {}

    double gBytesPerSec; // bandwidth in units of 2^30 bytes per second
};

// Print a bandwidth as "<value> GB/s", with the number right-aligned in a
// 10-character field. Taking a const reference lets both temporaries and named
// Bytes objects be streamed.
inline std::ostream & operator<< (std::ostream& s, const Bytes &b)
{
    s << std::setw(10) << b.gBytesPerSec << " GB/s";

    return s;
}
|
|
|
|
|
2018-10-17 20:26:48 +01:00
|
|
|
int main(int argc, char* argv[])
|
|
|
|
{
|
|
|
|
// parse command line
|
|
|
|
std::string parFilename;
|
|
|
|
|
|
|
|
if (argc != 2)
|
|
|
|
{
|
|
|
|
std::cerr << "usage: " << argv[0] << " <parameter file>";
|
|
|
|
std::cerr << std::endl;
|
|
|
|
|
|
|
|
return EXIT_FAILURE;
|
|
|
|
}
|
|
|
|
parFilename = argv[1];
|
|
|
|
|
|
|
|
// parse parameter file
|
|
|
|
ContractorPar par;
|
2018-10-18 17:50:35 +01:00
|
|
|
unsigned int nMat, nCont;
|
|
|
|
XmlReader reader(parFilename);
|
|
|
|
|
|
|
|
read(reader, "global", par.global);
|
|
|
|
read(reader, "a2aMatrix", par.a2aMatrix);
|
|
|
|
read(reader, "product", par.product);
|
2018-10-17 20:26:48 +01:00
|
|
|
nMat = par.a2aMatrix.size();
|
|
|
|
nCont = par.product.size();
|
|
|
|
|
|
|
|
// create diskvectors
|
|
|
|
std::map<std::string, EigenDiskVector<ComplexD>> a2aMat;
|
|
|
|
unsigned int cacheSize;
|
|
|
|
|
|
|
|
for (auto &p: par.a2aMatrix)
|
|
|
|
{
|
|
|
|
std::string dirName = par.global.diskVectorDir + "/" + p.name;
|
|
|
|
|
|
|
|
a2aMat.emplace(p.name, EigenDiskVector<ComplexD>(dirName, par.global.nt, p.cacheSize));
|
|
|
|
}
|
|
|
|
|
|
|
|
// load data
|
|
|
|
for (unsigned int i = 0; i < a2aMat.size(); ++i)
|
|
|
|
{
|
|
|
|
auto &p = par.a2aMatrix[i];
|
|
|
|
double t, size;
|
|
|
|
|
2018-11-07 19:16:55 +00:00
|
|
|
std::cout << "======== Loading '" << p.file << "'" << std::endl;
|
2018-10-17 20:26:48 +01:00
|
|
|
|
|
|
|
A2AMatrixIo<HADRONS_A2AM_IO_TYPE> a2aIo(p.file, p.dataset, par.global.nt);
|
|
|
|
|
|
|
|
a2aIo.load(a2aMat.at(p.name), &t);
|
2018-11-09 16:23:53 +00:00
|
|
|
std::cout << "Read " << a2aIo.getSize() << " bytes in " << t/1.0e6
|
|
|
|
<< " sec, " << a2aIo.getSize()/t*1.0e6/1024/1024 << " MB/s" << std::endl;
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// contract
|
|
|
|
EigenDiskVector<ComplexD>::Matrix buf;
|
|
|
|
|
|
|
|
for (auto &p: par.product)
|
|
|
|
{
|
2018-11-07 19:16:55 +00:00
|
|
|
std::vector<std::string> term = strToVec<std::string>(p.terms);
|
|
|
|
std::vector<std::set<unsigned int>> times;
|
|
|
|
std::vector<std::vector<unsigned int>> timeSeq;
|
|
|
|
std::set<unsigned int> translations;
|
|
|
|
std::vector<ComplexD> corr(par.global.nt);
|
|
|
|
std::vector<A2AMatrixTr<ComplexD>> lastTerm(par.global.nt);
|
|
|
|
A2AMatrix<ComplexD> prod, buf, tmp;
|
2018-11-09 16:23:53 +00:00
|
|
|
TimerArray tAr;
|
|
|
|
double fusec, busec, flops, bytes, tusec;
|
2018-11-07 19:16:55 +00:00
|
|
|
|
2018-11-09 16:23:53 +00:00
|
|
|
std::cout << "======== Contraction tr(";
|
2018-11-07 19:16:55 +00:00
|
|
|
for (unsigned int g = 0; g < term.size(); ++g)
|
2018-10-17 20:26:48 +01:00
|
|
|
{
|
2018-11-07 19:16:55 +00:00
|
|
|
std::cout << term[g] << ((g == term.size() - 1) ? ')' : '*');
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
2018-11-07 19:16:55 +00:00
|
|
|
std::cout << std::endl;
|
|
|
|
if (term.size() != p.times.size() + 1)
|
2018-10-17 20:26:48 +01:00
|
|
|
{
|
2018-11-07 19:16:55 +00:00
|
|
|
HADRONS_ERROR(Size, "number of terms (" + std::to_string(term.size())
|
|
|
|
+ ") different from number of times ("
|
|
|
|
+ std::to_string(p.times.size() + 1) + ")");
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
2018-11-07 19:16:55 +00:00
|
|
|
for (auto &s: p.times)
|
2018-10-17 20:26:48 +01:00
|
|
|
{
|
2018-11-07 19:16:55 +00:00
|
|
|
times.push_back(parseTimeRange(s, par.global.nt));
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
2018-11-07 19:16:55 +00:00
|
|
|
translations = parseTimeRange(p.translations, par.global.nt);
|
|
|
|
makeTimeSeq(timeSeq, times);
|
|
|
|
std::cout << timeSeq.size()*translations.size()*(term.size() - 2) << " A*B, "
|
|
|
|
<< timeSeq.size()*translations.size()*par.global.nt << " tr(A*B)"
|
|
|
|
<< std::endl;
|
2018-10-17 20:26:48 +01:00
|
|
|
|
2018-11-09 16:23:53 +00:00
|
|
|
std::cout << "* Caching transposed last term" << std::endl;
|
2018-10-17 20:26:48 +01:00
|
|
|
for (unsigned int t = 0; t < par.global.nt; ++t)
|
|
|
|
{
|
2018-11-08 18:46:28 +00:00
|
|
|
const A2AMatrix<ComplexD> &ref = a2aMat.at(term.back())[t];
|
|
|
|
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.startTimer("Transpose caching");
|
2018-11-08 18:46:28 +00:00
|
|
|
lastTerm[t] = ref;
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("Transpose caching");
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
2018-11-09 16:23:53 +00:00
|
|
|
bytes = par.global.nt*lastTerm[0].rows()*lastTerm[0].cols()*sizeof(ComplexD);
|
|
|
|
std::cout << Sec(tAr.getDTimer("Transpose caching")) << " "
|
|
|
|
<< Bytes(bytes, tAr.getDTimer("Transpose caching")) << std::endl;
|
2018-11-08 19:24:29 +00:00
|
|
|
for (unsigned int i = 0; i < timeSeq.size(); ++i)
|
2018-10-17 20:26:48 +01:00
|
|
|
{
|
2018-11-08 19:24:29 +00:00
|
|
|
unsigned int dti = 0;
|
|
|
|
auto &t = timeSeq[i];
|
|
|
|
|
2018-11-07 19:16:55 +00:00
|
|
|
for (unsigned int tLast = 0; tLast < par.global.nt; ++tLast)
|
|
|
|
{
|
|
|
|
corr[tLast] = 0.;
|
|
|
|
}
|
|
|
|
for (auto &dt: translations)
|
|
|
|
{
|
2018-11-08 19:24:29 +00:00
|
|
|
std::cout << "* Step " << i*translations.size() + dti + 1
|
|
|
|
<< "/" << timeSeq.size()*translations.size()
|
|
|
|
<< " -- positions= " << t << ", dt= " << dt << std::endl;
|
2018-11-07 19:59:11 +00:00
|
|
|
if (term.size() > 2)
|
|
|
|
{
|
2018-11-08 19:24:29 +00:00
|
|
|
std::cout << std::setw(8) << "products";
|
2018-11-07 19:59:11 +00:00
|
|
|
}
|
|
|
|
flops = 0.;
|
|
|
|
bytes = 0.;
|
2018-11-09 16:23:53 +00:00
|
|
|
fusec = tAr.getDTimer("A*B algebra");
|
|
|
|
busec = tAr.getDTimer("A*B total");
|
|
|
|
tAr.startTimer("Linear algebra");
|
|
|
|
tAr.startTimer("Disk vector overhead");
|
2018-11-07 19:16:55 +00:00
|
|
|
prod = a2aMat.at(term[0])[TIME_MOD(t[0] + dt)];
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("Disk vector overhead");
|
2018-11-08 19:24:29 +00:00
|
|
|
for (unsigned int j = 1; j < term.size() - 1; ++j)
|
2018-11-07 19:16:55 +00:00
|
|
|
{
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.startTimer("Disk vector overhead");
|
2018-11-08 19:24:29 +00:00
|
|
|
const A2AMatrix<ComplexD> &ref = a2aMat.at(term[j])[TIME_MOD(t[j] + dt)];
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("Disk vector overhead");
|
|
|
|
|
|
|
|
tAr.startTimer("A*B total");
|
|
|
|
tAr.startTimer("A*B algebra");
|
2018-11-07 19:16:55 +00:00
|
|
|
A2AContraction::mul(tmp, prod, ref);
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("A*B algebra");
|
2018-11-08 18:46:28 +00:00
|
|
|
flops += A2AContraction::mulFlops(prod, ref);
|
2018-11-07 19:59:11 +00:00
|
|
|
prod = tmp;
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("A*B total");
|
2018-11-07 19:59:11 +00:00
|
|
|
bytes += 3.*tmp.rows()*tmp.cols()*sizeof(ComplexD);
|
|
|
|
}
|
|
|
|
if (term.size() > 2)
|
|
|
|
{
|
2018-11-09 16:23:53 +00:00
|
|
|
std::cout << Sec(tAr.getDTimer("A*B total") - busec) << " "
|
|
|
|
<< Flops(flops, tAr.getDTimer("A*B algebra") - fusec) << " "
|
|
|
|
<< Bytes(bytes, tAr.getDTimer("A*B total") - busec) << std::endl;
|
2018-11-07 19:16:55 +00:00
|
|
|
}
|
2018-11-08 19:24:29 +00:00
|
|
|
std::cout << std::setw(8) << "traces";
|
2018-11-07 19:59:11 +00:00
|
|
|
flops = 0.;
|
|
|
|
bytes = 0.;
|
2018-11-09 16:23:53 +00:00
|
|
|
fusec = tAr.getDTimer("tr(A*B)");
|
|
|
|
busec = tAr.getDTimer("tr(A*B)");
|
2018-11-07 19:16:55 +00:00
|
|
|
for (unsigned int tLast = 0; tLast < par.global.nt; ++tLast)
|
|
|
|
{
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.startTimer("tr(A*B)");
|
2018-11-08 18:46:28 +00:00
|
|
|
A2AContraction::accTrMul(corr[TIME_MOD(tLast - dt)], prod, lastTerm[tLast]);
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("tr(A*B)");
|
2018-11-08 18:46:28 +00:00
|
|
|
flops += A2AContraction::accTrMulFlops(prod, lastTerm[tLast]);
|
2018-11-07 19:59:11 +00:00
|
|
|
bytes += 2.*prod.rows()*prod.cols()*sizeof(ComplexD);
|
2018-11-07 19:16:55 +00:00
|
|
|
}
|
2018-11-09 16:23:53 +00:00
|
|
|
tAr.stopTimer("Linear algebra");
|
|
|
|
std::cout << Sec(tAr.getDTimer("tr(A*B)") - busec) << " "
|
|
|
|
<< Flops(flops, tAr.getDTimer("tr(A*B)") - fusec) << " "
|
|
|
|
<< Bytes(bytes, tAr.getDTimer("tr(A*B)") - busec) << std::endl;
|
2018-11-08 19:24:29 +00:00
|
|
|
dti++;
|
2018-11-07 19:16:55 +00:00
|
|
|
}
|
|
|
|
for (unsigned int tLast = 0; tLast < par.global.nt; ++tLast)
|
|
|
|
{
|
|
|
|
std::cout << tLast << " " << corr[tLast] << std::endl;
|
|
|
|
}
|
2018-10-17 20:26:48 +01:00
|
|
|
}
|
2018-11-07 19:16:55 +00:00
|
|
|
}
|
2018-10-17 20:26:48 +01:00
|
|
|
|
|
|
|
return EXIT_SUCCESS;
|
|
|
|
}
|