From 386a89c6680f88cdc4a8a0cc920766ccc9762e99 Mon Sep 17 00:00:00 2001 From: Quadro Date: Wed, 9 Jun 2021 17:14:24 -0400 Subject: [PATCH] Updated mixed prec --- HMC/Mobius2p1fEOFA_4dPseudoFermion.cc | 153 ++------------------------ 1 file changed, 9 insertions(+), 144 deletions(-) diff --git a/HMC/Mobius2p1fEOFA_4dPseudoFermion.cc b/HMC/Mobius2p1fEOFA_4dPseudoFermion.cc index 5b8731ae..49a0de75 100644 --- a/HMC/Mobius2p1fEOFA_4dPseudoFermion.cc +++ b/HMC/Mobius2p1fEOFA_4dPseudoFermion.cc @@ -35,137 +35,8 @@ directory #ifdef GRID_DEFAULT_PRECISION_DOUBLE #define MIXED_PRECISION #endif +#include -NAMESPACE_BEGIN(Grid); - - /* - * Need a plan for gauge field update for mixed precision in HMC (2x speed up) - * -- Store the single prec action operator. - * -- Clone the gauge field from the operator function argument. - * -- Build the mixed precision operator dynamically from the passed operator and single prec clone. - */ - - template - class MixedPrecisionConjugateGradientOperatorFunction : public OperatorFunction { - public: - typedef typename FermionOperatorD::FermionField FieldD; - typedef typename FermionOperatorF::FermionField FieldF; - - using OperatorFunction::operator(); - - RealD Tolerance; - RealD InnerTolerance; //Initial tolerance for inner CG. Defaults to Tolerance but can be changed - Integer MaxInnerIterations; - Integer MaxOuterIterations; - GridBase* SinglePrecGrid4; //Grid for single-precision fields - GridBase* SinglePrecGrid5; //Grid for single-precision fields - RealD OuterLoopNormMult; //Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance - - FermionOperatorF &FermOpF; - FermionOperatorD &FermOpD;; - SchurOperatorF &LinOpF; - SchurOperatorD &LinOpD; - - Integer TotalInnerIterations; //Number of inner CG iterations - Integer TotalOuterIterations; //Number of restarts - Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step - - MixedPrecisionConjugateGradientOperatorFunction(RealD tol, - Integer maxinnerit, - Integer maxouterit, - GridBase* _sp_grid4, - GridBase* _sp_grid5, - FermionOperatorF &_FermOpF, - FermionOperatorD &_FermOpD, - SchurOperatorF &_LinOpF, - SchurOperatorD &_LinOpD): - LinOpF(_LinOpF), - LinOpD(_LinOpD), - FermOpF(_FermOpF), - FermOpD(_FermOpD), - Tolerance(tol), - InnerTolerance(tol), - MaxInnerIterations(maxinnerit), - MaxOuterIterations(maxouterit), - SinglePrecGrid4(_sp_grid4), - SinglePrecGrid5(_sp_grid5), - OuterLoopNormMult(100.) - { - /* Debugging instances of objects; references are stored - std::cout << GridLogMessage << " Mixed precision CG wrapper LinOpF " < &LinOpU, const FieldD &src, FieldD &psi) { - - std::cout << GridLogMessage << " Mixed precision CG wrapper operator() "<(&LinOpU); - - // std::cout << GridLogMessage << " Mixed precision CG wrapper operator() FermOpU " <_Mat)<_Mat)==&(LinOpD._Mat)); - - //////////////////////////////////////////////////////////////////////////////////// - // Must snarf a single precision copy of the gauge field in Linop_d argument - //////////////////////////////////////////////////////////////////////////////////// - typedef typename FermionOperatorF::GaugeField GaugeFieldF; - typedef typename FermionOperatorF::GaugeLinkField GaugeLinkFieldF; - typedef typename FermionOperatorD::GaugeField GaugeFieldD; - typedef typename FermionOperatorD::GaugeLinkField GaugeLinkFieldD; - - GridBase * GridPtrF = SinglePrecGrid4; - GridBase * GridPtrD = FermOpD.Umu.Grid(); - GaugeFieldF U_f (GridPtrF); - GaugeLinkFieldF Umu_f(GridPtrF); - // std::cout << " Dim gauge field "<Nd()<Nd()<(FermOpD.Umu, mu); - precisionChange(Umu_f,Umu_d); - PokeIndex(FermOpF.Umu, Umu_f, mu); - } - pickCheckerboard(Even,FermOpF.UmuEven,FermOpF.Umu); - pickCheckerboard(Odd ,FermOpF.UmuOdd ,FermOpF.Umu); - - //////////////////////////////////////////////////////////////////////////////////// - // Could test to make sure that LinOpF and LinOpD agree to single prec? - //////////////////////////////////////////////////////////////////////////////////// - /* - GridBase *Fgrid = psi._grid; - FieldD tmp2(Fgrid); - FieldD tmp1(Fgrid); - LinOpU.Op(src,tmp1); - LinOpD.Op(src,tmp2); - std::cout << " Double gauge field "<< norm2(FermOpD.Umu)< MPCG(Tolerance,MaxInnerIterations,MaxOuterIterations,SinglePrecGrid5,LinOpF,LinOpD); - std::cout << GridLogMessage << "Calling mixed precision Conjugate Gradient" < DerivativeCG(DerivativeStoppingCondition,MaxCGIterations); #ifdef MIXED_PRECISION const int MX_inner = 1000; + const RealD MX_tol = 1.0e-4; // Mixed precision EOFA LinearOperatorEOFAD Strange_LinOp_L (Strange_Op_L); LinearOperatorEOFAD Strange_LinOp_R (Strange_Op_R); LinearOperatorEOFAF Strange_LinOp_LF(Strange_Op_LF); LinearOperatorEOFAF Strange_LinOp_RF(Strange_Op_RF); - MxPCG_EOFA ActionCGL(ActionStoppingCondition, + MxPCG_EOFA ActionCGL(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_LF,Strange_Op_L, Strange_LinOp_LF,Strange_LinOp_L); - MxPCG_EOFA DerivativeCGL(DerivativeStoppingCondition, + MxPCG_EOFA DerivativeCGL(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_LF,Strange_Op_L, Strange_LinOp_LF,Strange_LinOp_L); - MxPCG_EOFA ActionCGR(ActionStoppingCondition, + MxPCG_EOFA ActionCGR(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_RF,Strange_Op_R, Strange_LinOp_RF,Strange_LinOp_R); - MxPCG_EOFA DerivativeCGR(DerivativeStoppingCondition, + MxPCG_EOFA DerivativeCGR(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_RF,Strange_Op_R, Strange_LinOp_RF,Strange_LinOp_R); @@ -413,26 +281,23 @@ int main(int argc, char **argv) { LinOpDagD.push_back(new LinearOperatorDagD(*Denominators[h])); LinOpDagF.push_back(new LinearOperatorDagF(*DenominatorsF[h])); - MPCG.push_back(new MxPCG(DerivativeStoppingCondition, + MPCG.push_back(new MxPCG(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpF[h], *LinOpD[h]) ); - MPCGdag.push_back(new MxDagPCG(DerivativeStoppingCondition, + MPCGdag.push_back(new MxDagPCG(DerivativeStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpDagF[h], *LinOpDagD[h]) ); - ActionMPCG.push_back(new MxPCG(ActionStoppingCondition, + ActionMPCG.push_back(new MxPCG(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpF[h], *LinOpD[h]) );