From f3e60a9feb27087c88c16d011aa607ba10dbcac3 Mon Sep 17 00:00:00 2001 From: Peter Boyle Date: Fri, 5 Jun 2015 10:17:10 +0100 Subject: [PATCH] Rework the linop support to get different forms of red black schur solver Moo on diag, or MooInv Moe MeeInv Meo --- lib/algorithms/LinearOperator.h | 154 +++++++++++++------ lib/algorithms/SparseMatrix.h | 26 ---- lib/algorithms/iterative/ConjugateGradient.h | 21 +-- lib/algorithms/iterative/SchurRedBlack.h | 15 +- tests/Test_cayley_cg.cc | 6 +- tests/Test_cayley_even_odd.cc | 11 +- tests/Test_contfrac_cg.cc | 6 +- tests/Test_contfrac_even_odd.cc | 11 +- tests/Test_dwf_cg_prec.cc | 2 +- tests/Test_dwf_cg_schur.cc | 2 +- tests/Test_dwf_cg_unprec.cc | 2 +- tests/Test_dwf_even_odd.cc | 10 +- tests/Test_wilson_cg_prec.cc | 2 +- tests/Test_wilson_cg_schur.cc | 2 +- tests/Test_wilson_cg_unprec.cc | 2 +- tests/Test_wilson_even_odd.cc | 9 +- 16 files changed, 155 insertions(+), 126 deletions(-) diff --git a/lib/algorithms/LinearOperator.h b/lib/algorithms/LinearOperator.h index 5839f2c1..710fcde7 100644 --- a/lib/algorithms/LinearOperator.h +++ b/lib/algorithms/LinearOperator.h @@ -18,84 +18,144 @@ namespace Grid { public: virtual void Op (const Field &in, Field &out) = 0; // Abstract base virtual void AdjOp (const Field &in, Field &out) = 0; // Abstract base + virtual void HermOpAndNorm(const Field &in, Field &out,double &n1,double &n2)=0; }; - ///////////////////////////////////////////////////////////////////////////////////////////// - // Hermitian operators are self adjoint and only require Op to be defined, so refine the base - ///////////////////////////////////////////////////////////////////////////////////////////// - template class HermitianOperatorBase : public LinearOperatorBase { - public: - virtual void OpAndNorm(const Field &in, Field &out,double &n1,double &n2)=0; - void AdjOp(const Field &in, Field &out) { - Op(in,out); - }; - void Op(const Field &in, Field &out) { - double n1,n2; - OpAndNorm(in,out,n1,n2); - }; - }; ///////////////////////////////////////////////////////////////////////////////////////////// - // Whereas non hermitian takes a generic sparse matrix (e.g. lattice action) - // conforming to sparse matrix interface and builds the full checkerboard non-herm operator - // Op and AdjOp distinct. // By sharing the class for Sparse Matrix across multiple operator wrappers, we can share code // between RB and non-RB variants. Sparse matrix is like the fermion action def, and then // the wrappers implement the specialisation of "Op" and "AdjOp" to the cases minimising // replication of code. + // + // I'm not entirely happy with implementation; to share the Schur code between herm and non-herm + // while still having a "OpAndNorm" in the abstract base I had to implement it in both cases + // with an assert trap in the non-herm. This isn't right; there must be a better C++ way to + // do it, but I fear it required multiple inheritance and mixed in abstract base classes ///////////////////////////////////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////////// + // Construct herm op from non-herm matrix + //////////////////////////////////////////////////////////////////// template - class NonHermitianOperator : public LinearOperatorBase { + class MdagMLinearOperator : public LinearOperatorBase { Matrix &_Mat; public: - NonHermitianOperator(Matrix &Mat): _Mat(Mat){}; + MdagMLinearOperator(Matrix &Mat): _Mat(Mat){}; void Op (const Field &in, Field &out){ _Mat.M(in,out); } void AdjOp (const Field &in, Field &out){ _Mat.Mdag(in,out); } + void HermOpAndNorm(const Field &in, Field &out,double &n1,double &n2){ + _Mat.MdagM(in,out,n1,n2); + } }; - - //////////////////////////////////////////////////////////////////////////////////// - // Redblack Non hermitian wrapper - //////////////////////////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////////// + // Wrap an already herm matrix + //////////////////////////////////////////////////////////////////// template - class NonHermitianCheckerBoardedOperator : public LinearOperatorBase { + class HermitianLinearOperator : public LinearOperatorBase { Matrix &_Mat; public: - NonHermitianCheckerBoardedOperator(Matrix &Mat): _Mat(Mat){}; + HermitianLinearOperator(Matrix &Mat): _Mat(Mat){}; void Op (const Field &in, Field &out){ - _Mat.Mpc(in,out); + _Mat.M(in,out); } - void AdjOp (const Field &in, Field &out){ // - _Mat.MpcDag(in,out); + void AdjOp (const Field &in, Field &out){ + _Mat.M(in,out); + } + void HermOpAndNorm(const Field &in, Field &out,double &n1,double &n2){ + ComplexD dot; + + _Mat.M(in,out); + + dot= innerProduct(in,out); + n1=real(dot); + + dot = innerProduct(out,out); + n2=real(dot); } }; - //////////////////////////////////////////////////////////////////////////////////// - // Hermitian wrapper - //////////////////////////////////////////////////////////////////////////////////// - template - class HermitianOperator : public HermitianOperatorBase { - Matrix &_Mat; + ////////////////////////////////////////////////////////// + // Even Odd Schur decomp operators; there are several + // ways to introduce the even odd checkerboarding + ////////////////////////////////////////////////////////// + + template + class SchurOperatorBase : public LinearOperatorBase { public: - HermitianOperator(Matrix &Mat): _Mat(Mat) {}; - void OpAndNorm(const Field &in, Field &out,double &n1,double &n2){ - return _Mat.MdagM(in,out,n1,n2); + virtual RealD Mpc (const Field &in, Field &out) =0; + virtual RealD MpcDag (const Field &in, Field &out) =0; + virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) { + Field tmp(in._grid); + ni=Mpc(in,tmp); + no=MpcDag(tmp,out); + } + void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){ + MpcDagMpc(in,out,n1,n2); + } + void Op (const Field &in, Field &out){ + Mpc(in,out); + } + void AdjOp (const Field &in, Field &out){ + MpcDag(in,out); } }; - - //////////////////////////////////////////////////////////////////////////////////// - // Hermitian CheckerBoarded wrapper - //////////////////////////////////////////////////////////////////////////////////// template - class HermitianCheckerBoardedOperator : public HermitianOperatorBase { + class SchurDiagMooeeOperator : public SchurOperatorBase { Matrix &_Mat; public: - HermitianCheckerBoardedOperator(Matrix &Mat): _Mat(Mat) {}; - void OpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){ - _Mat.MpcDagMpc(in,out,n1,n2); + SchurDiagMooeeOperator (Matrix &Mat): _Mat(Mat){}; + virtual RealD Mpc (const Field &in, Field &out) { + Field tmp(in._grid); + + _Mat.Meooe(in,tmp); + _Mat.MooeeInv(tmp,out); + _Mat.Meooe(out,tmp); + + _Mat.Mooee(in,out); + return axpy_norm(out,-1.0,tmp,out); + } + virtual RealD MpcDag (const Field &in, Field &out){ + Field tmp(in._grid); + + _Mat.MeooeDag(in,tmp); + _Mat.MooeeInvDag(tmp,out); + _Mat.MeooeDag(out,tmp); + + _Mat.MooeeDag(in,out); + return axpy_norm(out,-1.0,tmp,out); + } + }; + template + class SchurDiagOneOperator : public SchurOperatorBase { + Matrix &_Mat; + public: + SchurDiagOneOperator (Matrix &Mat): _Mat(Mat){}; + + virtual RealD Mpc (const Field &in, Field &out) { + Field tmp(in._grid); + + _Mat.Meooe(in,tmp); + _Mat.MooeeInv(tmp,out); + _Mat.Meooe(out,tmp); + _Mat.MooeeInv(tmp,out); + + return axpy_norm(out,-1.0,tmp,in); + } + virtual RealD MpcDag (const Field &in, Field &out){ + Field tmp(in._grid); + + _Mat.MooeeInvDag(in,out); + _Mat.MeooeDag(out,tmp); + _Mat.MooeeInvDag(tmp,out); + _Mat.MeooeDag(out,tmp); + + return axpy_norm(out,-1.0,tmp,in); } }; @@ -106,10 +166,6 @@ namespace Grid { public: virtual void operator() (LinearOperatorBase &Linop, const Field &in, Field &out) = 0; }; - template class HermitianOperatorFunction { - public: - virtual void operator() (HermitianOperatorBase &Linop, const Field &in, Field &out) = 0; - }; // FIXME : To think about diff --git a/lib/algorithms/SparseMatrix.h b/lib/algorithms/SparseMatrix.h index 9c955e9a..9146648f 100644 --- a/lib/algorithms/SparseMatrix.h +++ b/lib/algorithms/SparseMatrix.h @@ -36,32 +36,6 @@ namespace Grid { virtual void MooeeDag (const Field &in, Field &out)=0; virtual void MooeeInvDag (const Field &in, Field &out)=0; - // Schur decomp operators - virtual RealD Mpc (const Field &in, Field &out) { - Field tmp(in._grid); - - Meooe(in,tmp); - MooeeInv(tmp,out); - Meooe(out,tmp); - - Mooee(in,out); - return axpy_norm(out,-1.0,tmp,out); - } - virtual RealD MpcDag (const Field &in, Field &out){ - Field tmp(in._grid); - - MeooeDag(in,tmp); - MooeeInvDag(tmp,out); - MeooeDag(out,tmp); - - MooeeDag(in,out); - return axpy_norm(out,-1.0,tmp,out); - } - virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) { - Field tmp(in._grid); - ni=Mpc(in,tmp); - no=MpcDag(tmp,out); - } }; } diff --git a/lib/algorithms/iterative/ConjugateGradient.h b/lib/algorithms/iterative/ConjugateGradient.h index d4f89662..e6915d83 100644 --- a/lib/algorithms/iterative/ConjugateGradient.h +++ b/lib/algorithms/iterative/ConjugateGradient.h @@ -9,17 +9,17 @@ namespace Grid { ///////////////////////////////////////////////////////////// template - class ConjugateGradient : public HermitianOperatorFunction { + class ConjugateGradient : public OperatorFunction { public: RealD Tolerance; Integer MaxIterations; int verbose; ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) { - verbose=0; + verbose=1; }; - void operator() (HermitianOperatorBase &Linop,const Field &src, Field &psi){ + void operator() (LinearOperatorBase &Linop,const Field &src, Field &psi){ psi.checkerboard = src.checkerboard; conformable(psi,src); @@ -33,7 +33,7 @@ public: //Initial residual computation & set up RealD guess = norm2(psi); - Linop.OpAndNorm(psi,mmp,d,b); + Linop.HermOpAndNorm(psi,mmp,d,b); r= src-mmp; p= r; @@ -65,7 +65,7 @@ public: c=cp; - Linop.OpAndNorm(p,mmp,d,qq); + Linop.HermOpAndNorm(p,mmp,d,qq); RealD qqck = norm2(mmp); ComplexD dck = innerProduct(p,mmp); @@ -86,19 +86,10 @@ public: if (verbose) std::cout<<"ConjugateGradient: Iteration " < class SchurRedBlackSolve { + template class SchurRedBlackDiagMooeeSolve { private: - HermitianOperatorFunction & _HermitianRBSolver; + OperatorFunction & _HermitianRBSolver; int CBfactorise; public: ///////////////////////////////////////////////////// // Wrap the usual normal equations Schur trick ///////////////////////////////////////////////////// - SchurRedBlackSolve(HermitianOperatorFunction &HermitianRBSolver) : + SchurRedBlackDiagMooeeSolve(OperatorFunction &HermitianRBSolver) : _HermitianRBSolver(HermitianRBSolver) { CBfactorise=0; @@ -62,6 +62,8 @@ namespace Grid { // FIXME use CBfactorise to control schur decomp GridBase *grid = _Matrix.RedBlackGrid(); GridBase *fgrid= _Matrix.Grid(); + + SchurDiagMooeeOperator _HermOpEO(_Matrix); Field src_e(grid); Field src_o(grid); @@ -80,12 +82,13 @@ namespace Grid { _Matrix.MooeeInv(src_e,tmp); assert( tmp.checkerboard ==Even); _Matrix.Meooe (tmp,Mtmp); assert( Mtmp.checkerboard ==Odd); tmp=src_o-Mtmp; assert( tmp.checkerboard ==Odd); - _Matrix.MpcDag(tmp,src_o); assert(src_o.checkerboard ==Odd); + + // get the right MpcDag + _HermOpEO.MpcDag(tmp,src_o); assert(src_o.checkerboard ==Odd); ////////////////////////////////////////////////////////////// // Call the red-black solver ////////////////////////////////////////////////////////////// - HermitianCheckerBoardedOperator _HermOpEO(_Matrix); std::cout << "SchurRedBlack solver calling the MpcDagMp solver" < HermOp(Ddwf); + MdagMLinearOperator HermOp(Ddwf); ConjugateGradient CG(1.0e-8,10000); CG(HermOp,src,result); @@ -149,7 +149,7 @@ void TestCGprec(What & Ddwf, pickCheckerboard(Odd,src_o,src); result_o=zero; - HermitianCheckerBoardedOperator HermOpEO(Ddwf); + SchurDiagMooeeOperator HermOpEO(Ddwf); ConjugateGradient CG(1.0e-8,10000); CG(HermOpEO,src_o,result_o); } @@ -167,6 +167,6 @@ void TestCGschur(What & Ddwf, LatticeFermion result(FGrid); result=zero; ConjugateGradient CG(1.0e-8,10000); - SchurRedBlackSolve SchurSolver(CG); + SchurRedBlackDiagMooeeSolve SchurSolver(CG); SchurSolver(Ddwf,src,result); } diff --git a/tests/Test_cayley_even_odd.cc b/tests/Test_cayley_even_odd.cc index df28981b..4b5630f7 100644 --- a/tests/Test_cayley_even_odd.cc +++ b/tests/Test_cayley_even_odd.cc @@ -214,7 +214,7 @@ void TestWhat(What & Ddwf, std::cout << "norm diff "<< norm2(err)<< std::endl; std::cout<<"=============================================================="< HermOpEO(Ddwf); + HermOpEO.MpcDagMpc(chi_e,dchi_e,t1,t2); + HermOpEO.MpcDagMpc(chi_o,dchi_o,t1,t2); - Ddwf.MpcDagMpc(phi_e,dphi_e,t1,t2); - Ddwf.MpcDagMpc(phi_o,dphi_o,t1,t2); + HermOpEO.MpcDagMpc(phi_e,dphi_e,t1,t2); + HermOpEO.MpcDagMpc(phi_o,dphi_o,t1,t2); pDce = innerProduct(phi_e,dchi_e); pDco = innerProduct(phi_o,dchi_o); diff --git a/tests/Test_contfrac_cg.cc b/tests/Test_contfrac_cg.cc index b91faae2..5851f5e6 100644 --- a/tests/Test_contfrac_cg.cc +++ b/tests/Test_contfrac_cg.cc @@ -121,7 +121,7 @@ void TestCGunprec(What & Ddwf, LatticeFermion src (FGrid); random(*RNG5,src); LatticeFermion result(FGrid); result=zero; - HermitianOperator HermOp(Ddwf); + MdagMLinearOperator HermOp(Ddwf); ConjugateGradient CG(1.0e-8,10000); CG(HermOp,src,result); @@ -140,7 +140,7 @@ void TestCGprec(What & Ddwf, pickCheckerboard(Odd,src_o,src); result_o=zero; - HermitianCheckerBoardedOperator HermOpEO(Ddwf); + SchurDiagMooeeOperator HermOpEO(Ddwf); ConjugateGradient CG(1.0e-8,10000); CG(HermOpEO,src_o,result_o); } @@ -158,6 +158,6 @@ void TestCGschur(What & Ddwf, LatticeFermion result(FGrid); result=zero; ConjugateGradient CG(1.0e-8,10000); - SchurRedBlackSolve SchurSolver(CG); + SchurRedBlackDiagMooeeSolve SchurSolver(CG); SchurSolver(Ddwf,src,result); } diff --git a/tests/Test_contfrac_even_odd.cc b/tests/Test_contfrac_even_odd.cc index d3f5fb1b..a246b904 100644 --- a/tests/Test_contfrac_even_odd.cc +++ b/tests/Test_contfrac_even_odd.cc @@ -211,11 +211,12 @@ void TestWhat(What & Ddwf, pickCheckerboard(Odd ,phi_o,phi); RealD t1,t2; - Ddwf.MpcDagMpc(chi_e,dchi_e,t1,t2); - Ddwf.MpcDagMpc(chi_o,dchi_o,t1,t2); - - Ddwf.MpcDagMpc(phi_e,dphi_e,t1,t2); - Ddwf.MpcDagMpc(phi_o,dphi_o,t1,t2); + SchurDiagMooeeOperator HermOpEO(Ddwf); + HermOpEO.MpcDagMpc(chi_e,dchi_e,t1,t2); + HermOpEO.MpcDagMpc(chi_o,dchi_o,t1,t2); + + HermOpEO.MpcDagMpc(phi_e,dphi_e,t1,t2); + HermOpEO.MpcDagMpc(phi_o,dphi_o,t1,t2); pDce = innerProduct(phi_e,dchi_e); pDco = innerProduct(phi_o,dchi_o); diff --git a/tests/Test_dwf_cg_prec.cc b/tests/Test_dwf_cg_prec.cc index 0cf86a19..32a0c3ae 100644 --- a/tests/Test_dwf_cg_prec.cc +++ b/tests/Test_dwf_cg_prec.cc @@ -50,7 +50,7 @@ int main (int argc, char ** argv) pickCheckerboard(Odd,src_o,src); result_o=zero; - HermitianCheckerBoardedOperator HermOpEO(Ddwf); + SchurDiagMooeeOperator HermOpEO(Ddwf); ConjugateGradient CG(1.0e-8,10000); CG(HermOpEO,src_o,result_o); diff --git a/tests/Test_dwf_cg_schur.cc b/tests/Test_dwf_cg_schur.cc index aac4d3fd..d080045b 100644 --- a/tests/Test_dwf_cg_schur.cc +++ b/tests/Test_dwf_cg_schur.cc @@ -46,7 +46,7 @@ int main (int argc, char ** argv) DomainWallFermion Ddwf(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); ConjugateGradient CG(1.0e-8,10000); - SchurRedBlackSolve SchurSolver(CG); + SchurRedBlackDiagMooeeSolve SchurSolver(CG); SchurSolver(Ddwf,src,result); Grid_finalize(); diff --git a/tests/Test_dwf_cg_unprec.cc b/tests/Test_dwf_cg_unprec.cc index 5c9e7ad3..7a13be43 100644 --- a/tests/Test_dwf_cg_unprec.cc +++ b/tests/Test_dwf_cg_unprec.cc @@ -45,7 +45,7 @@ int main (int argc, char ** argv) RealD M5=1.8; DomainWallFermion Ddwf(Umu,*FGrid,*FrbGrid,*UGrid,*UrbGrid,mass,M5); - HermitianOperator HermOp(Ddwf); + MdagMLinearOperator HermOp(Ddwf); ConjugateGradient CG(1.0e-8,10000); CG(HermOp,src,result); diff --git a/tests/Test_dwf_even_odd.cc b/tests/Test_dwf_even_odd.cc index ac47bbf9..82b39f68 100644 --- a/tests/Test_dwf_even_odd.cc +++ b/tests/Test_dwf_even_odd.cc @@ -186,11 +186,13 @@ int main (int argc, char ** argv) pickCheckerboard(Odd ,phi_o,phi); RealD t1,t2; - Ddwf.MpcDagMpc(chi_e,dchi_e,t1,t2); - Ddwf.MpcDagMpc(chi_o,dchi_o,t1,t2); - Ddwf.MpcDagMpc(phi_e,dphi_e,t1,t2); - Ddwf.MpcDagMpc(phi_o,dphi_o,t1,t2); + SchurDiagMooeeOperator HermOpEO(Ddwf); + HermOpEO.MpcDagMpc(chi_e,dchi_e,t1,t2); + HermOpEO.MpcDagMpc(chi_o,dchi_o,t1,t2); + + HermOpEO.MpcDagMpc(phi_e,dphi_e,t1,t2); + HermOpEO.MpcDagMpc(phi_o,dphi_o,t1,t2); pDce = innerProduct(phi_e,dchi_e); pDco = innerProduct(phi_o,dchi_o); diff --git a/tests/Test_wilson_cg_prec.cc b/tests/Test_wilson_cg_prec.cc index e376349c..dd8f2821 100644 --- a/tests/Test_wilson_cg_prec.cc +++ b/tests/Test_wilson_cg_prec.cc @@ -53,7 +53,7 @@ int main (int argc, char ** argv) pickCheckerboard(Odd,src_o,src); result_o=zero; - HermitianCheckerBoardedOperator HermOpEO(Dw); + SchurDiagMooeeOperator HermOpEO(Dw); ConjugateGradient CG(1.0e-8,10000); CG(HermOpEO,src_o,result_o); diff --git a/tests/Test_wilson_cg_schur.cc b/tests/Test_wilson_cg_schur.cc index 28db1d4b..56abcd4a 100644 --- a/tests/Test_wilson_cg_schur.cc +++ b/tests/Test_wilson_cg_schur.cc @@ -40,7 +40,7 @@ int main (int argc, char ** argv) WilsonFermion Dw(Umu,Grid,RBGrid,mass); ConjugateGradient CG(1.0e-8,10000); - SchurRedBlackSolve SchurSolver(CG); + SchurRedBlackDiagMooeeSolve SchurSolver(CG); SchurSolver(Dw,src,result); diff --git a/tests/Test_wilson_cg_unprec.cc b/tests/Test_wilson_cg_unprec.cc index 905dfde5..c9f2856b 100644 --- a/tests/Test_wilson_cg_unprec.cc +++ b/tests/Test_wilson_cg_unprec.cc @@ -49,7 +49,7 @@ int main (int argc, char ** argv) RealD mass=0.5; WilsonFermion Dw(Umu,Grid,RBGrid,mass); - HermitianOperator HermOp(Dw); + MdagMLinearOperator HermOp(Dw); ConjugateGradient CG(1.0e-8,10000); CG(HermOp,src,result); diff --git a/tests/Test_wilson_even_odd.cc b/tests/Test_wilson_even_odd.cc index 3ebc4709..16019c9c 100644 --- a/tests/Test_wilson_even_odd.cc +++ b/tests/Test_wilson_even_odd.cc @@ -177,11 +177,12 @@ int main (int argc, char ** argv) pickCheckerboard(Odd ,phi_o,phi); RealD t1,t2; - Dw.MpcDagMpc(chi_e,dchi_e,t1,t2); - Dw.MpcDagMpc(chi_o,dchi_o,t1,t2); + SchurDiagMooeeOperator HermOpEO(Dw); + HermOpEO.MpcDagMpc(chi_e,dchi_e,t1,t2); + HermOpEO.MpcDagMpc(chi_o,dchi_o,t1,t2); - Dw.MpcDagMpc(phi_e,dphi_e,t1,t2); - Dw.MpcDagMpc(phi_o,dphi_o,t1,t2); + HermOpEO.MpcDagMpc(phi_e,dphi_e,t1,t2); + HermOpEO.MpcDagMpc(phi_o,dphi_o,t1,t2); pDce = innerProduct(phi_e,dchi_e); pDco = innerProduct(phi_o,dchi_o);