mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-10-30 19:44:32 +00:00 
			
		
		
		
	Updates now schur red black solver working
This commit is contained in:
		| @@ -10,6 +10,7 @@ namespace Grid { | |||||||
|   ///////////////////////////////////////////////////////////////////////////////////////////// |   ///////////////////////////////////////////////////////////////////////////////////////////// | ||||||
|     template<class Field> class SparseMatrixBase { |     template<class Field> class SparseMatrixBase { | ||||||
|     public: |     public: | ||||||
|  |       GridBase *_grid; | ||||||
|       // Full checkerboar operations |       // Full checkerboar operations | ||||||
|       virtual RealD M    (const Field &in, Field &out)=0; |       virtual RealD M    (const Field &in, Field &out)=0; | ||||||
|       virtual RealD Mdag (const Field &in, Field &out)=0; |       virtual RealD Mdag (const Field &in, Field &out)=0; | ||||||
| @@ -18,6 +19,7 @@ namespace Grid { | |||||||
| 	ni=M(in,tmp); | 	ni=M(in,tmp); | ||||||
| 	no=Mdag(tmp,out); | 	no=Mdag(tmp,out); | ||||||
|       } |       } | ||||||
|  |       SparseMatrixBase(GridBase *grid) : _grid(grid) {}; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|   ///////////////////////////////////////////////////////////////////////////////////////////// |   ///////////////////////////////////////////////////////////////////////////////////////////// | ||||||
| @@ -25,7 +27,7 @@ namespace Grid { | |||||||
|   ///////////////////////////////////////////////////////////////////////////////////////////// |   ///////////////////////////////////////////////////////////////////////////////////////////// | ||||||
|     template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> { |     template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> { | ||||||
|     public: |     public: | ||||||
|        |       GridBase *_cbgrid; | ||||||
|       // half checkerboard operaions |       // half checkerboard operaions | ||||||
|       virtual  void Meooe    (const Field &in, Field &out)=0; |       virtual  void Meooe    (const Field &in, Field &out)=0; | ||||||
|       virtual  void Mooee    (const Field &in, Field &out)=0; |       virtual  void Mooee    (const Field &in, Field &out)=0; | ||||||
| @@ -44,9 +46,7 @@ namespace Grid { | |||||||
| 	Meooe(out,tmp); | 	Meooe(out,tmp); | ||||||
|  |  | ||||||
| 	Mooee(in,out); | 	Mooee(in,out); | ||||||
| 	out=out-tmp; // axpy_norm | 	return axpy_norm(out,-1.0,tmp,out); | ||||||
| 	RealD n=norm2(out); |  | ||||||
| 	return n; |  | ||||||
|       } |       } | ||||||
|       virtual  RealD MpcDag   (const Field &in, Field &out){ |       virtual  RealD MpcDag   (const Field &in, Field &out){ | ||||||
| 	Field tmp(in._grid); | 	Field tmp(in._grid); | ||||||
| @@ -56,15 +56,15 @@ namespace Grid { | |||||||
| 	MeooeDag(out,tmp); | 	MeooeDag(out,tmp); | ||||||
|  |  | ||||||
| 	MooeeDag(in,out); | 	MooeeDag(in,out); | ||||||
| 	out=out-tmp; // axpy_norm | 	return axpy_norm(out,-1.0,tmp,out); | ||||||
| 	RealD n=norm2(out); |  | ||||||
| 	return n; |  | ||||||
|       } |       } | ||||||
|       virtual void MpcDagMpc(const Field &in, Field &out,RealD ni,RealD no) { |       virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) { | ||||||
| 	Field tmp(in._grid); | 	Field tmp(in._grid); | ||||||
| 	ni=Mpc(in,tmp); | 	ni=Mpc(in,tmp); | ||||||
| 	no=Mpc(tmp,out); | 	no=MpcDag(tmp,out); | ||||||
|  | 	//	std::cout<<"MpcDagMpc "<<ni<<" "<<no<<std::endl; | ||||||
|       } |       } | ||||||
|  |       CheckerBoardedSparseMatrixBase(GridBase *grid,GridBase *cbgrid) : SparseMatrixBase<Field>(grid), _cbgrid(cbgrid) {}; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -9,17 +9,21 @@ namespace Grid { | |||||||
|     ///////////////////////////////////////////////////////////// |     ///////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|   template<class Field>  |   template<class Field>  | ||||||
|     class ConjugateGradient :  public OperatorFunction<Field> { |     class ConjugateGradient : public HermitianOperatorFunction<Field> { | ||||||
| public:                                                 | public:                                                 | ||||||
|     RealD   Tolerance; |     RealD   Tolerance; | ||||||
|     Integer MaxIterations; |     Integer MaxIterations; | ||||||
|  |     int verbose; | ||||||
|     ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) {  |     ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) {  | ||||||
|  |       verbose=0; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     void operator() (LinearOperatorBase<Field> &Linop,const Field &src, Field &psi) {assert(0);}; |  | ||||||
|     void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){ |     void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){ | ||||||
|  |  | ||||||
|  |       psi.checkerboard = src.checkerboard; | ||||||
|  |       conformable(psi,src); | ||||||
|  |  | ||||||
|       RealD cp,c,a,d,b,ssq,qq,b_pred; |       RealD cp,c,a,d,b,ssq,qq,b_pred; | ||||||
|        |        | ||||||
|       Field   p(src); |       Field   p(src); | ||||||
| @@ -38,12 +42,14 @@ public: | |||||||
|       cp =a; |       cp =a; | ||||||
|       ssq=norm2(src); |       ssq=norm2(src); | ||||||
|  |  | ||||||
|  |       if ( verbose ) { | ||||||
| 	std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl; | 	std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl; | ||||||
| 	std::cout <<std::setprecision(4)<< "ConjugateGradient:   src "<<ssq  <<std::endl; | 	std::cout <<std::setprecision(4)<< "ConjugateGradient:   src "<<ssq  <<std::endl; | ||||||
| 	std::cout <<std::setprecision(4)<< "ConjugateGradient:    mp "<<d    <<std::endl; | 	std::cout <<std::setprecision(4)<< "ConjugateGradient:    mp "<<d    <<std::endl; | ||||||
| 	std::cout <<std::setprecision(4)<< "ConjugateGradient:   mmp "<<b    <<std::endl; | 	std::cout <<std::setprecision(4)<< "ConjugateGradient:   mmp "<<b    <<std::endl; | ||||||
|       std::cout <<std::setprecision(4)<< "ConjugateGradient:     r "<<cp   <<std::endl; | 	std::cout <<std::setprecision(4)<< "ConjugateGradient:  cp,r "<<cp   <<std::endl; | ||||||
| 	std::cout <<std::setprecision(4)<< "ConjugateGradient:     p "<<a    <<std::endl; | 	std::cout <<std::setprecision(4)<< "ConjugateGradient:     p "<<a    <<std::endl; | ||||||
|  |       } | ||||||
|  |  | ||||||
|       RealD rsq =  Tolerance* Tolerance*ssq; |       RealD rsq =  Tolerance* Tolerance*ssq; | ||||||
|        |        | ||||||
| @@ -61,13 +67,15 @@ public: | |||||||
| 	 | 	 | ||||||
| 	Linop.OpAndNorm(p,mmp,d,qq); | 	Linop.OpAndNorm(p,mmp,d,qq); | ||||||
|  |  | ||||||
| 	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  d,qq "<<d<< " "<<qq <<std::endl; | 	RealD    qqck = norm2(mmp); | ||||||
|  | 	ComplexD dck  = innerProduct(p,mmp); | ||||||
|  | 	//	if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient:  d,qq "<<d<< " "<<qq <<" qqcheck "<< qqck<< " dck "<< dck<<std::endl; | ||||||
|        |        | ||||||
| 	a      = c/d; | 	a      = c/d; | ||||||
| 	b_pred = a*(a*qq-d)/c; | 	b_pred = a*(a*qq-d)/c; | ||||||
|  |  | ||||||
| 	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  a,bp "<<a<< " "<<b_pred <<std::endl; |  | ||||||
|  |  | ||||||
|  | 	//	if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient:  a,bp "<<a<< " "<<b_pred <<std::endl; | ||||||
| 	cp = axpy_norm(r,-a,mmp,r); | 	cp = axpy_norm(r,-a,mmp,r); | ||||||
| 	b = cp/c; | 	b = cp/c; | ||||||
| 	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  cp,b "<<cp<< " "<<b <<std::endl; | 	//	std::cout <<std::setprecision(4)<< "ConjugateGradient:  cp,b "<<cp<< " "<<b <<std::endl; | ||||||
| @@ -76,7 +84,16 @@ public: | |||||||
| 	psi= a*p+psi; | 	psi= a*p+psi; | ||||||
| 	p  = p*b+r; | 	p  = p*b+r; | ||||||
| 	   | 	   | ||||||
| 	std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl; | 	if (verbose) std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl; | ||||||
|  |  | ||||||
|  | 	// Hack | ||||||
|  | 	if (0) { | ||||||
|  | 	  Field   tt(src); | ||||||
|  |       	  Linop.Op(psi,mmp); | ||||||
|  | 	  tt=mmp-src; | ||||||
|  | 	  RealD resnorm = norm2(tt); | ||||||
|  | 	  std::cout<<"ConjugateGradient: Iteration " <<k<<" true residual "<<resnorm << " computed " << cp <<std::endl; | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Stopping condition | 	// Stopping condition | ||||||
| 	if ( cp <= rsq ) {  | 	if ( cp <= rsq ) {  | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| #ifndef GRID_SCHUR_RED_BLACK_H | #ifndef GRID_SCHUR_RED_BLACK_H | ||||||
| #define GRID_SCHUR_RED_BLACK_H | #define GRID_SCHUR_RED_BLACK_H | ||||||
|  |  | ||||||
|  |  | ||||||
|   /* |   /* | ||||||
|    * Red black Schur decomposition |    * Red black Schur decomposition | ||||||
|    * |    * | ||||||
| @@ -25,53 +26,50 @@ | |||||||
|    *     M psi = eta |    *     M psi = eta | ||||||
|    *********************** |    *********************** | ||||||
|    *Odd |    *Odd | ||||||
|    * i)   (D_oo)^{\dag} D_oo psi_o = (D_oo)^\dag L^{-1}  eta_o |    * i)   (D_oo)^{\dag} D_oo psi_o = (D_oo)^dag L^{-1}  eta_o | ||||||
|    *                        eta_o' = D_oo (eta_o - Moe Mee^{-1} eta_e) |    *                        eta_o' = (D_oo)^dag (eta_o - Moe Mee^{-1} eta_e) | ||||||
|    *Even |    *Even | ||||||
|    * ii)  Mee psi_e + Meo psi_o = src_e |    * ii)  Mee psi_e + Meo psi_o = src_e | ||||||
|    * |    * | ||||||
|    *   => sol_e = M_ee^-1 * ( src_e - Meo sol_o )... |    *   => sol_e = M_ee^-1 * ( src_e - Meo sol_o )... | ||||||
|    * |    * | ||||||
|    */ |    */ | ||||||
|  |  | ||||||
| namespace Grid { | namespace Grid { | ||||||
|  |  | ||||||
|   /////////////////////////////////////////////////////////////////////////////////////////////////////// |   /////////////////////////////////////////////////////////////////////////////////////////////////////// | ||||||
|   // Take a matrix and form a Red Black solver calling a Herm solver |   // Take a matrix and form a Red Black solver calling a Herm solver | ||||||
|   // Use of RB info prevents making SchurRedBlackSolve conform to standard interface |   // Use of RB info prevents making SchurRedBlackSolve conform to standard interface | ||||||
|   /////////////////////////////////////////////////////////////////////////////////////////////////////// |   /////////////////////////////////////////////////////////////////////////////////////////////////////// | ||||||
|   template<class Field> class SchurRedBlackSolve : public OperatorFunction<Field>{ |   template<class Field> class SchurRedBlackSolve { | ||||||
|   private: |   private: | ||||||
|     SparseMatrixBase<Field> & _Matrix; |     HermitianOperatorFunction<Field> & _HermitianRBSolver; | ||||||
|     OperatorFunction<Field> & _HermitianRBSolver; |  | ||||||
|     int CBfactorise; |     int CBfactorise; | ||||||
|   public: |   public: | ||||||
|  |  | ||||||
|     ///////////////////////////////////////////////////// |     ///////////////////////////////////////////////////// | ||||||
|     // Wrap the usual normal equations Schur trick |     // Wrap the usual normal equations Schur trick | ||||||
|     ///////////////////////////////////////////////////// |     ///////////////////////////////////////////////////// | ||||||
|   SchurRedBlackSolve(SparseMatrixBase<Field> &Matrix, OperatorFunction<Field> &HermitianRBSolver)  |   SchurRedBlackSolve(HermitianOperatorFunction<Field> &HermitianRBSolver)  : | ||||||
|     :  _Matrix(Matrix), _HermitianRBSolver(HermitianRBSolver) {  |      _HermitianRBSolver(HermitianRBSolver)  | ||||||
|  |     {  | ||||||
|       CBfactorise=0; |       CBfactorise=0; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     void operator() (const Field &in, Field &out){ |     template<class Matrix> | ||||||
|  |       void operator() (Matrix & _Matrix,const Field &in, Field &out){ | ||||||
|  |  | ||||||
|       // FIXME CGdiagonalMee not implemented virtual function |       // FIXME CGdiagonalMee not implemented virtual function | ||||||
|       // FIXME need to make eo grid from full grid. |  | ||||||
|       // FIXME use CBfactorise to control schur decomp |       // FIXME use CBfactorise to control schur decomp | ||||||
|       const int Even=0; |       GridBase *grid = _Matrix._cbgrid; | ||||||
|       const int Odd =1; |       GridBase *fgrid= _Matrix._grid; | ||||||
|   |   | ||||||
|       // Make a cartesianRedBlack from full Grid |       Field src_e(grid); | ||||||
|       GridRedBlackCartesian grid(in._grid); |       Field src_o(grid); | ||||||
|   |       Field sol_e(grid); | ||||||
|       Field src_e(&grid); |       Field sol_o(grid); | ||||||
|       Field src_o(&grid); |       Field   tmp(grid); | ||||||
|       Field sol_e(&grid); |       Field  Mtmp(grid); | ||||||
|       Field sol_o(&grid); |       Field resid(fgrid); | ||||||
|       Field   tmp(&grid); |  | ||||||
|       Field  Mtmp(&grid); |  | ||||||
|  |  | ||||||
|       pickCheckerboard(Even,src_e,in); |       pickCheckerboard(Even,src_e,in); | ||||||
|       pickCheckerboard(Odd ,src_o,in); |       pickCheckerboard(Odd ,src_o,in); | ||||||
| @@ -79,26 +77,35 @@ namespace Grid { | |||||||
|       ///////////////////////////////////////////////////// |       ///////////////////////////////////////////////////// | ||||||
|       // src_o = Mdag * (source_o - Moe MeeInv source_e) |       // src_o = Mdag * (source_o - Moe MeeInv source_e) | ||||||
|       ///////////////////////////////////////////////////// |       ///////////////////////////////////////////////////// | ||||||
|       _Matrix.MooeeInv(src_e,tmp);     //    MooeeInv(source[Even],tmp,DaggerNo,Even); |       _Matrix.MooeeInv(src_e,tmp);     assert(  tmp.checkerboard ==Even); | ||||||
|       _Matrix.Meooe   (tmp,Mtmp);      //    Meo     (tmp,src,Odd,DaggerNo); |       _Matrix.Meooe   (tmp,Mtmp);      assert( Mtmp.checkerboard ==Odd);      | ||||||
|       tmp=src_o-Mtmp;                  //    axpy    (tmp,src,source[Odd],-1.0); |       tmp=src_o-Mtmp;                  assert(  tmp.checkerboard ==Odd);      | ||||||
|       _Matrix.MpcDag(tmp,src_o);       //    Mprec(tmp,src,Mtmp,DaggerYes);   |       _Matrix.MpcDag(tmp,src_o);       assert(src_o.checkerboard ==Odd);        | ||||||
|  |  | ||||||
|       ////////////////////////////////////////////////////////////// |       ////////////////////////////////////////////////////////////// | ||||||
|       // Call the red-black solver |       // Call the red-black solver | ||||||
|       ////////////////////////////////////////////////////////////// |       ////////////////////////////////////////////////////////////// | ||||||
|       _HermitianRBSolver(src_o,sol_o); //  CGNE_prec_MdagM(solution[Odd],src); |       HermitianCheckerBoardedOperator<Matrix,Field> _HermOpEO(_Matrix); | ||||||
|  |       std::cout << "SchurRedBlack solver calling the MpcDagMp solver" <<std::endl; | ||||||
|  |       _HermitianRBSolver(_HermOpEO,src_o,sol_o);  assert(sol_o.checkerboard==Odd); | ||||||
|  |  | ||||||
|       /////////////////////////////////////////////////// |       /////////////////////////////////////////////////// | ||||||
|       // sol_e = M_ee^-1 * ( src_e - Meo sol_o )... |       // sol_e = M_ee^-1 * ( src_e - Meo sol_o )... | ||||||
|       /////////////////////////////////////////////////// |       /////////////////////////////////////////////////// | ||||||
|       _Matrix.Meooe(sol_o,tmp);        // Meo(solution[Odd],tmp,Even,DaggerNo); |       _Matrix.Meooe(sol_o,tmp);        assert(  tmp.checkerboard   ==Even); | ||||||
|       src_e = src_e-tmp;               // axpy(src,tmp,source[Even],-1.0); |       src_e = src_e-tmp;               assert(  src_e.checkerboard ==Even); | ||||||
|       _Matrix.MooeeInv(src_e,sol_e);   // MooeeInv(src,solution[Even],DaggerNo,Even); |       _Matrix.MooeeInv(src_e,sol_e);   assert(  sol_e.checkerboard ==Even); | ||||||
|       |       | ||||||
|       setCheckerboard(out,sol_e); |       setCheckerboard(out,sol_e); assert(  sol_e.checkerboard ==Even); | ||||||
|       setCheckerboard(out,sol_o); |       setCheckerboard(out,sol_o); assert(  sol_o.checkerboard ==Odd ); | ||||||
|  |  | ||||||
|  |       // Verify the unprec residual | ||||||
|  |       _Matrix.M(out,resid);  | ||||||
|  |       resid = resid-in; | ||||||
|  |       RealD ns = norm2(in); | ||||||
|  |       RealD nr = norm2(resid); | ||||||
|  |  | ||||||
|  |       std::cout << "SchurRedBlack solver true unprec resid "<< sqrt(nr/ns) <<" nr "<< nr <<" ns "<<ns << std::endl; | ||||||
|     }      |     }      | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user