mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 05:54:32 +00:00 
			
		
		
		
	Rework the linop support to get different forms of red black schur solver
Moo on diag, or MooInv Moe MeeInv Meo
This commit is contained in:
		@@ -18,84 +18,144 @@ namespace Grid {
 | 
			
		||||
    public:
 | 
			
		||||
      virtual void Op     (const Field &in, Field &out) = 0; // Abstract base
 | 
			
		||||
      virtual void AdjOp  (const Field &in, Field &out) = 0; // Abstract base
 | 
			
		||||
      virtual void HermOpAndNorm(const Field &in, Field &out,double &n1,double &n2)=0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Hermitian operators are self adjoint and only require Op to be defined, so refine the base
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Field> class HermitianOperatorBase : public LinearOperatorBase<Field> {
 | 
			
		||||
    public:
 | 
			
		||||
      virtual void OpAndNorm(const Field &in, Field &out,double &n1,double &n2)=0;
 | 
			
		||||
      void AdjOp(const Field &in, Field &out) {
 | 
			
		||||
	Op(in,out);
 | 
			
		||||
      };
 | 
			
		||||
      void Op(const Field &in, Field &out) {
 | 
			
		||||
	double n1,n2;
 | 
			
		||||
	OpAndNorm(in,out,n1,n2);
 | 
			
		||||
      };
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Whereas non hermitian takes a generic sparse matrix (e.g. lattice action)
 | 
			
		||||
  // conforming to sparse matrix interface and builds the full checkerboard non-herm operator
 | 
			
		||||
  // Op and AdjOp distinct.
 | 
			
		||||
  // By sharing the class for Sparse Matrix across multiple operator wrappers, we can share code
 | 
			
		||||
  // between RB and non-RB variants. Sparse matrix is like the fermion action def, and then
 | 
			
		||||
  // the wrappers implement the specialisation of "Op" and "AdjOp" to the cases minimising
 | 
			
		||||
  // replication of code.
 | 
			
		||||
  //
 | 
			
		||||
  // I'm not entirely happy with implementation; to share the Schur code between herm and non-herm
 | 
			
		||||
  // while still having a "OpAndNorm" in the abstract base I had to implement it in both cases
 | 
			
		||||
  // with an assert trap in the non-herm. This isn't right; there must be a better C++ way to
 | 
			
		||||
  // do it, but I fear it required multiple inheritance and mixed in abstract base classes
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////
 | 
			
		||||
    // Construct herm op from non-herm matrix
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Matrix,class Field>
 | 
			
		||||
    class NonHermitianOperator : public LinearOperatorBase<Field> {
 | 
			
		||||
    class MdagMLinearOperator : public LinearOperatorBase<Field> {
 | 
			
		||||
      Matrix &_Mat;
 | 
			
		||||
    public:
 | 
			
		||||
      NonHermitianOperator(Matrix &Mat): _Mat(Mat){};
 | 
			
		||||
    MdagMLinearOperator(Matrix &Mat): _Mat(Mat){};
 | 
			
		||||
      void Op     (const Field &in, Field &out){
 | 
			
		||||
	_Mat.M(in,out);
 | 
			
		||||
      }
 | 
			
		||||
      void AdjOp     (const Field &in, Field &out){
 | 
			
		||||
	_Mat.Mdag(in,out);
 | 
			
		||||
      }
 | 
			
		||||
      void HermOpAndNorm(const Field &in, Field &out,double &n1,double &n2){
 | 
			
		||||
	_Mat.MdagM(in,out,n1,n2);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
    
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    // Redblack Non hermitian wrapper
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////
 | 
			
		||||
    // Wrap an already herm matrix
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Matrix,class Field>
 | 
			
		||||
    class NonHermitianCheckerBoardedOperator : public LinearOperatorBase<Field> {
 | 
			
		||||
    class HermitianLinearOperator : public LinearOperatorBase<Field> {
 | 
			
		||||
      Matrix &_Mat;
 | 
			
		||||
    public:
 | 
			
		||||
      NonHermitianCheckerBoardedOperator(Matrix &Mat): _Mat(Mat){};
 | 
			
		||||
    HermitianLinearOperator(Matrix &Mat): _Mat(Mat){};
 | 
			
		||||
      void Op     (const Field &in, Field &out){
 | 
			
		||||
	_Mat.Mpc(in,out);
 | 
			
		||||
	_Mat.M(in,out);
 | 
			
		||||
      }
 | 
			
		||||
      void AdjOp     (const Field &in, Field &out){ //
 | 
			
		||||
	_Mat.MpcDag(in,out);
 | 
			
		||||
      void AdjOp     (const Field &in, Field &out){
 | 
			
		||||
	_Mat.M(in,out);
 | 
			
		||||
      }
 | 
			
		||||
      void HermOpAndNorm(const Field &in, Field &out,double &n1,double &n2){
 | 
			
		||||
	ComplexD dot;
 | 
			
		||||
 | 
			
		||||
	_Mat.M(in,out);
 | 
			
		||||
	
 | 
			
		||||
	dot= innerProduct(in,out);
 | 
			
		||||
	n1=real(dot);
 | 
			
		||||
 | 
			
		||||
	dot = innerProduct(out,out);
 | 
			
		||||
	n2=real(dot);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    // Hermitian wrapper
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Matrix,class Field>
 | 
			
		||||
    class HermitianOperator : public HermitianOperatorBase<Field> {
 | 
			
		||||
      Matrix &_Mat;
 | 
			
		||||
    //////////////////////////////////////////////////////////
 | 
			
		||||
    // Even Odd Schur decomp operators; there are several
 | 
			
		||||
    // ways to introduce the even odd checkerboarding
 | 
			
		||||
    //////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
    template<class Field>
 | 
			
		||||
      class SchurOperatorBase :  public LinearOperatorBase<Field> {
 | 
			
		||||
    public:
 | 
			
		||||
      HermitianOperator(Matrix &Mat): _Mat(Mat) {};
 | 
			
		||||
      void OpAndNorm(const Field &in, Field &out,double &n1,double &n2){
 | 
			
		||||
	return _Mat.MdagM(in,out,n1,n2);
 | 
			
		||||
      virtual  RealD Mpc      (const Field &in, Field &out) =0;
 | 
			
		||||
      virtual  RealD MpcDag   (const Field &in, Field &out) =0;
 | 
			
		||||
      virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) {
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
	ni=Mpc(in,tmp);
 | 
			
		||||
	no=MpcDag(tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){
 | 
			
		||||
	MpcDagMpc(in,out,n1,n2);
 | 
			
		||||
      }
 | 
			
		||||
      void Op     (const Field &in, Field &out){
 | 
			
		||||
	Mpc(in,out);
 | 
			
		||||
      }
 | 
			
		||||
      void AdjOp     (const Field &in, Field &out){ 
 | 
			
		||||
	MpcDag(in,out);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    // Hermitian CheckerBoarded wrapper
 | 
			
		||||
    ////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Matrix,class Field>
 | 
			
		||||
    class HermitianCheckerBoardedOperator : public HermitianOperatorBase<Field> {
 | 
			
		||||
      class SchurDiagMooeeOperator :  public SchurOperatorBase<Field> {
 | 
			
		||||
      Matrix &_Mat;
 | 
			
		||||
    public:
 | 
			
		||||
      HermitianCheckerBoardedOperator(Matrix &Mat): _Mat(Mat) {};
 | 
			
		||||
      void OpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){
 | 
			
		||||
	_Mat.MpcDagMpc(in,out,n1,n2);
 | 
			
		||||
      SchurDiagMooeeOperator (Matrix &Mat): _Mat(Mat){};
 | 
			
		||||
      virtual  RealD Mpc      (const Field &in, Field &out) {
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
 | 
			
		||||
	_Mat.Meooe(in,tmp);
 | 
			
		||||
	_Mat.MooeeInv(tmp,out);
 | 
			
		||||
	_Mat.Meooe(out,tmp);
 | 
			
		||||
 | 
			
		||||
	_Mat.Mooee(in,out);
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      virtual  RealD MpcDag   (const Field &in, Field &out){
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
 | 
			
		||||
	_Mat.MeooeDag(in,tmp);
 | 
			
		||||
	_Mat.MooeeInvDag(tmp,out);
 | 
			
		||||
	_Mat.MeooeDag(out,tmp);
 | 
			
		||||
 | 
			
		||||
	_Mat.MooeeDag(in,out);
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
    template<class Matrix,class Field>
 | 
			
		||||
      class SchurDiagOneOperator :  public SchurOperatorBase<Field> {
 | 
			
		||||
      Matrix &_Mat;
 | 
			
		||||
    public:
 | 
			
		||||
      SchurDiagOneOperator (Matrix &Mat): _Mat(Mat){};
 | 
			
		||||
 | 
			
		||||
      virtual  RealD Mpc      (const Field &in, Field &out) {
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
 | 
			
		||||
	_Mat.Meooe(in,tmp);
 | 
			
		||||
	_Mat.MooeeInv(tmp,out);
 | 
			
		||||
	_Mat.Meooe(out,tmp);
 | 
			
		||||
	_Mat.MooeeInv(tmp,out);
 | 
			
		||||
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,in);
 | 
			
		||||
      }
 | 
			
		||||
      virtual  RealD MpcDag   (const Field &in, Field &out){
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
 | 
			
		||||
	_Mat.MooeeInvDag(in,out);
 | 
			
		||||
	_Mat.MeooeDag(out,tmp);
 | 
			
		||||
	_Mat.MooeeInvDag(tmp,out);
 | 
			
		||||
	_Mat.MeooeDag(out,tmp);
 | 
			
		||||
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,in);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@@ -106,10 +166,6 @@ namespace Grid {
 | 
			
		||||
    public:
 | 
			
		||||
      virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
 | 
			
		||||
    };
 | 
			
		||||
    template<class Field> class HermitianOperatorFunction {
 | 
			
		||||
    public:
 | 
			
		||||
      virtual void operator() (HermitianOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // FIXME : To think about
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -36,32 +36,6 @@ namespace Grid {
 | 
			
		||||
      virtual  void MooeeDag    (const Field &in, Field &out)=0;
 | 
			
		||||
      virtual  void MooeeInvDag (const Field &in, Field &out)=0;
 | 
			
		||||
 | 
			
		||||
      // Schur decomp operators
 | 
			
		||||
      virtual  RealD Mpc      (const Field &in, Field &out) {
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
 | 
			
		||||
	Meooe(in,tmp);
 | 
			
		||||
	MooeeInv(tmp,out);
 | 
			
		||||
	Meooe(out,tmp);
 | 
			
		||||
 | 
			
		||||
	Mooee(in,out);
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      virtual  RealD MpcDag   (const Field &in, Field &out){
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
 | 
			
		||||
	MeooeDag(in,tmp);
 | 
			
		||||
	MooeeInvDag(tmp,out);
 | 
			
		||||
	MeooeDag(out,tmp);
 | 
			
		||||
 | 
			
		||||
	MooeeDag(in,out);
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) {
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
	ni=Mpc(in,tmp);
 | 
			
		||||
	no=MpcDag(tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,17 +9,17 @@ namespace Grid {
 | 
			
		||||
    /////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
  template<class Field> 
 | 
			
		||||
    class ConjugateGradient : public HermitianOperatorFunction<Field> {
 | 
			
		||||
    class ConjugateGradient : public OperatorFunction<Field> {
 | 
			
		||||
public:                                                
 | 
			
		||||
    RealD   Tolerance;
 | 
			
		||||
    Integer MaxIterations;
 | 
			
		||||
    int verbose;
 | 
			
		||||
    ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) { 
 | 
			
		||||
      verbose=0;
 | 
			
		||||
      verbose=1;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){
 | 
			
		||||
    void operator() (LinearOperatorBase<Field> &Linop,const Field &src, Field &psi){
 | 
			
		||||
 | 
			
		||||
      psi.checkerboard = src.checkerboard;
 | 
			
		||||
      conformable(psi,src);
 | 
			
		||||
@@ -33,7 +33,7 @@ public:
 | 
			
		||||
      //Initial residual computation & set up
 | 
			
		||||
      RealD guess = norm2(psi);
 | 
			
		||||
      
 | 
			
		||||
      Linop.OpAndNorm(psi,mmp,d,b);
 | 
			
		||||
      Linop.HermOpAndNorm(psi,mmp,d,b);
 | 
			
		||||
      
 | 
			
		||||
      r= src-mmp;
 | 
			
		||||
      p= r;
 | 
			
		||||
@@ -65,7 +65,7 @@ public:
 | 
			
		||||
	
 | 
			
		||||
	c=cp;
 | 
			
		||||
	
 | 
			
		||||
	Linop.OpAndNorm(p,mmp,d,qq);
 | 
			
		||||
	Linop.HermOpAndNorm(p,mmp,d,qq);
 | 
			
		||||
 | 
			
		||||
	RealD    qqck = norm2(mmp);
 | 
			
		||||
	ComplexD dck  = innerProduct(p,mmp);
 | 
			
		||||
@@ -86,19 +86,10 @@ public:
 | 
			
		||||
	  
 | 
			
		||||
	if (verbose) std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl;
 | 
			
		||||
 | 
			
		||||
	// Hack
 | 
			
		||||
	if (0) {
 | 
			
		||||
	  Field   tt(src);
 | 
			
		||||
      	  Linop.Op(psi,mmp);
 | 
			
		||||
	  tt=mmp-src;
 | 
			
		||||
	  RealD resnorm = norm2(tt);
 | 
			
		||||
	  std::cout<<"ConjugateGradient: Iteration " <<k<<" true residual "<<resnorm << " computed " << cp <<std::endl;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Stopping condition
 | 
			
		||||
	if ( cp <= rsq ) { 
 | 
			
		||||
	  
 | 
			
		||||
	  Linop.Op(psi,mmp);
 | 
			
		||||
	  Linop.HermOpAndNorm(psi,mmp,d,qq);
 | 
			
		||||
	  p=mmp-src;
 | 
			
		||||
	  
 | 
			
		||||
	  RealD mmpnorm = sqrt(norm2(mmp));
 | 
			
		||||
 
 | 
			
		||||
@@ -40,16 +40,16 @@ namespace Grid {
 | 
			
		||||
  // Take a matrix and form a Red Black solver calling a Herm solver
 | 
			
		||||
  // Use of RB info prevents making SchurRedBlackSolve conform to standard interface
 | 
			
		||||
  ///////////////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  template<class Field> class SchurRedBlackSolve {
 | 
			
		||||
  template<class Field> class SchurRedBlackDiagMooeeSolve {
 | 
			
		||||
  private:
 | 
			
		||||
    HermitianOperatorFunction<Field> & _HermitianRBSolver;
 | 
			
		||||
    OperatorFunction<Field> & _HermitianRBSolver;
 | 
			
		||||
    int CBfactorise;
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
    /////////////////////////////////////////////////////
 | 
			
		||||
    // Wrap the usual normal equations Schur trick
 | 
			
		||||
    /////////////////////////////////////////////////////
 | 
			
		||||
  SchurRedBlackSolve(HermitianOperatorFunction<Field> &HermitianRBSolver)  :
 | 
			
		||||
  SchurRedBlackDiagMooeeSolve(OperatorFunction<Field> &HermitianRBSolver)  :
 | 
			
		||||
     _HermitianRBSolver(HermitianRBSolver) 
 | 
			
		||||
    { 
 | 
			
		||||
      CBfactorise=0;
 | 
			
		||||
@@ -62,6 +62,8 @@ namespace Grid {
 | 
			
		||||
      // FIXME use CBfactorise to control schur decomp
 | 
			
		||||
      GridBase *grid = _Matrix.RedBlackGrid();
 | 
			
		||||
      GridBase *fgrid= _Matrix.Grid();
 | 
			
		||||
 | 
			
		||||
      SchurDiagMooeeOperator<Matrix,Field> _HermOpEO(_Matrix);
 | 
			
		||||
 
 | 
			
		||||
      Field src_e(grid);
 | 
			
		||||
      Field src_o(grid);
 | 
			
		||||
@@ -80,12 +82,13 @@ namespace Grid {
 | 
			
		||||
      _Matrix.MooeeInv(src_e,tmp);     assert(  tmp.checkerboard ==Even);
 | 
			
		||||
      _Matrix.Meooe   (tmp,Mtmp);      assert( Mtmp.checkerboard ==Odd);     
 | 
			
		||||
      tmp=src_o-Mtmp;                  assert(  tmp.checkerboard ==Odd);     
 | 
			
		||||
      _Matrix.MpcDag(tmp,src_o);       assert(src_o.checkerboard ==Odd);       
 | 
			
		||||
 | 
			
		||||
      // get the right MpcDag
 | 
			
		||||
      _HermOpEO.MpcDag(tmp,src_o);     assert(src_o.checkerboard ==Odd);       
 | 
			
		||||
 | 
			
		||||
      //////////////////////////////////////////////////////////////
 | 
			
		||||
      // Call the red-black solver
 | 
			
		||||
      //////////////////////////////////////////////////////////////
 | 
			
		||||
      HermitianCheckerBoardedOperator<Matrix,Field> _HermOpEO(_Matrix);
 | 
			
		||||
      std::cout << "SchurRedBlack solver calling the MpcDagMp solver" <<std::endl;
 | 
			
		||||
      _HermitianRBSolver(_HermOpEO,src_o,sol_o);  assert(sol_o.checkerboard==Odd);
 | 
			
		||||
 | 
			
		||||
@@ -105,7 +108,7 @@ namespace Grid {
 | 
			
		||||
      RealD ns = norm2(in);
 | 
			
		||||
      RealD nr = norm2(resid);
 | 
			
		||||
 | 
			
		||||
      std::cout << "SchurRedBlack solver true unprec resid "<< sqrt(nr/ns) <<" nr "<< nr <<" ns "<<ns << std::endl;
 | 
			
		||||
      std::cout << "SchurRedBlackDiagMooee solver true unprec resid "<< sqrt(nr/ns) <<" nr "<< nr <<" ns "<<ns << std::endl;
 | 
			
		||||
    }     
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user