1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-14 17:55:38 +00:00
Grid/lib/algorithms/iterative/GeneralisedMinimalResidual.h

267 lines
8.2 KiB
C
Raw Normal View History

2017-07-17 20:02:10 +01:00
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/algorithms/iterative/GeneralisedMinimalResidual.h
Copyright (C) 2015
Author: Daniel Richtmann <daniel.richtmann@ur.de>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution
directory
*************************************************************************************/
/* END LEGAL */
#ifndef GRID_GENERALISED_MINIMAL_RESIDUAL_H
#define GRID_GENERALISED_MINIMAL_RESIDUAL_H
// GMRES(m), following Y. Saad, "Iterative Methods for Sparse Linear Systems", p. 172:
// Compute r_0 = b - A x_0, \beta := \|r_0\|_2, and v_1 := r_0 / \beta
// For j = 1, 2, ..., m Do:
//   Compute w_j := A v_j
//   For i = 1, ..., j Do:
//     h_{ij} := (w_j, v_i)
//     w_j := w_j - h_{ij} v_i
//   EndDo
//   h_{j+1,j} = \|w_j\|_2. If h_{j+1,j} = 0 set m := j and go to HERE
//   v_{j+1} = w_j / h_{j+1,j}
// EndDo
// Define the (m+1) \times m Hessenberg matrix \bar{H}_m = \{h_{ij}\}_{1 \le i \le m+1,\ 1 \le j \le m}. [HERE]
// Compute y_m, the minimizer of \|\beta e_1 - \bar{H}_m y\|_2, and x_m = x_0 + V_m y_m.
///////////////////////////////////////////////////////////////////////////////////////////////////////
// We solve A x = b, with A = LinOp, x = psi (initial guess on entry, solution on exit) and b = src.
2017-07-18 16:57:13 +01:00
namespace Grid {

/////////////////////////////////////////////////////////////////////////////////////////
// Restarted GMRES(m) solver for LinOp.Op(psi) = src.
//
// Each outer cycle ("restart") runs up to RestartLength Arnoldi steps. It builds an
// orthonormal Krylov basis v[0..j] and the Hessenberg matrix, which is triangularised
// on the fly with Givens rotations (qrUpdate). The squared modulus of the last rotated
// right-hand-side entry, gamma[j+1], is compared against the target; the correction
// V y is only added to psi when a cycle ends (computeSolution). The next cycle restarts
// from the updated psi.
/////////////////////////////////////////////////////////////////////////////////////////
template<class Field>
class GeneralisedMinimalResidual : public OperatorFunction<Field> {
 public:
  bool ErrorOnNoConverge; // Throw an assert when GMRES fails to converge,
                          // defaults to True.

  RealD   Tolerance;      // relative target: stop once |r|^2 <= Tolerance^2 * |src|^2
  Integer MaxIterations;  // cap on the total number of Arnoldi steps; checked between
                          // restarts, so it is effectively rounded up to whole cycles
  Integer RestartLength;  // Arnoldi steps per restart cycle (the "m" in GMRES(m))
  Integer IterationCount; // Number of iterations the GMRES took to finish,
                          // filled in upon completion

  GridStopWatch MatrixTimer;
  GridStopWatch LinalgTimer;
  GridStopWatch QrTimer;
  GridStopWatch CompSolutionTimer;

  // Work storage, sized once in the constructor from RestartLength.
  // H is stored transposed: H(j, i) holds the Hessenberg entry h_{i,j}
  // (row j = Arnoldi step, column i = basis vector index).
  Eigen::MatrixXcd H;
  std::vector<std::complex<double>> y;     // coefficients of the correction V y
  std::vector<std::complex<double>> gamma; // rotated right-hand side (beta e_1)
  std::vector<std::complex<double>> c;     // Givens rotation "cosines"
  std::vector<std::complex<double>> s;     // Givens rotation "sines"

