1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-12 20:27:06 +01:00

Merge branch 'develop' into feature/hmc_generalise

This commit is contained in:
Guido Cossu
2017-05-01 12:13:56 +01:00
69 changed files with 3971 additions and 3179 deletions

View File

@ -368,8 +368,8 @@ namespace Optimization {
b0 = _mm256_extractf128_si256(b,0);
a1 = _mm256_extractf128_si256(a,1);
b1 = _mm256_extractf128_si256(b,1);
a0 = _mm_mul_epi32(a0,b0);
a1 = _mm_mul_epi32(a1,b1);
a0 = _mm_mullo_epi32(a0,b0);
a1 = _mm_mullo_epi32(a1,b1);
return _mm256_set_m128i(a1,a0);
#endif
#if defined (AVX2)
@ -461,7 +461,52 @@ namespace Optimization {
return in;
};
};
#define USE_FP16
struct PrecisionChange {
static inline __m256i StoH (__m256 a,__m256 b) {
__m256i h;
#ifdef USE_FP16
__m128i ha = _mm256_cvtps_ph(a,0);
__m128i hb = _mm256_cvtps_ph(b,0);
h =(__m256i) _mm256_castps128_ps256((__m128)ha);
h =(__m256i) _mm256_insertf128_ps((__m256)h,(__m128)hb,1);
#else
assert(0);
#endif
return h;
}
static inline void HtoS (__m256i h,__m256 &sa,__m256 &sb) {
#ifdef USE_FP16
sa = _mm256_cvtph_ps((__m128i)_mm256_extractf128_ps((__m256)h,0));
sb = _mm256_cvtph_ps((__m128i)_mm256_extractf128_ps((__m256)h,1));
#else
assert(0);
#endif
}
static inline __m256 DtoS (__m256d a,__m256d b) {
__m128 sa = _mm256_cvtpd_ps(a);
__m128 sb = _mm256_cvtpd_ps(b);
__m256 s = _mm256_castps128_ps256(sa);
s = _mm256_insertf128_ps(s,sb,1);
return s;
}
static inline void StoD (__m256 s,__m256d &a,__m256d &b) {
a = _mm256_cvtps_pd(_mm256_extractf128_ps(s,0));
b = _mm256_cvtps_pd(_mm256_extractf128_ps(s,1));
}
static inline __m256i DtoH (__m256d a,__m256d b,__m256d c,__m256d d) {
__m256 sa,sb;
sa = DtoS(a,b);
sb = DtoS(c,d);
return StoH(sa,sb);
}
static inline void HtoD (__m256i h,__m256d &a,__m256d &b,__m256d &c,__m256d &d) {
__m256 sa,sb;
HtoS(h,sa,sb);
StoD(sa,a,b);
StoD(sb,c,d);
}
};
struct Exchange{
// 3210 ordering
static inline void Exchange0(__m256 &out1,__m256 &out2,__m256 in1,__m256 in2){
@ -666,6 +711,7 @@ namespace Optimization {
//////////////////////////////////////////////////////////////////////////////////////
// Here assign types
typedef __m256i SIMD_Htype; // Half precision type (16-bit values carried in an integer vector)
typedef __m256 SIMD_Ftype; // Single precision type
typedef __m256d SIMD_Dtype; // Double precision type
typedef __m256i SIMD_Itype; // Integer type

View File

@ -235,11 +235,9 @@ namespace Optimization {
// Single-precision multiply-accumulate: a is overwritten with the result of
// _mm512_fmadd_ps(b, c, a), i.e. b*c + a.
inline void mac(__m512 &a, __m512 b, __m512 c){
a= _mm512_fmadd_ps( b, c, a);
}
// Double-precision multiply-accumulate: a is overwritten with the result of
// _mm512_fmadd_pd(b, c, a), i.e. b*c + a.
inline void mac(__m512d &a, __m512d b, __m512d c){
a= _mm512_fmadd_pd( b, c, a);
}
// Real float
inline __m512 operator()(__m512 a, __m512 b){
return _mm512_mul_ps(a,b);
@ -342,7 +340,52 @@ namespace Optimization {
};
};
#define USE_FP16
struct PrecisionChange {
static inline __m512i StoH (__m512 a,__m512 b) {
__m512i h;
#ifdef USE_FP16
__m256i ha = _mm512_cvtps_ph(a,0);
__m256i hb = _mm512_cvtps_ph(b,0);
h =(__m512i) _mm512_castps256_ps512((__m256)ha);
h =(__m512i) _mm512_insertf64x4((__m512d)h,(__m256d)hb,1);
#else
assert(0);
#endif
return h;
}
static inline void HtoS (__m512i h,__m512 &sa,__m512 &sb) {
#ifdef USE_FP16
sa = _mm512_cvtph_ps((__m256i)_mm512_extractf64x4_pd((__m512d)h,0));
sb = _mm512_cvtph_ps((__m256i)_mm512_extractf64x4_pd((__m512d)h,1));
#else
assert(0);
#endif
}
static inline __m512 DtoS (__m512d a,__m512d b) {
__m256 sa = _mm512_cvtpd_ps(a);
__m256 sb = _mm512_cvtpd_ps(b);
__m512 s = _mm512_castps256_ps512(sa);
s =(__m512) _mm512_insertf64x4((__m512d)s,(__m256d)sb,1);
return s;
}
static inline void StoD (__m512 s,__m512d &a,__m512d &b) {
a = _mm512_cvtps_pd((__m256)_mm512_extractf64x4_pd((__m512d)s,0));
b = _mm512_cvtps_pd((__m256)_mm512_extractf64x4_pd((__m512d)s,1));
}
static inline __m512i DtoH (__m512d a,__m512d b,__m512d c,__m512d d) {
__m512 sa,sb;
sa = DtoS(a,b);
sb = DtoS(c,d);
return StoH(sa,sb);
}
static inline void HtoD (__m512i h,__m512d &a,__m512d &b,__m512d &c,__m512d &d) {
__m512 sa,sb;
HtoS(h,sa,sb);
StoD(sa,a,b);
StoD(sb,c,d);
}
};
// On extracting face: Ah Al , Bh Bl -> Ah Bh, Al Bl
// On merging buffers: Ah,Bh , Al Bl -> Ah Al, Bh, Bl
// The operation is its own inverse
@ -539,7 +582,9 @@ namespace Optimization {
//////////////////////////////////////////////////////////////////////////////////////
// Here assign types
typedef __m512 SIMD_Ftype; // Single precision type
typedef __m512i SIMD_Htype; // Half precision type (16-bit values carried in an integer vector)
typedef __m512 SIMD_Ftype; // Single precision type
typedef __m512d SIMD_Dtype; // Double precision type
typedef __m512i SIMD_Itype; // Integer type

View File

@ -279,6 +279,101 @@ namespace Optimization {
#undef timesi
struct PrecisionChange {
  // Pack two float vectors into one vector of 16-bit words.
  // Each float contributes the 16-bit word found at odd index 2*i+1 when the
  // float vector is viewed as a vector of uint16_t (the upper half of the
  // float on a little-endian host -- NOTE(review): byte order is assumed, confirm).
  // a fills the first W<float>::r entries of the result, b the next W<float>::r.
  // Without USE_FP16 the routine asserts.
  static inline vech StoH (const vecf &a,const vecf &b) {
    vech ret;  // declared outside the #ifdef so that 'return ret' compiles in both configurations
#ifdef USE_FP16
    const vech *ha = (const vech *)&a;
    const vech *hb = (const vech *)&b;
    const int nf = W<float>::r;
    VECTOR_FOR(i, nf,1){ ret.v[i]    = ha->v[2*i+1]; }
    VECTOR_FOR(i, nf,1){ ret.v[i+nf] = hb->v[2*i+1]; }
#else
    assert(0);
#endif
    return ret;
  }
  // Inverse of StoH: zero both outputs, then write each stored 16-bit word
  // back at odd index 2*i+1 of the corresponding float.
  static inline void HtoS (vech h,vecf &sa,vecf &sb) {
#ifdef USE_FP16
    const int nf = W<float>::r;
    vech *ha = (vech *)&sa;
    vech *hb = (vech *)&sb;
    VECTOR_FOR(i, nf, 1){ sb.v[i]= sa.v[i] = 0; }
    VECTOR_FOR(i, nf, 1){ ha->v[2*i+1]=h.v[i]; }
    VECTOR_FOR(i, nf, 1){ hb->v[2*i+1]=h.v[i+nf]; }
#else
    assert(0);
#endif
  }
  // Narrow two double vectors into one float vector: a fills entries
  // [0,nd), b fills entries [nd,2*nd), with nd = W<double>::r.
  static inline vecf DtoS (vecd a,vecd b) {
    const int nd = W<double>::r;
    vecf ret;
    VECTOR_FOR(i, nd,1){ ret.v[i]    = a.v[i] ; }
    VECTOR_FOR(i, nd,1){ ret.v[i+nd] = b.v[i] ; }
    return ret;
  }
  // Widen one float vector into two double vectors (inverse layout of DtoS).
  static inline void StoD (vecf s,vecd &a,vecd &b) {
    const int nd = W<double>::r;
    VECTOR_FOR(i, nd,1){ a.v[i] = s.v[i] ; }
    VECTOR_FOR(i, nd,1){ b.v[i] = s.v[i+nd] ; }
  }
  // Four double vectors -> one half vector, by composing DtoS and StoH.
  static inline vech DtoH (vecd a,vecd b,vecd c,vecd d) {
    vecf sa,sb;
    sa = DtoS(a,b);
    sb = DtoS(c,d);
    return StoH(sa,sb);
  }
  // One half vector -> four double vectors, by composing HtoS and StoD.
  static inline void HtoD (vech h,vecd &a,vecd &b,vecd &c,vecd &d) {
    vecf sa,sb;
    HtoS(h,sa,sb);
    StoD(sa,a,b);
    StoD(sb,c,d);
  }
};
//////////////////////////////////////////////
// Exchange support
struct Exchange{

  // Level-n exchange of two vectors of w = W<T>::r entries.
  // With mask = w >> (n+1): an output index i whose mask bit is clear reads
  // from in1, otherwise from in2. out1 takes the source entry with the mask
  // bit cleared, out2 the source entry with the mask bit set.
  template <typename T,int n>
  static inline void ExchangeN(vec<T> &out1,vec<T> &out2,vec<T> &in1,vec<T> &in2){
    const int w = W<T>::r;
    unsigned int mask = w >> (n + 1);
    VECTOR_FOR(i, w, 1) {
      const bool fromFirst = ( (i&mask) == 0 );
      int lo = i&(~mask);
      int hi = i|mask;
      if ( fromFirst ) {
        out1.v[i]=in1.v[lo];
        out2.v[i]=in1.v[hi];
      } else {
        out1.v[i]=in2.v[lo];
        out2.v[i]=in2.v[hi];
      }
    }
  }

  // Fixed-level entry points forwarding to ExchangeN<T,level>.
  template <typename T>
  static inline void Exchange0(vec<T> &out1,vec<T> &out2,vec<T> &in1,vec<T> &in2){
    ExchangeN<T,0>(out1,out2,in1,in2);
  }
  template <typename T>
  static inline void Exchange1(vec<T> &out1,vec<T> &out2,vec<T> &in1,vec<T> &in2){
    ExchangeN<T,1>(out1,out2,in1,in2);
  }
  template <typename T>
  static inline void Exchange2(vec<T> &out1,vec<T> &out2,vec<T> &in1,vec<T> &in2){
    ExchangeN<T,2>(out1,out2,in1,in2);
  }
  template <typename T>
  static inline void Exchange3(vec<T> &out1,vec<T> &out2,vec<T> &in1,vec<T> &in2){
    ExchangeN<T,3>(out1,out2,in1,in2);
  }
};
//////////////////////////////////////////////
// Some Template specialization
#define perm(a, b, n, w)\
@ -403,6 +498,7 @@ namespace Optimization {
//////////////////////////////////////////////////////////////////////////////////////
// Here assign types
typedef Optimization::vech SIMD_Htype; // Reduced precision type
typedef Optimization::vecf SIMD_Ftype; // Single precision type
typedef Optimization::vecd SIMD_Dtype; // Double precision type
typedef Optimization::veci SIMD_Itype; // Integer type

View File

@ -66,6 +66,10 @@ namespace Optimization {
template <> struct W<Integer> {
constexpr static unsigned int r = GEN_SIMD_WIDTH/4u;
};
// Element counts for 16-bit storage (used by vec<uint16_t>, the half precision comms type).
template <> struct W<uint16_t> {
constexpr static unsigned int c = GEN_SIMD_WIDTH/4u; // NOTE(review): presumably the count of (re,im) pairs, i.e. r/2 -- confirm against the other W<> specialisations
constexpr static unsigned int r = GEN_SIMD_WIDTH/2u; // number of 2-byte elements: GEN_SIMD_WIDTH divided by 2
};
// SIMD vector types
template <typename T>
@ -73,8 +77,9 @@ namespace Optimization {
alignas(GEN_SIMD_WIDTH) T v[W<T>::r];
};
typedef vec<float> vecf;
typedef vec<double> vecd;
typedef vec<Integer> veci;
typedef vec<float> vecf;
typedef vec<double> vecd;
typedef vec<uint16_t> vech; // half precision comms
typedef vec<Integer> veci;
}}

View File

@ -33,6 +33,14 @@
#include "Grid_generic_types.h" // Definitions for simulated integer SIMD.
namespace Grid {
#ifdef QPX
#include <spi/include/kernel/location.h>
#include <spi/include/l1p/types.h>
#include <hwi/include/bqc/l1p_mmio.h>
#include <hwi/include/bqc/A2_inlines.h>
#endif
namespace Optimization {
typedef struct
{
@ -125,7 +133,6 @@ namespace Optimization {
f[2] = a.v2;
f[3] = a.v3;
}
//Double
inline void operator()(double *d, vector4double a){
vec_st(a, 0, d);

View File

@ -328,6 +328,140 @@ namespace Optimization {
};
};
// Element-granular wrappers around _mm_alignr_epi8: n counts 32-bit
// (resp. 64-bit) elements and is converted to a byte count modulo 16.
// Every argument is parenthesised so that an expression argument such as
// n = k+1 is scaled as a whole rather than expanding to k+1*4.
#define _my_alignr_epi32(a,b,n) _mm_alignr_epi8((a),(b),(((n)*4)%16))
#define _my_alignr_epi64(a,b,n) _mm_alignr_epi8((a),(b),(((n)*8)%16))
#ifdef SFW_FP16
// Raw 16-bit container for a half-precision value; storage only, no arithmetic.
// The default constructor leaves x unset.
struct Grid_half {
  Grid_half(){}
  Grid_half(uint16_t raw) : x(raw) {}
  uint16_t x;
};
// View of a 32-bit float as its bit pattern.
union FP32 {
  unsigned int u;
  float f;
};
// PAB - Lifted and adapted from Eigen, which is GPL V2
//
// Software conversion of a 16-bit half (sign:1, exponent:5, mantissa:10) to float.
// The half bits are widened to unsigned before shifting: shifting the promoted
// (signed int) value 0x8000 left by 16 would move a bit into the sign position.
inline float sfw_half_to_float(Grid_half h) {
  const FP32 magic = { 113 << 23 };
  const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
  const unsigned int bits = h.x;                 // unsigned copy of the raw half
  FP32 o;

  o.u = (bits & 0x7fffu) << 13;           // exponent/mantissa bits
  unsigned int exp = shifted_exp & o.u;   // just the exponent
  o.u += (127 - 15) << 23;                // exponent adjust

  // handle exponent special cases
  if (exp == shifted_exp) {     // Inf/NaN?
    o.u += (128 - 16) << 23;    // extra exp adjust
  } else if (exp == 0) {        // Zero/Denormal?
    o.u += 1 << 23;             // extra exp adjust
    o.f -= magic.f;             // renormalize
  }

  o.u |= (bits & 0x8000u) << 16;          // sign bit
  return o.f;
}
// Software conversion of a float to a 16-bit half stored in Grid_half::x.
// The sign is stripped first and OR-ed back in at the end; the magnitude is
// then handled in three ranges: Inf/NaN, values that become subnormal or zero
// in half precision, and values that stay normalised.
inline Grid_half sfw_float_to_half(float ff) {
FP32 f; f.f = ff;
const FP32 f32infty = { 255 << 23 };
const FP32 f16max = { (127 + 16) << 23 };
const FP32 denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
unsigned int sign_mask = 0x80000000u;
Grid_half o;
o.x = static_cast<unsigned short>(0x0u);
unsigned int sign = f.u & sign_mask;
// clear the sign bit so the comparisons below operate on the magnitude only
f.u ^= sign;
// NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code
// (since there's no unsigned PCMPGTD).
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
} else { // (De)normalized number or zero
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
// use a magic value to align our 10 mantissa bits at the bottom of
// the float. as long as FP addition is round-to-nearest-even this
// just works.
f.f += denorm_magic.f;
// and one integer subtract of the bias later, we have our final float!
o.x = static_cast<unsigned short>(f.u - denorm_magic.u);
} else {
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
// update exponent, rounding bias part 1
f.u += ((unsigned int)(15 - 127) << 23) + 0xfff;
// rounding bias part 2
f.u += mant_odd;
// take the bits!
o.x = static_cast<unsigned short>(f.u >> 13);
}
}
// restore the sign into bit 15 of the half
o.x |= static_cast<unsigned short>(sign >> 16);
return o;
}
// Software stand-in for the float -> half conversion intrinsic: converts the
// four floats of f with sfw_float_to_half and stores them in the first four
// 16-bit slots of a zeroed __m128i. The second argument is ignored.
static inline __m128i Grid_mm_cvtps_ph(__m128 f,int discard) {
  __m128i packed=(__m128i)_mm_setzero_ps();
  const float *src = (const float *)&f;
  Grid_half   *dst = (Grid_half *)&packed;
  for(int lane=0;lane<4;lane++){
    dst[lane] = sfw_float_to_half(src[lane]);
  }
  return packed;
}
// Software stand-in for the half -> float conversion intrinsic: expands the
// first four 16-bit slots of h with sfw_half_to_float into the four floats of
// the result. The second argument is ignored.
static inline __m128 Grid_mm_cvtph_ps(__m128i h,int discard) {
  __m128 widened=_mm_setzero_ps();
  float     *dst = (float *)&widened;
  Grid_half *src = (Grid_half *)&h;
  for(int lane=0;lane<4;lane++){
    dst[lane] = sfw_half_to_float(src[lane]);
  }
  return widened;
}
#else
// Hardware path: map the Grid wrappers onto the conversion intrinsics.
// Callers use the two-argument form Grid_mm_cvtph_ps(h,0), matching the
// software fallback, but _mm_cvtph_ps takes a single operand, so the unused
// second argument is dropped here instead of being forwarded.
#define Grid_mm_cvtps_ph _mm_cvtps_ph
#define Grid_mm_cvtph_ps(h,discard) _mm_cvtph_ps(h)
#endif
struct PrecisionChange {
static inline __m128i StoH (__m128 a,__m128 b) {
__m128i ha = Grid_mm_cvtps_ph(a,0);
__m128i hb = Grid_mm_cvtps_ph(b,0);
__m128i h =(__m128i) _mm_shuffle_ps((__m128)ha,(__m128)hb,_MM_SELECT_FOUR_FOUR(1,0,1,0));
return h;
}
static inline void HtoS (__m128i h,__m128 &sa,__m128 &sb) {
sa = Grid_mm_cvtph_ps(h,0);
h = (__m128i)_my_alignr_epi32((__m128i)h,(__m128i)h,2);
sb = Grid_mm_cvtph_ps(h,0);
}
static inline __m128 DtoS (__m128d a,__m128d b) {
__m128 sa = _mm_cvtpd_ps(a);
__m128 sb = _mm_cvtpd_ps(b);
__m128 s = _mm_shuffle_ps(sa,sb,_MM_SELECT_FOUR_FOUR(1,0,1,0));
return s;
}
static inline void StoD (__m128 s,__m128d &a,__m128d &b) {
a = _mm_cvtps_pd(s);
s = (__m128)_my_alignr_epi32((__m128i)s,(__m128i)s,2);
b = _mm_cvtps_pd(s);
}
static inline __m128i DtoH (__m128d a,__m128d b,__m128d c,__m128d d) {
__m128 sa,sb;
sa = DtoS(a,b);
sb = DtoS(c,d);
return StoH(sa,sb);
}
static inline void HtoD (__m128i h,__m128d &a,__m128d &b,__m128d &c,__m128d &d) {
__m128 sa,sb;
HtoS(h,sa,sb);
StoD(sa,a,b);
StoD(sb,c,d);
}
};
struct Exchange{
// 3210 ordering
static inline void Exchange0(__m128 &out1,__m128 &out2,__m128 in1,__m128 in2){
@ -335,8 +469,10 @@ namespace Optimization {
out2= _mm_shuffle_ps(in1,in2,_MM_SELECT_FOUR_FOUR(3,2,3,2));
};
// Level-1 exchange for four-float vectors. The letter comments follow the
// existing labelling (in1,in2 = ABCD,EFGH -- NOTE(review): labelling inferred from the comments below).
// The first shuffle pair gathers even / odd entries, the second pair swaps
// the two middle entries of each result.
static inline void Exchange1(__m128 &out1,__m128 &out2,__m128 in1,__m128 in2){
out1= _mm_shuffle_ps(in1,in2,_MM_SELECT_FOUR_FOUR(2,0,2,0));
out2= _mm_shuffle_ps(in1,in2,_MM_SELECT_FOUR_FOUR(3,1,3,1));
out1= _mm_shuffle_ps(in1,in2,_MM_SELECT_FOUR_FOUR(2,0,2,0)); /*ACEG*/
out2= _mm_shuffle_ps(in1,in2,_MM_SELECT_FOUR_FOUR(3,1,3,1)); /*BDFH*/
out1= _mm_shuffle_ps(out1,out1,_MM_SELECT_FOUR_FOUR(3,1,2,0)); /*AECG*/
out2= _mm_shuffle_ps(out2,out2,_MM_SELECT_FOUR_FOUR(3,1,2,0)); /*BFDH*/
};
static inline void Exchange2(__m128 &out1,__m128 &out2,__m128 in1,__m128 in2){
assert(0);
@ -383,14 +519,9 @@ namespace Optimization {
default: assert(0);
}
}
#ifndef _mm_alignr_epi64
#define _mm_alignr_epi32(a,b,n) _mm_alignr_epi8(a,b,(n*4)%16)
#define _mm_alignr_epi64(a,b,n) _mm_alignr_epi8(a,b,(n*8)%16)
#endif
template<int n> static inline __m128 tRotate(__m128 in){ return (__m128)_mm_alignr_epi32((__m128i)in,(__m128i)in,n); };
template<int n> static inline __m128d tRotate(__m128d in){ return (__m128d)_mm_alignr_epi64((__m128i)in,(__m128i)in,n); };
template<int n> static inline __m128 tRotate(__m128 in){ return (__m128)_my_alignr_epi32((__m128i)in,(__m128i)in,n); };
template<int n> static inline __m128d tRotate(__m128d in){ return (__m128d)_my_alignr_epi64((__m128i)in,(__m128i)in,n); };
};
//////////////////////////////////////////////
@ -450,7 +581,8 @@ namespace Optimization {
//////////////////////////////////////////////////////////////////////////////////////
// Here assign types
typedef __m128 SIMD_Ftype; // Single precision type
typedef __m128i SIMD_Htype; // Half precision type (16-bit values carried in an integer vector)
typedef __m128 SIMD_Ftype; // Single precision type
typedef __m128d SIMD_Dtype; // Double precision type
typedef __m128i SIMD_Itype; // Integer type

View File

@ -2,7 +2,7 @@
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/simd/Grid_vector_types.h
Source file: ./lib/simd/Grid_vector_type.h
Copyright (C) 2015
@ -53,12 +53,14 @@ directory
#if defined IMCI
#include "Grid_imci.h"
#endif
#if defined QPX
#include "Grid_qpx.h"
#endif
#ifdef NEONv8
#include "Grid_neon.h"
#endif
#if defined QPX
#include "Grid_qpx.h"
#endif
#include "l1p.h"
namespace Grid {
@ -74,12 +76,14 @@ struct RealPart<std::complex<T> > {
typedef T type;
};
#include <type_traits>
//////////////////////////////////////
// demote a vector to real type
//////////////////////////////////////
// type alias used to simplify the syntax of std::enable_if
template <typename T> using Invoke = typename T::type;
template <typename Condition, typename ReturnType> using EnableIf = Invoke<std::enable_if<Condition::value, ReturnType> >;
template <typename Condition, typename ReturnType> using EnableIf = Invoke<std::enable_if<Condition::value, ReturnType> >;
template <typename Condition, typename ReturnType> using NotEnableIf = Invoke<std::enable_if<!Condition::value, ReturnType> >;
////////////////////////////////////////////////////////
@ -88,13 +92,15 @@ template <typename T> struct is_complex : public std::false_type {};
template <> struct is_complex<std::complex<double> > : public std::true_type {};
template <> struct is_complex<std::complex<float> > : public std::true_type {};
template <typename T> using IfReal = Invoke<std::enable_if<std::is_floating_point<T>::value, int> >;
template <typename T> using IfComplex = Invoke<std::enable_if<is_complex<T>::value, int> >;
template <typename T> using IfInteger = Invoke<std::enable_if<std::is_integral<T>::value, int> >;
template <typename T> using IfReal = Invoke<std::enable_if<std::is_floating_point<T>::value, int> >;
template <typename T> using IfComplex = Invoke<std::enable_if<is_complex<T>::value, int> >;
template <typename T> using IfInteger = Invoke<std::enable_if<std::is_integral<T>::value, int> >;
template <typename T1,typename T2> using IfSame = Invoke<std::enable_if<std::is_same<T1,T2>::value, int> >;
template <typename T> using IfNotReal = Invoke<std::enable_if<!std::is_floating_point<T>::value, int> >;
template <typename T> using IfNotComplex = Invoke<std::enable_if<!is_complex<T>::value, int> >;
template <typename T> using IfNotInteger = Invoke<std::enable_if<!std::is_integral<T>::value, int> >;
template <typename T> using IfNotReal = Invoke<std::enable_if<!std::is_floating_point<T>::value, int> >;
template <typename T> using IfNotComplex = Invoke<std::enable_if<!is_complex<T>::value, int> >;
template <typename T> using IfNotInteger = Invoke<std::enable_if<!std::is_integral<T>::value, int> >;
template <typename T1,typename T2> using IfNotSame = Invoke<std::enable_if<!std::is_same<T1,T2>::value, int> >;
////////////////////////////////////////////////////////
// Define the operation templates functors
@ -358,16 +364,12 @@ class Grid_simd {
{
if (n==3) {
Optimization::Exchange::Exchange3(out1.v,out2.v,in1.v,in2.v);
// std::cout << " Exchange3 "<< out1<<" "<< out2<<" <- " << in1 << " "<<in2<<std::endl;
} else if(n==2) {
Optimization::Exchange::Exchange2(out1.v,out2.v,in1.v,in2.v);
// std::cout << " Exchange2 "<< out1<<" "<< out2<<" <- " << in1 << " "<<in2<<std::endl;
} else if(n==1) {
Optimization::Exchange::Exchange1(out1.v,out2.v,in1.v,in2.v);
// std::cout << " Exchange1 "<< out1<<" "<< out2<<" <- " << in1 << " "<<in2<<std::endl;
} else if(n==0) {
Optimization::Exchange::Exchange0(out1.v,out2.v,in1.v,in2.v);
// std::cout << " Exchange0 "<< out1<<" "<< out2<<" <- " << in1 << " "<<in2<<std::endl;
}
}
@ -428,7 +430,6 @@ template <class S, class V, IfNotComplex<S> = 0>
inline Grid_simd<S, V> rotate(Grid_simd<S, V> b, int nrot) {
nrot = nrot % Grid_simd<S, V>::Nsimd();
Grid_simd<S, V> ret;
// std::cout << "Rotate Real by "<<nrot<<std::endl;
ret.v = Optimization::Rotate::rotate(b.v, nrot);
return ret;
}
@ -436,7 +437,6 @@ template <class S, class V, IfComplex<S> = 0>
inline Grid_simd<S, V> rotate(Grid_simd<S, V> b, int nrot) {
nrot = nrot % Grid_simd<S, V>::Nsimd();
Grid_simd<S, V> ret;
// std::cout << "Rotate Complex by "<<nrot<<std::endl;
ret.v = Optimization::Rotate::rotate(b.v, 2 * nrot);
return ret;
}
@ -444,14 +444,12 @@ template <class S, class V, IfNotComplex<S> =0>
inline void rotate( Grid_simd<S,V> &ret,Grid_simd<S,V> b,int nrot)
{
nrot = nrot % Grid_simd<S,V>::Nsimd();
// std::cout << "Rotate Real by "<<nrot<<std::endl;
ret.v = Optimization::Rotate::rotate(b.v,nrot);
}
template <class S, class V, IfComplex<S> =0>
inline void rotate(Grid_simd<S,V> &ret,Grid_simd<S,V> b,int nrot)
{
nrot = nrot % Grid_simd<S,V>::Nsimd();
// std::cout << "Rotate Complex by "<<nrot<<std::endl;
ret.v = Optimization::Rotate::rotate(b.v,2*nrot);
}
@ -711,7 +709,6 @@ inline Grid_simd<S, V> innerProduct(const Grid_simd<S, V> &l,
const Grid_simd<S, V> &r) {
return conjugate(l) * r;
}
template <class S, class V>
inline Grid_simd<S, V> outerProduct(const Grid_simd<S, V> &l,
const Grid_simd<S, V> &r) {
@ -771,6 +768,67 @@ typedef Grid_simd<std::complex<float>, SIMD_Ftype> vComplexF;
typedef Grid_simd<std::complex<double>, SIMD_Dtype> vComplexD;
typedef Grid_simd<Integer, SIMD_Itype> vInteger;
// Half precision; no arithmetic support
typedef Grid_simd<uint16_t, SIMD_Htype> vRealH;
typedef Grid_simd<std::complex<uint16_t>, SIMD_Htype> vComplexH;
// Double -> single: every pair of input vectors in[2k],in[2k+1] is narrowed
// into one output vector out[k]. nvec counts the double vectors and must be even.
inline void precisionChange(vRealF *out,vRealD *in,int nvec)
{
  assert((nvec&0x1)==0);
  int dst = 0;
  while ( dst*2 < nvec ) {
    const int src = dst*2;
    out[dst].v = Optimization::PrecisionChange::DtoS(in[src].v,in[src+1].v);
    dst++;
  }
}
// Double -> half: every group of four input vectors in[4k..4k+3] is packed
// into one output vector out[k]. nvec counts the double vectors and must be a
// multiple of four.
inline void precisionChange(vRealH *out,vRealD *in,int nvec)
{
  assert((nvec&0x3)==0);
  int dst = 0;
  while ( dst*4 < nvec ) {
    const int src = dst*4;
    out[dst].v = Optimization::PrecisionChange::DtoH(in[src].v,in[src+1].v,in[src+2].v,in[src+3].v);
    dst++;
  }
}
// Single -> half: every pair of input vectors in[2k],in[2k+1] is packed into
// one output vector out[k]. nvec counts the single vectors and must be even.
inline void precisionChange(vRealH *out,vRealF *in,int nvec)
{
  assert((nvec&0x1)==0);
  int dst = 0;
  while ( dst*2 < nvec ) {
    const int src = dst*2;
    out[dst].v = Optimization::PrecisionChange::StoH(in[src].v,in[src+1].v);
    dst++;
  }
}
// Single -> double: each input vector in[k] is widened into the output pair
// out[2k],out[2k+1]. nvec counts the double vectors and must be even.
inline void precisionChange(vRealD *out,vRealF *in,int nvec)
{
  assert((nvec&0x1)==0);
  int src = 0;
  while ( src*2 < nvec ) {
    const int dst = src*2;
    Optimization::PrecisionChange::StoD(in[src].v,out[dst].v,out[dst+1].v);
    src++;
  }
}
// Half -> double: each input vector in[k] is expanded into the four output
// vectors out[4k..4k+3]. nvec counts the double vectors and must be a
// multiple of four.
inline void precisionChange(vRealD *out,vRealH *in,int nvec)
{
  assert((nvec&0x3)==0);
  int src = 0;
  while ( src*4 < nvec ) {
    const int dst = src*4;
    Optimization::PrecisionChange::HtoD(in[src].v,out[dst].v,out[dst+1].v,out[dst+2].v,out[dst+3].v);
    src++;
  }
}
// Half -> single: each input vector in[k] is expanded into the output pair
// out[2k],out[2k+1]. nvec counts the single vectors and must be even.
inline void precisionChange(vRealF *out,vRealH *in,int nvec)
{
  assert((nvec&0x1)==0);
  int src = 0;
  while ( src*2 < nvec ) {
    const int dst = src*2;
    Optimization::PrecisionChange::HtoS(in[src].v,out[dst].v,out[dst+1].v);
    src++;
  }
}
// Complex overloads: each reinterprets its complex vector pointers as the
// matching real vector pointers and forwards to the real-valued
// precisionChange with the same nvec.
inline void precisionChange(vComplexF *out,vComplexD *in,int nvec){ precisionChange((vRealF *)out,(vRealD *)in,nvec);}
inline void precisionChange(vComplexH *out,vComplexD *in,int nvec){ precisionChange((vRealH *)out,(vRealD *)in,nvec);}
inline void precisionChange(vComplexH *out,vComplexF *in,int nvec){ precisionChange((vRealH *)out,(vRealF *)in,nvec);}
inline void precisionChange(vComplexD *out,vComplexF *in,int nvec){ precisionChange((vRealD *)out,(vRealF *)in,nvec);}
inline void precisionChange(vComplexD *out,vComplexH *in,int nvec){ precisionChange((vRealD *)out,(vRealH *)in,nvec);}
inline void precisionChange(vComplexF *out,vComplexH *in,int nvec){ precisionChange((vRealF *)out,(vRealH *)in,nvec);}
// Check our vector types are of an appropriate size.
#if defined QPX
static_assert(2*sizeof(SIMD_Ftype) == sizeof(SIMD_Dtype), "SIMD vector lengths incorrect");

37
lib/simd/l1p.h Normal file
View File

@ -0,0 +1,37 @@
#pragma once
namespace Grid {
// L1p optimisation
//
// Builds a prefetch configuration word from the L1P_CFG_PF_USR_* fields and
// stores it at the L1P_CFG_PF_USR address. Only active when QPX is defined;
// otherwise the function is a no-op.
//
//   mode != 0 : stream prefetch settings with adaptive throttle 0xF,
//               data fetch depth 2 and max footprint 3.
//   mode == 0 : data fetch depth 1, max footprint 2, with
//               pf_stream_prefetch_enable set.
//
// An alternative setting of depth 1 / footprint 2 for single precision in the
// mode != 0 branch is not enabled.
inline void bgq_l1p_optimisation(int mode)
{
#ifdef QPX
#undef L1P_CFG_PF_USR
#define L1P_CFG_PF_USR (0x3fde8000108ll) /* (64 bit reg, 23 bits wide, user/unpriv) */

  uint64_t cfg_pf_usr;
  if ( mode ) {
    cfg_pf_usr =
        L1P_CFG_PF_USR_ifetch_depth(0)
      | L1P_CFG_PF_USR_ifetch_max_footprint(1)
      | L1P_CFG_PF_USR_pf_stream_est_on_dcbt
      | L1P_CFG_PF_USR_pf_stream_establish_enable
      | L1P_CFG_PF_USR_pf_stream_optimistic
      | L1P_CFG_PF_USR_pf_adaptive_throttle(0xF) ;
    cfg_pf_usr |= L1P_CFG_PF_USR_dfetch_depth(2)| L1P_CFG_PF_USR_dfetch_max_footprint(3) ;
  } else {
    cfg_pf_usr = L1P_CFG_PF_USR_dfetch_depth(1)
      | L1P_CFG_PF_USR_dfetch_max_footprint(2)
      | L1P_CFG_PF_USR_ifetch_depth(0)
      | L1P_CFG_PF_USR_ifetch_max_footprint(1)
      | L1P_CFG_PF_USR_pf_stream_est_on_dcbt
      | L1P_CFG_PF_USR_pf_stream_establish_enable
      | L1P_CFG_PF_USR_pf_stream_optimistic
      | L1P_CFG_PF_USR_pf_stream_prefetch_enable;
  }
  // The target is a fixed hardware address: store through a volatile pointer
  // so the compiler may not elide or reorder the write.
  *((volatile uint64_t *)L1P_CFG_PF_USR) = cfg_pf_usr;
#else
  (void)mode; // unused when QPX support is compiled out
#endif
}
}