1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-04 19:25:56 +01:00

Have to make all kernel-called routines static, since the object reference will be a host pointer on the GPU

This commit is contained in:
Peter Boyle 2018-03-24 19:29:26 -04:00
parent b50f37cfb4
commit 1f70cedbab

View File

@ -140,6 +140,7 @@ public:
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
#define INHERIT_FIMPL_TYPES(Impl)\ #define INHERIT_FIMPL_TYPES(Impl)\
typedef typename Impl::Coeff_t Coeff_t; \
typedef typename Impl::FermionField FermionField; \ typedef typename Impl::FermionField FermionField; \
typedef typename Impl::PropagatorField PropagatorField; \ typedef typename Impl::PropagatorField PropagatorField; \
typedef typename Impl::DoubledGaugeField DoubledGaugeField; \ typedef typename Impl::DoubledGaugeField DoubledGaugeField; \
@ -149,7 +150,9 @@ public:
typedef typename Impl::Compressor Compressor; \ typedef typename Impl::Compressor Compressor; \
typedef typename Impl::StencilImpl StencilImpl; \ typedef typename Impl::StencilImpl StencilImpl; \
typedef typename Impl::ImplParams ImplParams; \ typedef typename Impl::ImplParams ImplParams; \
typedef typename Impl::Coeff_t Coeff_t; typedef typename Impl::StencilImpl::View_type StencilView; \
typedef typename ViewMap<FermionField>::Type FermionFieldView; \
typedef typename ViewMap<DoubledGaugeField>::Type DoubledGaugeFieldView;
#define INHERIT_IMPL_TYPES(Base) \ #define INHERIT_IMPL_TYPES(Base) \
INHERIT_GIMPL_TYPES(Base) \ INHERIT_GIMPL_TYPES(Base) \
@ -194,33 +197,34 @@ public:
typedef WilsonCompressor<SiteHalfCommSpinor,SiteHalfSpinor, SiteSpinor> Compressor; typedef WilsonCompressor<SiteHalfCommSpinor,SiteHalfSpinor, SiteSpinor> Compressor;
typedef WilsonImplParams ImplParams; typedef WilsonImplParams ImplParams;
typedef WilsonStencil<SiteSpinor, SiteHalfSpinor> StencilImpl; typedef WilsonStencil<SiteSpinor, SiteHalfSpinor> StencilImpl;
typedef typename StencilImpl::View_type StencilView;
ImplParams Params; ImplParams Params;
WilsonImpl(const ImplParams &p = ImplParams()) : Params(p){ WilsonImpl(const ImplParams &p = ImplParams()) : Params(p){
assert(Params.boundary_phases.size() == Nd); assert(Params.boundary_phases.size() == Nd);
}; };
bool overlapCommsCompute(void) { return Params.overlapCommsCompute; }; static accelerator_inline void multLink(SiteHalfSpinor &phi,
const SiteDoubledGaugeField &U,
accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteHalfSpinor &chi,
const SiteDoubledGaugeField &U, int mu,
const SiteHalfSpinor &chi, StencilEntry *SE,
int mu, StencilView &St)
StencilEntry *SE, {
typename StencilImpl::View_type &St) {
mult(&phi(), &U(mu), &chi()); mult(&phi(), &U(mu), &chi());
} }
accelerator_inline void multLinkProp(SitePropagator &phi, static accelerator_inline void multLinkProp(SitePropagator &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SitePropagator &chi, const SitePropagator &chi,
int mu) { int mu)
{
mult(&phi(), &U(mu), &chi()); mult(&phi(), &U(mu), &chi());
} }
template <class ref> template <class ref>
accelerator_inline void loadLinkElement(Simd &reg, ref &memory) { static accelerator_inline void loadLinkElement(Simd &reg, ref &memory)
{
reg = memory; reg = memory;
} }
@ -325,21 +329,22 @@ public:
typedef WilsonCompressor<SiteHalfCommSpinor,SiteHalfSpinor, SiteSpinor> Compressor; typedef WilsonCompressor<SiteHalfCommSpinor,SiteHalfSpinor, SiteSpinor> Compressor;
typedef WilsonImplParams ImplParams; typedef WilsonImplParams ImplParams;
typedef WilsonStencil<SiteSpinor, SiteHalfSpinor> StencilImpl; typedef WilsonStencil<SiteSpinor, SiteHalfSpinor> StencilImpl;
typedef typename StencilImpl::View_type StencilView;
ImplParams Params; ImplParams Params;
DomainWallVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){}; DomainWallVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){};
bool overlapCommsCompute(void) { return false; };
template <class ref> template <class ref>
accelerator_inline void loadLinkElement(Simd &reg, ref &memory) { static accelerator_inline void loadLinkElement(Simd &reg, ref &memory)
{
vsplat(reg, memory); vsplat(reg, memory);
} }
accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, static accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi, int mu, StencilEntry *SE, const SiteHalfSpinor &chi, int mu, StencilEntry *SE,
typename StencilImpl::View_type &St) { StencilView &St)
{
SiteGaugeLink UU; SiteGaugeLink UU;
for (int i = 0; i < Nrepresentation; i++) { for (int i = 0; i < Nrepresentation; i++) {
for (int j = 0; j < Nrepresentation; j++) { for (int j = 0; j < Nrepresentation; j++) {
@ -349,10 +354,10 @@ public:
mult(&phi(), &UU(), &chi()); mult(&phi(), &UU(), &chi());
} }
accelerator_inline void multLinkProp(SitePropagator &phi, static accelerator_inline void multLinkProp(SitePropagator &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SitePropagator &chi, const SitePropagator &chi,int mu)
int mu) { {
SiteGaugeLink UU; SiteGaugeLink UU;
for (int i = 0; i < Nrepresentation; i++) { for (int i = 0; i < Nrepresentation; i++) {
for (int j = 0; j < Nrepresentation; j++) { for (int j = 0; j < Nrepresentation; j++) {
@ -477,6 +482,7 @@ public:
typedef WilsonCompressor<SiteHalfCommSpinor,SiteHalfSpinor, SiteSpinor> Compressor; typedef WilsonCompressor<SiteHalfCommSpinor,SiteHalfSpinor, SiteSpinor> Compressor;
typedef WilsonStencil<SiteSpinor, SiteHalfSpinor> StencilImpl; typedef WilsonStencil<SiteSpinor, SiteHalfSpinor> StencilImpl;
typedef typename StencilImpl::View_type StencilView;
typedef GparityWilsonImplParams ImplParams; typedef GparityWilsonImplParams ImplParams;
@ -484,13 +490,15 @@ public:
GparityWilsonImpl(const ImplParams &p = ImplParams()) : Params(p){}; GparityWilsonImpl(const ImplParams &p = ImplParams()) : Params(p){};
bool overlapCommsCompute(void) { return Params.overlapCommsCompute; };
// provide the multiply by link that is differentiated between Gparity (with // provide the multiply by link that is differentiated between Gparity (with
// flavour index) and non-Gparity // flavour index) and non-Gparity
accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, static accelerator_inline void multLink(SiteHalfSpinor &phi,
const SiteHalfSpinor &chi, int mu, StencilEntry *SE, const SiteDoubledGaugeField &U,
typename StencilImpl::View_type &St) { const SiteHalfSpinor &chi,
int mu,
StencilEntry *SE,
StencilView &St)
{
typedef SiteHalfSpinor vobj; typedef SiteHalfSpinor vobj;
typedef typename SiteHalfSpinor::scalar_object sobj; typedef typename SiteHalfSpinor::scalar_object sobj;
@ -556,14 +564,17 @@ public:
} }
// Fixme: Gparity prop * link // Fixme: Gparity prop * link
accelerator_inline void multLinkProp(SitePropagator &phi, const SiteDoubledGaugeField &U, static accelerator_inline void multLinkProp(SitePropagator &phi,
const SitePropagator &chi, int mu) const SiteDoubledGaugeField &U,
const SitePropagator &chi,
int mu)
{ {
assert(0); assert(0);
} }
template <class ref> template <class ref>
accelerator_inline void loadLinkElement(Simd &reg, ref &memory) { static accelerator_inline void loadLinkElement(Simd &reg, ref &memory)
{
reg = memory; reg = memory;
} }
@ -698,26 +709,30 @@ public:
typedef SimpleCompressor<SiteSpinor> Compressor; typedef SimpleCompressor<SiteSpinor> Compressor;
typedef StaggeredImplParams ImplParams; typedef StaggeredImplParams ImplParams;
typedef CartesianStencil<SiteSpinor, SiteSpinor> StencilImpl; typedef CartesianStencil<SiteSpinor, SiteSpinor> StencilImpl;
typedef typename StencilImpl::View_type StencilView;
ImplParams Params; ImplParams Params;
StaggeredImpl(const ImplParams &p = ImplParams()) : Params(p){}; StaggeredImpl(const ImplParams &p = ImplParams()) : Params(p){};
accelerator_inline void multLink(SiteSpinor &phi, static accelerator_inline void multLink(SiteSpinor &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SiteSpinor &chi, const SiteSpinor &chi,
int mu){ int mu)
{
mult(&phi(), &U(mu), &chi()); mult(&phi(), &U(mu), &chi());
} }
accelerator_inline void multLinkAdd(SiteSpinor &phi, static accelerator_inline void multLinkAdd(SiteSpinor &phi,
const SiteDoubledGaugeField &U, const SiteDoubledGaugeField &U,
const SiteSpinor &chi, const SiteSpinor &chi,
int mu){ int mu)
{
mac(&phi(), &U(mu), &chi()); mac(&phi(), &U(mu), &chi());
} }
template <class ref> template <class ref>
accelerator_inline void loadLinkElement(Simd &reg, ref &memory) { static accelerator_inline void loadLinkElement(Simd &reg, ref &memory)
{
reg = memory; reg = memory;
} }
@ -834,18 +849,23 @@ public:
typedef SimpleCompressor<SiteSpinor> Compressor; typedef SimpleCompressor<SiteSpinor> Compressor;
typedef StaggeredImplParams ImplParams; typedef StaggeredImplParams ImplParams;
typedef CartesianStencil<SiteSpinor, SiteSpinor> StencilImpl; typedef CartesianStencil<SiteSpinor, SiteSpinor> StencilImpl;
typedef typename StencilImpl::View_type StencilView;
ImplParams Params; ImplParams Params;
StaggeredVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){}; StaggeredVec5dImpl(const ImplParams &p = ImplParams()) : Params(p){};
template <class ref> template <class ref>
accelerator_inline void loadLinkElement(Simd &reg, ref &memory) { static accelerator_inline void loadLinkElement(Simd &reg, ref &memory)
{
vsplat(reg, memory); vsplat(reg, memory);
} }
accelerator_inline void multLink(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, static accelerator_inline void multLink(SiteHalfSpinor &phi,
const SiteHalfSpinor &chi, int mu) { const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi,
int mu)
{
SiteGaugeLink UU; SiteGaugeLink UU;
for (int i = 0; i < Dimension; i++) { for (int i = 0; i < Dimension; i++) {
for (int j = 0; j < Dimension; j++) { for (int j = 0; j < Dimension; j++) {
@ -854,8 +874,11 @@ public:
} }
mult(&phi(), &UU(), &chi()); mult(&phi(), &UU(), &chi());
} }
accelerator_inline void multLinkAdd(SiteHalfSpinor &phi, const SiteDoubledGaugeField &U, static accelerator_inline void multLinkAdd(SiteHalfSpinor &phi,
const SiteHalfSpinor &chi, int mu) { const SiteDoubledGaugeField &U,
const SiteHalfSpinor &chi,
int mu)
{
SiteGaugeLink UU; SiteGaugeLink UU;
for (int i = 0; i < Dimension; i++) { for (int i = 0; i < Dimension; i++) {
for (int j = 0; j < Dimension; j++) { for (int j = 0; j < Dimension; j++) {