From e3acb36de6b23039ce26c3bb75a07062520b44c7 Mon Sep 17 00:00:00 2001 From: Peter Boyle Date: Sun, 10 May 2015 15:22:31 +0100 Subject: [PATCH] Bringing expression templates for faster vector loops --- lib/Grid_lattice.h | 105 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/lib/Grid_lattice.h b/lib/Grid_lattice.h index 75a80ef8..861a9e24 100644 --- a/lib/Grid_lattice.h +++ b/lib/Grid_lattice.h @@ -20,18 +20,106 @@ namespace Grid { extern int GridCshiftPermuteMap[4][16]; +//////////////////////////////////////////////// +// Basic expressions used in Expression Template +//////////////////////////////////////////////// + +class LatticeBase {}; +class LatticeExpressionBase {}; + +template +class LatticeUnaryExpression : public std::pair > , public LatticeExpressionBase { + public: + LatticeUnaryExpression(const std::pair > &arg): std::pair >(arg) {}; +}; + +template +class LatticeBinaryExpression : public std::pair > , public LatticeExpressionBase { + public: + LatticeBinaryExpression(const std::pair > &arg): std::pair >(arg) {}; +}; + +template +class LatticeTrinaryExpression :public std::pair >, public LatticeExpressionBase { + public: + LatticeTrinaryExpression(const std::pair > &arg): std::pair >(arg) {}; +}; + template -class Lattice +class Lattice : public LatticeBase { public: + GridBase *_grid; int checkerboard; std::vector > _odata; - //std::valarray _odata; -public: +public: typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; + typedef vobj vector_object; + + //////////////////////////////////////////////////////////////////////////////// + // Expression Template closure support + //////////////////////////////////////////////////////////////////////////////// + template inline Lattice & operator=(const LatticeUnaryExpression &expr) + { +#pragma omp parallel for + for(int ss=0;ss<_grid->oSites();ss++){ + vobj tmp= eval(ss,expr); + vstream(_odata[ss] ,tmp); + } + return *this; + } + template inline Lattice & operator=(const LatticeBinaryExpression &expr) + { +#pragma omp parallel for + for(int ss=0;ss<_grid->oSites();ss++){ + vobj tmp= eval(ss,expr); + vstream(_odata[ss] ,tmp); + } + return *this; + } + template inline Lattice & operator=(const LatticeTrinaryExpression &expr) + { +#pragma omp parallel for + for(int ss=0;ss<_grid->oSites();ss++){ + vobj tmp= eval(ss,expr); + vstream(_odata[ss] ,tmp); + } + return *this; + } + //GridFromExpression is tricky to do + template + Lattice(const LatticeUnaryExpression & expr): _grid(nullptr){ + GridFromExpression(_grid,expr); + assert(_grid!=nullptr); + _odata.resize(_grid->oSites()); +#pragma omp parallel for + for(int ss=0;ss<_grid->oSites();ss++){ + _odata[ss] = eval(ss,expr); + } + }; + template + Lattice(const LatticeBinaryExpression & expr): _grid(nullptr){ + GridFromExpression(_grid,expr); + assert(_grid!=nullptr); + _odata.resize(_grid->oSites()); +#pragma omp parallel for + for(int ss=0;ss<_grid->oSites();ss++){ + _odata[ss] = eval(ss,expr); + } + }; + template + Lattice(const LatticeTrinaryExpression & expr): _grid(nullptr){ + GridFromExpression(_grid,expr); + assert(_grid!=nullptr); + _odata.resize(_grid->oSites()); +#pragma omp parallel for + for(int ss=0;ss<_grid->oSites();ss++){ + _odata[ss] = eval(ss,expr); + } + }; ////////////////////////////////////////////////////////////////// // Constructor requires "grid" passed. @@ -54,7 +142,7 @@ public: template inline Lattice & operator = (const Lattice & r){ conformable(*this,r); std::cout<<"Lattice operator ="<oSites();ss++){ this->_odata[ss]=r._odata[ss]; } @@ -88,8 +176,17 @@ public: }; // class Lattice } +#define GRID_LATTICE_EXPRESSION_TEMPLATES + #include + +#ifdef GRID_LATTICE_EXPRESSION_TEMPLATES +#include +#else +#include +#endif #include + #include #include #include