mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-09 23:45:36 +00:00
Add option to record the CG polynomial
This commit is contained in:
parent
fe65fa4988
commit
3752c49ef0
@ -38,12 +38,13 @@ NAMESPACE_BEGIN(Grid);
|
|||||||
// single input vec, single output vec.
|
// single input vec, single output vec.
|
||||||
/////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
template <class Field>
|
template <class Field>
|
||||||
class ConjugateGradient : public OperatorFunction<Field> {
|
class ConjugateGradient : public OperatorFunction<Field> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using OperatorFunction<Field>::operator();
|
using OperatorFunction<Field>::operator();
|
||||||
|
|
||||||
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
||||||
// Defaults true.
|
// Defaults true.
|
||||||
RealD Tolerance;
|
RealD Tolerance;
|
||||||
@ -57,10 +58,22 @@ public:
|
|||||||
ErrorOnNoConverge(err_on_no_conv)
|
ErrorOnNoConverge(err_on_no_conv)
|
||||||
{};
|
{};
|
||||||
|
|
||||||
void operator()(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi) {
|
virtual void LogIteration(int k,RealD a,RealD b){
|
||||||
|
// std::cout << "ConjugageGradient::LogIteration() "<<std::endl;
|
||||||
|
};
|
||||||
|
virtual void LogBegin(void){
|
||||||
|
std::cout << "ConjugageGradient::LogBegin() "<<std::endl;
|
||||||
|
};
|
||||||
|
|
||||||
GRID_TRACE("ConjugateGradient");
|
void operator()(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi) {
|
||||||
|
|
||||||
|
this->LogBegin();
|
||||||
|
|
||||||
|
GRID_TRACE("ConjugateGradient");
|
||||||
GridStopWatch PreambleTimer;
|
GridStopWatch PreambleTimer;
|
||||||
|
GridStopWatch ConstructTimer;
|
||||||
|
GridStopWatch NormTimer;
|
||||||
|
GridStopWatch AssignTimer;
|
||||||
PreambleTimer.Start();
|
PreambleTimer.Start();
|
||||||
psi.Checkerboard() = src.Checkerboard();
|
psi.Checkerboard() = src.Checkerboard();
|
||||||
|
|
||||||
@ -70,14 +83,19 @@ public:
|
|||||||
//RealD b_pred;
|
//RealD b_pred;
|
||||||
|
|
||||||
// Was doing copies
|
// Was doing copies
|
||||||
Field p(src.Grid());
|
ConstructTimer.Start();
|
||||||
|
Field p (src.Grid());
|
||||||
Field mmp(src.Grid());
|
Field mmp(src.Grid());
|
||||||
Field r(src.Grid());
|
Field r (src.Grid());
|
||||||
|
ConstructTimer.Stop();
|
||||||
|
|
||||||
// Initial residual computation & set up
|
// Initial residual computation & set up
|
||||||
|
NormTimer.Start();
|
||||||
ssq = norm2(src);
|
ssq = norm2(src);
|
||||||
RealD guess = norm2(psi);
|
RealD guess = norm2(psi);
|
||||||
|
NormTimer.Stop();
|
||||||
assert(std::isnan(guess) == 0);
|
assert(std::isnan(guess) == 0);
|
||||||
|
AssignTimer.Start();
|
||||||
if ( guess == 0.0 ) {
|
if ( guess == 0.0 ) {
|
||||||
r = src;
|
r = src;
|
||||||
p = r;
|
p = r;
|
||||||
@ -89,6 +107,7 @@ public:
|
|||||||
a = norm2(p);
|
a = norm2(p);
|
||||||
}
|
}
|
||||||
cp = a;
|
cp = a;
|
||||||
|
AssignTimer.Stop();
|
||||||
|
|
||||||
// Handle trivial case of zero src
|
// Handle trivial case of zero src
|
||||||
if (ssq == 0.){
|
if (ssq == 0.){
|
||||||
@ -164,6 +183,7 @@ public:
|
|||||||
}
|
}
|
||||||
LinearCombTimer.Stop();
|
LinearCombTimer.Stop();
|
||||||
LinalgTimer.Stop();
|
LinalgTimer.Stop();
|
||||||
|
LogIteration(k,a,b);
|
||||||
|
|
||||||
IterationTimer.Stop();
|
IterationTimer.Stop();
|
||||||
if ( (k % 500) == 0 ) {
|
if ( (k % 500) == 0 ) {
|
||||||
@ -220,6 +240,9 @@ public:
|
|||||||
<<" residual "<< std::sqrt(cp / ssq)<< std::endl;
|
<<" residual "<< std::sqrt(cp / ssq)<< std::endl;
|
||||||
SolverTimer.Stop();
|
SolverTimer.Stop();
|
||||||
std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() <<std::endl;
|
std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() <<std::endl;
|
||||||
|
std::cout << GridLogMessage << "\tConstruct " << ConstructTimer.Elapsed() <<std::endl;
|
||||||
|
std::cout << GridLogMessage << "\tNorm " << NormTimer.Elapsed() <<std::endl;
|
||||||
|
std::cout << GridLogMessage << "\tAssign " << AssignTimer.Elapsed() <<std::endl;
|
||||||
std::cout << GridLogMessage << "\tSolver " << SolverTimer.Elapsed() <<std::endl;
|
std::cout << GridLogMessage << "\tSolver " << SolverTimer.Elapsed() <<std::endl;
|
||||||
std::cout << GridLogMessage << "Solver breakdown "<<std::endl;
|
std::cout << GridLogMessage << "Solver breakdown "<<std::endl;
|
||||||
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
||||||
@ -233,5 +256,118 @@ public:
|
|||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <class Field>
|
||||||
|
class ConjugateGradientPolynomial : public ConjugateGradient<Field> {
|
||||||
|
public:
|
||||||
|
// Optionally record the CG polynomial
|
||||||
|
std::vector<double> ak;
|
||||||
|
std::vector<double> bk;
|
||||||
|
std::vector<double> poly_p;
|
||||||
|
std::vector<double> poly_r;
|
||||||
|
std::vector<double> poly_Ap;
|
||||||
|
std::vector<double> polynomial;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ConjugateGradientPolynomial(RealD tol, Integer maxit, bool err_on_no_conv = true)
|
||||||
|
: ConjugateGradient<Field>(tol,maxit,err_on_no_conv)
|
||||||
|
{ };
|
||||||
|
void PolyHermOp(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi)
|
||||||
|
{
|
||||||
|
Field tmp(src.Grid());
|
||||||
|
Field AtoN(src.Grid());
|
||||||
|
AtoN = src;
|
||||||
|
psi=AtoN*polynomial[0];
|
||||||
|
for(int n=1;n<polynomial.size();n++){
|
||||||
|
tmp = AtoN;
|
||||||
|
Linop.HermOp(tmp,AtoN);
|
||||||
|
psi = psi + polynomial[n]*AtoN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void CGsequenceHermOp(LinearOperatorBase<Field> &Linop, const Field &src, Field &x)
|
||||||
|
{
|
||||||
|
Field Ap(src.Grid());
|
||||||
|
Field r(src.Grid());
|
||||||
|
Field p(src.Grid());
|
||||||
|
p=src;
|
||||||
|
r=src;
|
||||||
|
x=Zero();
|
||||||
|
x.Checkerboard()=src.Checkerboard();
|
||||||
|
for(int k=0;k<ak.size();k++){
|
||||||
|
x = x + ak[k]*p;
|
||||||
|
Linop.HermOp(p,Ap);
|
||||||
|
r = r - ak[k] * Ap;
|
||||||
|
p = r + bk[k] * p;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void Solve(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi)
|
||||||
|
{
|
||||||
|
psi=Zero();
|
||||||
|
this->operator ()(Linop,src,psi);
|
||||||
|
}
|
||||||
|
virtual void LogBegin(void)
|
||||||
|
{
|
||||||
|
std::cout << "ConjugageGradientPolynomial::LogBegin() "<<std::endl;
|
||||||
|
ak.resize(0);
|
||||||
|
bk.resize(0);
|
||||||
|
polynomial.resize(0);
|
||||||
|
poly_Ap.resize(0);
|
||||||
|
poly_Ap.resize(0);
|
||||||
|
poly_p.resize(1);
|
||||||
|
poly_r.resize(1);
|
||||||
|
poly_p[0]=1.0;
|
||||||
|
poly_r[0]=1.0;
|
||||||
|
};
|
||||||
|
virtual void LogIteration(int k,RealD a,RealD b)
|
||||||
|
{
|
||||||
|
// With zero guess,
|
||||||
|
// p = r = src
|
||||||
|
//
|
||||||
|
// iterate:
|
||||||
|
// x = x + a p
|
||||||
|
// r = r - a A p
|
||||||
|
// p = r + b p
|
||||||
|
//
|
||||||
|
// [0]
|
||||||
|
// r = x
|
||||||
|
// p = x
|
||||||
|
// Ap=0
|
||||||
|
//
|
||||||
|
// [1]
|
||||||
|
// Ap = A x + 0 ==> shift poly P right by 1 and add 0.
|
||||||
|
// x = x + a p ==> add polynomials term by term
|
||||||
|
// r = r - a A p ==> add polynomials term by term
|
||||||
|
// p = r + b p ==> add polynomials term by term
|
||||||
|
//
|
||||||
|
std::cout << "ConjugageGradientPolynomial::LogIteration() "<<k<<std::endl;
|
||||||
|
ak.push_back(a);
|
||||||
|
bk.push_back(b);
|
||||||
|
// Ap= right_shift(p)
|
||||||
|
poly_Ap.resize(k+1);
|
||||||
|
poly_Ap[0]=0.0;
|
||||||
|
for(int i=0;i<k;i++){
|
||||||
|
poly_Ap[i+1]=poly_p[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// x = x + a p
|
||||||
|
polynomial.resize(k);
|
||||||
|
polynomial[k-1]=0.0;
|
||||||
|
for(int i=0;i<k;i++){
|
||||||
|
polynomial[i] = polynomial[i] + a * poly_p[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// r = r - a Ap
|
||||||
|
// p = r + b p
|
||||||
|
poly_r.resize(k+1);
|
||||||
|
poly_p.resize(k+1);
|
||||||
|
poly_r[k] = poly_p[k] = 0.0;
|
||||||
|
for(int i=0;i<k+1;i++){
|
||||||
|
poly_r[i] = poly_r[i] - a * poly_Ap[i];
|
||||||
|
poly_p[i] = poly_r[i] + b * poly_p[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
NAMESPACE_END(Grid);
|
NAMESPACE_END(Grid);
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user