From 9f280b82c4e56d8b32034dfbb83187e15add41e9 Mon Sep 17 00:00:00 2001 From: Christopher Kelly Date: Tue, 25 Jul 2017 11:30:41 -0400 Subject: [PATCH] Added mixed-precision CG with reliable updates --- lib/algorithms/Algorithms.h | 1 + .../ConjugateGradientReliableUpdate.h | 231 ++++++++++++++++++ tests/Test_dwf_mixedcg_prec_halfcomms.cc | 40 ++- 3 files changed, 260 insertions(+), 12 deletions(-) create mode 100644 lib/algorithms/iterative/ConjugateGradientReliableUpdate.h diff --git a/lib/algorithms/Algorithms.h b/lib/algorithms/Algorithms.h index 5123c7a1..f8dc2dc2 100644 --- a/lib/algorithms/Algorithms.h +++ b/lib/algorithms/Algorithms.h @@ -44,6 +44,7 @@ Author: Peter Boyle #include #include #include +#include // Lanczos support //#include diff --git a/lib/algorithms/iterative/ConjugateGradientReliableUpdate.h b/lib/algorithms/iterative/ConjugateGradientReliableUpdate.h new file mode 100644 index 00000000..1aab064d --- /dev/null +++ b/lib/algorithms/iterative/ConjugateGradientReliableUpdate.h @@ -0,0 +1,231 @@ + /************************************************************************************* + + Grid physics library, www.github.com/paboyle/Grid + + Source file: ./lib/algorithms/iterative/ConjugateGradientReliableUpdate.h + + Copyright (C) 2015 + +Author: Christopher Kelly + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + + See the full license in the file "LICENSE" in the top level distribution directory + *************************************************************************************/ + /* END LEGAL */ +#ifndef GRID_CONJUGATE_GRADIENT_RELIABLE_UPDATE_H +#define GRID_CONJUGATE_GRADIENT_RELIABLE_UPDATE_H + +namespace Grid { + + template::value == 2, int>::type = 0,typename std::enable_if< getPrecision::value == 1, int>::type = 0> + class ConjugateGradientReliableUpdate : public LinearFunction { + public: + bool ErrorOnNoConverge; // throw an assert when the CG fails to converge. + // Defaults true. + RealD Tolerance; + Integer MaxIterations; + Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion + Integer ReliableUpdatesPerformed; + + bool DoFinalCleanup; //Final DP cleanup, defaults to true + Integer IterationsToCleanup; //Final DP cleanup step iterations + + LinearOperatorBase &Linop_f; + LinearOperatorBase &Linop_d; + GridBase* SinglePrecGrid; + RealD Delta; //reliable update parameter + + ConjugateGradientReliableUpdate(RealD tol, Integer maxit, RealD _delta, GridBase* _sp_grid, LinearOperatorBase &_Linop_f, LinearOperatorBase &_Linop_d, bool err_on_no_conv = true) + : Tolerance(tol), + MaxIterations(maxit), + Delta(_delta), + Linop_f(_Linop_f), + Linop_d(_Linop_d), + SinglePrecGrid(_sp_grid), + ErrorOnNoConverge(err_on_no_conv), + DoFinalCleanup(true) + {}; + + void operator()(const FieldD &src, FieldD &psi) { + psi.checkerboard = src.checkerboard; + conformable(psi, src); + + RealD cp, c, a, d, b, ssq, qq, b_pred; + + FieldD p(src); + FieldD mmp(src); + FieldD r(src); + + // Initial residual computation & set up + RealD guess = norm2(psi); + assert(std::isnan(guess) == 0); + + Linop_d.HermOpAndNorm(psi, mmp, d, b); + + r = src - mmp; + p = r; + + a = norm2(p); + cp = a; + ssq = norm2(src); + + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: guess " << guess << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: src " << ssq << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: mp " << d << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: mmp " << b << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: cp,r " << cp << std::endl; + std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: p " << a << std::endl; + + RealD rsq = Tolerance * Tolerance * ssq; + + // Check if guess is really REALLY good :) + if (cp <= rsq) { + return; + } + + //Single prec initialization + FieldF r_f(SinglePrecGrid); + r_f.checkerboard = r.checkerboard; + precisionChange(r_f, r); + + FieldF psi_f(r_f); + psi_f = zero; + + FieldF p_f(r_f); + FieldF mmp_f(r_f); + + RealD MaxResidSinceLastRelUp = cp; //initial residual + + std::cout << GridLogIterative << std::setprecision(4) + << "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl; + + GridStopWatch LinalgTimer; + GridStopWatch MatrixTimer; + GridStopWatch SolverTimer; + + SolverTimer.Start(); + int k = 0; + int l = 0; + + for (k = 1; k <= MaxIterations; k++) { + c = cp; + + MatrixTimer.Start(); + Linop_f.HermOpAndNorm(p_f, mmp_f, d, qq); + MatrixTimer.Stop(); + + LinalgTimer.Start(); + + a = c / d; + b_pred = a * (a * qq - d) / c; + + cp = axpy_norm(r_f, -a, mmp_f, r_f); + b = cp / c; + + // Fuse these loops ; should be really easy + psi_f = a * p_f + psi_f; + //p_f = p_f * b + r_f; + + LinalgTimer.Stop(); + + std::cout << GridLogIterative << "ConjugateGradientReliableUpdate: Iteration " << k + << " residual " << cp << " target " << rsq << std::endl; + std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl; + std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl; + + if(cp > MaxResidSinceLastRelUp){ + std::cout << GridLogIterative << "ConjugateGradientReliableUpdate: updating MaxResidSinceLastRelUp : " << MaxResidSinceLastRelUp << " -> " << cp << std::endl; + MaxResidSinceLastRelUp = cp; + } + + // Stopping condition + if (cp <= rsq) { + //Although not written in the paper, I assume that I have to add on the final solution + precisionChange(mmp, psi_f); + psi = psi + mmp; + + + SolverTimer.Stop(); + Linop_d.HermOpAndNorm(psi, mmp, d, qq); + p = mmp - src; + + RealD srcnorm = sqrt(norm2(src)); + RealD resnorm = sqrt(norm2(p)); + RealD true_residual = resnorm / srcnorm; + + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate Converged on iteration " << k << " after " << l << " reliable updates" << std::endl; + std::cout << GridLogMessage << "\tComputed residual " << sqrt(cp / ssq)< CG(Tolerance,MaxIterations); + CG.ErrorOnNoConverge = ErrorOnNoConverge; + CG(Linop_d,src,psi); + IterationsToCleanup = CG.IterationsToComplete; + } + else if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0); + + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate complete.\n"; + return; + } + else if(cp < Delta * MaxResidSinceLastRelUp) { //reliable update + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate " + << cp << "(residual) < " << Delta << "(Delta) * " << MaxResidSinceLastRelUp << "(MaxResidSinceLastRelUp) on iteration " << k << " : performing reliable update\n"; + precisionChange(mmp, psi_f); + psi = psi + mmp; + + Linop_d.HermOpAndNorm(psi, mmp, d, qq); + r = src - mmp; + + psi_f = zero; + precisionChange(r_f, r); + cp = norm2(r); + MaxResidSinceLastRelUp = cp; + + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate new residual " << cp << std::endl; + + l = l+1; + } + + p_f = p_f * b + r_f; //update search vector after reliable update appears to help convergence + + } + std::cout << GridLogMessage << "ConjugateGradientReliableUpdate did NOT converge" + << std::endl; + + if (ErrorOnNoConverge) assert(0); + IterationsToComplete = k; + ReliableUpdatesPerformed = l; + } + }; + + +}; + + + +#endif diff --git a/tests/Test_dwf_mixedcg_prec_halfcomms.cc b/tests/Test_dwf_mixedcg_prec_halfcomms.cc index 9cc935d9..d6aaa21e 100644 --- a/tests/Test_dwf_mixedcg_prec_halfcomms.cc +++ b/tests/Test_dwf_mixedcg_prec_halfcomms.cc @@ -80,31 +80,47 @@ int main (int argc, char ** argv) LatticeFermionD src_o(FrbGrid); - LatticeFermionD result_o(FrbGrid); - LatticeFermionD result_o_2(FrbGrid); + LatticeFermionD result_cg(FrbGrid); pickCheckerboard(Odd,src_o,src); - result_o.checkerboard = Odd; - result_o = zero; - result_o_2.checkerboard = Odd; - result_o_2 = zero; + result_cg.checkerboard = Odd; + result_cg = zero; + LatticeFermionD result_mcg(result_cg); + LatticeFermionD result_rlcg(result_cg); SchurDiagMooeeOperator HermOpEO(Ddwf); SchurDiagMooeeOperator HermOpEO_f(Ddwf_f); + //#define DO_MIXED_CG +#define DO_RLUP_CG + +#ifdef DO_MIXED_CG std::cout << "Starting mixed CG" << std::endl; MixedPrecisionConjugateGradient mCG(1.0e-8, 10000, 50, FrbGrid_f, HermOpEO_f, HermOpEO); mCG.InnerTolerance = 3.0e-5; - mCG(src_o,result_o); + mCG(src_o,result_mcg); +#endif +#ifdef DO_RLUP_CG + std::cout << "Starting reliable update CG" << std::endl; + ConjugateGradientReliableUpdate rlCG(1.e-8, 10000, 0.1, FrbGrid_f, HermOpEO_f, HermOpEO); + rlCG(src_o,result_rlcg); +#endif + std::cout << "Starting regular CG" << std::endl; ConjugateGradient CG(1.0e-8,10000); - CG(HermOpEO,src_o,result_o_2); + CG(HermOpEO,src_o,result_cg); - LatticeFermionD diff_o(FrbGrid); - RealD diff = axpy_norm(diff_o, -1.0, result_o, result_o_2); - - std::cout << "Diff between mixed and regular CG: " << diff << std::endl; +#ifdef DO_MIXED_CG + LatticeFermionD diff_mcg(FrbGrid); + RealD vdiff_mcg = axpy_norm(diff_mcg, -1.0, result_cg, result_mcg); + std::cout << "Diff between mixed and regular CG: " << vdiff_mcg << std::endl; +#endif +#ifdef DO_RLUP_CG + LatticeFermionD diff_rlcg(FrbGrid); + RealD vdiff_rlcg = axpy_norm(diff_rlcg, -1.0, result_cg, result_rlcg); + std::cout << "Diff between reliable update and regular CG: " << vdiff_rlcg << std::endl; +#endif Grid_finalize(); }