diff --git a/Grid/algorithms/iterative/ConjugateGradient.h b/Grid/algorithms/iterative/ConjugateGradient.h index dc812cb6..8b4c8fc5 100644 --- a/Grid/algorithms/iterative/ConjugateGradient.h +++ b/Grid/algorithms/iterative/ConjugateGradient.h @@ -54,11 +54,14 @@ public: ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true) : Tolerance(tol), MaxIterations(maxit), - ErrorOnNoConverge(err_on_no_conv){}; + ErrorOnNoConverge(err_on_no_conv) + {}; void operator()(LinearOperatorBase &Linop, const Field &src, Field &psi) { GRID_TRACE("ConjugateGradient"); + GridStopWatch PreambleTimer; + PreambleTimer.Start(); psi.Checkerboard() = src.Checkerboard(); conformable(psi, src); @@ -66,22 +69,26 @@ public: RealD cp, c, a, d, b, ssq, qq; //RealD b_pred; - Field p(src); - Field mmp(src); - Field r(src); + // Was doing copies + Field p(src.Grid()); + Field mmp(src.Grid()); + Field r(src.Grid()); // Initial residual computation & set up + ssq = norm2(src); RealD guess = norm2(psi); assert(std::isnan(guess) == 0); - - Linop.HermOpAndNorm(psi, mmp, d, b); - - r = src - mmp; - p = r; - - a = norm2(p); + if ( guess == 0.0 ) { + r = src; + p = r; + a = ssq; + } else { + Linop.HermOpAndNorm(psi, mmp, d, b); + r = src - mmp; + p = r; + a = norm2(p); + } cp = a; - ssq = norm2(src); // Handle trivial case of zero src if (ssq == 0.){ @@ -103,7 +110,7 @@ public: // Check if guess is really REALLY good :) if (cp <= rsq) { TrueResidual = std::sqrt(a/ssq); - std::cout << GridLogMessage << "ConjugateGradient guess is converged already : cp " << cp <<" rsq "<