1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-17 07:17:06 +01:00

Improvement in the CG interface for Repro

This commit is contained in:
Guido Cossu
2016-12-09 05:20:38 +00:00
parent 6ceee102e8
commit ec0c53fa68
3 changed files with 27 additions and 32 deletions

View File

@ -47,7 +47,7 @@ struct CG_state {
}; };
enum CGexec_modes{ Default, ReproducibilityTest }; enum CGexec_mode{ Default, ReproducibilityTest };
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Base classes for iterative processes based on operators // Base classes for iterative processes based on operators
@ -67,21 +67,22 @@ class ConjugateGradient : public OperatorFunction<Field> {
CG_state CGState; //to check reproducibility by repeating the CG CG_state CGState; //to check reproducibility by repeating the CG
ReproducibilityState<typename Field::vector_object> ReprTest; // for the inner proucts ReproducibilityState<typename Field::vector_object> ReprTest; // for the inner proucts
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true, // Constructor
bool ReproducibilityTest = false) ConjugateGradient(RealD tol, Integer maxit, CGexec_mode Mode = Default)
: Tolerance(tol), : Tolerance(tol),MaxIterations(maxit){
MaxIterations(maxit), switch(Mode)
ErrorOnNoConverge(err_on_no_conv), {
ReproTest(ReproducibilityTest){ case Default :
if(ReproducibilityTest == true && err_on_no_conv == true){ ErrorOnNoConverge = true;
std::cout << GridLogMessage << "CG: Reproducibility test ON "<< ReproTest = false;
"and error on convergence ON are incompatible options" << std::endl; case ReproducibilityTest :
exit(1); ErrorOnNoConverge = false;
ReproTest = true;
} }
}; };
void operator()(LinearOperatorBase<Field> &Linop, const Field &src, void operator()(LinearOperatorBase<Field> &Linop, const Field &src,
Field &psi) { Field &psi) {
psi.checkerboard = src.checkerboard; psi.checkerboard = src.checkerboard;
@ -116,18 +117,12 @@ class ConjugateGradient : public OperatorFunction<Field> {
cp = a; cp = a;
ssq = norm2(src); ssq = norm2(src);
std::cout << GridLogIterative << std::setprecision(4) std::cout << GridLogIterative << "ConjugateGradient: guess " << guess << std::endl;
<< "ConjugateGradient: guess " << guess << std::endl; std::cout << GridLogIterative << "ConjugateGradient: src " << ssq << std::endl;
std::cout << GridLogIterative << std::setprecision(4) std::cout << GridLogIterative << "ConjugateGradient: mp " << d << std::endl;
<< "ConjugateGradient: src " << ssq << std::endl; std::cout << GridLogIterative << "ConjugateGradient: mmp " << b << std::endl;
std::cout << GridLogIterative << std::setprecision(4) std::cout << GridLogIterative << "ConjugateGradient: cp,r " << cp << std::endl;
<< "ConjugateGradient: mp " << d << std::endl; std::cout << GridLogIterative << "ConjugateGradient: p " << a << std::endl;
std::cout << GridLogIterative << std::setprecision(4)
<< "ConjugateGradient: mmp " << b << std::endl;
std::cout << GridLogIterative << std::setprecision(4)
<< "ConjugateGradient: cp,r " << cp << std::endl;
std::cout << GridLogIterative << std::setprecision(4)
<< "ConjugateGradient: p " << a << std::endl;
RealD rsq = Tolerance * Tolerance * ssq; RealD rsq = Tolerance * Tolerance * ssq;
@ -162,7 +157,7 @@ class ConjugateGradient : public OperatorFunction<Field> {
axpy(r, -a, mmp, r);// new residual r = r_old - a * Ap axpy(r, -a, mmp, r);// new residual r = r_old - a * Ap
cp = norm2(r, ReprTest); // cp = norm2(r, ReprTest); // bookmarking this norm
if (ReproTest && !CGState.do_repro) { if (ReproTest && !CGState.do_repro) {
CGState.residuals.push_back(cp); // save residuals state CGState.residuals.push_back(cp); // save residuals state
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl; std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;

View File

@ -168,7 +168,7 @@ void TestCGunprec(What & Ddwf,
LatticeFermion result(FGrid); result=zero; LatticeFermion result(FGrid); result=zero;
MdagMLinearOperator<What,LatticeFermion> HermOp(Ddwf); MdagMLinearOperator<What,LatticeFermion> HermOp(Ddwf);
ConjugateGradient<LatticeFermion> CG(1.0e-8,10000, false, true); ConjugateGradient<LatticeFermion> CG(1.0e-8,10000, ReproducibilityTest);
CG(HermOp,src,result); CG(HermOp,src,result);
} }
@ -187,7 +187,7 @@ void TestCGprec(What & Ddwf,
result_o=zero; result_o=zero;
SchurDiagMooeeOperator<What,LatticeFermion> HermOpEO(Ddwf); SchurDiagMooeeOperator<What,LatticeFermion> HermOpEO(Ddwf);
ConjugateGradient<LatticeFermion> CG(1.0e-8,10000, false, true); ConjugateGradient<LatticeFermion> CG(1.0e-8,10000, ReproducibilityTest);
CG(HermOpEO,src_o,result_o); CG(HermOpEO,src_o,result_o);
} }
@ -203,7 +203,7 @@ void TestCGschur(What & Ddwf,
LatticeFermion src (FGrid); random(*RNG5,src); LatticeFermion src (FGrid); random(*RNG5,src);
LatticeFermion result(FGrid); result=zero; LatticeFermion result(FGrid); result=zero;
ConjugateGradient<LatticeFermion> CG(1.0e-8,10000, false, true); ConjugateGradient<LatticeFermion> CG(1.0e-8,10000, ReproducibilityTest);
SchurRedBlackDiagMooeeSolve<LatticeFermion> SchurSolver(CG); SchurRedBlackDiagMooeeSolve<LatticeFermion> SchurSolver(CG);
SchurSolver(Ddwf,src,result); SchurSolver(Ddwf,src,result);
} }

View File

@ -64,9 +64,9 @@ class HmcRunner : public NerscHmcRunner {
FermionAction FermOp(U, *FGrid, *FrbGrid, mass); FermionAction FermOp(U, *FGrid, *FrbGrid, mass);
// To enable the CG reproducibility tests use // To enable the CG reproducibility tests use
ConjugateGradient<FermionField> CG(1.0e-8, 10000, false, true); //ConjugateGradient<FermionField> CG(1.0e-8, 10000, ReproducibilityTest);
// This is the plain version // This is the plain version
//ConjugateGradient<FermionField> CG(1.0e-8, 10000); ConjugateGradient<FermionField> CG(1.0e-8, 10000);
TwoFlavourPseudoFermionAction<ImplPolicy> Nf2(FermOp, CG, CG); TwoFlavourPseudoFermionAction<ImplPolicy> Nf2(FermOp, CG, CG);