mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-16 06:47:06 +01:00
Reproducibility checks for inner product
This commit is contained in:
@ -62,9 +62,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
Integer MaxIterations;
|
||||
bool ReproTest;
|
||||
CG_state CGState;//to check reproducibility by repeating the CG
|
||||
ReproducibilityState<typename Field::vector_object> ReprTest;
|
||||
|
||||
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true,
|
||||
bool ReproducibilityTest = false)
|
||||
bool ReproducibilityTest = false)
|
||||
: Tolerance(tol),
|
||||
MaxIterations(maxit),
|
||||
ErrorOnNoConverge(err_on_no_conv),
|
||||
@ -84,7 +85,7 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
Field psi_start(psi);// save for the repro test
|
||||
|
||||
if (CGState.do_repro)
|
||||
std::cout << GridLogMessage << "Starting reproducibility test" << std::endl;
|
||||
std::cout << GridLogMessage << "Starting reproducibility test" << std::endl;
|
||||
|
||||
// Initial residual computation & set up
|
||||
RealD guess = norm2(psi);
|
||||
@ -93,6 +94,10 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
|
||||
Linop.HermOpAndNorm(psi, mmp, d, b);
|
||||
|
||||
if(!ReprTest.do_check)
|
||||
ReprTest.reset();
|
||||
ReprTest.enable_reprocheck=ReproTest;
|
||||
|
||||
|
||||
r = src - mmp;
|
||||
p = r;
|
||||
@ -146,21 +151,22 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
b_pred = a * (a * qq - d) / c;// a check
|
||||
|
||||
|
||||
cp = axpy_norm(r, -a, mmp, r);// new residual r = r_old - a * Ap
|
||||
axpy(r, -a, mmp, r);// new residual r = r_old - a * Ap
|
||||
cp = norm2(r, ReprTest);//
|
||||
if (ReproTest && !CGState.do_repro) {
|
||||
CGState.residuals.push_back(cp); // save residuals state
|
||||
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;
|
||||
}
|
||||
CGState.residuals.push_back(cp); // save residuals state
|
||||
std::cout << GridLogIterative << "ReproTest: Saving state" << std::endl;
|
||||
}
|
||||
if (ReproTest && CGState.do_repro){
|
||||
// check that the residual agrees with the previous run
|
||||
std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl;
|
||||
if (cp != CGState.residuals[k-1]){
|
||||
std::cout << GridLogMessage << "Failing reproducibility test";
|
||||
std::cout << GridLogMessage << " at k=" << k << std::endl;
|
||||
std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1]
|
||||
<< " cp = " << cp << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
// check that the residual agrees with the previous run
|
||||
std::cout << GridLogIterative << "ReproTest: Checking state k=" << k << std::endl;
|
||||
if (cp != CGState.residuals[k-1]){
|
||||
std::cout << GridLogMessage << "Failing reproducibility test";
|
||||
std::cout << GridLogMessage << " at k=" << k << std::endl;
|
||||
std::cout << GridLogMessage << "saved residual = " << CGState.residuals[k-1]
|
||||
<< " cp = " << cp << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
b = cp / c;
|
||||
|
||||
@ -197,13 +203,16 @@ class ConjugateGradient : public OperatorFunction<Field> {
|
||||
|
||||
if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
|
||||
|
||||
if (!CGState.do_repro && ReproTest){
|
||||
CGState.do_repro = true;
|
||||
this->operator()(Linop, src, psi_start);// run the repro test
|
||||
}
|
||||
if (!CGState.do_repro && ReproTest){
|
||||
CGState.do_repro = true;
|
||||
ReprTest.do_check = true;
|
||||
ReprTest.reset_counter();
|
||||
this->operator()(Linop, src, psi_start);// run the repro test
|
||||
}
|
||||
|
||||
// Clear state
|
||||
CGState.reset();
|
||||
// Clear state
|
||||
CGState.reset();
|
||||
ReprTest.reset();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user