From 4f5ad735015e2eaff469df297e0a5ad5225b29ed Mon Sep 17 00:00:00 2001 From: Quadro Date: Wed, 9 Jun 2021 16:33:02 -0400 Subject: [PATCH] Mixed prec update --- HMC/Mobius2p1fEOFA.cc | 150 +++--------------------------------------- 1 file changed, 8 insertions(+), 142 deletions(-) diff --git a/HMC/Mobius2p1fEOFA.cc b/HMC/Mobius2p1fEOFA.cc index 68a3a5b5..68ecc355 100644 --- a/HMC/Mobius2p1fEOFA.cc +++ b/HMC/Mobius2p1fEOFA.cc @@ -33,137 +33,8 @@ directory #ifdef GRID_DEFAULT_PRECISION_DOUBLE #define MIXED_PRECISION #endif +#include -NAMESPACE_BEGIN(Grid); - - /* - * Need a plan for gauge field update for mixed precision in HMC (2x speed up) - * -- Store the single prec action operator. - * -- Clone the gauge field from the operator function argument. - * -- Build the mixed precision operator dynamically from the passed operator and single prec clone. - */ - - template - class MixedPrecisionConjugateGradientOperatorFunction : public OperatorFunction { - public: - typedef typename FermionOperatorD::FermionField FieldD; - typedef typename FermionOperatorF::FermionField FieldF; - - using OperatorFunction::operator(); - - RealD Tolerance; - RealD InnerTolerance; //Initial tolerance for inner CG. Defaults to Tolerance but can be changed - Integer MaxInnerIterations; - Integer MaxOuterIterations; - GridBase* SinglePrecGrid4; //Grid for single-precision fields - GridBase* SinglePrecGrid5; //Grid for single-precision fields - RealD OuterLoopNormMult; //Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance - - FermionOperatorF &FermOpF; - FermionOperatorD &FermOpD;; - SchurOperatorF &LinOpF; - SchurOperatorD &LinOpD; - - Integer TotalInnerIterations; //Number of inner CG iterations - Integer TotalOuterIterations; //Number of restarts - Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step - - MixedPrecisionConjugateGradientOperatorFunction(RealD tol, - Integer maxinnerit, - Integer maxouterit, - GridBase* _sp_grid4, - GridBase* _sp_grid5, - FermionOperatorF &_FermOpF, - FermionOperatorD &_FermOpD, - SchurOperatorF &_LinOpF, - SchurOperatorD &_LinOpD): - LinOpF(_LinOpF), - LinOpD(_LinOpD), - FermOpF(_FermOpF), - FermOpD(_FermOpD), - Tolerance(tol), - InnerTolerance(tol), - MaxInnerIterations(maxinnerit), - MaxOuterIterations(maxouterit), - SinglePrecGrid4(_sp_grid4), - SinglePrecGrid5(_sp_grid5), - OuterLoopNormMult(100.) - { - /* Debugging instances of objects; references are stored - std::cout << GridLogMessage << " Mixed precision CG wrapper LinOpF " < &LinOpU, const FieldD &src, FieldD &psi) { - - std::cout << GridLogMessage << " Mixed precision CG wrapper operator() "<(&LinOpU); - - // std::cout << GridLogMessage << " Mixed precision CG wrapper operator() FermOpU " <_Mat)<_Mat)==&(LinOpD._Mat)); - - //////////////////////////////////////////////////////////////////////////////////// - // Must snarf a single precision copy of the gauge field in Linop_d argument - //////////////////////////////////////////////////////////////////////////////////// - typedef typename FermionOperatorF::GaugeField GaugeFieldF; - typedef typename FermionOperatorF::GaugeLinkField GaugeLinkFieldF; - typedef typename FermionOperatorD::GaugeField GaugeFieldD; - typedef typename FermionOperatorD::GaugeLinkField GaugeLinkFieldD; - - GridBase * GridPtrF = SinglePrecGrid4; - GridBase * GridPtrD = FermOpD.Umu.Grid(); - GaugeFieldF U_f (GridPtrF); - GaugeLinkFieldF Umu_f(GridPtrF); - // std::cout << " Dim gauge field "<Nd()<Nd()<(FermOpD.Umu, mu); - precisionChange(Umu_f,Umu_d); - PokeIndex(FermOpF.Umu, Umu_f, mu); - } - pickCheckerboard(Even,FermOpF.UmuEven,FermOpF.Umu); - pickCheckerboard(Odd ,FermOpF.UmuOdd ,FermOpF.Umu); - - //////////////////////////////////////////////////////////////////////////////////// - // Could test to make sure that LinOpF and LinOpD agree to single prec? - //////////////////////////////////////////////////////////////////////////////////// - /* - GridBase *Fgrid = psi._grid; - FieldD tmp2(Fgrid); - FieldD tmp1(Fgrid); - LinOpU.Op(src,tmp1); - LinOpD.Op(src,tmp2); - std::cout << " Double gauge field "<< norm2(FermOpD.Umu)< MPCG(Tolerance,MaxInnerIterations,MaxOuterIterations,SinglePrecGrid5,LinOpF,LinOpD); - std::cout << GridLogMessage << "Calling mixed precision Conjugate Gradient" < DerivativeCG(DerivativeStoppingCondition,MaxCGIterations); #ifdef MIXED_PRECISION const int MX_inner = 1000; + const RealD MX_tol = 1.0e-6; // Mixed precision EOFA LinearOperatorEOFAD Strange_LinOp_L (Strange_Op_L); LinearOperatorEOFAD Strange_LinOp_R (Strange_Op_R); LinearOperatorEOFAF Strange_LinOp_LF(Strange_Op_LF); LinearOperatorEOFAF Strange_LinOp_RF(Strange_Op_RF); - MxPCG_EOFA ActionCGL(ActionStoppingCondition, + MxPCG_EOFA ActionCGL(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_LF,Strange_Op_L, Strange_LinOp_LF,Strange_LinOp_L); - MxPCG_EOFA DerivativeCGL(DerivativeStoppingCondition, + MxPCG_EOFA DerivativeCGL(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_LF,Strange_Op_L, Strange_LinOp_LF,Strange_LinOp_L); - MxPCG_EOFA ActionCGR(ActionStoppingCondition, + MxPCG_EOFA ActionCGR(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_RF,Strange_Op_R, Strange_LinOp_RF,Strange_LinOp_R); - MxPCG_EOFA DerivativeCGR(DerivativeStoppingCondition, + MxPCG_EOFA DerivativeCGR(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_RF,Strange_Op_R, Strange_LinOp_RF,Strange_LinOp_R); @@ -401,18 +269,16 @@ int main(int argc, char **argv) { LinOpD.push_back(new LinearOperatorD(*Denominators[h])); LinOpF.push_back(new LinearOperatorF(*DenominatorsF[h])); - MPCG.push_back(new MxPCG(DerivativeStoppingCondition, + MPCG.push_back(new MxPCG(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpF[h], *LinOpD[h]) ); - ActionMPCG.push_back(new MxPCG(ActionStoppingCondition, + ActionMPCG.push_back(new MxPCG(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpF[h], *LinOpD[h]) );