1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

CG improvements for smoother

This commit is contained in:
Peter Boyle 2024-05-16 10:55:18 -04:00
parent 0ac85fa70b
commit 14e9d8ed9f

View File

@ -54,11 +54,14 @@ public:
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true)
: Tolerance(tol),
MaxIterations(maxit),
ErrorOnNoConverge(err_on_no_conv){};
ErrorOnNoConverge(err_on_no_conv)
{};
void operator()(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi) {
GRID_TRACE("ConjugateGradient");
GridStopWatch PreambleTimer;
PreambleTimer.Start();
psi.Checkerboard() = src.Checkerboard();
conformable(psi, src);
@ -66,22 +69,26 @@ public:
RealD cp, c, a, d, b, ssq, qq;
//RealD b_pred;
Field p(src);
Field mmp(src);
Field r(src);
// Was doing copies
Field p(src.Grid());
Field mmp(src.Grid());
Field r(src.Grid());
// Initial residual computation & set up
ssq = norm2(src);
RealD guess = norm2(psi);
assert(std::isnan(guess) == 0);
Linop.HermOpAndNorm(psi, mmp, d, b);
r = src - mmp;
p = r;
a = norm2(p);
if ( guess == 0.0 ) {
r = src;
p = r;
a = ssq;
} else {
Linop.HermOpAndNorm(psi, mmp, d, b);
r = src - mmp;
p = r;
a = norm2(p);
}
cp = a;
ssq = norm2(src);
// Handle trivial case of zero src
if (ssq == 0.){
@ -103,7 +110,7 @@ public:
// Check if guess is really REALLY good :)
if (cp <= rsq) {
TrueResidual = std::sqrt(a/ssq);
std::cout << GridLogMessage << "ConjugateGradient guess is converged already : cp " << cp <<" rsq "<<rsq <<" ssq "<<ssq<< std::endl;
std::cout << GridLogMessage << "ConjugateGradient guess is converged already " << std::endl;
IterationsToComplete = 0;
return;
}
@ -111,6 +118,7 @@ public:
std::cout << GridLogIterative << std::setprecision(8)
<< "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
PreambleTimer.Stop();
GridStopWatch LinalgTimer;
GridStopWatch InnerTimer;
GridStopWatch AxpyNormTimer;
@ -183,7 +191,8 @@ public:
<< "\tTrue residual " << true_residual
<< "\tTarget " << Tolerance << std::endl;
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
// std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() <<std::endl;
std::cout << GridLogMessage << "\tSolver Elapsed " << SolverTimer.Elapsed() <<std::endl;
std::cout << GridLogPerformance << "Time breakdown "<<std::endl;
std::cout << GridLogPerformance << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
std::cout << GridLogPerformance << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
@ -202,13 +211,22 @@ public:
}
}
// Failed. Calculate true residual before giving up
Linop.HermOpAndNorm(psi, mmp, d, qq);
p = mmp - src;
TrueResidual = sqrt(norm2(p)/ssq);
// Linop.HermOpAndNorm(psi, mmp, d, qq);
// p = mmp - src;
//TrueResidual = sqrt(norm2(p)/ssq);
// TrueResidual = 1;
std::cout << GridLogMessage << "ConjugateGradient did NOT converge "<<k<<" / "<< MaxIterations
<<" residual "<< TrueResidual<< std::endl;
<<" residual "<< std::sqrt(cp / ssq)<< std::endl;
SolverTimer.Stop();
std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() <<std::endl;
std::cout << GridLogMessage << "\tSolver " << SolverTimer.Elapsed() <<std::endl;
std::cout << GridLogMessage << "Solver breakdown "<<std::endl;
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
std::cout << GridLogMessage<< "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
std::cout << GridLogPerformance << "\t\tInner " << InnerTimer.Elapsed() <<std::endl;
std::cout << GridLogPerformance << "\t\tAxpyNorm " << AxpyNormTimer.Elapsed() <<std::endl;
std::cout << GridLogPerformance << "\t\tLinearComb " << LinearCombTimer.Elapsed() <<std::endl;
if (ErrorOnNoConverge) assert(0);
IterationsToComplete = k;