diff --git a/HMC/Mobius2p1fEOFA_F1.cc b/HMC/Mobius2p1fEOFA_F1.cc index 3f0a7bf6..e0f12335 100644 --- a/HMC/Mobius2p1fEOFA_F1.cc +++ b/HMC/Mobius2p1fEOFA_F1.cc @@ -34,8 +34,6 @@ directory #define MIXED_PRECISION #endif -NAMESPACE_BEGIN(Grid); - /* * Need a plan for gauge field update for mixed precision in HMC (2x speed up) * -- Store the single prec action operator. @@ -43,111 +41,7 @@ NAMESPACE_BEGIN(Grid); * -- Build the mixed precision operator dynamically from the passed operator and single prec clone. */ - template - class MixedPrecisionConjugateGradientOperatorFunction : public OperatorFunction { - public: - typedef typename FermionOperatorD::FermionField FieldD; - typedef typename FermionOperatorF::FermionField FieldF; - - using OperatorFunction::operator(); - - RealD Tolerance; - RealD InnerTolerance; //Initial tolerance for inner CG. Defaults to Tolerance but can be changed - Integer MaxInnerIterations; - Integer MaxOuterIterations; - GridBase* SinglePrecGrid4; //Grid for single-precision fields - GridBase* SinglePrecGrid5; //Grid for single-precision fields - RealD OuterLoopNormMult; //Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance - - FermionOperatorF &FermOpF; - FermionOperatorD &FermOpD;; - SchurOperatorF &LinOpF; - SchurOperatorD &LinOpD; - - Integer TotalInnerIterations; //Number of inner CG iterations - Integer TotalOuterIterations; //Number of restarts - Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step - - MixedPrecisionConjugateGradientOperatorFunction(RealD tol, - Integer maxinnerit, - Integer maxouterit, - GridBase* _sp_grid4, - GridBase* _sp_grid5, - FermionOperatorF &_FermOpF, - FermionOperatorD &_FermOpD, - SchurOperatorF &_LinOpF, - SchurOperatorD &_LinOpD): - LinOpF(_LinOpF), - LinOpD(_LinOpD), - FermOpF(_FermOpF), - FermOpD(_FermOpD), - Tolerance(tol), - InnerTolerance(tol), - MaxInnerIterations(maxinnerit), - MaxOuterIterations(maxouterit), - SinglePrecGrid4(_sp_grid4), - SinglePrecGrid5(_sp_grid5), - OuterLoopNormMult(100.) - { - /* Debugging instances of objects; references are stored - std::cout << GridLogMessage << " Mixed precision CG wrapper LinOpF " < &LinOpU, const FieldD &src, FieldD &psi) { - - std::cout << GridLogMessage << " Mixed precision CG wrapper operator() "<(&LinOpU); - - // std::cout << GridLogMessage << " Mixed precision CG wrapper operator() FermOpU " <_Mat)<_Mat)==&(LinOpD._Mat)); - - //////////////////////////////////////////////////////////////////////////////////// - // Must snarf a single precision copy of the gauge field in Linop_d argument - //////////////////////////////////////////////////////////////////////////////////// - typedef typename FermionOperatorF::GaugeField GaugeFieldF; - typedef typename FermionOperatorF::GaugeLinkField GaugeLinkFieldF; - typedef typename FermionOperatorD::GaugeField GaugeFieldD; - typedef typename FermionOperatorD::GaugeLinkField GaugeLinkFieldD; - - GridBase * GridPtrF = SinglePrecGrid4; - GridBase * GridPtrD = FermOpD.Umu.Grid(); - GaugeFieldF U_f (GridPtrF); - GaugeLinkFieldF Umu_f(GridPtrF); - // std::cout << " Dim gauge field "<Nd()<Nd()<(FermOpD.Umu, mu); - precisionChange(Umu_f,Umu_d); - PokeIndex(FermOpF.Umu, Umu_f, mu); - } - pickCheckerboard(Even,FermOpF.UmuEven,FermOpF.Umu); - pickCheckerboard(Odd ,FermOpF.UmuOdd ,FermOpF.Umu); - - //////////////////////////////////////////////////////////////////////////////////// - // Make a mixed precision conjugate gradient - //////////////////////////////////////////////////////////////////////////////////// - MixedPrecisionConjugateGradient MPCG(Tolerance,MaxInnerIterations,MaxOuterIterations,SinglePrecGrid5,LinOpF,LinOpD); - std::cout << GridLogMessage << "Calling mixed precision Conjugate Gradient" < int main(int argc, char **argv) { using namespace Grid; @@ -290,6 +184,7 @@ int main(int argc, char **argv) { ConjugateGradient DerivativeCG(DerivativeStoppingCondition,MaxCGIterations); #ifdef MIXED_PRECISION const int MX_inner = 5000; + const RealD MX_tol = 1.0e-6; // Mixed precision EOFA LinearOperatorEOFAD Strange_LinOp_L (Strange_Op_L); @@ -297,34 +192,30 @@ int main(int argc, char **argv) { LinearOperatorEOFAF Strange_LinOp_LF(Strange_Op_LF); LinearOperatorEOFAF Strange_LinOp_RF(Strange_Op_RF); - MxPCG_EOFA ActionCGL(ActionStoppingCondition, + MxPCG_EOFA ActionCGL(ActionStoppingCondition, MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_LF,Strange_Op_L, Strange_LinOp_LF,Strange_LinOp_L); - MxPCG_EOFA DerivativeCGL(DerivativeStoppingCondition, + MxPCG_EOFA DerivativeCGL(DerivativeStoppingCondition, MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_LF,Strange_Op_L, Strange_LinOp_LF,Strange_LinOp_L); - MxPCG_EOFA ActionCGR(ActionStoppingCondition, + MxPCG_EOFA ActionCGR(ActionStoppingCondition, MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_RF,Strange_Op_R, Strange_LinOp_RF,Strange_LinOp_R); - MxPCG_EOFA DerivativeCGR(DerivativeStoppingCondition, + MxPCG_EOFA DerivativeCGR(DerivativeStoppingCondition, MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, Strange_Op_RF,Strange_Op_R, Strange_LinOp_RF,Strange_LinOp_R); @@ -394,18 +285,16 @@ int main(int argc, char **argv) { double conv = DerivativeStoppingCondition; if (h<3) conv= DerivativeStoppingConditionLoose; // Relax on first two hasenbusch factors - MPCG.push_back(new MxPCG(conv, + MPCG.push_back(new MxPCG(conv,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpF[h], *LinOpD[h]) ); - ActionMPCG.push_back(new MxPCG(ActionStoppingCondition, + ActionMPCG.push_back(new MxPCG(ActionStoppingCondition,MX_tol, MX_inner, MaxCGIterations, - GridPtrF, FrbGridF, *DenominatorsF[h],*Denominators[h], *LinOpF[h], *LinOpD[h]) );