From 38d8cd228e7fd31e3cacfdff379483ac5054cfb6 Mon Sep 17 00:00:00 2001 From: Quadro Date: Tue, 1 Jun 2021 13:31:18 -0400 Subject: [PATCH] Reusable mixed precision wrapper --- .../utils/MixedPrecisionOperatorFunction.h | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 Grid/qcd/utils/MixedPrecisionOperatorFunction.h diff --git a/Grid/qcd/utils/MixedPrecisionOperatorFunction.h b/Grid/qcd/utils/MixedPrecisionOperatorFunction.h new file mode 100644 index 00000000..7ef26c85 --- /dev/null +++ b/Grid/qcd/utils/MixedPrecisionOperatorFunction.h @@ -0,0 +1,110 @@ +/************************************************************************************* + +Grid physics library, www.github.com/paboyle/Grid + +Source file: + +Copyright (C) 2015-2016 + +Author: Peter Boyle + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program; if not, write to the Free Software Foundation, Inc., +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +See the full license in the file "LICENSE" in the top level distribution +directory +*************************************************************************************/ +/* END LEGAL */ + +#pragma once + +NAMESPACE_BEGIN(Grid); + +template +class MixedPrecisionConjugateGradientOperatorFunction : public OperatorFunction { + + public: + typedef typename FermionOperatorD::FermionField FieldD; + typedef typename FermionOperatorF::FermionField FieldF; + + using OperatorFunction::operator(); + + RealD Tolerance; + RealD InnerTolerance; //Initial tolerance for inner CG. Defaults to Tolerance but can be changed + Integer MaxInnerIterations; + Integer MaxOuterIterations; + GridBase* SinglePrecGrid; + RealD OuterLoopNormMult; //Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance + + FermionOperatorF &FermOpF; + FermionOperatorD &FermOpD;; + SchurOperatorF &LinOpF; + SchurOperatorD &LinOpD; + + Integer TotalInnerIterations; //Number of inner CG iterations + Integer TotalOuterIterations; //Number of restarts + Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step + + MixedPrecisionConjugateGradientOperatorFunction(RealD tol, + Integer maxinnerit, + Integer maxouterit, + GridBase *_SinglePrecGrid, + FermionOperatorF &_FermOpF, + FermionOperatorD &_FermOpD, + SchurOperatorF &_LinOpF, + SchurOperatorD &_LinOpD) : + LinOpF(_LinOpF), + LinOpD(_LinOpD), + FermOpF(_FermOpF), + FermOpD(_FermOpD), + Tolerance(tol), + InnerTolerance(tol), + MaxInnerIterations(maxinnerit), + MaxOuterIterations(maxouterit), + SinglePrecGrid(_SinglePrecGrid), + OuterLoopNormMult(100.) + { }; + + void operator()(LinearOperatorBase &LinOpU, const FieldD &src, FieldD &psi) + { + + SchurOperatorD * SchurOpU = static_cast(&LinOpU); + + // Assumption made in code to extract gauge field + // We could avoid storing LinopD reference alltogether ? + assert(&(SchurOpU->_Mat)==&(LinOpD._Mat)); + + //////////////////////////////////////////////////////////////////////////////////// + // Moving this to a Clone method of fermion operator would allow to duplicate the + // physics parameters and decrease gauge field copies + //////////////////////////////////////////////////////////////////////////////////// + auto &Umu_d = FermOpD.GetDoubledGaugeField(); + auto &Umu_f = FermOpF.GetDoubledGaugeField(); + auto &Umu_fe= FermOpF.GetDoubledGaugeFieldE(); + auto &Umu_fo= FermOpF.GetDoubledGaugeFieldO(); + precisionChange(Umu_f,Umu_d); + pickCheckerboard(Even,Umu_fe,Umu_f); + pickCheckerboard(Odd ,Umu_fo,Umu_f); + + ////////////////////////////////////////////////////////////////////////////////////////// + // Make a mixed precision conjugate gradient + ////////////////////////////////////////////////////////////////////////////////////////// + // Could assume red black solver here and remove the SinglePrecGrid parameter??? + MixedPrecisionConjugateGradient MPCG(Tolerance,MaxInnerIterations,MaxOuterIterations,SinglePrecGrid,LinOpF,LinOpD); + std::cout << GridLogMessage << "Calling mixed precision Conjugate Gradient src "<