1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-11 14:40:46 +01:00

Offload the linear combinations in CG

This commit is contained in:
Peter Boyle 2019-01-01 13:42:13 +00:00
parent 9efcc535bc
commit ab063f33c0
3 changed files with 15 additions and 9 deletions

View File

@ -391,12 +391,12 @@ template<class Matrix,class Field> using SchurStagOperator = SchurStaggeredOpera
template<class Field> class OperatorFunction { template<class Field> class OperatorFunction {
public: public:
virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0; virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
virtual void operator() (LinearOperatorBase<Field> &Linop, const std::vector<Field> &in,std::vector<Field> &out) { virtual void operator() (LinearOperatorBase<Field> &Linop, const std::vector<Field> &in,std::vector<Field> &out) {
assert(in.size()==out.size()); assert(in.size()==out.size());
for(int k=0;k<in.size();k++){ for(int k=0;k<in.size();k++){
(*this)(Linop,in[k],out[k]); (*this)(Linop,in[k],out[k]);
} }
}; };
}; };
template<class Field> class LinearFunction { template<class Field> class LinearFunction {

View File

@ -47,6 +47,8 @@ struct ChebyParams : Serializable {
template<class Field> template<class Field>
class Chebyshev : public OperatorFunction<Field> { class Chebyshev : public OperatorFunction<Field> {
private: private:
using OperatorFunction<Field>::operator();
std::vector<RealD> Coeffs; std::vector<RealD> Coeffs;
int order; int order;
RealD hi; RealD hi;

View File

@ -41,6 +41,9 @@ NAMESPACE_BEGIN(Grid);
template <class Field> template <class Field>
class ConjugateGradient : public OperatorFunction<Field> { class ConjugateGradient : public OperatorFunction<Field> {
public: public:
using OperatorFunction<Field>::operator();
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge. bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
// Defaults true. // Defaults true.
RealD Tolerance; RealD Tolerance;
@ -58,7 +61,8 @@ public:
conformable(psi, src); conformable(psi, src);
RealD cp, c, a, d, b, ssq, qq, b_pred; RealD cp, c, a, d, b, ssq, qq;
//RealD b_pred;
Field p(src); Field p(src);
Field mmp(src); Field mmp(src);
@ -128,10 +132,10 @@ public:
auto psi_v = psi.View(); auto psi_v = psi.View();
auto p_v = p.View(); auto p_v = p.View();
auto r_v = r.View(); auto r_v = r.View();
parallel_for(int ss=0;ss<src.Grid()->oSites();ss++){ accelerator_loop(ss,p_v,{
vstream(psi_v[ss], a * p_v[ss] + psi_v[ss]); vstream(psi_v[ss], a * p_v[ss] + psi_v[ss]);
vstream(p_v [ss], b * p_v[ss] + r_v[ss]); vstream(p_v [ss], b * p_v[ss] + r_v[ss]);
} });
LinearCombTimer.Stop(); LinearCombTimer.Stop();
LinalgTimer.Stop(); LinalgTimer.Stop();