mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-11 14:40:46 +01:00
Offload the linear combinations in CG
This commit is contained in:
parent
9efcc535bc
commit
ab063f33c0
@ -391,12 +391,12 @@ template<class Matrix,class Field> using SchurStagOperator = SchurStaggeredOpera
|
|||||||
template<class Field> class OperatorFunction {
|
template<class Field> class OperatorFunction {
|
||||||
public:
|
public:
|
||||||
virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
|
virtual void operator() (LinearOperatorBase<Field> &Linop, const Field &in, Field &out) = 0;
|
||||||
virtual void operator() (LinearOperatorBase<Field> &Linop, const std::vector<Field> &in,std::vector<Field> &out) {
|
virtual void operator() (LinearOperatorBase<Field> &Linop, const std::vector<Field> &in,std::vector<Field> &out) {
|
||||||
assert(in.size()==out.size());
|
assert(in.size()==out.size());
|
||||||
for(int k=0;k<in.size();k++){
|
for(int k=0;k<in.size();k++){
|
||||||
(*this)(Linop,in[k],out[k]);
|
(*this)(Linop,in[k],out[k]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template<class Field> class LinearFunction {
|
template<class Field> class LinearFunction {
|
||||||
|
@ -47,6 +47,8 @@ struct ChebyParams : Serializable {
|
|||||||
template<class Field>
|
template<class Field>
|
||||||
class Chebyshev : public OperatorFunction<Field> {
|
class Chebyshev : public OperatorFunction<Field> {
|
||||||
private:
|
private:
|
||||||
|
using OperatorFunction<Field>::operator();
|
||||||
|
|
||||||
std::vector<RealD> Coeffs;
|
std::vector<RealD> Coeffs;
|
||||||
int order;
|
int order;
|
||||||
RealD hi;
|
RealD hi;
|
||||||
|
@ -41,6 +41,9 @@ NAMESPACE_BEGIN(Grid);
|
|||||||
template <class Field>
|
template <class Field>
|
||||||
class ConjugateGradient : public OperatorFunction<Field> {
|
class ConjugateGradient : public OperatorFunction<Field> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
using OperatorFunction<Field>::operator();
|
||||||
|
|
||||||
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
|
||||||
// Defaults true.
|
// Defaults true.
|
||||||
RealD Tolerance;
|
RealD Tolerance;
|
||||||
@ -58,7 +61,8 @@ public:
|
|||||||
|
|
||||||
conformable(psi, src);
|
conformable(psi, src);
|
||||||
|
|
||||||
RealD cp, c, a, d, b, ssq, qq, b_pred;
|
RealD cp, c, a, d, b, ssq, qq;
|
||||||
|
//RealD b_pred;
|
||||||
|
|
||||||
Field p(src);
|
Field p(src);
|
||||||
Field mmp(src);
|
Field mmp(src);
|
||||||
@ -128,10 +132,10 @@ public:
|
|||||||
auto psi_v = psi.View();
|
auto psi_v = psi.View();
|
||||||
auto p_v = p.View();
|
auto p_v = p.View();
|
||||||
auto r_v = r.View();
|
auto r_v = r.View();
|
||||||
parallel_for(int ss=0;ss<src.Grid()->oSites();ss++){
|
accelerator_loop(ss,p_v,{
|
||||||
vstream(psi_v[ss], a * p_v[ss] + psi_v[ss]);
|
vstream(psi_v[ss], a * p_v[ss] + psi_v[ss]);
|
||||||
vstream(p_v [ss], b * p_v[ss] + r_v[ss]);
|
vstream(p_v [ss], b * p_v[ss] + r_v[ss]);
|
||||||
}
|
});
|
||||||
LinearCombTimer.Stop();
|
LinearCombTimer.Stop();
|
||||||
LinalgTimer.Stop();
|
LinalgTimer.Stop();
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user