mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-14 05:07:05 +01:00
Added a new, much faster implementation of precision change that (optionally) uses a precomputed, device-resident workspace containing pointer offsets, so that all lattice copying occurs only on the device and no host<->device transfer is required other than for the pointer table. It also avoids the need to unpack and repack the fields using explicit lane copying. When this new precisionChange is called without a workspace, one is computed on the fly; even so, it is still considerably faster than the original implementation. In the special case of using double2 and when the Grids are the same, calls to the new precisionChange automatically use precisionChangeFast, so there is a single API call for all precision changes. Reliable update and mixed-precision multishift have been modified to precompute precision-change workspaces. Renamed the original precisionChange to precisionChangeOrig. Fixed an incorrect pointer-offset bug in copyLane. Added a test and a benchmark for precisionChange. Added a test for the reliable-update CG.
278 lines
10 KiB
C++
278 lines
10 KiB
C++
/*************************************************************************************
|
|
|
|
Grid physics library, www.github.com/paboyle/Grid
|
|
|
|
Source file: ./lib/algorithms/iterative/ConjugateGradientReliableUpdate.h
|
|
|
|
Copyright (C) 2015
|
|
|
|
Author: Christopher Kelly <ckelly@phys.columbia.edu>
|
|
|
|
This program is free software; you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation; either version 2 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
|
|
See the full license in the file "LICENSE" in the top level distribution directory
|
|
*************************************************************************************/
|
|
/* END LEGAL */
|
|
#ifndef GRID_CONJUGATE_GRADIENT_RELIABLE_UPDATE_H
|
|
#define GRID_CONJUGATE_GRADIENT_RELIABLE_UPDATE_H
|
|
|
|
NAMESPACE_BEGIN(Grid);
|
|
|
|
template<class FieldD,class FieldF,
|
|
typename std::enable_if< getPrecision<FieldD>::value == 2, int>::type = 0,
|
|
typename std::enable_if< getPrecision<FieldF>::value == 1, int>::type = 0>
|
|
class ConjugateGradientReliableUpdate : public LinearFunction<FieldD> {
|
|
public:
|
|
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
|
// Defaults true.
|
|
RealD Tolerance;
|
|
Integer MaxIterations;
|
|
Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
|
|
Integer ReliableUpdatesPerformed;
|
|
|
|
bool DoFinalCleanup; //Final DP cleanup, defaults to true
|
|
Integer IterationsToCleanup; //Final DP cleanup step iterations
|
|
|
|
LinearOperatorBase<FieldF> &Linop_f;
|
|
LinearOperatorBase<FieldD> &Linop_d;
|
|
GridBase* SinglePrecGrid;
|
|
RealD Delta; //reliable update parameter. A reliable update is performed when the residual drops by a factor of Delta relative to its value at the last update
|
|
|
|
//Optional ability to switch to a different linear operator once the tolerance reaches a certain point. Useful for single/half -> single/single
|
|
LinearOperatorBase<FieldF> *Linop_fallback;
|
|
RealD fallback_transition_tol;
|
|
|
|
|
|
ConjugateGradientReliableUpdate(RealD tol, Integer maxit, RealD _delta, GridBase* _sp_grid, LinearOperatorBase<FieldF> &_Linop_f, LinearOperatorBase<FieldD> &_Linop_d, bool err_on_no_conv = true)
|
|
: Tolerance(tol),
|
|
MaxIterations(maxit),
|
|
Delta(_delta),
|
|
Linop_f(_Linop_f),
|
|
Linop_d(_Linop_d),
|
|
SinglePrecGrid(_sp_grid),
|
|
ErrorOnNoConverge(err_on_no_conv),
|
|
DoFinalCleanup(true),
|
|
Linop_fallback(NULL)
|
|
{
|
|
assert(Delta > 0. && Delta < 1. && "Expect 0 < Delta < 1");
|
|
};
|
|
|
|
void setFallbackLinop(LinearOperatorBase<FieldF> &_Linop_fallback, const RealD _fallback_transition_tol){
|
|
Linop_fallback = &_Linop_fallback;
|
|
fallback_transition_tol = _fallback_transition_tol;
|
|
}
|
|
|
|
void operator()(const FieldD &src, FieldD &psi) {
|
|
GRID_TRACE("ConjugateGradientReliableUpdate");
|
|
LinearOperatorBase<FieldF> *Linop_f_use = &Linop_f;
|
|
bool using_fallback = false;
|
|
|
|
psi.Checkerboard() = src.Checkerboard();
|
|
conformable(psi, src);
|
|
|
|
RealD cp, c, a, d, b, ssq, qq, b_pred;
|
|
|
|
FieldD p(src);
|
|
FieldD mmp(src);
|
|
FieldD r(src);
|
|
|
|
// Initial residual computation & set up
|
|
RealD guess = norm2(psi);
|
|
assert(std::isnan(guess) == 0);
|
|
|
|
Linop_d.HermOpAndNorm(psi, mmp, d, b);
|
|
|
|
r = src - mmp;
|
|
p = r;
|
|
|
|
a = norm2(p);
|
|
cp = a;
|
|
ssq = norm2(src);
|
|
|
|
std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: guess " << guess << std::endl;
|
|
std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: src " << ssq << std::endl;
|
|
std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: mp " << d << std::endl;
|
|
std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: mmp " << b << std::endl;
|
|
std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: cp,r " << cp << std::endl;
|
|
std::cout << GridLogIterative << std::setprecision(4) << "ConjugateGradientReliableUpdate: p " << a << std::endl;
|
|
|
|
RealD rsq = Tolerance * Tolerance * ssq;
|
|
|
|
// Check if guess is really REALLY good :)
|
|
if (cp <= rsq) {
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate guess was REALLY good\n";
|
|
std::cout << GridLogMessage << "\tComputed residual " << std::sqrt(cp / ssq)<<std::endl;
|
|
return;
|
|
}
|
|
|
|
//Single prec initialization
|
|
precisionChangeWorkspace pc_wk_sp_to_dp(src.Grid(), SinglePrecGrid);
|
|
precisionChangeWorkspace pc_wk_dp_to_sp(SinglePrecGrid, src.Grid());
|
|
|
|
FieldF r_f(SinglePrecGrid);
|
|
r_f.Checkerboard() = r.Checkerboard();
|
|
precisionChange(r_f, r, pc_wk_dp_to_sp);
|
|
|
|
FieldF psi_f(r_f);
|
|
psi_f = Zero();
|
|
|
|
FieldF p_f(r_f);
|
|
FieldF mmp_f(r_f);
|
|
|
|
RealD MaxResidSinceLastRelUp = cp; //initial residual
|
|
|
|
std::cout << GridLogIterative << std::setprecision(4)
|
|
<< "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
|
|
|
|
GridStopWatch LinalgTimer;
|
|
GridStopWatch MatrixTimer;
|
|
GridStopWatch SolverTimer;
|
|
GridStopWatch PrecChangeTimer;
|
|
|
|
SolverTimer.Start();
|
|
int k = 0;
|
|
int l = 0;
|
|
|
|
for (k = 1; k <= MaxIterations; k++) {
|
|
c = cp;
|
|
|
|
MatrixTimer.Start();
|
|
Linop_f_use->HermOpAndNorm(p_f, mmp_f, d, qq);
|
|
MatrixTimer.Stop();
|
|
|
|
LinalgTimer.Start();
|
|
|
|
a = c / d;
|
|
b_pred = a * (a * qq - d) / c;
|
|
|
|
cp = axpy_norm(r_f, -a, mmp_f, r_f);
|
|
b = cp / c;
|
|
|
|
// Fuse these loops ; should be really easy
|
|
psi_f = a * p_f + psi_f;
|
|
//p_f = p_f * b + r_f;
|
|
|
|
LinalgTimer.Stop();
|
|
|
|
std::cout << GridLogIterative << "ConjugateGradientReliableUpdate: Iteration " << k
|
|
<< " residual " << cp << " target " << rsq << std::endl;
|
|
std::cout << GridLogDebug << "a = "<< a << " b_pred = "<< b_pred << " b = "<< b << std::endl;
|
|
std::cout << GridLogDebug << "qq = "<< qq << " d = "<< d << " c = "<< c << std::endl;
|
|
|
|
if(cp > MaxResidSinceLastRelUp){
|
|
std::cout << GridLogIterative << "ConjugateGradientReliableUpdate: updating MaxResidSinceLastRelUp : " << MaxResidSinceLastRelUp << " -> " << cp << std::endl;
|
|
MaxResidSinceLastRelUp = cp;
|
|
}
|
|
|
|
// Stopping condition
|
|
if (cp <= rsq) {
|
|
//Although not written in the paper, I assume that I have to add on the final solution
|
|
PrecChangeTimer.Start();
|
|
precisionChange(mmp, psi_f, pc_wk_sp_to_dp);
|
|
PrecChangeTimer.Stop();
|
|
psi = psi + mmp;
|
|
|
|
|
|
SolverTimer.Stop();
|
|
Linop_d.HermOpAndNorm(psi, mmp, d, qq);
|
|
p = mmp - src;
|
|
|
|
RealD srcnorm = std::sqrt(norm2(src));
|
|
RealD resnorm = std::sqrt(norm2(p));
|
|
RealD true_residual = resnorm / srcnorm;
|
|
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate Converged on iteration " << k << " after " << l << " reliable updates" << std::endl;
|
|
std::cout << GridLogMessage << "\tComputed residual " << std::sqrt(cp / ssq)<<std::endl;
|
|
std::cout << GridLogMessage << "\tTrue residual " << true_residual<<std::endl;
|
|
std::cout << GridLogMessage << "\tTarget " << Tolerance << std::endl;
|
|
|
|
std::cout << GridLogMessage << "Time breakdown "<<std::endl;
|
|
std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
|
|
std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
|
|
std::cout << GridLogMessage << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
|
|
std::cout << GridLogMessage << "\tPrecChange " << PrecChangeTimer.Elapsed() <<std::endl;
|
|
std::cout << GridLogMessage << "\tPrecChange avg time " << PrecChangeTimer.Elapsed()/(2*l+1) <<std::endl;
|
|
|
|
|
|
IterationsToComplete = k;
|
|
ReliableUpdatesPerformed = l;
|
|
|
|
if(DoFinalCleanup){
|
|
//Do a final CG to cleanup
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate performing final cleanup.\n";
|
|
ConjugateGradient<FieldD> CG(Tolerance,MaxIterations);
|
|
CG.ErrorOnNoConverge = ErrorOnNoConverge;
|
|
CG(Linop_d,src,psi);
|
|
IterationsToCleanup = CG.IterationsToComplete;
|
|
}
|
|
else if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
|
|
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate complete.\n";
|
|
return;
|
|
}
|
|
else if(cp < Delta * MaxResidSinceLastRelUp) { //reliable update
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate "
|
|
<< cp << "(residual) < " << Delta << "(Delta) * " << MaxResidSinceLastRelUp << "(MaxResidSinceLastRelUp) on iteration " << k << " : performing reliable update\n";
|
|
PrecChangeTimer.Start();
|
|
precisionChange(mmp, psi_f, pc_wk_sp_to_dp);
|
|
PrecChangeTimer.Stop();
|
|
psi = psi + mmp;
|
|
|
|
MatrixTimer.Start();
|
|
Linop_d.HermOpAndNorm(psi, mmp, d, qq);
|
|
MatrixTimer.Stop();
|
|
|
|
r = src - mmp;
|
|
|
|
psi_f = Zero();
|
|
PrecChangeTimer.Start();
|
|
precisionChange(r_f, r, pc_wk_dp_to_sp);
|
|
PrecChangeTimer.Stop();
|
|
cp = norm2(r);
|
|
MaxResidSinceLastRelUp = cp;
|
|
|
|
b = cp/c;
|
|
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate new residual " << cp << std::endl;
|
|
|
|
l = l+1;
|
|
}
|
|
|
|
p_f = p_f * b + r_f; //update search vector after reliable update appears to help convergence
|
|
|
|
if(!using_fallback && Linop_fallback != NULL && cp < fallback_transition_tol){
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate switching to fallback linear operator on iteration " << k << " at residual " << cp << std::endl;
|
|
Linop_f_use = Linop_fallback;
|
|
using_fallback = true;
|
|
}
|
|
|
|
|
|
}
|
|
std::cout << GridLogMessage << "ConjugateGradientReliableUpdate did NOT converge"
|
|
<< std::endl;
|
|
|
|
if (ErrorOnNoConverge) assert(0);
|
|
IterationsToComplete = k;
|
|
ReliableUpdatesPerformed = l;
|
|
}
|
|
};
|
|
|
|
|
|
NAMESPACE_END(Grid);
|
|
|
|
|
|
|
|
#endif
|