1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-10 07:55:35 +00:00

Updates now schur red black solver working

This commit is contained in:
Peter Boyle 2015-05-25 13:43:58 +01:00
parent ac99832d21
commit 624c0ac3ef
3 changed files with 82 additions and 58 deletions

View File

@ -10,6 +10,7 @@ namespace Grid {
///////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////
template<class Field> class SparseMatrixBase { template<class Field> class SparseMatrixBase {
public: public:
GridBase *_grid;
// Full checkerboar operations // Full checkerboar operations
virtual RealD M (const Field &in, Field &out)=0; virtual RealD M (const Field &in, Field &out)=0;
virtual RealD Mdag (const Field &in, Field &out)=0; virtual RealD Mdag (const Field &in, Field &out)=0;
@ -18,6 +19,7 @@ namespace Grid {
ni=M(in,tmp); ni=M(in,tmp);
no=Mdag(tmp,out); no=Mdag(tmp,out);
} }
SparseMatrixBase(GridBase *grid) : _grid(grid) {};
}; };
///////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////
@ -25,7 +27,7 @@ namespace Grid {
///////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////
template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> { template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> {
public: public:
GridBase *_cbgrid;
// half checkerboard operaions // half checkerboard operaions
virtual void Meooe (const Field &in, Field &out)=0; virtual void Meooe (const Field &in, Field &out)=0;
virtual void Mooee (const Field &in, Field &out)=0; virtual void Mooee (const Field &in, Field &out)=0;
@ -44,9 +46,7 @@ namespace Grid {
Meooe(out,tmp); Meooe(out,tmp);
Mooee(in,out); Mooee(in,out);
out=out-tmp; // axpy_norm return axpy_norm(out,-1.0,tmp,out);
RealD n=norm2(out);
return n;
} }
virtual RealD MpcDag (const Field &in, Field &out){ virtual RealD MpcDag (const Field &in, Field &out){
Field tmp(in._grid); Field tmp(in._grid);
@ -56,15 +56,15 @@ namespace Grid {
MeooeDag(out,tmp); MeooeDag(out,tmp);
MooeeDag(in,out); MooeeDag(in,out);
out=out-tmp; // axpy_norm return axpy_norm(out,-1.0,tmp,out);
RealD n=norm2(out);
return n;
} }
virtual void MpcDagMpc(const Field &in, Field &out,RealD ni,RealD no) { virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) {
Field tmp(in._grid); Field tmp(in._grid);
ni=Mpc(in,tmp); ni=Mpc(in,tmp);
no=Mpc(tmp,out); no=MpcDag(tmp,out);
// std::cout<<"MpcDagMpc "<<ni<<" "<<no<<std::endl;
} }
CheckerBoardedSparseMatrixBase(GridBase *grid,GridBase *cbgrid) : SparseMatrixBase<Field>(grid), _cbgrid(cbgrid) {};
}; };
} }

View File

