mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-14 13:57:07 +01:00
First implementation of the CG reproducibility test
This commit is contained in:
@ -33,6 +33,17 @@ directory
|
||||
|
||||
namespace Grid {
|
||||
|
||||
struct CG_state{
|
||||
bool do_repro;
|
||||
std::vector<RealD> residuals;
|
||||
|
||||
CG_state(){
|
||||
do_repro = false;
|
||||
residuals.clear();}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////
|
||||
// Base classes for iterative processes based on operators
|
||||
// single input vec, single output vec.
|
||||
@ -45,10 +56,16 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
// Defaults true.
|
||||
RealD Tolerance;
|
||||
Integer MaxIterations;
|
||||
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true)
|
||||
bool ReproTest;
|
||||
CG_state CGState;
|
||||
|
||||
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true,
|
||||
bool ReproducibilityTest = false)
|
||||
: Tolerance(tol),
|
||||
MaxIterations(maxit),
|
||||
ErrorOnNoConverge(err_on_no_conv){};
|
||||
ErrorOnNoConverge(err_on_no_conv),
|
||||
ReproTest(ReproducibilityTest){};
|
||||
|
||||
|
||||
void operator()(LinearOperatorBase<Field> &Linop, const Field &src,
|
||||
Field &psi) {
|
||||
@ -60,6 +77,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
Field p(src);
|
||||
Field mmp(src);
|
||||
Field r(src);
|
||||
Field psi_start(psi);// save for the repro test
|
||||
|
||||
if (CGState.do_repro)
|
||||
std::cout << GridLogMessage << "Starting reproducibility test" << std::endl;
|
||||
|
||||
// Initial residual computation & set up
|
||||
RealD guess = norm2(psi);
|
||||
@ -107,10 +128,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
SolverTimer.Start();
|
||||
int k;
|
||||
for (k = 1; k <= MaxIterations; k++) {
|
||||
c = cp;
|
||||
c = cp;// old residual
|
||||
|
||||
MatrixTimer.Start();
|
||||
Linop.HermOpAndNorm(p, mmp, d, qq);
|
||||
Linop.HermOpAndNorm(p, mmp, d, qq);// mmp = Ap, d=pAp
|
||||
MatrixTimer.Stop();
|
||||
|
||||
LinalgTimer.Start();
|
||||
@ -118,14 +139,30 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
// ComplexD dck = innerProduct(p,mmp);
|
||||
|
||||
a = c / d;
|
||||
b_pred = a * (a * qq - d) / c;
|
||||
b_pred = a * (a * qq - d) / c;// a check
|
||||
|
||||
cp = axpy_norm(r, -a, mmp, r);
|
||||
|
||||
cp = axpy_norm(r, -a, mmp, r);// new residual r = r_old - a * Ap
|
||||
if (ReproTest && !CGState.do_repro) {
|
||||
CGState.residuals.push_back(cp); // save residuals state
|
||||
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;
|
||||
}
|
||||
if (ReproTest && CGState.do_repro){
|
||||
// check that the residual agrees with the previous run
|
||||
std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl;
|
||||
if (cp != CGState.residuals[k-1]){
|
||||
std::cout << GridLogMessage << "Failing reproducibility test";
|
||||
std::cout << GridLogMessage << " at k=" << k << std::endl;
|
||||
std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1]
|
||||
<< " cp = " << cp << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
b = cp / c;
|
||||
|
||||
// Fuse these loops ; should be really easy
|
||||
psi = a * p + psi;
|
||||
p = p * b + r;
|
||||
psi = a * p + psi; // update solution
|
||||
p = p * b + r; // update search direction
|
||||
|
||||
LinalgTimer.Stop();
|
||||
std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
|
||||
@ -156,6 +193,14 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
|
||||
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 1000.0);
|
||||
|
||||
if (!CGState.do_repro && ReproTest){
|
||||
CGState.do_repro = true;
|
||||
this->operator()(Linop, src, psi_start);// run the repro test
|
||||
}
|
||||
|
||||
// Clear state
|
||||
CGState.residuals.clear();
|
||||
CGState.do_repro = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ int main (int argc, char ** argv)
|
||||
WilsonFermionR Dw(Umu,Grid,RBGrid,mass);
|
||||
|
||||
MdagMLinearOperator<WilsonFermionR,LatticeFermion> HermOp(Dw);
|
||||
ConjugateGradient<LatticeFermion> CG(1.0e-8,10000);
|
||||
ConjugateGradient<LatticeFermion> CG(1.0e-8,10000,true, true);
|
||||
CG(HermOp,src,result);
|
||||
|
||||
Grid_finalize();
|
||||
|
Reference in New Issue
Block a user