mirror of
https://github.com/paboyle/Grid.git
synced 2025-10-24 17:54:47 +01:00
Updates now schur red black solver working
This commit is contained in:
@@ -10,6 +10,7 @@ namespace Grid {
|
|||||||
/////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template<class Field> class SparseMatrixBase {
|
template<class Field> class SparseMatrixBase {
|
||||||
public:
|
public:
|
||||||
|
GridBase *_grid;
|
||||||
// Full checkerboar operations
|
// Full checkerboar operations
|
||||||
virtual RealD M (const Field &in, Field &out)=0;
|
virtual RealD M (const Field &in, Field &out)=0;
|
||||||
virtual RealD Mdag (const Field &in, Field &out)=0;
|
virtual RealD Mdag (const Field &in, Field &out)=0;
|
||||||
@@ -18,6 +19,7 @@ namespace Grid {
|
|||||||
ni=M(in,tmp);
|
ni=M(in,tmp);
|
||||||
no=Mdag(tmp,out);
|
no=Mdag(tmp,out);
|
||||||
}
|
}
|
||||||
|
SparseMatrixBase(GridBase *grid) : _grid(grid) {};
|
||||||
};
|
};
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -25,7 +27,7 @@ namespace Grid {
|
|||||||
/////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> {
|
template<class Field> class CheckerBoardedSparseMatrixBase : public SparseMatrixBase<Field> {
|
||||||
public:
|
public:
|
||||||
|
GridBase *_cbgrid;
|
||||||
// half checkerboard operaions
|
// half checkerboard operaions
|
||||||
virtual void Meooe (const Field &in, Field &out)=0;
|
virtual void Meooe (const Field &in, Field &out)=0;
|
||||||
virtual void Mooee (const Field &in, Field &out)=0;
|
virtual void Mooee (const Field &in, Field &out)=0;
|
||||||
@@ -44,9 +46,7 @@ namespace Grid {
|
|||||||
Meooe(out,tmp);
|
Meooe(out,tmp);
|
||||||
|
|
||||||
Mooee(in,out);
|
Mooee(in,out);
|
||||||
out=out-tmp; // axpy_norm
|
return axpy_norm(out,-1.0,tmp,out);
|
||||||
RealD n=norm2(out);
|
|
||||||
return n;
|
|
||||||
}
|
}
|
||||||
virtual RealD MpcDag (const Field &in, Field &out){
|
virtual RealD MpcDag (const Field &in, Field &out){
|
||||||
Field tmp(in._grid);
|
Field tmp(in._grid);
|
||||||
@@ -56,15 +56,15 @@ namespace Grid {
|
|||||||
MeooeDag(out,tmp);
|
MeooeDag(out,tmp);
|
||||||
|
|
||||||
MooeeDag(in,out);
|
MooeeDag(in,out);
|
||||||
out=out-tmp; // axpy_norm
|
return axpy_norm(out,-1.0,tmp,out);
|
||||||
RealD n=norm2(out);
|
|
||||||
return n;
|
|
||||||
}
|
}
|
||||||
virtual void MpcDagMpc(const Field &in, Field &out,RealD ni,RealD no) {
|
virtual void MpcDagMpc(const Field &in, Field &out,RealD &ni,RealD &no) {
|
||||||
Field tmp(in._grid);
|
Field tmp(in._grid);
|
||||||
ni=Mpc(in,tmp);
|
ni=Mpc(in,tmp);
|
||||||
no=Mpc(tmp,out);
|
no=MpcDag(tmp,out);
|
||||||
|
// std::cout<<"MpcDagMpc "<<ni<<" "<<no<<std::endl;
|
||||||
}
|
}
|
||||||
|
CheckerBoardedSparseMatrixBase(GridBase *grid,GridBase *cbgrid) : SparseMatrixBase<Field>(grid), _cbgrid(cbgrid) {};
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -9,17 +9,21 @@ namespace Grid {
|
|||||||
/////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<class Field>
|
template<class Field>
|
||||||
class ConjugateGradient : public OperatorFunction<Field> {
|
class ConjugateGradient : public HermitianOperatorFunction<Field> {
|
||||||
public:
|
public:
|
||||||
RealD Tolerance;
|
RealD Tolerance;
|
||||||
Integer MaxIterations;
|
Integer MaxIterations;
|
||||||
|
int verbose;
|
||||||
ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) {
|
ConjugateGradient(RealD tol,Integer maxit) : Tolerance(tol), MaxIterations(maxit) {
|
||||||
|
verbose=0;
|
||||||
};
|
};
|
||||||
|
|
||||||
void operator() (LinearOperatorBase<Field> &Linop,const Field &src, Field &psi) {assert(0);};
|
|
||||||
void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){
|
void operator() (HermitianOperatorBase<Field> &Linop,const Field &src, Field &psi){
|
||||||
|
|
||||||
|
psi.checkerboard = src.checkerboard;
|
||||||
|
conformable(psi,src);
|
||||||
|
|
||||||
RealD cp,c,a,d,b,ssq,qq,b_pred;
|
RealD cp,c,a,d,b,ssq,qq,b_pred;
|
||||||
|
|
||||||
Field p(src);
|
Field p(src);
|
||||||
@@ -37,14 +41,16 @@ public:
|
|||||||
a =norm2(p);
|
a =norm2(p);
|
||||||
cp =a;
|
cp =a;
|
||||||
ssq=norm2(src);
|
ssq=norm2(src);
|
||||||
|
|
||||||
std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl;
|
if ( verbose ) {
|
||||||
std::cout <<std::setprecision(4)<< "ConjugateGradient: src "<<ssq <<std::endl;
|
std::cout <<std::setprecision(4)<< "ConjugateGradient: guess "<<guess<<std::endl;
|
||||||
std::cout <<std::setprecision(4)<< "ConjugateGradient: mp "<<d <<std::endl;
|
std::cout <<std::setprecision(4)<< "ConjugateGradient: src "<<ssq <<std::endl;
|
||||||
std::cout <<std::setprecision(4)<< "ConjugateGradient: mmp "<<b <<std::endl;
|
std::cout <<std::setprecision(4)<< "ConjugateGradient: mp "<<d <<std::endl;
|
||||||
std::cout <<std::setprecision(4)<< "ConjugateGradient: r "<<cp <<std::endl;
|
std::cout <<std::setprecision(4)<< "ConjugateGradient: mmp "<<b <<std::endl;
|
||||||
std::cout <<std::setprecision(4)<< "ConjugateGradient: p "<<a <<std::endl;
|
std::cout <<std::setprecision(4)<< "ConjugateGradient: cp,r "<<cp <<std::endl;
|
||||||
|
std::cout <<std::setprecision(4)<< "ConjugateGradient: p "<<a <<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
RealD rsq = Tolerance* Tolerance*ssq;
|
RealD rsq = Tolerance* Tolerance*ssq;
|
||||||
|
|
||||||
//Check if guess is really REALLY good :)
|
//Check if guess is really REALLY good :)
|
||||||
@@ -60,14 +66,16 @@ public:
|
|||||||
c=cp;
|
c=cp;
|
||||||
|
|
||||||
Linop.OpAndNorm(p,mmp,d,qq);
|
Linop.OpAndNorm(p,mmp,d,qq);
|
||||||
|
|
||||||
// std::cout <<std::setprecision(4)<< "ConjugateGradient: d,qq "<<d<< " "<<qq <<std::endl;
|
RealD qqck = norm2(mmp);
|
||||||
|
ComplexD dck = innerProduct(p,mmp);
|
||||||
|
// if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient: d,qq "<<d<< " "<<qq <<" qqcheck "<< qqck<< " dck "<< dck<<std::endl;
|
||||||
|
|
||||||
a = c/d;
|
a = c/d;
|
||||||
b_pred = a*(a*qq-d)/c;
|
b_pred = a*(a*qq-d)/c;
|
||||||
|
|
||||||
// std::cout <<std::setprecision(4)<< "ConjugateGradient: a,bp "<<a<< " "<<b_pred <<std::endl;
|
|
||||||
|
|
||||||
|
|
||||||
|
// if (verbose) std::cout <<std::setprecision(4)<< "ConjugateGradient: a,bp "<<a<< " "<<b_pred <<std::endl;
|
||||||
cp = axpy_norm(r,-a,mmp,r);
|
cp = axpy_norm(r,-a,mmp,r);
|
||||||
b = cp/c;
|
b = cp/c;
|
||||||
// std::cout <<std::setprecision(4)<< "ConjugateGradient: cp,b "<<cp<< " "<<b <<std::endl;
|
// std::cout <<std::setprecision(4)<< "ConjugateGradient: cp,b "<<cp<< " "<<b <<std::endl;
|
||||||
@@ -76,7 +84,16 @@ public:
|
|||||||
psi= a*p+psi;
|
psi= a*p+psi;
|
||||||
p = p*b+r;
|
p = p*b+r;
|
||||||
|
|
||||||
std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl;
|
if (verbose) std::cout<<"ConjugateGradient: Iteration " <<k<<" residual "<<cp<< " target"<< rsq<<std::endl;
|
||||||
|
|
||||||
|
// Hack
|
||||||
|
if (0) {
|
||||||
|
Field tt(src);
|
||||||
|
Linop.Op(psi,mmp);
|
||||||
|
tt=mmp-src;
|
||||||
|
RealD resnorm = norm2(tt);
|
||||||
|
std::cout<<"ConjugateGradient: Iteration " <<k<<" true residual "<<resnorm << " computed " << cp <<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
// Stopping condition
|
// Stopping condition
|
||||||
if ( cp <= rsq ) {
|
if ( cp <= rsq ) {
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
#ifndef GRID_SCHUR_RED_BLACK_H
|
#ifndef GRID_SCHUR_RED_BLACK_H
|
||||||
#define GRID_SCHUR_RED_BLACK_H
|
#define GRID_SCHUR_RED_BLACK_H
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Red black Schur decomposition
|
* Red black Schur decomposition
|
||||||
*
|
*
|
||||||
@@ -25,80 +26,86 @@
|
|||||||
* M psi = eta
|
* M psi = eta
|
||||||
***********************
|
***********************
|
||||||
*Odd
|
*Odd
|
||||||
* i) (D_oo)^{\dag} D_oo psi_o = (D_oo)^\dag L^{-1} eta_o
|
* i) (D_oo)^{\dag} D_oo psi_o = (D_oo)^dag L^{-1} eta_o
|
||||||
* eta_o' = D_oo (eta_o - Moe Mee^{-1} eta_e)
|
* eta_o' = (D_oo)^dag (eta_o - Moe Mee^{-1} eta_e)
|
||||||
*Even
|
*Even
|
||||||
* ii) Mee psi_e + Meo psi_o = src_e
|
* ii) Mee psi_e + Meo psi_o = src_e
|
||||||
*
|
*
|
||||||
* => sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
|
* => sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
namespace Grid {
|
namespace Grid {
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Take a matrix and form a Red Black solver calling a Herm solver
|
// Take a matrix and form a Red Black solver calling a Herm solver
|
||||||
// Use of RB info prevents making SchurRedBlackSolve conform to standard interface
|
// Use of RB info prevents making SchurRedBlackSolve conform to standard interface
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template<class Field> class SchurRedBlackSolve : public OperatorFunction<Field>{
|
template<class Field> class SchurRedBlackSolve {
|
||||||
private:
|
private:
|
||||||
SparseMatrixBase<Field> & _Matrix;
|
HermitianOperatorFunction<Field> & _HermitianRBSolver;
|
||||||
OperatorFunction<Field> & _HermitianRBSolver;
|
|
||||||
int CBfactorise;
|
int CBfactorise;
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////
|
||||||
// Wrap the usual normal equations Schur trick
|
// Wrap the usual normal equations Schur trick
|
||||||
/////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////
|
||||||
SchurRedBlackSolve(SparseMatrixBase<Field> &Matrix, OperatorFunction<Field> &HermitianRBSolver)
|
SchurRedBlackSolve(HermitianOperatorFunction<Field> &HermitianRBSolver) :
|
||||||
: _Matrix(Matrix), _HermitianRBSolver(HermitianRBSolver) {
|
_HermitianRBSolver(HermitianRBSolver)
|
||||||
|
{
|
||||||
CBfactorise=0;
|
CBfactorise=0;
|
||||||
};
|
};
|
||||||
|
|
||||||
void operator() (const Field &in, Field &out){
|
template<class Matrix>
|
||||||
|
void operator() (Matrix & _Matrix,const Field &in, Field &out){
|
||||||
|
|
||||||
// FIXME CGdiagonalMee not implemented virtual function
|
// FIXME CGdiagonalMee not implemented virtual function
|
||||||
// FIXME need to make eo grid from full grid.
|
|
||||||
// FIXME use CBfactorise to control schur decomp
|
// FIXME use CBfactorise to control schur decomp
|
||||||
const int Even=0;
|
GridBase *grid = _Matrix._cbgrid;
|
||||||
const int Odd =1;
|
GridBase *fgrid= _Matrix._grid;
|
||||||
|
|
||||||
// Make a cartesianRedBlack from full Grid
|
|
||||||
GridRedBlackCartesian grid(in._grid);
|
|
||||||
|
|
||||||
Field src_e(&grid);
|
Field src_e(grid);
|
||||||
Field src_o(&grid);
|
Field src_o(grid);
|
||||||
Field sol_e(&grid);
|
Field sol_e(grid);
|
||||||
Field sol_o(&grid);
|
Field sol_o(grid);
|
||||||
Field tmp(&grid);
|
Field tmp(grid);
|
||||||
Field Mtmp(&grid);
|
Field Mtmp(grid);
|
||||||
|
Field resid(fgrid);
|
||||||
|
|
||||||
pickCheckerboard(Even,src_e,in);
|
pickCheckerboard(Even,src_e,in);
|
||||||
pickCheckerboard(Odd ,src_o,in);
|
pickCheckerboard(Odd ,src_o,in);
|
||||||
|
|
||||||
/////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////
|
||||||
// src_o = Mdag * (source_o - Moe MeeInv source_e)
|
// src_o = Mdag * (source_o - Moe MeeInv source_e)
|
||||||
/////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////
|
||||||
_Matrix.MooeeInv(src_e,tmp); // MooeeInv(source[Even],tmp,DaggerNo,Even);
|
_Matrix.MooeeInv(src_e,tmp); assert( tmp.checkerboard ==Even);
|
||||||
_Matrix.Meooe (tmp,Mtmp); // Meo (tmp,src,Odd,DaggerNo);
|
_Matrix.Meooe (tmp,Mtmp); assert( Mtmp.checkerboard ==Odd);
|
||||||
tmp=src_o-Mtmp; // axpy (tmp,src,source[Odd],-1.0);
|
tmp=src_o-Mtmp; assert( tmp.checkerboard ==Odd);
|
||||||
_Matrix.MpcDag(tmp,src_o); // Mprec(tmp,src,Mtmp,DaggerYes);
|
_Matrix.MpcDag(tmp,src_o); assert(src_o.checkerboard ==Odd);
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////
|
||||||
// Call the red-black solver
|
// Call the red-black solver
|
||||||
//////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////
|
||||||
_HermitianRBSolver(src_o,sol_o); // CGNE_prec_MdagM(solution[Odd],src);
|
HermitianCheckerBoardedOperator<Matrix,Field> _HermOpEO(_Matrix);
|
||||||
|
std::cout << "SchurRedBlack solver calling the MpcDagMp solver" <<std::endl;
|
||||||
|
_HermitianRBSolver(_HermOpEO,src_o,sol_o); assert(sol_o.checkerboard==Odd);
|
||||||
|
|
||||||
///////////////////////////////////////////////////
|
///////////////////////////////////////////////////
|
||||||
// sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
|
// sol_e = M_ee^-1 * ( src_e - Meo sol_o )...
|
||||||
///////////////////////////////////////////////////
|
///////////////////////////////////////////////////
|
||||||
_Matrix.Meooe(sol_o,tmp); // Meo(solution[Odd],tmp,Even,DaggerNo);
|
_Matrix.Meooe(sol_o,tmp); assert( tmp.checkerboard ==Even);
|
||||||
src_e = src_e-tmp; // axpy(src,tmp,source[Even],-1.0);
|
src_e = src_e-tmp; assert( src_e.checkerboard ==Even);
|
||||||
_Matrix.MooeeInv(src_e,sol_e); // MooeeInv(src,solution[Even],DaggerNo,Even);
|
_Matrix.MooeeInv(src_e,sol_e); assert( sol_e.checkerboard ==Even);
|
||||||
|
|
||||||
setCheckerboard(out,sol_e);
|
setCheckerboard(out,sol_e); assert( sol_e.checkerboard ==Even);
|
||||||
setCheckerboard(out,sol_o);
|
setCheckerboard(out,sol_o); assert( sol_o.checkerboard ==Odd );
|
||||||
|
|
||||||
|
// Verify the unprec residual
|
||||||
|
_Matrix.M(out,resid);
|
||||||
|
resid = resid-in;
|
||||||
|
RealD ns = norm2(in);
|
||||||
|
RealD nr = norm2(resid);
|
||||||
|
|
||||||
|
std::cout << "SchurRedBlack solver true unprec resid "<< sqrt(nr/ns) <<" nr "<< nr <<" ns "<<ns << std::endl;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user