1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-24 12:45:56 +01:00

Schur solver for Mdag

This commit is contained in:
Peter Boyle 2021-05-06 23:21:15 +02:00
parent f8d7d23893
commit 07d1030660

View File

@ -40,7 +40,7 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
* (-MoeMee^{-1} 1 ) * (-MoeMee^{-1} 1 )
* L^{dag} = ( 1 Mee^{-dag} Moe^{dag} ) * L^{dag} = ( 1 Mee^{-dag} Moe^{dag} )
* ( 0 1 ) * ( 0 1 )
* L^{-d} = ( 1 -Mee^{-dag} Moe^{dag} ) * L^{-dag}= ( 1 -Mee^{-dag} Moe^{dag} )
* ( 0 1 ) * ( 0 1 )
* *
* U^-1 = (1 -Mee^{-1} Meo) * U^-1 = (1 -Mee^{-1} Meo)
@ -82,7 +82,8 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
* c) M_oo^-dag Doo^{dag} Doo Moo^-1 phi_0 = M_oo^-dag (D_oo)^dag L^{-1} eta_o * c) M_oo^-dag Doo^{dag} Doo Moo^-1 phi_0 = M_oo^-dag (D_oo)^dag L^{-1} eta_o
* eta_o' = M_oo^-dag (D_oo)^dag (eta_o - Moe Mee^{-1} eta_e) * eta_o' = M_oo^-dag (D_oo)^dag (eta_o - Moe Mee^{-1} eta_e)
* psi_o = M_oo^-1 phi_o * psi_o = M_oo^-1 phi_o
* TODO: Deflation *
*
*/ */
namespace Grid { namespace Grid {
@ -97,6 +98,7 @@ namespace Grid {
protected: protected:
typedef CheckerBoardedSparseMatrixBase<Field> Matrix; typedef CheckerBoardedSparseMatrixBase<Field> Matrix;
OperatorFunction<Field> & _HermitianRBSolver; OperatorFunction<Field> & _HermitianRBSolver;
int CBfactorise; int CBfactorise;
bool subGuess; bool subGuess;
bool useSolnAsInitGuess; // if true user-supplied solution vector is used as initial guess for solver bool useSolnAsInitGuess; // if true user-supplied solution vector is used as initial guess for solver
@ -190,12 +192,19 @@ namespace Grid {
// Check unprec residual if possible // Check unprec residual if possible
///////////////////////////////////////////////// /////////////////////////////////////////////////
if ( ! subGuess ) { if ( ! subGuess ) {
_Matrix.M(out[b],resid);
if ( this->adjoint() ) _Matrix.Mdag(out[b],resid);
else _Matrix.M(out[b],resid);
resid = resid-in[b]; resid = resid-in[b];
RealD ns = norm2(in[b]); RealD ns = norm2(in[b]);
RealD nr = norm2(resid); RealD nr = norm2(resid);
std::cout<<GridLogMessage<< "SchurRedBlackBase solver true unprec resid["<<b<<"] "<<std::sqrt(nr/ns) << std::endl; std::cout<<GridLogMessage<< "SchurRedBlackBase adjoint "<< this->adjoint() << std::endl;
if ( this->adjoint() )
std::cout<<GridLogMessage<< "SchurRedBlackBase adjoint solver true unprec resid["<<b<<"] "<<std::sqrt(nr/ns) << std::endl;
else
std::cout<<GridLogMessage<< "SchurRedBlackBase solver true unprec resid["<<b<<"] "<<std::sqrt(nr/ns) << std::endl;
} else { } else {
std::cout<<GridLogMessage<< "SchurRedBlackBase Guess subtracted after solve["<<b<<"] " << std::endl; std::cout<<GridLogMessage<< "SchurRedBlackBase Guess subtracted after solve["<<b<<"] " << std::endl;
} }
@ -249,12 +258,21 @@ namespace Grid {
// Verify the unprec residual // Verify the unprec residual
if ( ! subGuess ) { if ( ! subGuess ) {
_Matrix.M(out,resid);
std::cout<<GridLogMessage<< "SchurRedBlackBase adjoint "<< this->adjoint() << std::endl;
if ( this->adjoint() ) _Matrix.Mdag(out,resid);
else _Matrix.M(out,resid);
resid = resid-in; resid = resid-in;
RealD ns = norm2(in); RealD ns = norm2(in);
RealD nr = norm2(resid); RealD nr = norm2(resid);
std::cout<<GridLogMessage << "SchurRedBlackBase solver true unprec resid "<< std::sqrt(nr/ns) << std::endl; if ( this->adjoint() )
std::cout<<GridLogMessage<< "SchurRedBlackBase adjoint solver true unprec resid "<<std::sqrt(nr/ns) << std::endl;
else
std::cout<<GridLogMessage<< "SchurRedBlackBase solver true unprec resid "<<std::sqrt(nr/ns) << std::endl;
} else { } else {
std::cout << GridLogMessage << "SchurRedBlackBase Guess subtracted after solve." << std::endl; std::cout << GridLogMessage << "SchurRedBlackBase Guess subtracted after solve." << std::endl;
} }
@ -263,6 +281,7 @@ namespace Grid {
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
// Override in derived. // Override in derived.
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
virtual bool adjoint(void) { return false; }
virtual void RedBlackSource (Matrix & _Matrix,const Field &src, Field &src_e,Field &src_o) =0; virtual void RedBlackSource (Matrix & _Matrix,const Field &src, Field &src_e,Field &src_o) =0;
virtual void RedBlackSolution(Matrix & _Matrix,const Field &sol_o, const Field &src_e,Field &sol) =0; virtual void RedBlackSolution(Matrix & _Matrix,const Field &sol_o, const Field &src_e,Field &sol) =0;
virtual void RedBlackSolve (Matrix & _Matrix,const Field &src_o, Field &sol_o) =0; virtual void RedBlackSolve (Matrix & _Matrix,const Field &src_o, Field &sol_o) =0;
@ -616,6 +635,127 @@ namespace Grid {
this->_HermitianRBSolver(_OpEO, src_o, sol_o); this->_HermitianRBSolver(_OpEO, src_o, sol_o);
} }
}; };
/*
* Red black Schur decomposition
*
* M = (Mee Meo) = (1 0 ) (Mee 0 ) (1 Mee^{-1} Meo)
* (Moe Moo) (Moe Mee^-1 1 ) (0 Moo-Moe Mee^-1 Meo) (0 1 )
* = L D U
*
* L^-1 = (1 0 )
* (-MoeMee^{-1} 1 )
* L^{dag} = ( 1 Mee^{-dag} Moe^{dag} )
* ( 0 1 )
*
* U^-1 = (1 -Mee^{-1} Meo)
* (0 1 )
* U^{dag} = ( 1 0)
* (Meo^dag Mee^{-dag} 1)
* U^{-dag} = ( 1 0)
* (-Meo^dag Mee^{-dag} 1)
*
*
***********************
* M^dag psi = eta
***********************
*
* Really for Mobius: (Wilson - easier to just use gamma 5 hermiticity)
*
* Mdag psi = Udag Ddag Ldag psi = eta
*
* U^{-dag} = ( 1 0)
* (-Meo^dag Mee^{-dag} 1)
*
*
* i) D^dag phi = (U^{-dag} eta)
* eta'_e = eta_e
* eta'_o = (eta_o - Meo^dag Mee^{-dag} eta_e)
*
* phi_o = D_oo^-dag eta'_o = D_oo^-dag (eta_o - Meo^dag Mee^{-dag} eta_e)
*
* phi_e = D_ee^-dag eta'_e = D_ee^-dag eta_e
*
* Solve:
*
* D_oo D_oo^dag phi_o = D_oo (eta_o - Meo^dag Mee^{-dag} eta_e)
*
* ii)
* phi = L^dag psi => psi = L^-dag phi.
*
* L^{-dag} = ( 1 -Mee^{-dag} Moe^{dag} )
* ( 0 1 )
*
* => sol_e = M_ee^-dag * ( src_e - Moe^dag phi_o )...
* => sol_o = phi_o
*/
///////////////////////////////////////////////////////////////////////////////////////////////////////
// Site diagonal has Mooee on it, but solve the Adjoint system
///////////////////////////////////////////////////////////////////////////////////////////////////////
template<class Field> class SchurRedBlackDiagMooeeDagSolve : public SchurRedBlackBase<Field> {
public:
typedef CheckerBoardedSparseMatrixBase<Field> Matrix;
virtual bool adjoint(void) { return true; }
SchurRedBlackDiagMooeeDagSolve(OperatorFunction<Field> &HermitianRBSolver,
const bool initSubGuess = false,
const bool _solnAsInitGuess = false)
: SchurRedBlackBase<Field> (HermitianRBSolver,initSubGuess,_solnAsInitGuess) {};
//////////////////////////////////////////////////////
// Override RedBlack specialisation
//////////////////////////////////////////////////////
virtual void RedBlackSource(Matrix & _Matrix,const Field &src, Field &src_e,Field &src_o)
{
GridBase *grid = _Matrix.RedBlackGrid();
GridBase *fgrid= _Matrix.Grid();
Field tmp(grid);
Field Mtmp(grid);
pickCheckerboard(Even,src_e,src);
pickCheckerboard(Odd ,src_o,src);
/////////////////////////////////////////////////////
// src_o = (source_o - Moe^dag MeeInvDag source_e)
/////////////////////////////////////////////////////
_Matrix.MooeeInvDag(src_e,tmp); assert( tmp.Checkerboard() ==Even);
_Matrix.MeooeDag (tmp,Mtmp); assert( Mtmp.Checkerboard() ==Odd);
tmp=src_o-Mtmp; assert( tmp.Checkerboard() ==Odd);
// get the right Mpc
SchurDiagMooeeOperator<Matrix,Field> _HermOpEO(_Matrix);
_HermOpEO.Mpc(tmp,src_o); assert(src_o.Checkerboard() ==Odd);
}
virtual void RedBlackSolve (Matrix & _Matrix,const Field &src_o, Field &sol_o)
{
SchurDiagMooeeDagOperator<Matrix,Field> _HermOpEO(_Matrix);
this->_HermitianRBSolver(_HermOpEO,src_o,sol_o);
};
virtual void RedBlackSolve (Matrix & _Matrix,const std::vector<Field> &src_o, std::vector<Field> &sol_o)
{
SchurDiagMooeeDagOperator<Matrix,Field> _HermOpEO(_Matrix);
this->_HermitianRBSolver(_HermOpEO,src_o,sol_o);
}
virtual void RedBlackSolution(Matrix & _Matrix,const Field &sol_o, const Field &src_e,Field &sol)
{
GridBase *grid = _Matrix.RedBlackGrid();
GridBase *fgrid= _Matrix.Grid();
Field sol_e(grid);
Field tmp(grid);
///////////////////////////////////////////////////
// sol_e = M_ee^-dag * ( src_e - Moe^dag phi_o )...
// sol_o = phi_o
///////////////////////////////////////////////////
_Matrix.MeooeDag(sol_o,tmp); assert(tmp.Checkerboard()==Even);
tmp = src_e-tmp; assert(tmp.Checkerboard()==Even);
_Matrix.MooeeInvDag(tmp,sol_e); assert(sol_e.Checkerboard()==Even);
setCheckerboard(sol,sol_e); assert( sol_e.Checkerboard() ==Even);
setCheckerboard(sol,sol_o); assert( sol_o.Checkerboard() ==Odd );
}
};
} }
#endif #endif