From 9720c9ba3f06db688e2165135c79ff98a5ede821 Mon Sep 17 00:00:00 2001 From: Guido Cossu Date: Sun, 6 Nov 2016 11:13:29 +0000 Subject: [PATCH] First implementation of the CG reproducibility test --- lib/algorithms/iterative/ConjugateGradient.h | 61 +++++++++++++++++--- tests/solver/Test_wilson_cg_unprec.cc | 2 +- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/lib/algorithms/iterative/ConjugateGradient.h b/lib/algorithms/iterative/ConjugateGradient.h index f340eb38..bd90b864 100644 --- a/lib/algorithms/iterative/ConjugateGradient.h +++ b/lib/algorithms/iterative/ConjugateGradient.h @@ -33,6 +33,17 @@ directory namespace Grid { +struct CG_state{ + bool do_repro; + std::vector residuals; + + CG_state(){ + do_repro = false; + residuals.clear();} + +}; + + ///////////////////////////////////////////////////////////// // Base classes for iterative processes based on operators // single input vec, single output vec. @@ -45,10 +56,16 @@ class ConjugateGradient : public OperatorFunction { // Defaults true. RealD Tolerance; Integer MaxIterations; - ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true) + bool ReproTest; + CG_state CGState; + + ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true, + bool ReproducibilityTest = false) : Tolerance(tol), MaxIterations(maxit), - ErrorOnNoConverge(err_on_no_conv){}; + ErrorOnNoConverge(err_on_no_conv), + ReproTest(ReproducibilityTest){}; + void operator()(LinearOperatorBase &Linop, const Field &src, Field &psi) { @@ -60,6 +77,10 @@ class ConjugateGradient : public OperatorFunction { Field p(src); Field mmp(src); Field r(src); + Field psi_start(psi);// save for the repro test + + if (CGState.do_repro) + std::cout << GridLogMessage << "Starting reproducibility test" << std::endl; // Initial residual computation & set up RealD guess = norm2(psi); @@ -107,10 +128,10 @@ class ConjugateGradient : public OperatorFunction { SolverTimer.Start(); int k; for (k = 1; k <= MaxIterations; k++) { - c = cp; + c = cp;// old residual MatrixTimer.Start(); - Linop.HermOpAndNorm(p, mmp, d, qq); + Linop.HermOpAndNorm(p, mmp, d, qq);// mmp = Ap, d=pAp MatrixTimer.Stop(); LinalgTimer.Start(); @@ -118,14 +139,30 @@ class ConjugateGradient : public OperatorFunction { // ComplexD dck = innerProduct(p,mmp); a = c / d; - b_pred = a * (a * qq - d) / c; + b_pred = a * (a * qq - d) / c;// a check - cp = axpy_norm(r, -a, mmp, r); + + cp = axpy_norm(r, -a, mmp, r);// new residual r = r_old - a * Ap + if (ReproTest && !CGState.do_repro) { + CGState.residuals.push_back(cp); // save residuals state + std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl; + } + if (ReproTest && CGState.do_repro){ + // check that the residual agrees with the previous run + std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl; + if (cp != CGState.residuals[k-1]){ + std::cout << GridLogMessage << "Failing reproducibility test"; + std::cout << GridLogMessage << " at k=" << k << std::endl; + std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1] + << " cp = " << cp << std::endl; + exit(-1); + } + } b = cp / c; // Fuse these loops ; should be really easy - psi = a * p + psi; - p = p * b + r; + psi = a * p + psi; // update solution + p = p * b + r; // update search direction LinalgTimer.Stop(); std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k @@ -156,6 +193,14 @@ class ConjugateGradient : public OperatorFunction { if (ErrorOnNoConverge) assert(true_residual / Tolerance < 1000.0); + if (!CGState.do_repro && ReproTest){ + CGState.do_repro = true; + this->operator()(Linop, src, psi_start);// run the repro test + } + + // Clear state + CGState.residuals.clear(); + CGState.do_repro = false; return; } } diff --git a/tests/solver/Test_wilson_cg_unprec.cc b/tests/solver/Test_wilson_cg_unprec.cc index 34b0a687..abc2e86e 100644 --- a/tests/solver/Test_wilson_cg_unprec.cc +++ b/tests/solver/Test_wilson_cg_unprec.cc @@ -71,7 +71,7 @@ int main (int argc, char ** argv) WilsonFermionR Dw(Umu,Grid,RBGrid,mass); MdagMLinearOperator HermOp(Dw); - ConjugateGradient CG(1.0e-8,10000); + ConjugateGradient CG(1.0e-8,10000,true, true); CG(HermOp,src,result); Grid_finalize();