1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-25 13:15:55 +01:00

Small simplification of FermionOperatorImpl towards GPU but not there yet

This commit is contained in:
paboyle 2018-02-01 22:41:54 +00:00
parent 79b50feacf
commit 8ae77d3706

View File

@ -203,7 +203,7 @@ public:
bool overlapCommsCompute(void) { return Params.overlapCommsCompute; }; bool overlapCommsCompute(void) { return Params.overlapCommsCompute; };
inline void multLink(SiteHalfSpinor &phi, accelerator_inline void multLink(SiteHalfSpinor &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi, const SiteHalfSpinor &chi,
int mu, int mu,
@ -212,7 +212,7 @@ public:
mult(&phi(), &U(mu), &chi()); mult(&phi(), &U(mu), &chi());
} }
inline void multLinkProp(SitePropagator &phi, accelerator_inline void multLinkProp(SitePropagator &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SitePropagator &chi, const SitePropagator &chi,
int mu) { int mu) {
@ -220,7 +220,7 @@ public:
} }
template <class ref> template <class ref>
inline void loadLinkElement(Simd &reg, ref &memory) { accelerator_inline void loadLinkElement(Simd &reg, ref &memory) {
reg = memory; reg = memory;
} }
@ -331,11 +331,11 @@ public:
bool overlapCommsCompute(void) { return false; }; bool overlapCommsCompute(void) { return false; };
template <class ref> template <class ref>
inline void loadLinkElement(Simd &reg, ref &memory) { accelerator_inline void loadLinkElement(Simd &reg, ref &memory) {
vsplat(reg, memory); vsplat(reg, memory);
} }
inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi, int mu, StencilEntry *SE, const SiteHalfSpinor &chi, int mu, StencilEntry *SE,
StencilImpl &St) { StencilImpl &St) {
SiteGaugeLink UU; SiteGaugeLink UU;
@ -347,7 +347,7 @@ public:
mult(&phi(), &UU(), &chi()); mult(&phi(), &UU(), &chi());
} }
inline void multLinkProp(SitePropagator &phi, accelerator_inline void multLinkProp(SitePropagator &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SitePropagator &chi, const SitePropagator &chi,
int mu) { int mu) {
@ -486,24 +486,26 @@ public:
// provide the multiply by link that is differentiated between Gparity (with // provide the multiply by link that is differentiated between Gparity (with
// flavour index) and non-Gparity // flavour index) and non-Gparity
inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi, int mu, StencilEntry *SE, const SiteHalfSpinor &chi, int mu, StencilEntry *SE,
StencilImpl &St) { StencilImpl &St) {
typedef SiteHalfSpinor vobj; typedef SiteHalfSpinor vobj;
typedef typename SiteHalfSpinor::scalar_object sobj; typedef typename SiteHalfSpinor::scalar_object sobj;
typedef typename SiteHalfSpinor::vector_type vector_type;
vobj vtmp; vobj vtmp;
sobj stmp; sobj stmp;
const int Nsimd =vector_type::Nsimd();
// const int Nsimd = grid->Nsimd();
GridBase *grid= St.Grid(); GridBase *grid= St.Grid();
const int Nsimd = grid->Nsimd();
int direction = St._directions[mu]; int direction = St._directions[mu];
int distance = St._distances[mu]; int distance = St._distances[mu];
int ptype = St._permute_type[mu]; int ptype = St._permute_type[mu];
int sl = St.Grid()->_simd_layout[direction]; int sl = grid->_simd_layout[direction];
// Fixme X.Y.Z.T hardcode in stencil // Fixme X.Y.Z.T hardcode in stencil
int mmu = mu % Nd; int mmu = mu % Nd;
@ -556,14 +558,14 @@ public:
} }
// Fixme: Gparity prop * link // Fixme: Gparity prop * link
inline void multLinkProp(SitePropagator &phi, const SiteDoubledGaugeField &U, accelerator_inline void multLinkProp(SitePropagator &phi, const SiteDoubledGaugeField &U,
const SitePropagator &chi, int mu) const SitePropagator &chi, int mu)
{ {
assert(0); assert(0);
} }
template <class ref> template <class ref>
inline void loadLinkElement(Simd &reg, ref &memory) { accelerator_inline void loadLinkElement(Simd &reg, ref &memory) {
reg = memory; reg = memory;
} }
@ -695,13 +697,13 @@ public:
StaggeredImpl(const ImplParams &p = ImplParams()) : Params(p){}; StaggeredImpl(const ImplParams &p = ImplParams()) : Params(p){};
inline void multLink(SiteSpinor &phi, accelerator_inline void multLink(SiteSpinor &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SiteSpinor &chi, const SiteSpinor &chi,
int mu){ int mu){
mult(&phi(), &U(mu), &chi()); mult(&phi(), &U(mu), &chi());
} }
inline void multLinkAdd(SiteSpinor &phi, accelerator_inline void multLinkAdd(SiteSpinor &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SiteSpinor &chi, const SiteSpinor &chi,
int mu){ int mu){
@ -709,7 +711,7 @@ public:
} }
template <class ref> template <class ref>
inline void loadLinkElement(Simd &reg, ref &memory) { accelerator_inline void loadLinkElement(Simd &reg, ref &memory) {
reg = memory; reg = memory;
} }
@ -832,11 +834,11 @@ public:
StaggeredVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){}; StaggeredVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){};
template <class ref> template <class ref>
inline void loadLinkElement(Simd &reg, ref &memory) { accelerator_inline void loadLinkElement(Simd &reg, ref &memory) {
vsplat(reg, memory); vsplat(reg, memory);
} }
inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi, int mu) { const SiteHalfSpinor &chi, int mu) {
SiteGaugeLink UU; SiteGaugeLink UU;
for (int i = 0; i < Dimension; i++) { for (int i = 0; i < Dimension; i++) {
@ -846,7 +848,7 @@ public:
} }
mult(&phi(), &UU(), &chi()); mult(&phi(), &UU(), &chi());
} }
inline void multLinkAdd(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, accelerator_inline void multLinkAdd(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi, int mu) { const SiteHalfSpinor &chi, int mu) {
SiteGaugeLink UU; SiteGaugeLink UU;
for (int i = 0; i < Dimension; i++) { for (int i = 0; i < Dimension; i++) {