mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-09 21:50:45 +01:00
MultiRHS solver test
This commit is contained in:
parent
3d99b09dba
commit
d80d802f9d
@ -81,6 +81,30 @@ static void sliceMaddMatrix (Lattice<vobj> &R,Eigen::MatrixXcd &aa,const Lattice
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
template<class vobj>
|
template<class vobj>
|
||||||
|
static void sliceMaddVector (Lattice<vobj> &R,std::vector<RealD> &a,const Lattice<vobj> &X,const Lattice<vobj> &Y,
|
||||||
|
int Orthog,RealD scale=1.0)
|
||||||
|
{
|
||||||
|
typedef typename vobj::scalar_object sobj;
|
||||||
|
typedef typename vobj::scalar_type scalar_type;
|
||||||
|
typedef typename vobj::vector_type vector_type;
|
||||||
|
|
||||||
|
int Nblock = X._grid->GlobalDimensions()[Orthog];
|
||||||
|
|
||||||
|
GridBase *FullGrid = X._grid;
|
||||||
|
GridBase *SliceGrid = makeSubSliceGrid(FullGrid,Orthog);
|
||||||
|
|
||||||
|
Lattice<vobj> Xslice(SliceGrid);
|
||||||
|
Lattice<vobj> Rslice(SliceGrid);
|
||||||
|
// If we based this on Cshift it would work for spread out
|
||||||
|
// but it would be even slower
|
||||||
|
for(int i=0;i<Nblock;i++){
|
||||||
|
ExtractSlice(Rslice,Y,i,Orthog);
|
||||||
|
ExtractSlice(Xslice,X,i,Orthog);
|
||||||
|
Rslice = Rslice + Xslice*(scale*a[i]);
|
||||||
|
InsertSlice(Rslice,R,i,Orthog);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template<class vobj>
|
||||||
static void sliceInnerProductMatrix( Eigen::MatrixXcd &mat, const Lattice<vobj> &lhs,const Lattice<vobj> &rhs,int Orthog)
|
static void sliceInnerProductMatrix( Eigen::MatrixXcd &mat, const Lattice<vobj> &lhs,const Lattice<vobj> &rhs,int Orthog)
|
||||||
{
|
{
|
||||||
typedef typename vobj::scalar_object sobj;
|
typedef typename vobj::scalar_object sobj;
|
||||||
@ -194,6 +218,8 @@ static void sliceInnerProductMatrixOld( Eigen::MatrixXcd &mat, const Lattice<vo
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// Block conjugate gradient. Dimension zero should be the block direction
|
// Block conjugate gradient. Dimension zero should be the block direction
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
@ -333,5 +359,138 @@ void operator()(LinearOperatorBase<Field> &Linop, const Field &Src, Field &Psi)
|
|||||||
IterationsToComplete = k;
|
IterationsToComplete = k;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// multiRHS conjugate gradient. Dimension zero should be the block direction
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <class Field>
|
||||||
|
class MultiRHSConjugateGradient : public OperatorFunction<Field> {
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef typename Field::scalar_type scomplex;
|
||||||
|
|
||||||
|
const int blockDim = 0;
|
||||||
|
|
||||||
|
int Nblock;
|
||||||
|
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
||||||
|
// Defaults true.
|
||||||
|
RealD Tolerance;
|
||||||
|
Integer MaxIterations;
|
||||||
|
Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
|
||||||
|
|
||||||
|
MultiRHSConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true)
|
||||||
|
: Tolerance(tol),
|
||||||
|
MaxIterations(maxit),
|
||||||
|
ErrorOnNoConverge(err_on_no_conv){};
|
||||||
|
|
||||||
|
void operator()(LinearOperatorBase<Field> &Linop, const Field &Src, Field &Psi)
|
||||||
|
{
|
||||||
|
int Orthog = 0; // First dimension is block dim
|
||||||
|
Nblock = Src._grid->_fdimensions[Orthog];
|
||||||
|
std::cout<<GridLogMessage<<" MultiRHS Conjugate Gradient : Orthog "<<Orthog<<std::endl;
|
||||||
|
std::cout<<GridLogMessage<<" MultiRHS Conjugate Gradient : Nblock "<<Nblock<<std::endl;
|
||||||
|
|
||||||
|
Psi.checkerboard = Src.checkerboard;
|
||||||
|
conformable(Psi, Src);
|
||||||
|
|
||||||
|
Field P(Src);
|
||||||
|
Field AP(Src);
|
||||||
|
Field R(Src);
|
||||||
|
|
||||||
|
std::vector<ComplexD> v_pAp(Nblock);
|
||||||
|
std::vector<RealD> v_rr (Nblock);
|
||||||
|
std::vector<RealD> v_rr_inv(Nblock);
|
||||||
|
std::vector<RealD> v_alpha(Nblock);
|
||||||
|
std::vector<RealD> v_beta(Nblock);
|
||||||
|
|
||||||
|
// Initial residual computation & set up
|
||||||
|
std::vector<RealD> residuals(Nblock);
|
||||||
|
std::vector<RealD> ssq(Nblock);
|
||||||
|
|
||||||
|
sliceNorm(ssq,Src,Orthog);
|
||||||
|
RealD sssum=0;
|
||||||
|
for(int b=0;b<Nblock;b++) sssum+=ssq[b];
|
||||||
|
|
||||||
|
sliceNorm(residuals,Src,Orthog);
|
||||||
|
for(int b=0;b<Nblock;b++){ assert(std::isnan(residuals[b])==0); }
|
||||||
|
|
||||||
|
sliceNorm(residuals,Psi,Orthog);
|
||||||
|
for(int b=0;b<Nblock;b++){ assert(std::isnan(residuals[b])==0); }
|
||||||
|
|
||||||
|
// Initial search dir is guess
|
||||||
|
Linop.HermOp(Psi, AP);
|
||||||
|
|
||||||
|
R = Src - AP;
|
||||||
|
P = R;
|
||||||
|
sliceNorm(v_rr,R,Orthog);
|
||||||
|
|
||||||
|
int k;
|
||||||
|
for (k = 1; k <= MaxIterations; k++) {
|
||||||
|
|
||||||
|
RealD rrsum=0;
|
||||||
|
for(int b=0;b<Nblock;b++) rrsum+=real(v_rr[b]);
|
||||||
|
|
||||||
|
std::cout << GridLogIterative << " iteration "<<k<<" rr_sum "<<rrsum<<" ssq_sum "<< sssum
|
||||||
|
<<" / "<<std::sqrt(rrsum/sssum) <<std::endl;
|
||||||
|
|
||||||
|
Linop.HermOp(P, AP);
|
||||||
|
|
||||||
|
// Alpha
|
||||||
|
sliceInnerProductVector(v_pAp,P,AP,Orthog);
|
||||||
|
for(int b=0;b<Nblock;b++){
|
||||||
|
v_alpha[b] = v_rr[b]/real(v_pAp[b]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Psi, R update
|
||||||
|
sliceMaddVector(Psi,v_alpha, P,Psi,Orthog); // add alpha * P to psi
|
||||||
|
sliceMaddVector(R ,v_alpha,AP, R,Orthog,-1.0);// sub alpha * AP to resid
|
||||||
|
|
||||||
|
// Beta
|
||||||
|
for(int b=0;b<Nblock;b++){
|
||||||
|
v_rr_inv[b] = 1.0/v_rr[b];
|
||||||
|
}
|
||||||
|
sliceNorm(v_rr,R,Orthog);
|
||||||
|
for(int b=0;b<Nblock;b++){
|
||||||
|
v_beta[b] = v_rr_inv[b] *v_rr[b];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search update
|
||||||
|
sliceMaddVector(P,v_beta,P,R,Orthog);
|
||||||
|
|
||||||
|
/*********************
|
||||||
|
* convergence monitor
|
||||||
|
*********************
|
||||||
|
*/
|
||||||
|
RealD max_resid=0;
|
||||||
|
for(int b=0;b<Nblock;b++){
|
||||||
|
RealD rr = v_rr[b]/ssq[b];
|
||||||
|
if ( rr > max_resid ) max_resid = rr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ( max_resid < Tolerance*Tolerance ) {
|
||||||
|
std::cout << GridLogMessage<<" MultiRHS solver has converged in "
|
||||||
|
<<k<<" iterations; max residual is "<<std::sqrt(max_resid)<<std::endl;
|
||||||
|
for(int b=0;b<Nblock;b++){
|
||||||
|
std::cout << GridLogMessage<< " block "<<b<<" resid "<< std::sqrt(v_rr[b]/ssq[b])<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
Linop.HermOp(Psi, AP);
|
||||||
|
AP = AP-Src;
|
||||||
|
std::cout << " MultiRHS solver true residual is " << std::sqrt(norm2(AP)/norm2(Src)) <<std::endl;
|
||||||
|
IterationsToComplete = k;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
std::cout << GridLogMessage << "MultiRHSConjugateGradient did NOT converge" << std::endl;
|
||||||
|
|
||||||
|
if (ErrorOnNoConverge) assert(0);
|
||||||
|
IterationsToComplete = k;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -81,11 +81,16 @@ int main (int argc, char ** argv)
|
|||||||
|
|
||||||
ConjugateGradient<FermionField> CG(1.0e-8,10000);
|
ConjugateGradient<FermionField> CG(1.0e-8,10000);
|
||||||
BlockConjugateGradient<FermionField> BCG(1.0e-8,10000);
|
BlockConjugateGradient<FermionField> BCG(1.0e-8,10000);
|
||||||
|
MultiRHSConjugateGradient<FermionField> mCG(1.0e-8,10000);
|
||||||
|
|
||||||
std::cout << GridLogMessage << " Calling CG "<<std::endl;
|
std::cout << GridLogMessage << " Calling CG "<<std::endl;
|
||||||
result=zero;
|
result=zero;
|
||||||
CG(HermOp,src,result);
|
CG(HermOp,src,result);
|
||||||
|
|
||||||
|
std::cout << GridLogMessage << " Calling multiRHS CG "<<std::endl;
|
||||||
|
result=zero;
|
||||||
|
mCG(HermOp,src,result);
|
||||||
|
|
||||||
std::cout << GridLogMessage << " Calling Block CG "<<std::endl;
|
std::cout << GridLogMessage << " Calling Block CG "<<std::endl;
|
||||||
result=zero;
|
result=zero;
|
||||||
BCG(HermOp,src,result);
|
BCG(HermOp,src,result);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user