diff --git a/benchmarks/Grid_wilson.cc b/benchmarks/Grid_wilson.cc index 32255b3e..3b0d04bc 100644 --- a/benchmarks/Grid_wilson.cc +++ b/benchmarks/Grid_wilson.cc @@ -31,11 +31,9 @@ int main (int argc, char ** argv) std::cout << "Grid is setup to use "< seeds({1,2,3,4}); - GridParallelRNG pRNG(&Grid); - // std::vector seeds({1,2,3,4}); - // pRNG.SeedFixedIntegers(seeds); - pRNG.SeedRandomDevice(); + pRNG.SeedFixedIntegers(seeds); + // pRNG.SeedRandomDevice(); LatticeFermion src (&Grid); random(pRNG,src); LatticeFermion result(&Grid); result=zero; @@ -55,8 +53,10 @@ int main (int argc, char ** argv) Complex cone(1.0,0.0); for(int nn=0;nn(Umu,U[nn],nn); } @@ -85,7 +85,7 @@ int main (int argc, char ** argv) WilsonMatrix Dw(Umu,Grid,RBGrid,mass); std::cout << "Calling Dw"< +inline void Gpermute0(vsimd &y,const vsimd &b) { + union { + fvec f; + decltype(vsimd::v) v; + } conv; + conv.v = b.v; +#ifdef SSE4 + conv.f = _mm_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(1,0,3,2)); +#endif +#if defined(AVX1)||defined(AVX2) + conv.f = _mm256_permute2f128_ps(conv.f,conv.f,0x01); +#endif +#ifdef AVX512 + conv.f = _mm512_permute4f128_ps(conv.f,(_MM_PERM_ENUM)_MM_SHUFFLE(1,0,3,2)); +#endif + y.v=conv.v; +}; +template +inline void Gpermute1(vsimd &y,const vsimd &b) { + union { + fvec f; + decltype(vsimd::v) v; + } conv; + conv.v = b.v; +#ifdef SSE4 + conv.f = _mm_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(2,3,0,1)); +#endif +#if defined(AVX1)||defined(AVX2) + conv.f = _mm256_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(1,0,3,2)); +#endif +#ifdef AVX512 + conv.f = _mm512_permute4f128_ps(conv.f,(_MM_PERM_ENUM)_MM_SHUFFLE(2,3,0,1)); +#endif + y.v=conv.v; +}; +template +inline void Gpermute2(vsimd &y,const vsimd &b) { + union { + fvec f; + decltype(vsimd::v) v; + } conv; + conv.v = b.v; +#ifdef SSE4 +#endif +#if defined(AVX1)||defined(AVX2) + conv.f = _mm256_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(2,3,0,1)); +#endif +#ifdef AVX512 + conv.f = _mm512_swizzle_ps(conv.f,_MM_SWIZ_REG_BADC); +#endif + y.v=conv.v; + +}; +template +inline void Gpermute3(vsimd &y,const vsimd &b) { + union { + fvec f; + decltype(vsimd::v) v; + } conv; + conv.v = b.v; +#ifdef AVX512 + conv.f = _mm512_swizzle_ps(conv.f,_MM_SWIZ_REG_CDAB); +#endif + y.v=conv.v; + +}; + template inline void Gpermute(vsimd &y,const vsimd &b,int perm){ union { @@ -170,36 +238,12 @@ inline void Gpermute(vsimd &y,const vsimd &b,int perm){ } conv; conv.v = b.v; switch (perm){ -#if defined(AVX1)||defined(AVX2) - // 8x32 bits=>3 permutes - case 2: - conv.f = _mm256_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(2,3,0,1)); - break; - case 1: conv.f = _mm256_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(1,0,3,2)); break; - case 0: conv.f = _mm256_permute2f128_ps(conv.f,conv.f,0x01); break; -#endif -#ifdef SSE4 - case 1: conv.f = _mm_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(2,3,0,1)); break; - case 0: conv.f = _mm_shuffle_ps(conv.f,conv.f,_MM_SHUFFLE(1,0,3,2));break; -#endif -#ifdef AVX512 - // 16 floats=> permutes - // Permute 0 every abcd efgh ijkl mnop -> badc fehg jilk nmpo - // Permute 1 every abcd efgh ijkl mnop -> cdab ghef jkij opmn - // Permute 2 every abcd efgh ijkl mnop -> efgh abcd mnop ijkl - // Permute 3 every abcd efgh ijkl mnop -> ijkl mnop abcd efgh - case 3: conv.f = _mm512_swizzle_ps(conv.f,_MM_SWIZ_REG_CDAB); break; - case 2: conv.f = _mm512_swizzle_ps(conv.f,_MM_SWIZ_REG_BADC); break; - case 1: conv.f = _mm512_permute4f128_ps(conv.f,(_MM_PERM_ENUM)_MM_SHUFFLE(2,3,0,1)); break; - case 0: conv.f = _mm512_permute4f128_ps(conv.f,(_MM_PERM_ENUM)_MM_SHUFFLE(1,0,3,2)); break; -#endif -#ifdef QPX -#error not implemented -#endif + case 3: Gpermute3(y,b); break; + case 2: Gpermute2(y,b); break; + case 1: Gpermute1(y,b); break; + case 0: Gpermute0(y,b); break; default: assert(0); break; } - y.v=conv.v; - }; }; diff --git a/lib/Makefile.am b/lib/Makefile.am index 82459763..6bb5e187 100644 --- a/lib/Makefile.am +++ b/lib/Makefile.am @@ -18,6 +18,8 @@ libGrid_a_SOURCES = \ Grid_init.cc \ stencil/Grid_stencil_common.cc \ qcd/Grid_qcd_dirac.cc \ + qcd/Grid_qcd_dhop.cc \ + qcd/Grid_qcd_dhop_hand.cc \ qcd/Grid_qcd_wilson_dop.cc \ algorithms/approx/Zolotarev.cc \ algorithms/approx/Remez.cc \ diff --git a/lib/lattice/Grid_lattice_base.h b/lib/lattice/Grid_lattice_base.h index 1d3b1efb..4a6d3180 100644 --- a/lib/lattice/Grid_lattice_base.h +++ b/lib/lattice/Grid_lattice_base.h @@ -47,6 +47,11 @@ class LatticeTrinaryExpression :public std::pair >, publ LatticeTrinaryExpression(const std::pair > &arg): std::pair >(arg) {}; }; +void inline conformable(GridBase *lhs,GridBase *rhs) +{ + assert(lhs == rhs); +} + template class Lattice : public LatticeBase { @@ -60,7 +65,8 @@ public: typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; typedef vobj vector_object; - + + //////////////////////////////////////////////////////////////////////////////// // Expression Template closure support //////////////////////////////////////////////////////////////////////////////// @@ -276,17 +282,15 @@ PARALLEL_FOR_LOOP } -#include +#include #define GRID_LATTICE_EXPRESSION_TEMPLATES #ifdef GRID_LATTICE_EXPRESSION_TEMPLATES #include #else #include #endif - #include - #include #include #include diff --git a/lib/lattice/Grid_lattice_conformable.h b/lib/lattice/Grid_lattice_conformable.h index faa8c7a7..a77e57af 100644 --- a/lib/lattice/Grid_lattice_conformable.h +++ b/lib/lattice/Grid_lattice_conformable.h @@ -3,16 +3,11 @@ namespace Grid { - template - void conformable(const Lattice &lhs,const Lattice &rhs) + template void conformable(const Lattice &lhs,const Lattice &rhs) { assert(lhs._grid == rhs._grid); assert(lhs.checkerboard == rhs.checkerboard); } - void inline conformable(const GridBase *lhs,GridBase *rhs) - { - assert(lhs == rhs); - } } #endif diff --git a/lib/qcd/Grid_qcd_wilson_dop.cc b/lib/qcd/Grid_qcd_wilson_dop.cc index 318e18df..9a3f5f6a 100644 --- a/lib/qcd/Grid_qcd_wilson_dop.cc +++ b/lib/qcd/Grid_qcd_wilson_dop.cc @@ -1,4 +1,3 @@ - #include namespace Grid { @@ -7,15 +6,7 @@ namespace QCD { const std::vector WilsonMatrix::directions ({0,1,2,3, 0, 1, 2, 3}); const std::vector WilsonMatrix::displacements({1,1,1,1,-1,-1,-1,-1}); - // Should be in header? -const int WilsonMatrix::Xp = 0; -const int WilsonMatrix::Yp = 1; -const int WilsonMatrix::Zp = 2; -const int WilsonMatrix::Tp = 3; -const int WilsonMatrix::Xm = 4; -const int WilsonMatrix::Ym = 5; -const int WilsonMatrix::Zm = 6; -const int WilsonMatrix::Tm = 7; + int WilsonMatrix::HandOptDslash; class WilsonCompressor { public: @@ -39,28 +30,28 @@ const int WilsonMatrix::Tm = 7; mudag=(mu+Nd)%(2*Nd); } switch(mudag) { - case WilsonMatrix::Xp: + case Xp: spProjXp(ret,in); break; - case WilsonMatrix::Yp: + case Yp: spProjYp(ret,in); break; - case WilsonMatrix::Zp: + case Zp: spProjZp(ret,in); break; - case WilsonMatrix::Tp: + case Tp: spProjTp(ret,in); break; - case WilsonMatrix::Xm: + case Xm: spProjXm(ret,in); break; - case WilsonMatrix::Ym: + case Ym: spProjYm(ret,in); break; - case WilsonMatrix::Zm: + case Zm: spProjZm(ret,in); break; - case WilsonMatrix::Tm: + case Tm: spProjTm(ret,in); break; default: @@ -157,316 +148,36 @@ void WilsonMatrix::MooeeInvDag(const LatticeFermion &in, LatticeFermion &out) MooeeInv(in,out); } -void WilsonMatrix::DhopSite(CartesianStencil &st,LatticeDoubledGaugeField &U, - std::vector > &buf, - int ss,const LatticeFermion &in, LatticeFermion &out) -{ - vHalfSpinColourVector tmp; - vHalfSpinColourVector chi; - vSpinColourVector result; - vHalfSpinColourVector Uchi; - int offset,local,perm, ptype; - - //#define VERBOSE( A) if ( ss<10 ) { std::cout << "site " < > &buf, - int ss,const LatticeFermion &in, LatticeFermion &out) -{ - vHalfSpinColourVector tmp; - vHalfSpinColourVector chi; - vSpinColourVector result; - vHalfSpinColourVector Uchi; - int offset,local,perm, ptype; - - // Xp - offset = st._offsets [Xm][ss]; - local = st._is_local[Xm][ss]; - perm = st._permute[Xm][ss]; - - ptype = st._permute_type[Xm]; - if ( local && perm ) { - spProjXp(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjXp(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Xm),&chi()); - spReconXp(result,Uchi); - - // Yp - offset = st._offsets [Ym][ss]; - local = st._is_local[Ym][ss]; - perm = st._permute[Ym][ss]; - ptype = st._permute_type[Ym]; - if ( local && perm ) { - spProjYp(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjYp(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Ym),&chi()); - accumReconYp(result,Uchi); - - // Zp - offset = st._offsets [Zm][ss]; - local = st._is_local[Zm][ss]; - perm = st._permute[Zm][ss]; - ptype = st._permute_type[Zm]; - if ( local && perm ) { - spProjZp(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjZp(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Zm),&chi()); - accumReconZp(result,Uchi); - - // Tp - offset = st._offsets [Tm][ss]; - local = st._is_local[Tm][ss]; - perm = st._permute[Tm][ss]; - ptype = st._permute_type[Tm]; - if ( local && perm ) { - spProjTp(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjTp(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Tm),&chi()); - accumReconTp(result,Uchi); - - // Xm - offset = st._offsets [Xp][ss]; - local = st._is_local[Xp][ss]; - perm = st._permute[Xp][ss]; - ptype = st._permute_type[Xp]; - - if ( local && perm ) - { - spProjXm(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjXm(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Xp),&chi()); - accumReconXm(result,Uchi); - - // Ym - offset = st._offsets [Yp][ss]; - local = st._is_local[Yp][ss]; - perm = st._permute[Yp][ss]; - ptype = st._permute_type[Yp]; - - if ( local && perm ) { - spProjYm(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjYm(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Yp),&chi()); - accumReconYm(result,Uchi); - - // Zm - offset = st._offsets [Zp][ss]; - local = st._is_local[Zp][ss]; - perm = st._permute[Zp][ss]; - ptype = st._permute_type[Zp]; - if ( local && perm ) { - spProjZm(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjZm(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Zp),&chi()); - accumReconZm(result,Uchi); - - // Tm - offset = st._offsets [Tp][ss]; - local = st._is_local[Tp][ss]; - perm = st._permute[Tp][ss]; - ptype = st._permute_type[Tp]; - if ( local && perm ) { - spProjTm(tmp,in._odata[offset]); - permute(chi,tmp,ptype); - } else if ( local ) { - spProjTm(chi,in._odata[offset]); - } else { - chi=buf[offset]; - } - mult(&Uchi(),&U._odata[ss](Tp),&chi()); - accumReconTm(result,Uchi); - - vstream(out._odata[ss],result); -} - void WilsonMatrix::DhopInternal(CartesianStencil & st,LatticeDoubledGaugeField & U, const LatticeFermion &in, LatticeFermion &out,int dag) { assert((dag==DaggerNo) ||(dag==DaggerYes)); WilsonCompressor compressor(dag); - st.HaloExchange(in,comm_buf,compressor); if ( dag == DaggerYes ) { + if( HandOptDslash ) { PARALLEL_FOR_LOOP - for(int sss=0;sssoSites();sss++){ - DhopSiteDag(st,U,comm_buf,sss,in,out); + for(int sss=0;sssoSites();sss++){ + DiracOptHand::DhopSiteDag(st,U,comm_buf,sss,in,out); + } + } else { +PARALLEL_FOR_LOOP + for(int sss=0;sssoSites();sss++){ + DiracOpt::DhopSiteDag(st,U,comm_buf,sss,in,out); + } } } else { + if( HandOptDslash ) { PARALLEL_FOR_LOOP - for(int sss=0;sssoSites();sss++){ - DhopSite(st,U,comm_buf,sss,in,out); + for(int sss=0;sssoSites();sss++){ + DiracOptHand::DhopSite(st,U,comm_buf,sss,in,out); + } + } else { +PARALLEL_FOR_LOOP + for(int sss=0;sssoSites();sss++){ + DiracOpt::DhopSite(st,U,comm_buf,sss,in,out); + } } } } diff --git a/lib/qcd/Grid_qcd_wilson_dop.h b/lib/qcd/Grid_qcd_wilson_dop.h index 96b29cd0..87418603 100644 --- a/lib/qcd/Grid_qcd_wilson_dop.h +++ b/lib/qcd/Grid_qcd_wilson_dop.h @@ -6,10 +6,22 @@ namespace Grid { namespace QCD { + // Should be in header? + const int Xp = 0; + const int Yp = 1; + const int Zp = 2; + const int Tp = 3; + const int Xm = 4; + const int Ym = 5; + const int Zm = 6; + const int Tm = 7; + class WilsonMatrix : public CheckerBoardedSparseMatrixBase { //NB r=1; public: + static int HandOptDslash; + double mass; // GridBase * grid; // Inherited // GridBase * cbgrid; @@ -56,14 +68,6 @@ namespace Grid { void DhopEO(const LatticeFermion &in, LatticeFermion &out,int dag); void DhopInternal(CartesianStencil & st,LatticeDoubledGaugeField &U, const LatticeFermion &in, LatticeFermion &out,int dag); - // These ones will need to be package intelligently. WilsonType base class - // for use by DWF etc.. - void DhopSite(CartesianStencil &st,LatticeDoubledGaugeField &U, - std::vector > &buf, - int ss,const LatticeFermion &in, LatticeFermion &out); - void DhopSiteDag(CartesianStencil &st,LatticeDoubledGaugeField &U, - std::vector > &buf, - int ss,const LatticeFermion &in, LatticeFermion &out); typedef iScalar > matrix; @@ -71,6 +75,31 @@ namespace Grid { }; + class DiracOpt { + public: + // These ones will need to be package intelligently. WilsonType base class + // for use by DWF etc.. + static void DhopSite(CartesianStencil &st,LatticeDoubledGaugeField &U, + std::vector > &buf, + int ss,const LatticeFermion &in, LatticeFermion &out); + static void DhopSiteDag(CartesianStencil &st,LatticeDoubledGaugeField &U, + std::vector > &buf, + int ss,const LatticeFermion &in, LatticeFermion &out); + + }; + class DiracOptHand { + public: + // These ones will need to be package intelligently. WilsonType base class + // for use by DWF etc.. + static void DhopSite(CartesianStencil &st,LatticeDoubledGaugeField &U, + std::vector > &buf, + int ss,const LatticeFermion &in, LatticeFermion &out); + static void DhopSiteDag(CartesianStencil &st,LatticeDoubledGaugeField &U, + std::vector > &buf, + int ss,const LatticeFermion &in, LatticeFermion &out); + + }; + } } #endif