1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-15 06:17:05 +01:00

First implementation of the CG reproducibility test

This commit is contained in:
Guido Cossu
2016-11-06 11:13:29 +00:00
parent 96ba42a297
commit 9720c9ba3f
2 changed files with 54 additions and 9 deletions

View File

@ -33,6 +33,17 @@ directory
namespace Grid { namespace Grid {
struct CG_state{
bool do_repro;
std::vector<RealD> residuals;
CG_state(){
do_repro = false;
residuals.clear();}
};
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Base classes for iterative processes based on operators // Base classes for iterative processes based on operators
// single input vec, single output vec. // single input vec, single output vec.
@ -45,10 +56,16 @@ class ConjugateGradient : public OperatorFunction<Field> {
// Defaults true. // Defaults true.
RealD Tolerance; RealD Tolerance;
Integer MaxIterations; Integer MaxIterations;
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true) bool ReproTest;
CG_state CGState;
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true,
bool ReproducibilityTest = false)
: Tolerance(tol), : Tolerance(tol),
MaxIterations(maxit), MaxIterations(maxit),
ErrorOnNoConverge(err_on_no_conv){}; ErrorOnNoConverge(err_on_no_conv),
ReproTest(ReproducibilityTest){};
void operator()(LinearOperatorBase<Field> &Linop, const Field &src, void operator()(LinearOperatorBase<Field> &Linop, const Field &src,
Field &psi) { Field &psi) {
@ -60,6 +77,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
Field p(src); Field p(src);
Field mmp(src); Field mmp(src);
Field r(src); Field r(src);
Field psi_start(psi);// save for the repro test
if (CGState.do_repro)
std::cout << GridLogMessage << "Starting reproducibility test" << std::endl;
// Initial residual computation & set up // Initial residual computation & set up
RealD guess = norm2(psi); RealD guess = norm2(psi);
@ -107,10 +128,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
SolverTimer.Start(); SolverTimer.Start();
int k; int k;
for (k = 1; k <= MaxIterations; k++) { for (k = 1; k <= MaxIterations; k++) {
c = cp; c = cp;// old residual
MatrixTimer.Start(); MatrixTimer.Start();
Linop.HermOpAndNorm(p, mmp, d, qq); Linop.HermOpAndNorm(p, mmp, d, qq);// mmp = Ap, d=pAp
MatrixTimer.Stop(); MatrixTimer.Stop();
LinalgTimer.Start(); LinalgTimer.Start();
@ -118,14 +139,30 @@ class ConjugateGradient : public OperatorFunction<Field> {
// ComplexD dck = innerProduct(p,mmp); // ComplexD dck = innerProduct(p,mmp);
a = c / d; a = c / d;
b_pred = a * (a * qq - d) / c; b_pred = a * (a * qq - d) / c;// a check
cp = axpy_norm(r, -a, mmp, r);
cp = axpy_norm(r, -a, mmp, r);// new residual r = r_old - a * Ap
if (ReproTest && !CGState.do_repro) {
CGState.residuals.push_back(cp); // save residuals state
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;
}
if (ReproTest && CGState.do_repro){
// check that the residual agrees with the previous run
std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl;
if (cp != CGState.residuals[k-1]){
std::cout << GridLogMessage << "Failing reproducibility test";
std::cout << GridLogMessage << " at k=" << k << std::endl;
std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1]
<< " cp = " << cp << std::endl;
exit(-1);
}
}
b = cp / c; b = cp / c;
// Fuse these loops ; should be really easy // Fuse these loops ; should be really easy
psi = a * p + psi; psi = a * p + psi; // update solution
p = p * b + r; p = p * b + r; // update search direction
LinalgTimer.Stop(); LinalgTimer.Stop();
std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
@ -156,6 +193,14 @@ class ConjugateGradient : public OperatorFunction<Field> {
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 1000.0); if (ErrorOnNoConverge) assert(true_residual / Tolerance < 1000.0);
if (!CGState.do_repro && ReproTest){
CGState.do_repro = true;
this->operator()(Linop, src, psi_start);// run the repro test
}
// Clear state
CGState.residuals.clear();
CGState.do_repro = false;
return; return;
} }
} }

View File

@ -71,7 +71,7 @@ int main (int argc, char ** argv)
WilsonFermionR Dw(Umu,Grid,RBGrid,mass); WilsonFermionR Dw(Umu,Grid,RBGrid,mass);
MdagMLinearOperator<WilsonFermionR,LatticeFermion> HermOp(Dw); MdagMLinearOperator<WilsonFermionR,LatticeFermion> HermOp(Dw);
ConjugateGradient<LatticeFermion> CG(1.0e-8,10000); ConjugateGradient<LatticeFermion> CG(1.0e-8,10000,true, true);
CG(HermOp,src,result); CG(HermOp,src,result);
Grid_finalize(); Grid_finalize();