diff --git a/lib/algorithms/iterative/ConjugateGradient.h b/lib/algorithms/iterative/ConjugateGradient.h index b070d10b..00374f7b 100644 --- a/lib/algorithms/iterative/ConjugateGradient.h +++ b/lib/algorithms/iterative/ConjugateGradient.h @@ -47,7 +47,7 @@ struct CG_state { }; -enum CGexec_modes{ Default, ReproducibilityTest }; +enum CGexec_mode{ Default, ReproducibilityTest }; ///////////////////////////////////////////////////////////// // Base classes for iterative processes based on operators @@ -67,19 +67,20 @@ class ConjugateGradient : public OperatorFunction { CG_state CGState; //to check reproducibility by repeating the CG ReproducibilityState ReprTest; // for the inner proucts - ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true, - bool ReproducibilityTest = false) - : Tolerance(tol), - MaxIterations(maxit), - ErrorOnNoConverge(err_on_no_conv), - ReproTest(ReproducibilityTest){ - if(ReproducibilityTest == true && err_on_no_conv == true){ - std::cout << GridLogMessage << "CG: Reproducibility test ON "<< - "and error on convergence ON are incompatible options" << std::endl; - exit(1); - } - - }; + // Constructor + ConjugateGradient(RealD tol, Integer maxit, CGexec_mode Mode = Default) + : Tolerance(tol),MaxIterations(maxit){ + switch(Mode) + { + case Default : + ErrorOnNoConverge = true; + ReproTest = false; + case ReproducibilityTest : + ErrorOnNoConverge = false; + ReproTest = true; + } + }; + void operator()(LinearOperatorBase &Linop, const Field &src, @@ -116,18 +117,12 @@ class ConjugateGradient : public OperatorFunction { cp = a; ssq = norm2(src); - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: guess " << guess << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: src " << ssq << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: mp " << d << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: mmp " << b << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: cp,r " << cp << std::endl; - std::cout << GridLogIterative << std::setprecision(4) - << "ConjugateGradient: p " << a << std::endl; + std::cout << GridLogIterative << "ConjugateGradient: guess " << guess << std::endl; + std::cout << GridLogIterative << "ConjugateGradient: src " << ssq << std::endl; + std::cout << GridLogIterative << "ConjugateGradient: mp " << d << std::endl; + std::cout << GridLogIterative << "ConjugateGradient: mmp " << b << std::endl; + std::cout << GridLogIterative << "ConjugateGradient: cp,r " << cp << std::endl; + std::cout << GridLogIterative << "ConjugateGradient: p " << a << std::endl; RealD rsq = Tolerance * Tolerance * ssq; @@ -162,7 +157,7 @@ class ConjugateGradient : public OperatorFunction { axpy(r, -a, mmp, r);// new residual r = r_old - a * Ap - cp = norm2(r, ReprTest); // + cp = norm2(r, ReprTest); // bookmarking this norm if (ReproTest && !CGState.do_repro) { CGState.residuals.push_back(cp); // save residuals state std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl; diff --git a/tests/debug/Test_cayley_cg_reproducibility.cc b/tests/debug/Test_cayley_cg_reproducibility.cc index a0bb16f2..25a40fe6 100644 --- a/tests/debug/Test_cayley_cg_reproducibility.cc +++ b/tests/debug/Test_cayley_cg_reproducibility.cc @@ -168,7 +168,7 @@ void TestCGunprec(What & Ddwf, LatticeFermion result(FGrid); result=zero; MdagMLinearOperator HermOp(Ddwf); - ConjugateGradient CG(1.0e-8,10000, false, true); + ConjugateGradient CG(1.0e-8,10000, ReproducibilityTest); CG(HermOp,src,result); } @@ -187,7 +187,7 @@ void TestCGprec(What & Ddwf, result_o=zero; SchurDiagMooeeOperator HermOpEO(Ddwf); - ConjugateGradient CG(1.0e-8,10000, false, true); + ConjugateGradient CG(1.0e-8,10000, ReproducibilityTest); CG(HermOpEO,src_o,result_o); } @@ -203,7 +203,7 @@ void TestCGschur(What & Ddwf, LatticeFermion src (FGrid); random(*RNG5,src); LatticeFermion result(FGrid); result=zero; - ConjugateGradient CG(1.0e-8,10000, false, true); + ConjugateGradient CG(1.0e-8,10000, ReproducibilityTest); SchurRedBlackDiagMooeeSolve SchurSolver(CG); SchurSolver(Ddwf,src,result); } diff --git a/tests/hmc/Test_hmc_WilsonFermionGauge.cc b/tests/hmc/Test_hmc_WilsonFermionGauge.cc index dd6ab81b..b67b0222 100644 --- a/tests/hmc/Test_hmc_WilsonFermionGauge.cc +++ b/tests/hmc/Test_hmc_WilsonFermionGauge.cc @@ -64,9 +64,9 @@ class HmcRunner : public NerscHmcRunner { FermionAction FermOp(U, *FGrid, *FrbGrid, mass); // To enable the CG reproducibility tests use - ConjugateGradient CG(1.0e-8, 10000, false, true); + //ConjugateGradient CG(1.0e-8, 10000, ReproducibilityTest); // This is the plain version - //ConjugateGradient CG(1.0e-8, 10000); + ConjugateGradient CG(1.0e-8, 10000); TwoFlavourPseudoFermionAction Nf2(FermOp, CG, CG);