mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-07 04:35:56 +01:00
Offload the linear combinations in CG
This commit is contained in:
parent
9efcc535bc
commit
ab063f33c0
@ -391,12 +391,12 @@ template<class Matrix,class Field> using SchurStagOperator = SchurStaggeredOpera
|
||||
template<class Field> class OperatorFunction {
|
||||
public:
|
||||
virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
|
||||
virtual void operator() (LinearOperatorBase<Field> &Linop, const std::vector<Field> &in,std::vector<Field> &out) {
|
||||
assert(in.size()==out.size());
|
||||
for(int k=0;k<in.size();k++){
|
||||
(*this)(Linop,in[k],out[k]);
|
||||
}
|
||||
};
|
||||
virtual void operator() (LinearOperatorBase<Field> &Linop, const std::vector<Field> &in,std::vector<Field> &out) {
|
||||
assert(in.size()==out.size());
|
||||
for(int k=0;k<in.size();k++){
|
||||
(*this)(Linop,in[k],out[k]);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template<class Field> class LinearFunction {
|
||||
|
@ -47,6 +47,8 @@ struct ChebyParams : Serializable {
|
||||
template<class Field>
|
||||
class Chebyshev : public OperatorFunction<Field> {
|
||||
private:
|
||||
using OperatorFunction<Field>::operator();
|
||||
|
||||
std::vector<RealD> Coeffs;
|
||||
int order;
|
||||
RealD hi;
|
||||
|
@ -41,6 +41,9 @@ NAMESPACE_BEGIN(Grid);
|
||||
template <class Field>
|
||||
class ConjugateGradient : public OperatorFunction<Field> {
|
||||
public:
|
||||
|
||||
using OperatorFunction<Field>::operator();
|
||||
|
||||
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
||||
// Defaults true.
|
||||
RealD Tolerance;
|
||||
@ -58,7 +61,8 @@ public:
|
||||
|
||||
conformable(psi, src);
|
||||
|
||||
RealD cp, c, a, d, b, ssq, qq, b_pred;
|
||||
RealD cp, c, a, d, b, ssq, qq;
|
||||
//RealD b_pred;
|
||||
|
||||
Field p(src);
|
||||
Field mmp(src);
|
||||
@ -128,10 +132,10 @@ public:
|
||||
auto psi_v = psi.View();
|
||||
auto p_v = p.View();
|
||||
auto r_v = r.View();
|
||||
parallel_for(int ss=0;ss<src.Grid()->oSites();ss++){
|
||||
accelerator_loop(ss,p_v,{
|
||||
vstream(psi_v[ss], a * p_v[ss] + psi_v[ss]);
|
||||
vstream(p_v [ss], b * p_v[ss] + r_v[ss]);
|
||||
}
|
||||
});
|
||||
LinearCombTimer.Stop();
|
||||
LinalgTimer.Stop();
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user