mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 05:54:32 +00:00 
			
		
		
		
	Updates now schur red black solver working
This commit is contained in:
		@@ -10,6 +10,7 @@ namespace Grid {
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Field> class SparseMatrixBase {
 | 
			
		||||
    public:
 | 
			
		||||
      GridBase *_grid;
 | 
			
		||||
      // Full checkerboar operations
 | 
			
		||||
      virtual RealD M    (const Field &in, Field &out)=0;
 | 
			
		||||
      virtual RealD Mdag (const Field &in, Field &out)=0;
 | 
			
		||||
@@ -18,6 +19,7 @@ namespace Grid {
 | 
			
		||||
	ni=M(in,tmp);
 | 
			
		||||
	no=Mdag(tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      SparseMatrixBase(GridBase *grid) : _grid(grid) {};
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
@@ -25,7 +27,7 @@ namespace Grid {
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
    template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> {
 | 
			
		||||
    public:
 | 
			
		||||
      
 | 
			
		||||
      GridBase *_cbgrid;
 | 
			
		||||
      // half checkerboard operaions
 | 
			
		||||
      virtual  void Meooe    (const Field &in, Field &out)=0;
 | 
			
		||||
      virtual  void Mooee    (const Field &in, Field &out)=0;
 | 
			
		||||
@@ -44,9 +46,7 @@ namespace Grid {
 | 
			
		||||
	Meooe(out,tmp);
 | 
			
		||||
 | 
			
		||||
	Mooee(in,out);
 | 
			
		||||
	out=out-tmp; // axpy_norm
 | 
			
		||||
	RealD n=norm2(out);
 | 
			
		||||
	return n;
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      virtual  RealD MpcDag   (const Field &in, Field &out){
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
@@ -56,15 +56,15 @@ namespace Grid {
 | 
			
		||||
	MeooeDag(out,tmp);
 | 
			
		||||
 | 
			
		||||
	MooeeDag(in,out);
 | 
			
		||||
	out=out-tmp; // axpy_norm
 | 
			
		||||
	RealD n=norm2(out);
 | 
			
		||||
	return n;
 | 
			
		||||
	return axpy_norm(out,-1.0,tmp,out);
 | 
			
		||||
      }
 | 
			
		||||
      virtual void MpcDagMpc(const Field &in, Field &out,RealD ni,RealD no) {
 | 
			
		||||
      virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) {
 | 
			
		||||
	Field tmp(in._grid);
 | 
			
		||||
	ni=Mpc(in,tmp);
 | 
			
		||||
	no=Mpc(tmp,out);
 | 
			
		||||
	no=MpcDag(tmp,out);
 | 
			
		||||
	//	std::cout<<"MpcDagMpc "<<ni<<" "<<no<<std::endl;
 | 
			
		||||
      }
 | 
			
		||||
      CheckerBoardedSparseMatrixBase(GridBase *grid,GridBase *cbgrid) : SparseMatrixBase<Field>(grid), _cbgrid(cbgrid) {};
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,17 +9,21 @@ namespace Grid {
 | 
			
		||||
    /////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
  template<class Field> 
 | 
			
		||||
    class ConjugateGradient :  public OperatorFunction<Field> {
 | 
			
		||||
    class ConjugateGradient : public HermitianOperatorFunction<Field> {
 | 
			
		||||
public:                                                
 | 
			
		||||
    RealD   Tolerance;
 | 
			
		||||
    Integer MaxIterations;
 | 
			
		||||
 | 
			
		||||
    int verbose;
 | 
			
		||||
    ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) { 
 | 
			
		||||
      verbose=0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    void operator() (LinearOperatorBase<Field> &Linop,const Field &src, Field &psi) {assert(0);};
 | 
			
		||||
 | 
			
		||||
    void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){
 | 
			
		||||
 | 
			
		||||
      psi.checkerboard = src.checkerboard;
 | 
			
		||||
      conformable(psi,src);
 | 
			
		||||
 | 
			
		||||
      RealD cp,c,a,d,b,ssq,qq,b_pred;
 | 
			
		||||
      
 | 
			
		||||
      Field   p(src);
 | 
			
		||||
@@ -38,12 +42,14 @@ public:
 | 
			
		||||
      cp =a;
 | 
			
		||||
      ssq=norm2(src);
 | 
			
		||||
 | 
			
		||||
      std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl;
 | 
			
		||||
      std::cout <<std::setprecision(4)<< "ConjugateGradient:   src "<<ssq  <<std::endl;
 | 
			
		||||
      std::cout <<std::setprecision(4)<< "ConjugateGradient:    mp "<<d    <<std::endl;
 | 
			
		||||
      std::cout <<std::setprecision(4)<< "ConjugateGradient:   mmp "<<b    <<std::endl;
 | 
			
		||||
      std::cout <<std::setprecision(4)<< "ConjugateGradient:     r "<<cp   <<std::endl;
 | 
			
		||||
      std::cout <<std::setprecision(4)<< "ConjugateGradient:     p "<<a    <<std::endl;
 | 
			
		||||
      if ( verbose ) {
 | 
			
		||||
	std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl;
 | 
			
		||||
	std::cout <<std::setprecision(4)<< "ConjugateGradient:   src "<<ssq  <<std::endl;
 | 
			
		||||
	std::cout <<std::setprecision(4)<< "ConjugateGradient:    mp "<<d    <<std::endl;
 | 
			
		||||
	std::cout <<std::setprecision(4)<< "ConjugateGradient:   mmp "<<b    <<std::endl;
 | 
			
		||||
	std::cout <<std::setprecision(4)<< "ConjugateGradient:  cp,r "<<cp   <<std::endl;
 | 
			
		||||
	std::cout <<std::setprecision(4)<< "ConjugateGradient:     p "<<a    <<std::endl;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      RealD rsq =  Tolerance* Tolerance*ssq;
 | 
			
		||||
      
 | 
			
		||||
@@ -61,13 +67,15 @@ public:
 | 
			
		||||
	
 | 
			
		||||
	Linop.OpAndNorm(p,mmp,d,qq);
 | 
			
		||||
 | 
			
		||||
	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  d,qq "<<d<< " "<<qq <<std::endl;
 | 
			
		||||
	RealD    qqck = norm2(mmp);
 | 
			
		||||
	ComplexD dck  = innerProduct(p,mmp);
 | 
			
		||||
	//	if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient:  d,qq "<<d<< " "<<qq <<" qqcheck "<< qqck<< " dck "<< dck<<std::endl;
 | 
			
		||||
      
 | 
			
		||||
	a      = c/d;
 | 
			
		||||
	b_pred = a*(a*qq-d)/c;
 | 
			
		||||
 | 
			
		||||
	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  a,bp "<<a<< " "<<b_pred <<std::endl;
 | 
			
		||||
 | 
			
		||||
	//	if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient:  a,bp "<<a<< " "<<b_pred <<std::endl;
 | 
			
		||||
	cp = axpy_norm(r,-a,mmp,r);
 | 
			
		||||
	b = cp/c;
 | 
			
		||||
	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  cp,b "<<cp<< " "<<b <<std::endl;
 | 
			
		||||
@@ -76,7 +84,16 @@ public:
 | 
			
		||||
	psi= a*p+psi;
 | 
			
		||||
	p  = p*b+r;
 | 
			
		||||
	  
 | 
			
		||||
	std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl;
 | 
			
		||||
	if (verbose) std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl;
 | 
			
		||||
 | 
			
		||||
	// Hack
 | 
			
		||||
	if (0) {
 | 
			
		||||
	  Field   tt(src);
 | 
			
		||||
      	  Linop.Op(psi,mmp);
 | 
			
		||||
	  tt=mmp-src;
 | 
			
		||||
	  RealD resnorm = norm2(tt);
 | 
			
		||||
	  std::cout<<"ConjugateGradient: Iteration " <<k<<" true residual "<<resnorm << " computed " << cp <<std::endl;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Stopping condition
 | 
			
		||||
	if ( cp <= rsq ) { 
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
#ifndef GRID_SCHUR_RED_BLACK_H
 | 
			
		||||
#define GRID_SCHUR_RED_BLACK_H
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  /*
 | 
			
		||||
   * Red black Schur decomposition
 | 
			
		||||
   *
 | 
			
		||||
@@ -25,53 +26,50 @@
 | 
			
		||||
   *     M psi = eta
 | 
			
		||||
   ***********************
 | 
			
		||||
   *Odd
 | 
			
		||||
   * i)   (D_oo)^{\dag} D_oo psi_o = (D_oo)^\dag L^{-1}  eta_o
 | 
			
		||||
   *                        eta_o' = D_oo (eta_o - Moe Mee^{-1} eta_e)
 | 
			
		||||
   * i)   (D_oo)^{\dag} D_oo psi_o = (D_oo)^dag L^{-1}  eta_o
 | 
			
		||||
   *                        eta_o' = (D_oo)^dag (eta_o - Moe Mee^{-1} eta_e)
 | 
			
		||||
   *Even
 | 
			
		||||
   * ii)  Mee psi_e + Meo psi_o = src_e
 | 
			
		||||
   *
 | 
			
		||||
   *   => sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
 | 
			
		||||
   *
 | 
			
		||||
   */
 | 
			
		||||
 | 
			
		||||
namespace Grid {
 | 
			
		||||
 | 
			
		||||
  ///////////////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Take a matrix and form a Red Black solver calling a Herm solver
 | 
			
		||||
  // Use of RB info prevents making SchurRedBlackSolve conform to standard interface
 | 
			
		||||
  ///////////////////////////////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  template<class Field> class SchurRedBlackSolve : public OperatorFunction<Field>{
 | 
			
		||||
  template<class Field> class SchurRedBlackSolve {
 | 
			
		||||
  private:
 | 
			
		||||
    SparseMatrixBase<Field> & _Matrix;
 | 
			
		||||
    OperatorFunction<Field> & _HermitianRBSolver;
 | 
			
		||||
    HermitianOperatorFunction<Field> & _HermitianRBSolver;
 | 
			
		||||
    int CBfactorise;
 | 
			
		||||
  public:
 | 
			
		||||
 | 
			
		||||
    /////////////////////////////////////////////////////
 | 
			
		||||
    // Wrap the usual normal equations Schur trick
 | 
			
		||||
    /////////////////////////////////////////////////////
 | 
			
		||||
  SchurRedBlackSolve(SparseMatrixBase<Field> &Matrix, OperatorFunction<Field> &HermitianRBSolver) 
 | 
			
		||||
    :  _Matrix(Matrix), _HermitianRBSolver(HermitianRBSolver) { 
 | 
			
		||||
  SchurRedBlackSolve(HermitianOperatorFunction<Field> &HermitianRBSolver)  :
 | 
			
		||||
     _HermitianRBSolver(HermitianRBSolver) 
 | 
			
		||||
    { 
 | 
			
		||||
      CBfactorise=0;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    void operator() (const Field &in, Field &out){
 | 
			
		||||
    template<class Matrix>
 | 
			
		||||
      void operator() (Matrix & _Matrix,const Field &in, Field &out){
 | 
			
		||||
 | 
			
		||||
      // FIXME CGdiagonalMee not implemented virtual function
 | 
			
		||||
      // FIXME need to make eo grid from full grid.
 | 
			
		||||
      // FIXME use CBfactorise to control schur decomp
 | 
			
		||||
      const int Even=0;
 | 
			
		||||
      const int Odd =1;
 | 
			
		||||
      GridBase *grid = _Matrix._cbgrid;
 | 
			
		||||
      GridBase *fgrid= _Matrix._grid;
 | 
			
		||||
 
 | 
			
		||||
      // Make a cartesianRedBlack from full Grid
 | 
			
		||||
      GridRedBlackCartesian grid(in._grid);
 | 
			
		||||
 
 | 
			
		||||
      Field src_e(&grid);
 | 
			
		||||
      Field src_o(&grid);
 | 
			
		||||
      Field sol_e(&grid);
 | 
			
		||||
      Field sol_o(&grid);
 | 
			
		||||
      Field   tmp(&grid);
 | 
			
		||||
      Field  Mtmp(&grid);
 | 
			
		||||
      Field src_e(grid);
 | 
			
		||||
      Field src_o(grid);
 | 
			
		||||
      Field sol_e(grid);
 | 
			
		||||
      Field sol_o(grid);
 | 
			
		||||
      Field   tmp(grid);
 | 
			
		||||
      Field  Mtmp(grid);
 | 
			
		||||
      Field resid(fgrid);
 | 
			
		||||
 | 
			
		||||
      pickCheckerboard(Even,src_e,in);
 | 
			
		||||
      pickCheckerboard(Odd ,src_o,in);
 | 
			
		||||
@@ -79,26 +77,35 @@ namespace Grid {
 | 
			
		||||
      /////////////////////////////////////////////////////
 | 
			
		||||
      // src_o = Mdag * (source_o - Moe MeeInv source_e)
 | 
			
		||||
      /////////////////////////////////////////////////////
 | 
			
		||||
      _Matrix.MooeeInv(src_e,tmp);     //    MooeeInv(source[Even],tmp,DaggerNo,Even);
 | 
			
		||||
      _Matrix.Meooe   (tmp,Mtmp);      //    Meo     (tmp,src,Odd,DaggerNo);
 | 
			
		||||
      tmp=src_o-Mtmp;                  //    axpy    (tmp,src,source[Odd],-1.0);
 | 
			
		||||
      _Matrix.MpcDag(tmp,src_o);       //    Mprec(tmp,src,Mtmp,DaggerYes);  
 | 
			
		||||
      _Matrix.MooeeInv(src_e,tmp);     assert(  tmp.checkerboard ==Even);
 | 
			
		||||
      _Matrix.Meooe   (tmp,Mtmp);      assert( Mtmp.checkerboard ==Odd);     
 | 
			
		||||
      tmp=src_o-Mtmp;                  assert(  tmp.checkerboard ==Odd);     
 | 
			
		||||
      _Matrix.MpcDag(tmp,src_o);       assert(src_o.checkerboard ==Odd);       
 | 
			
		||||
 | 
			
		||||
      //////////////////////////////////////////////////////////////
 | 
			
		||||
      // Call the red-black solver
 | 
			
		||||
      //////////////////////////////////////////////////////////////
 | 
			
		||||
      _HermitianRBSolver(src_o,sol_o); //  CGNE_prec_MdagM(solution[Odd],src);
 | 
			
		||||
      HermitianCheckerBoardedOperator<Matrix,Field> _HermOpEO(_Matrix);
 | 
			
		||||
      std::cout << "SchurRedBlack solver calling the MpcDagMp solver" <<std::endl;
 | 
			
		||||
      _HermitianRBSolver(_HermOpEO,src_o,sol_o);  assert(sol_o.checkerboard==Odd);
 | 
			
		||||
 | 
			
		||||
      ///////////////////////////////////////////////////
 | 
			
		||||
      // sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
 | 
			
		||||
      ///////////////////////////////////////////////////
 | 
			
		||||
      _Matrix.Meooe(sol_o,tmp);        // Meo(solution[Odd],tmp,Even,DaggerNo);
 | 
			
		||||
      src_e = src_e-tmp;               // axpy(src,tmp,source[Even],-1.0);
 | 
			
		||||
      _Matrix.MooeeInv(src_e,sol_e);   // MooeeInv(src,solution[Even],DaggerNo,Even);
 | 
			
		||||
      _Matrix.Meooe(sol_o,tmp);        assert(  tmp.checkerboard   ==Even);
 | 
			
		||||
      src_e = src_e-tmp;               assert(  src_e.checkerboard ==Even);
 | 
			
		||||
      _Matrix.MooeeInv(src_e,sol_e);   assert(  sol_e.checkerboard ==Even);
 | 
			
		||||
     
 | 
			
		||||
      setCheckerboard(out,sol_e);
 | 
			
		||||
      setCheckerboard(out,sol_o);
 | 
			
		||||
      setCheckerboard(out,sol_e); assert(  sol_e.checkerboard ==Even);
 | 
			
		||||
      setCheckerboard(out,sol_o); assert(  sol_o.checkerboard ==Odd );
 | 
			
		||||
 | 
			
		||||
      // Verify the unprec residual
 | 
			
		||||
      _Matrix.M(out,resid); 
 | 
			
		||||
      resid = resid-in;
 | 
			
		||||
      RealD ns = norm2(in);
 | 
			
		||||
      RealD nr = norm2(resid);
 | 
			
		||||
 | 
			
		||||
      std::cout << "SchurRedBlack solver true unprec resid "<< sqrt(nr/ns) <<" nr "<< nr <<" ns "<<ns << std::endl;
 | 
			
		||||
    }     
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user