diff --git a/Grid/algorithms/iterative/ConjugateGradient.h b/Grid/algorithms/iterative/ConjugateGradient.h index 8b4c8fc5..65a77d83 100644 --- a/Grid/algorithms/iterative/ConjugateGradient.h +++ b/Grid/algorithms/iterative/ConjugateGradient.h @@ -38,12 +38,13 @@ NAMESPACE_BEGIN(Grid); // single input vec, single output vec. ///////////////////////////////////////////////////////////// + template class ConjugateGradient : public OperatorFunction { public: using OperatorFunction::operator(); - + bool ErrorOnNoConverge; // throw an assert when the CG fails to converge. // Defaults true. RealD Tolerance; @@ -57,10 +58,22 @@ public: ErrorOnNoConverge(err_on_no_conv) {}; - void operator()(LinearOperatorBase &Linop, const Field &src, Field &psi) { + virtual void LogIteration(int k,RealD a,RealD b){ + // std::cout << "ConjugageGradient::LogIteration() "< &Linop, const Field &src, Field &psi) { + + this->LogBegin(); + + GRID_TRACE("ConjugateGradient"); GridStopWatch PreambleTimer; + GridStopWatch ConstructTimer; + GridStopWatch NormTimer; + GridStopWatch AssignTimer; PreambleTimer.Start(); psi.Checkerboard() = src.Checkerboard(); @@ -70,14 +83,19 @@ public: //RealD b_pred; // Was doing copies - Field p(src.Grid()); + ConstructTimer.Start(); + Field p (src.Grid()); Field mmp(src.Grid()); - Field r(src.Grid()); + Field r (src.Grid()); + ConstructTimer.Stop(); // Initial residual computation & set up + NormTimer.Start(); ssq = norm2(src); RealD guess = norm2(psi); + NormTimer.Stop(); assert(std::isnan(guess) == 0); + AssignTimer.Start(); if ( guess == 0.0 ) { r = src; p = r; @@ -89,6 +107,7 @@ public: a = norm2(p); } cp = a; + AssignTimer.Stop(); // Handle trivial case of zero src if (ssq == 0.){ @@ -164,6 +183,7 @@ public: } LinearCombTimer.Stop(); LinalgTimer.Stop(); + LogIteration(k,a,b); IterationTimer.Stop(); if ( (k % 500) == 0 ) { @@ -220,6 +240,9 @@ public: <<" residual "<< std::sqrt(cp / ssq)<< std::endl; SolverTimer.Stop(); std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() < +class ConjugateGradientPolynomial : public ConjugateGradient { +public: + // Optionally record the CG polynomial + std::vector ak; + std::vector bk; + std::vector poly_p; + std::vector poly_r; + std::vector poly_Ap; + std::vector polynomial; + +public: + ConjugateGradientPolynomial(RealD tol, Integer maxit, bool err_on_no_conv = true) + : ConjugateGradient(tol,maxit,err_on_no_conv) + { }; + void PolyHermOp(LinearOperatorBase &Linop, const Field &src, Field &psi) + { + Field tmp(src.Grid()); + Field AtoN(src.Grid()); + AtoN = src; + psi=AtoN*polynomial[0]; + for(int n=1;n &Linop, const Field &src, Field &x) + { + Field Ap(src.Grid()); + Field r(src.Grid()); + Field p(src.Grid()); + p=src; + r=src; + x=Zero(); + x.Checkerboard()=src.Checkerboard(); + for(int k=0;k &Linop, const Field &src, Field &psi) + { + psi=Zero(); + this->operator ()(Linop,src,psi); + } + virtual void LogBegin(void) + { + std::cout << "ConjugageGradientPolynomial::LogBegin() "< shift poly P right by 1 and add 0. + // x = x + a p ==> add polynomials term by term + // r = r - a A p ==> add polynomials term by term + // p = r + b p ==> add polynomials term by term + // + std::cout << "ConjugageGradientPolynomial::LogIteration() "<