@ -9,17 +9,21 @@ namespace Grid {
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
template<class Field> template<class Field>
class ConjugateGradient : public OperatorFunction<Field> { class ConjugateGradient : public HermitianOperatorFunction<Field> {
public: public:
RealD Tolerance; RealD Tolerance;
Integer MaxIterations; Integer MaxIterations;
int verbose;
ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) { ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) {
verbose=0;
}; };
void operator() (LinearOperatorBase<Field> &Linop,const Field &src, Field &psi) {assert(0);};
void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){ void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){
psi.checkerboard = src.checkerboard;
conformable(psi,src);
RealD cp,c,a,d,b,ssq,qq,b_pred; RealD cp,c,a,d,b,ssq,qq,b_pred;
Field p(src); Field p(src);
@ -37,14 +41,16 @@ public:
a =norm2(p); a =norm2(p);
cp =a; cp =a;
ssq=norm2(src); ssq=norm2(src);
std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl; if ( verbose ) {
std::cout <<std::setprecision(4)<< "ConjugateGradient: src "<<ssq <<std::endl; std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl;
std::cout <<std::setprecision(4)<< "ConjugateGradient: mp "<<d <<std::endl; std::cout <<std::setprecision(4)<< "ConjugateGradient: src "<<ssq <<std::endl;
std::cout <<std::setprecision(4)<< "ConjugateGradient: mmp "<<b <<std::endl; std::cout <<std::setprecision(4)<< "ConjugateGradient: mp "<<d <<std::endl;
std::cout <<std::setprecision(4)<< "ConjugateGradient: r "<<cp <<std::endl; std::cout <<std::setprecision(4)<< "ConjugateGradient: mmp "<<b <<std::endl;
std::cout <<std::setprecision(4)<< "ConjugateGradient: p "<<a <<std::endl; std::cout <<std::setprecision(4)<< "ConjugateGradient: cp,r "<<cp <<std::endl;
std::cout <<std::setprecision(4)<< "ConjugateGradient: p "<<a <<std::endl;
}
RealD rsq = Tolerance* Tolerance*ssq; RealD rsq = Tolerance* Tolerance*ssq;
//Check if guess is really REALLY good :) //Check if guess is really REALLY good :)
@ -60,14 +66,16 @@ public:
c=cp; c=cp;
Linop.OpAndNorm(p,mmp,d,qq); Linop.OpAndNorm(p,mmp,d,qq);
// std::cout <<std::setprecision(4)<< "ConjugateGradient: d,qq "<<d<< " "<<qq <<std::endl; RealD qqck = norm2(mmp);
ComplexD dck = innerProduct(p,mmp);
// if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient: d,qq "<<d<< " "<<qq <<" qqcheck "<< qqck<< " dck "<< dck<<std::endl;
a = c/d; a = c/d;
b_pred = a*(a*qq-d)/c; b_pred = a*(a*qq-d)/c;
// std::cout <<std::setprecision(4)<< "ConjugateGradient: a,bp "<<a<< " "<<b_pred <<std::endl;
// if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient: a,bp "<<a<< " "<<b_pred <<std::endl;
cp = axpy_norm(r,-a,mmp,r); cp = axpy_norm(r,-a,mmp,r);
b = cp/c; b = cp/c;
// std::cout <<std::setprecision(4)<< "ConjugateGradient: cp,b "<<cp<< " "<<b <<std::endl; // std::cout <<std::setprecision(4)<< "ConjugateGradient: cp,b "<<cp<< " "<<b <<std::endl;
@ -76,7 +84,16 @@ public:
psi= a*p+psi; psi= a*p+psi;
p = p*b+r; p = p*b+r;
std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl; if (verbose) std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl;
// Hack
if (0) {
Field tt(src);
Linop.Op(psi,mmp);
tt=mmp-src;
RealD resnorm = norm2(tt);
std::cout<<"ConjugateGradient: Iteration " <<k<<" true residual "<<resnorm << " computed " << cp <<std::endl;
}
// Stopping condition // Stopping condition
if ( cp <= rsq ) { if ( cp <= rsq ) {

View File

@ -1,6 +1,7 @@
#ifndef GRID_SCHUR_RED_BLACK_H #ifndef GRID_SCHUR_RED_BLACK_H
#define GRID_SCHUR_RED_BLACK_H #define GRID_SCHUR_RED_BLACK_H
/* /*
* Red black Schur decomposition * Red black Schur decomposition
* *
@ -25,80 +26,86 @@
* M psi = eta * M psi = eta
*********************** ***********************
*Odd *Odd
* i) (D_oo)^{\dag} D_oo psi_o = (D_oo)^\dag L^{-1} eta_o * i) (D_oo)^{\dag} D_oo psi_o = (D_oo)^dag L^{-1} eta_o
* eta_o' = D_oo (eta_o - Moe Mee^{-1} eta_e) * eta_o' = (D_oo)^dag (eta_o - Moe Mee^{-1} eta_e)
*Even *Even
* ii) Mee psi_e + Meo psi_o = src_e * ii) Mee psi_e + Meo psi_o = src_e
* *
* => sol_e = M_ee^-1 * ( src_e - Meo sol_o )... * => sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
* *
*/ */
namespace Grid { namespace Grid {
/////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////
// Take a matrix and form a Red Black solver calling a Herm solver // Take a matrix and form a Red Black solver calling a Herm solver
// Use of RB info prevents making SchurRedBlackSolve conform to standard interface // Use of RB info prevents making SchurRedBlackSolve conform to standard interface
/////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Field> class SchurRedBlackSolve : public OperatorFunction<Field>{ template<class Field> class SchurRedBlackSolve {
private: private:
SparseMatrixBase<Field> & _Matrix; HermitianOperatorFunction<Field> & _HermitianRBSolver;
OperatorFunction<Field> & _HermitianRBSolver;
int CBfactorise; int CBfactorise;
public: public:
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
// Wrap the usual normal equations Schur trick // Wrap the usual normal equations Schur trick
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
SchurRedBlackSolve(SparseMatrixBase<Field> &Matrix, OperatorFunction<Field> &HermitianRBSolver) SchurRedBlackSolve(HermitianOperatorFunction<Field> &HermitianRBSolver) :
: _Matrix(Matrix), _HermitianRBSolver(HermitianRBSolver) { _HermitianRBSolver(HermitianRBSolver)
{
CBfactorise=0; CBfactorise=0;
}; };
void operator() (const Field &in, Field &out){ template<class Matrix>
void operator() (Matrix & _Matrix,const Field &in, Field &out){
// FIXME CGdiagonalMee not implemented virtual function // FIXME CGdiagonalMee not implemented virtual function
// FIXME need to make eo grid from full grid.
// FIXME use CBfactorise to control schur decomp // FIXME use CBfactorise to control schur decomp
const int Even=0; GridBase *grid = _Matrix._cbgrid;
const int Odd =1; GridBase *fgrid= _Matrix._grid;
// Make a cartesianRedBlack from full Grid
GridRedBlackCartesian grid(in._grid);
Field src_e(&grid); Field src_e(grid);
Field src_o(&grid); Field src_o(grid);
Field sol_e(&grid); Field sol_e(grid);
Field sol_o(&grid); Field sol_o(grid);
Field tmp(&grid); Field tmp(grid);
Field Mtmp(&grid); Field Mtmp(grid);
Field resid(fgrid);
pickCheckerboard(Even,src_e,in); pickCheckerboard(Even,src_e,in);
pickCheckerboard(Odd ,src_o,in); pickCheckerboard(Odd ,src_o,in);
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
// src_o = Mdag * (source_o - Moe MeeInv source_e) // src_o = Mdag * (source_o - Moe MeeInv source_e)
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
_Matrix.MooeeInv(src_e,tmp); // MooeeInv(source[Even],tmp,DaggerNo,Even); _Matrix.MooeeInv(src_e,tmp); assert( tmp.checkerboard ==Even);
_Matrix.Meooe (tmp,Mtmp); // Meo (tmp,src,Odd,DaggerNo); _Matrix.Meooe (tmp,Mtmp); assert( Mtmp.checkerboard ==Odd);
tmp=src_o-Mtmp; // axpy (tmp,src,source[Odd],-1.0); tmp=src_o-Mtmp; assert( tmp.checkerboard ==Odd);
_Matrix.MpcDag(tmp,src_o); // Mprec(tmp,src,Mtmp,DaggerYes); _Matrix.MpcDag(tmp,src_o); assert(src_o.checkerboard ==Odd);
////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////
// Call the red-black solver // Call the red-black solver
////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////
_HermitianRBSolver(src_o,sol_o); // CGNE_prec_MdagM(solution[Odd],src); HermitianCheckerBoardedOperator<Matrix,Field> _HermOpEO(_Matrix);
std::cout << "SchurRedBlack solver calling the MpcDagMp solver" <<std::endl;
_HermitianRBSolver(_HermOpEO,src_o,sol_o); assert(sol_o.checkerboard==Odd);
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
// sol_e = M_ee^-1 * ( src_e - Meo sol_o )... // sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
_Matrix.Meooe(sol_o,tmp); // Meo(solution[Odd],tmp,Even,DaggerNo); _Matrix.Meooe(sol_o,tmp); assert( tmp.checkerboard ==Even);
src_e = src_e-tmp; // axpy(src,tmp,source[Even],-1.0); src_e = src_e-tmp; assert( src_e.checkerboard ==Even);
_Matrix.MooeeInv(src_e,sol_e); // MooeeInv(src,solution[Even],DaggerNo,Even); _Matrix.MooeeInv(src_e,sol_e); assert( sol_e.checkerboard ==Even);
setCheckerboard(out,sol_e); setCheckerboard(out,sol_e); assert( sol_e.checkerboard ==Even);
setCheckerboard(out,sol_o); setCheckerboard(out,sol_o); assert( sol_o.checkerboard ==Odd );
// Verify the unprec residual
_Matrix.M(out,resid);
resid = resid-in;
RealD ns = norm2(in);
RealD nr = norm2(resid);
std::cout << "SchurRedBlack solver true unprec resid "<< sqrt(nr/ns) <<" nr "<< nr <<" ns "<<ns << std::endl;
} }
}; };