1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-16 06:47:06 +01:00

Reproducibility checks for inner product

This commit is contained in:
Guido Cossu
2016-11-23 11:42:04 +00:00
parent f1908c7bc9
commit 7144ee7ae8
6 changed files with 117 additions and 56 deletions

View File

@ -62,9 +62,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
Integer MaxIterations;
bool ReproTest;
CG_state CGState;//to check reproducibility by repeating the CG
ReproducibilityState<typename Field::vector_object> ReprTest;
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true,
bool ReproducibilityTest = false)
bool ReproducibilityTest = false)
: Tolerance(tol),
MaxIterations(maxit),
ErrorOnNoConverge(err_on_no_conv),
@ -84,7 +85,7 @@ class ConjugateGradient : public OperatorFunction<Field> {
Field psi_start(psi);// save for the repro test
if (CGState.do_repro)
std::cout << GridLogMessage << "Starting reproducibility test" << std::endl;
std::cout << GridLogMessage << "Starting reproducibility test" << std::endl;
// Initial residual computation & set up
RealD guess = norm2(psi);
@ -93,6 +94,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
Linop.HermOpAndNorm(psi, mmp, d, b);
if(!ReprTest.do_check)
ReprTest.reset();
ReprTest.enable_reprocheck=ReproTest;
r = src - mmp;
p = r;
@ -146,21 +151,22 @@ class ConjugateGradient : public OperatorFunction<Field> {
b_pred = a * (a * qq - d) / c;// a check
cp = axpy_norm(r, -a, mmp, r);// new residual r = r_old - a * Ap
axpy(r, -a, mmp, r);// new residual r = r_old - a * Ap
cp = norm2(r, ReprTest);//
if (ReproTest && !CGState.do_repro) {
CGState.residuals.push_back(cp); // save residuals state
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;
}
CGState.residuals.push_back(cp); // save residuals state
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;
}
if (ReproTest && CGState.do_repro){
// check that the residual agrees with the previous run
std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl;
if (cp != CGState.residuals[k-1]){
std::cout << GridLogMessage << "Failing reproducibility test";
std::cout << GridLogMessage << " at k=" << k << std::endl;
std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1]
<< " cp = " << cp << std::endl;
exit(-1);
}
// check that the residual agrees with the previous run
std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl;
if (cp != CGState.residuals[k-1]){
std::cout << GridLogMessage << "Failing reproducibility test";
std::cout << GridLogMessage << " at k=" << k << std::endl;
std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1]
<< " cp = " << cp << std::endl;
exit(-1);
}
}
b = cp / c;
@ -197,13 +203,16 @@ class ConjugateGradient : public OperatorFunction<Field> {
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
if (!CGState.do_repro && ReproTest){
CGState.do_repro = true;
this->operator()(Linop, src, psi_start);// run the repro test
}
if (!CGState.do_repro && ReproTest){
CGState.do_repro = true;
ReprTest.do_check = true;
ReprTest.reset_counter();
this->operator()(Linop, src, psi_start);// run the repro test
}
// Clear state
CGState.reset();
// Clear state
CGState.reset();
ReprTest.reset();
return;
}
}