  GeneralisedMinimalResidual(RealD   tol,
                             Integer maxit,
                             Integer restart_length,
                             bool    err_on_no_conv = true)
    // Initialisers listed in declaration order (avoids -Wreorder); H and the
    // vectors depend on RestartLength, which is initialised before them.
    : ErrorOnNoConverge(err_on_no_conv)
    , Tolerance(tol)
    , MaxIterations(maxit)
    , RestartLength(restart_length)
    , H(Eigen::MatrixXcd::Zero(RestartLength, RestartLength + 1)) // sizes taken from DD-αAMG code base
    , y(RestartLength + 1, 0.)
    , gamma(RestartLength + 1, 0.)
    , c(RestartLength + 1, 0.)
    , s(RestartLength + 1, 0.) {}

  // Solve LinOp.Op(psi) = src, using psi on entry as the initial guess.
  // On convergence psi holds the solution and timing/residual info is logged;
  // otherwise a message is logged and, if ErrorOnNoConverge, the solver asserts.
  void operator()(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi) {

    psi.checkerboard = src.checkerboard;
    conformable(psi, src);

    // A cycle with zero Arnoldi steps can never make progress.
    assert(RestartLength > 0);

    RealD guess = norm2(psi);
    assert(std::isnan(guess) == 0);

    RealD cp;
    RealD ssq = norm2(src);
    RealD rsq = Tolerance * Tolerance * ssq;

    Field r(src._grid);

    std::cout << std::setprecision(4) << std::scientific << std::endl;
    std::cout << GridLogIterative << "GeneralisedMinimalResidual: guess " << guess << std::endl;
    std::cout << GridLogIterative << "GeneralisedMinimalResidual: src " << ssq << std::endl;

    MatrixTimer.Reset();
    LinalgTimer.Reset();
    QrTimer.Reset();
    CompSolutionTimer.Reset();

    GridStopWatch SolverTimer;
    SolverTimer.Start();

    IterationCount = 0;

    // MaxIterations bounds the total number of Arnoldi steps (IterationCount),
    // not the number of restarts. A cycle that does not converge always performs
    // RestartLength >= 1 steps, so this loop terminates.
    while (IterationCount < MaxIterations) {

      cp = outerLoopBody(LinOp, src, psi, rsq);

      // Stopping condition
      if (cp <= rsq) {

        SolverTimer.Stop();

        // Recompute the true residual src - A psi to report alongside the estimate.
        LinOp.Op(psi, r);
        axpy(r, -1.0, src, r);

        RealD srcnorm       = sqrt(ssq);
        RealD resnorm       = sqrt(norm2(r));
        RealD true_residual = resnorm / srcnorm;

        std::cout << GridLogMessage << "GeneralisedMinimalResidual: Converged on iteration " << IterationCount << std::endl;
        std::cout << GridLogMessage << "\tComputed residual " << sqrt(cp / ssq) << std::endl;
        std::cout << GridLogMessage << "\tTrue residual " << true_residual << std::endl;
        std::cout << GridLogMessage << "\tTarget " << Tolerance << std::endl;

        std::cout << GridLogMessage << "GeneralisedMinimalResidual Time breakdown" << std::endl;
        std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() << std::endl;
        std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() << std::endl;
        std::cout << GridLogMessage << "\tLinalg " << LinalgTimer.Elapsed() << std::endl;
        std::cout << GridLogMessage << "\tQR " << QrTimer.Elapsed() << std::endl;
        std::cout << GridLogMessage << "\tCompSol " << CompSolutionTimer.Elapsed() << std::endl;
        return;
      }
    }

    std::cout << GridLogMessage << "GeneralisedMinimalResidual did NOT converge" << std::endl;

    if (ErrorOnNoConverge)
      assert(0);
  }

