/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/algorithms/iterative/ConjugateGradientMultiShift.h Copyright (C) 2015 Author: Azusa Yamaguchi Author: Peter Boyle Author: Christopher Kelly This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef GRID_CONJUGATE_GRADIENT_MULTI_SHIFT_MIXEDPREC_H #define GRID_CONJUGATE_GRADIENT_MULTI_SHIFT_MIXEDPREC_H NAMESPACE_BEGIN(Grid); //CK 2020: A variant of the multi-shift conjugate gradient with the matrix multiplication in single precision. //The residual is stored in single precision, but the search directions and solution are stored in double precision. //Every update_freq iterations the residual is corrected in double precision. //For safety the a final regular CG is applied to clean up if necessary //Linop to add shift to input linop, used in cleanup CG namespace ConjugateGradientMultiShiftMixedPrecSupport{ template class ShiftedLinop: public LinearOperatorBase{ public: LinearOperatorBase &linop_base; RealD shift; ShiftedLinop(LinearOperatorBase &_linop_base, RealD _shift): linop_base(_linop_base), shift(_shift){} void OpDiag (const Field &in, Field &out){ assert(0); } void OpDir (const Field &in, Field &out,int dir,int disp){ assert(0); } void OpDirAll (const Field &in, std::vector &out){ assert(0); } void Op (const Field &in, Field &out){ assert(0); } void AdjOp (const Field &in, Field &out){ assert(0); } void HermOp(const Field &in, Field &out){ linop_base.HermOp(in, out); axpy(out, shift, in, out); } void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){ HermOp(in,out); ComplexD dot = innerProduct(in,out); n1=real(dot); n2=norm2(out); } }; }; template::value == 2, int>::type = 0, typename std::enable_if< getPrecision::value == 1, int>::type = 0> class ConjugateGradientMultiShiftMixedPrec : public OperatorMultiFunction, public OperatorFunction { public: using OperatorFunction::operator(); RealD Tolerance; Integer MaxIterationsMshift; Integer MaxIterations; Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion std::vector IterationsToCompleteShift; // Iterations for this shift int verbose; MultiShiftFunction shifts; std::vector TrueResidualShift; int ReliableUpdateFreq; //number of iterations between reliable updates GridBase* SinglePrecGrid; //Grid for single-precision fields LinearOperatorBase &Linop_f; //single precision ConjugateGradientMultiShiftMixedPrec(Integer maxit, const MultiShiftFunction &_shifts, GridBase* _SinglePrecGrid, LinearOperatorBase &_Linop_f, int _ReliableUpdateFreq) : MaxIterationsMshift(maxit), shifts(_shifts), SinglePrecGrid(_SinglePrecGrid), Linop_f(_Linop_f), ReliableUpdateFreq(_ReliableUpdateFreq), MaxIterations(20000) { verbose=1; IterationsToCompleteShift.resize(_shifts.order); TrueResidualShift.resize(_shifts.order); } void operator() (LinearOperatorBase &Linop, const FieldD &src, FieldD &psi) { GridBase *grid = src.Grid(); int nshift = shifts.order; std::vector results(nshift,grid); (*this)(Linop,src,results,psi); } void operator() (LinearOperatorBase &Linop, const FieldD &src, std::vector &results, FieldD &psi) { int nshift = shifts.order; (*this)(Linop,src,results); psi = shifts.norm*src; for(int i=0;i &Linop_d, const FieldD &src_d, std::vector &psi_d) { GRID_TRACE("ConjugateGradientMultiShiftMixedPrec"); GridBase *DoublePrecGrid = src_d.Grid(); //////////////////////////////////////////////////////////////////////// // Convenience references to the info stored in "MultiShiftFunction" //////////////////////////////////////////////////////////////////////// int nshift = shifts.order; std::vector &mass(shifts.poles); // Make references to array in "shifts" std::vector &mresidual(shifts.tolerances); std::vector alpha(nshift,1.0); //Double precision search directions FieldD p_d(DoublePrecGrid); std::vector ps_d(nshift, DoublePrecGrid);// Search directions (double precision) FieldD tmp_d(DoublePrecGrid); FieldD r_d(DoublePrecGrid); FieldD mmp_d(DoublePrecGrid); assert(psi_d.size()==nshift); assert(mass.size()==nshift); assert(mresidual.size()==nshift); // dynamic sized arrays on stack; 2d is a pain with vector RealD bs[nshift]; RealD rsq[nshift]; RealD rsqf[nshift]; RealD z[nshift][2]; int converged[nshift]; const int primary =0; //Primary shift fields CG iteration RealD a,b,c,d; RealD cp,bp,qq; //prev // Matrix mult fields FieldF p_f(SinglePrecGrid); FieldF mmp_f(SinglePrecGrid); // Check lightest mass for(int s=0;s= mass[primary] ); converged[s]=0; } // Wire guess to zero // Residuals "r" are src // First search direction "p" is also src cp = norm2(src_d); // Handle trivial case of zero src. if( cp == 0. ){ for(int s=0;s= rsq[s]){ CleanupTimer.Start(); std::cout< Linop_shift_d(Linop_d, mass[s]); ConjugateGradientMultiShiftMixedPrecSupport::ShiftedLinop Linop_shift_f(Linop_f, mass[s]); MixedPrecisionConjugateGradient cg(mresidual[s], MaxIterations, MaxIterations, SinglePrecGrid, Linop_shift_f, Linop_shift_d); cg(src_d, psi_d[s]); TrueResidualShift[s] = cg.TrueResidual; CleanupTimer.Stop(); } } std::cout << GridLogMessage << "ConjugateGradientMultiShiftMixedPrec: Time Breakdown for body"<