From 426d2365d14744159bff2ef260f0ad148cf8506c Mon Sep 17 00:00:00 2001 From: Quadro Date: Tue, 1 Jun 2021 13:35:38 -0400 Subject: [PATCH] Schur factored matrix --- .../fermion/SchurFactoredFermionOperator.h | 172 +++++++++++++----- 1 file changed, 125 insertions(+), 47 deletions(-) diff --git a/Grid/qcd/action/fermion/SchurFactoredFermionOperator.h b/Grid/qcd/action/fermion/SchurFactoredFermionOperator.h index 84fb0bd1..7c971600 100644 --- a/Grid/qcd/action/fermion/SchurFactoredFermionOperator.h +++ b/Grid/qcd/action/fermion/SchurFactoredFermionOperator.h @@ -27,6 +27,8 @@ Author: Peter Boyle /* END LEGAL */ #pragma once +#include + NAMESPACE_BEGIN(Grid); //////////////////////////////////////////////////////// @@ -63,47 +65,91 @@ NAMESPACE_BEGIN(Grid); // //////////////////////////////////////////////////////// -template -class SchurFactoredFermionOperator : public Impl + +template +class SchurFactoredFermionOperator : public ImplD { - INHERIT_IMPL_TYPES(Impl); + INHERIT_IMPL_TYPES(ImplD); + typedef typename ImplF::FermionField FermionFieldF; + typedef typename ImplD::FermionField FermionFieldD; + + typedef SchurDiagMooeeOperator,FermionFieldD> LinearOperatorD; + typedef SchurDiagMooeeOperator,FermionFieldF> LinearOperatorF; + typedef SchurDiagMooeeDagOperator,FermionFieldD> LinearOperatorDagD; + typedef SchurDiagMooeeDagOperator,FermionFieldF> LinearOperatorDagF; + + typedef MixedPrecisionConjugateGradientOperatorFunction, + FermionOperator, + LinearOperatorD, + LinearOperatorF> MxPCG; + + typedef MixedPrecisionConjugateGradientOperatorFunction, + FermionOperator, + LinearOperatorDagD, + LinearOperatorDagF> MxDagPCG; public: - - FermionOperator & DirichletFermOp; - FermionOperator & FermOp; - OperatorFunction &OmegaSolver; - OperatorFunction &OmegaDagSolver; - OperatorFunction &DSolver; - OperatorFunction &DdagSolver; - Coordinate Block; - SchurFactoredFermionOperator(FermionOperator & _FermOp, - FermionOperator & _DirichletFermOp, - OperatorFunction &_OmegaSolver, - OperatorFunction &_OmegaDagSolver, - OperatorFunction &_DSolver, - OperatorFunction &_DdagSolver, + GridBase *FermionGrid(void) { return PeriodicFermOpD.FermionGrid(); }; + GridBase *GaugeGrid(void) { return PeriodicFermOpD.GaugeGrid(); }; + + FermionOperator & DirichletFermOpD; + FermionOperator & DirichletFermOpF; + FermionOperator & PeriodicFermOpD; + FermionOperator & PeriodicFermOpF; + + LinearOperatorD DirichletLinOpD; + LinearOperatorF DirichletLinOpF; + LinearOperatorD PeriodicLinOpD; + LinearOperatorF PeriodicLinOpF; + + LinearOperatorDagD DirichletLinOpDagD; + LinearOperatorDagF DirichletLinOpDagF; + LinearOperatorDagD PeriodicLinOpDagD; + LinearOperatorDagF PeriodicLinOpDagF; + + // Can tinker with these in the pseudofermion for force vs. action solves + Integer maxinnerit; + Integer maxouterit; + RealD tol; + + Coordinate Block; + Coordinate InnerBlock; + + SchurFactoredFermionOperator(FermionOperator & _PeriodicFermOpD, + FermionOperator & _PeriodicFermOpF, + FermionOperator & _DirichletFermOpD, + FermionOperator & _DirichletFermOpF, Coordinate &_Block) : Block(_Block), - FermOp(_FermOp), - DirichletFermOp(_DirichletFermOp), - OmegaSolver(_OmegaSolver), - OmegaDagSolver(_OmegaDagSolver), - DSolver(_DSolver), - DdagSolver(_DdagSolver) + PeriodicFermOpD(_PeriodicFermOpD), + PeriodicFermOpF(_PeriodicFermOpF), + DirichletFermOpD(_DirichletFermOpD), + DirichletFermOpF(_DirichletFermOpF), + DirichletLinOpD(DirichletFermOpD), + DirichletLinOpF(DirichletFermOpF), + PeriodicLinOpD(PeriodicFermOpD), + PeriodicLinOpF(PeriodicFermOpF), + DirichletLinOpDagD(DirichletFermOpD), + DirichletLinOpDagF(DirichletFermOpF), + PeriodicLinOpDagD(PeriodicFermOpD), + PeriodicLinOpDagF(PeriodicFermOpF) { - // Pass in Dirichlet FermOp because we really need two dirac operators - // as double stored gauge fields differ and they will otherwise overwrite - assert(_FermOp.FermionGrid() == _DirichletFermOp.FermionGrid()); // May not be true in future if change communicator scheme + InnerBlock = Coordinate({16,16,16,20}); + tol=1.0e-10; + maxinnerit=1000; + maxouterit=10; + assert(PeriodicFermOpD.FermionGrid() == DirichletFermOpD.FermionGrid()); + assert(PeriodicFermOpF.FermionGrid() == DirichletFermOpF.FermionGrid()); }; enum Domain { Omega=0, OmegaBar=1 }; void ImportGauge(const GaugeField &Umu) { - FermOp.ImportGauge(Umu); - DirichletFermOp.ImportGauge(Umu); + PeriodicFermOpD.ImportGauge(Umu); + DirichletFermOpD.ImportGauge(Umu); + // Single precision will update in the mixed prec CG } void ProjectBoundaryBothDomains (FermionField &f,int sgn) { @@ -202,56 +248,56 @@ public: { FermionField tmp(in); ProjectOmegaBar(tmp); - FermOp.M(tmp,out); + PeriodicFermOpD.M(tmp,out); ProjectOmega(out); }; void dBoundaryDag (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmega(tmp); - FermOp.Mdag(tmp,out); + PeriodicFermOpD.Mdag(tmp,out); ProjectOmegaBar(out); }; void dBoundaryBar (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmega(tmp); - FermOp.M(tmp,out); + PeriodicFermOpD.M(tmp,out); ProjectOmegaBar(out); }; void dBoundaryBarDag (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmegaBar(tmp); - FermOp.Mdag(tmp,out); + PeriodicFermOpD.Mdag(tmp,out); ProjectOmega(out); }; void dOmega (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmega(tmp); - DirichletFermOp.M(tmp,out); + DirichletFermOpD.M(tmp,out); ProjectOmega(out); }; void dOmegaBar (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmegaBar(tmp); - DirichletFermOp.M(tmp,out); + DirichletFermOpD.M(tmp,out); ProjectOmegaBar(out); }; void dOmegaDag (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmega(tmp); - DirichletFermOp.Mdag(tmp,out); + DirichletFermOpD.Mdag(tmp,out); ProjectOmega(out); }; void dOmegaBarDag (FermionField &in,FermionField &out) { FermionField tmp(in); ProjectOmegaBar(tmp); - DirichletFermOp.Mdag(tmp,out); + DirichletFermOpD.Mdag(tmp,out); ProjectOmegaBar(out); }; void dOmegaInv (FermionField &in,FermionField &out) @@ -284,20 +330,36 @@ public: }; void dOmegaInvAndOmegaBarInv(FermionField &in,FermionField &out) { + MxPCG OmegaSolver(tol, + maxinnerit, + maxouterit, + DirichletFermOpF.FermionRedBlackGrid(), + DirichletFermOpF, + DirichletFermOpD, + DirichletLinOpF, + DirichletLinOpD); SchurRedBlackDiagMooeeSolve PrecSolve(OmegaSolver); - PrecSolve(DirichletFermOp,in,out); + PrecSolve(DirichletFermOpD,in,out); }; void dOmegaDagInvAndOmegaBarDagInv(FermionField &in,FermionField &out) { + MxDagPCG OmegaDagSolver(tol, + maxinnerit, + maxouterit, + DirichletFermOpF.FermionRedBlackGrid(), + DirichletFermOpF, + DirichletFermOpD, + DirichletLinOpDagF, + DirichletLinOpDagD); SchurRedBlackDiagMooeeDagSolve PrecSolve(OmegaDagSolver); - PrecSolve(DirichletFermOp,in,out); + PrecSolve(DirichletFermOpD,in,out); }; // Rdag = Pdbar - DdbarDag DomegabarDagInv DdDag DomegaDagInv Pdbar void RDag(FermionField &in,FermionField &out) { - FermionField tmp1(FermOp.FermionGrid()); - FermionField tmp2(FermOp.FermionGrid()); + FermionField tmp1(PeriodicFermOpD.FermionGrid()); + FermionField tmp2(PeriodicFermOpD.FermionGrid()); out = in; ProjectBoundaryBar(out); dOmegaDagInv(out,tmp1); @@ -310,8 +372,8 @@ public: // R = Pdbar - Pdbar DomegaInv Dd DomegabarInv Ddbar void R(FermionField &in,FermionField &out) { - FermionField tmp1(FermOp.FermionGrid()); - FermionField tmp2(FermOp.FermionGrid()); + FermionField tmp1(PeriodicFermOpD.FermionGrid()); + FermionField tmp2(PeriodicFermOpD.FermionGrid()); out = in; ProjectBoundaryBar(out); dBoundaryBar(out,tmp1); @@ -326,7 +388,7 @@ public: // R = Pdbar - Pdbar Dinv Ddbar void RInv(FermionField &in,FermionField &out) { - FermionField tmp1(FermOp.FermionGrid()); + FermionField tmp1(PeriodicFermOpD.FermionGrid()); dBoundaryBar(in,out); Dinverse(out,tmp1); out =in -tmp1; @@ -335,8 +397,8 @@ public: // R = Pdbar - DdbarDag DinvDag Pdbar void RDagInv(FermionField &in,FermionField &out) { - FermionField tmp(FermOp.FermionGrid()); - FermionField Pin(FermOp.FermionGrid()); + FermionField tmp(PeriodicFermOpD.FermionGrid()); + FermionField Pin(PeriodicFermOpD.FermionGrid()); Pin = in; ProjectBoundaryBar(Pin); DinverseDag(Pin,out); dBoundaryBarDag(out,tmp); @@ -345,13 +407,29 @@ public: // Non-dirichlet inverter using red-black preconditioning void Dinverse(FermionField &in,FermionField &out) { + MxPCG DSolver(tol, + maxinnerit, + maxouterit, + PeriodicFermOpF.FermionRedBlackGrid(), + PeriodicFermOpF, + PeriodicFermOpD, + PeriodicLinOpF, + PeriodicLinOpD); SchurRedBlackDiagMooeeSolve Solve(DSolver); - Solve(FermOp,in,out); + Solve(PeriodicFermOpD,in,out); } void DinverseDag(FermionField &in,FermionField &out) { + MxDagPCG DdagSolver(tol, + maxinnerit, + maxouterit, + PeriodicFermOpF.FermionRedBlackGrid(), + PeriodicFermOpF, + PeriodicFermOpD, + PeriodicLinOpDagF, + PeriodicLinOpDagD); SchurRedBlackDiagMooeeDagSolve Solve(DdagSolver); - Solve(FermOp,in,out); + Solve(PeriodicFermOpD,in,out); } };