  // One restart cycle: start from the current psi, run up to RestartLength Arnoldi
  // steps, and add the correction to psi. Returns the squared residual estimate
  // (or the exact |src - A psi|^2 when psi already meets the target on entry).
  RealD outerLoopBody(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi, RealD rsq) {

    RealD cp = 0;

    Field w(src._grid);
    Field r(src._grid);

    MatrixTimer.Start();
    LinOp.Op(psi, w);
    MatrixTimer.Stop();

    LinalgTimer.Start();
    r = src - w;

    RealD rnorm2 = norm2(r);
    if (rnorm2 <= rsq) {
      // psi already satisfies the target (e.g. an exact initial guess or src == 0).
      // Stop here: forming v[0] = r / |r| would divide by zero when r == 0.
      LinalgTimer.Stop();
      return rnorm2;
    }

    gamma[0] = sqrt(rnorm2);
    LinalgTimer.Stop();

    // Krylov basis for this cycle; allocated only once we know it is needed.
    std::vector<Field> v(RestartLength + 1, src._grid);

    LinalgTimer.Start();
    v[0] = (1. / gamma[0]) * r;
    LinalgTimer.Stop();

    for (int i = 0; i < RestartLength; i++) {

      IterationCount++;

      arnoldiStep(LinOp, v, w, i);

      qrUpdate(i);

      cp = std::norm(gamma[i + 1]);

      std::cout << GridLogIterative << "GeneralisedMinimalResidual: Iteration " << IterationCount
                << " residual " << cp << " target " << rsq << std::endl;

      if ((i == RestartLength - 1) || (cp <= rsq)) {

        computeSolution(v, psi, i);

        return cp;
      }
    }

    assert(0); // Never reached
    return cp;
  }

  // Arnoldi step `iter` (modified Gram-Schmidt): w = A v[iter] orthogonalised against
  // v[0..iter]; coefficients go into row `iter` of the transposed Hessenberg H.
  void arnoldiStep(LinearOperatorBase<Field> &LinOp, std::vector<Field> &v, Field &w, int iter) {

    MatrixTimer.Start();
    LinOp.Op(v[iter], w);
    MatrixTimer.Stop();

    LinalgTimer.Start();
    for (int i = 0; i <= iter; ++i) {
      H(iter, i) = innerProduct(v[i], w);
      w = w - H(iter, i) * v[i];
    }

    RealD hnext = sqrt(norm2(w));
    H(iter, iter + 1) = hnext;

    // On a (lucky) breakdown ||w|| == 0 the Krylov space is invariant. qrUpdate then
    // yields gamma[iter+1] == 0, the cycle converges, and v[iter+1] is never read,
    // so skip the normalisation instead of dividing by zero.
    if (hnext > 0.) {
      v[iter + 1] = (1. / H(iter, iter + 1)) * w;
    }
    LinalgTimer.Stop();
  }

  // Apply all previous Givens rotations to the new Hessenberg column `iter`, then
  // compute and apply the rotation that annihilates its subdiagonal entry. The same
  // rotation is applied to the right-hand side gamma.
  void qrUpdate(int iter) {

    QrTimer.Start();
    for (int i = 0; i < iter; ++i) {
      auto tmp       = -s[i] * H(iter, i) + c[i] * H(iter, i + 1);
      H(iter, i)     = std::conj(c[i]) * H(iter, i) + std::conj(s[i]) * H(iter, i + 1);
      H(iter, i + 1) = tmp;
    }

    // Compute new Givens Rotation
    ComplexD nu = sqrt(std::norm(H(iter, iter)) + std::norm(H(iter, iter + 1)));
    c[iter]     = H(iter, iter) / nu;
    s[iter]     = H(iter, iter + 1) / nu;

    // Apply new Givens rotation
    H(iter, iter)     = nu;
    H(iter, iter + 1) = 0.;

    gamma[iter + 1] = -s[iter] * gamma[iter];
    gamma[iter]     = std::conj(c[iter]) * gamma[iter];
    QrTimer.Stop();
  }

  // Solve the (iter+1)x(iter+1) upper-triangular system R y = gamma by back
  // substitution (R(i,k) is stored transposed as H(k,i)), then update psi += V y.
  void computeSolution(std::vector<Field> const &v, Field &psi, int iter) {

    CompSolutionTimer.Start();
    for (int i = iter; i >= 0; i--) {
      y[i] = gamma[i];
      for (int k = i + 1; k <= iter; k++)
        y[i] = y[i] - H(k, i) * y[k];
      y[i] = y[i] / H(i, i);
    }

    // psi is this cycle's starting vector, so the correction is accumulated on top of it.
    // TODO: Use axpys or similar for these
    for (int i = 0; i <= iter; i++)
      psi = psi + v[i] * y[i];

    CompSolutionTimer.Stop();
  }
};

}
2017-07-18 16:57:13 +01:00
#endif