/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/algorithms/iterative/ConjugateGradientMixedPrecBatched.h Copyright (C) 2015 Author: Raoul Hodgson This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef GRID_CONJUGATE_GRADIENT_MIXED_PREC_BATCHED_H #define GRID_CONJUGATE_GRADIENT_MIXED_PREC_BATCHED_H NAMESPACE_BEGIN(Grid); //Mixed precision restarted defect correction CG template::value == 2, int>::type = 0, typename std::enable_if< getPrecision::value == 1, int>::type = 0> class MixedPrecisionConjugateGradientBatched : public LinearFunction { public: using LinearFunction::operator(); RealD Tolerance; RealD InnerTolerance; //Initial tolerance for inner CG. Defaults to Tolerance but can be changed Integer MaxInnerIterations; Integer MaxOuterIterations; Integer MaxPatchupIterations; GridBase* SinglePrecGrid; //Grid for single-precision fields RealD OuterLoopNormMult; //Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance LinearOperatorBase &Linop_f; LinearOperatorBase &Linop_d; //Option to speed up *inner single precision* solves using a LinearFunction that produces a guess LinearFunction *guesser; bool updateResidual; MixedPrecisionConjugateGradientBatched(RealD tol, Integer maxinnerit, Integer maxouterit, Integer maxpatchit, GridBase* _sp_grid, LinearOperatorBase &_Linop_f, LinearOperatorBase &_Linop_d, bool _updateResidual=true) : Linop_f(_Linop_f), Linop_d(_Linop_d), Tolerance(tol), InnerTolerance(tol), MaxInnerIterations(maxinnerit), MaxOuterIterations(maxouterit), MaxPatchupIterations(maxpatchit), SinglePrecGrid(_sp_grid), OuterLoopNormMult(100.), guesser(NULL), updateResidual(_updateResidual) { }; void useGuesser(LinearFunction &g){ guesser = &g; } void operator() (const FieldD &src_d_in, FieldD &sol_d){ std::vector srcs_d_in{src_d_in}; std::vector sols_d{sol_d}; (*this)(srcs_d_in,sols_d); sol_d = sols_d[0]; } void operator() (const std::vector &src_d_in, std::vector &sol_d){ assert(src_d_in.size() == sol_d.size()); int NBatch = src_d_in.size(); std::cout << GridLogMessage << "NBatch = " << NBatch << std::endl; Integer TotalOuterIterations = 0; //Number of restarts std::vector TotalInnerIterations(NBatch,0); //Number of inner CG iterations std::vector TotalFinalStepIterations(NBatch,0); //Number of CG iterations in final patch-up step GridStopWatch TotalTimer; TotalTimer.Start(); GridStopWatch InnerCGtimer; GridStopWatch PrecChangeTimer; int cb = src_d_in[0].Checkerboard(); std::vector src_norm; std::vector norm; std::vector stop; GridBase* DoublePrecGrid = src_d_in[0].Grid(); FieldD tmp_d(DoublePrecGrid); tmp_d.Checkerboard() = cb; FieldD tmp2_d(DoublePrecGrid); tmp2_d.Checkerboard() = cb; std::vector src_d; std::vector src_f; std::vector sol_f; for (int i=0; i CG_f(inner_tol, MaxInnerIterations); CG_f.ErrorOnNoConverge = false; Integer &outer_iter = TotalOuterIterations; //so it will be equal to the final iteration count for(outer_iter = 0; outer_iter < MaxOuterIterations; outer_iter++){ std::cout << GridLogMessage << std::endl; std::cout << GridLogMessage << "Outer iteration " << outer_iter << std::endl; bool allConverged = true; for (int i=0; i OuterLoopNormMult * stop[i]) { allConverged = false; } } if (allConverged) break; if (updateResidual) { RealD normMax = *std::max_element(std::begin(norm), std::end(norm)); RealD stopMax = *std::max_element(std::begin(stop), std::end(stop)); while( normMax * inner_tol * inner_tol < stopMax) inner_tol *= 2; // inner_tol = sqrt(stop/norm) ?? CG_f.Tolerance = inner_tol; } //Optionally improve inner solver guess (eg using known eigenvectors) if(guesser != NULL) { (*guesser)(src_f, sol_f); } for (int i=0; i CG_d(Tolerance, MaxPatchupIterations); CG_d(Linop_d, src_d_in[i], sol_d[i]); TotalFinalStepIterations[i] += CG_d.IterationsToComplete; } TotalTimer.Stop(); std::cout << GridLogMessage << std::endl; for (int i=0; i