/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/lattice/Lattice_ET.h Copyright (C) 2015 Author: Azusa Yamaguchi Author: Peter Boyle Author: neo Author: Christoph Lehner #include #include #include NAMESPACE_BEGIN(Grid); //////////////////////////////////////////////////// // Predicated where support //////////////////////////////////////////////////// #ifdef GRID_SIMT // drop to scalar in SIMT; cleaner in fact template accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse) { Integer mask = TensorRemove(predicate); typename std::remove_const::type ret= iffalse; if (mask) ret=iftrue; return ret; } #else template accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse) { typename std::remove_const::type ret; typedef typename vobj::scalar_object scalar_object; typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; const int Nsimd = vobj::vector_type::Nsimd(); ExtractBuffer mask(Nsimd); ExtractBuffer truevals(Nsimd); ExtractBuffer falsevals(Nsimd); extract(iftrue, truevals); extract(iffalse, falsevals); extract(TensorRemove(predicate), mask); for (int s = 0; s < Nsimd; s++) { if (mask[s]) falsevals[s] = truevals[s]; } merge(ret, falsevals); return ret; } #endif ///////////////////////////////////////////////////// //Specialization of getVectorType for lattices ///////////////////////////////////////////////////// template struct getVectorType >{ typedef typename Lattice::vector_object type; }; //////////////////////////////////////////// //-- recursive evaluation of expressions; -- // handle leaves of syntax tree /////////////////////////////////////////////////// template::value&&!is_lattice_expr::value,sobj>::type * = nullptr> accelerator_inline sobj eval(const uint64_t ss, const sobj &arg) { return arg; } template accelerator_inline auto eval(const uint64_t ss, const LatticeView &arg) -> decltype(arg(ss)) { return arg(ss); } //////////////////////////////////////////// //-- recursive evaluation of expressions; -- // whole vector return, used only for expression return type inference /////////////////////////////////////////////////// template accelerator_inline sobj vecEval(const uint64_t ss, const sobj &arg) { return arg; } template accelerator_inline const lobj & vecEval(const uint64_t ss, const LatticeView &arg) { return arg[ss]; } /////////////////////////////////////////////////// // handle nodes in syntax tree- eval one operand // vecEval needed (but never called as all expressions offloaded) to infer the return type // in SIMT contexts of closure. /////////////////////////////////////////////////// template accelerator_inline auto vecEval(const uint64_t ss, const LatticeUnaryExpression &expr) -> decltype(expr.op.func( vecEval(ss, expr.arg1))) { return expr.op.func( vecEval(ss, expr.arg1) ); } // vecEval two operands template accelerator_inline auto vecEval(const uint64_t ss, const LatticeBinaryExpression &expr) -> decltype(expr.op.func( vecEval(ss,expr.arg1),vecEval(ss,expr.arg2))) { return expr.op.func( vecEval(ss,expr.arg1), vecEval(ss,expr.arg2) ); } // vecEval three operands template accelerator_inline auto vecEval(const uint64_t ss, const LatticeTrinaryExpression &expr) -> decltype(expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3))) { return expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3)); } /////////////////////////////////////////////////// // handle nodes in syntax tree- eval one operand coalesced /////////////////////////////////////////////////// template accelerator_inline auto eval(const uint64_t ss, const LatticeUnaryExpression &expr) -> decltype(expr.op.func( eval(ss, expr.arg1))) { return expr.op.func( eval(ss, expr.arg1) ); } // eval two operands template accelerator_inline auto eval(const uint64_t ss, const LatticeBinaryExpression &expr) -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2))) { return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) ); } // eval three operands template accelerator_inline auto eval(const uint64_t ss, const LatticeTrinaryExpression &expr) -> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) { #ifdef GRID_SIMT // Handles Nsimd (vInteger) != Nsimd(ComplexD) typedef decltype(vecEval(ss, expr.arg2)) rvobj; typedef typename std::remove_reference::type vobj; const int Nsimd = vobj::vector_type::Nsimd(); auto vpred = vecEval(ss,expr.arg1); ExtractBuffer mask(Nsimd); extract(TensorRemove(vpred), mask); int s = acceleratorSIMTlane(Nsimd); return expr.op.func(mask[s], eval(ss, expr.arg2), eval(ss, expr.arg3)); #else return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); #endif } ////////////////////////////////////////////////////////////////////////// // Obtain the grid from an expression, ensuring conformable. This must follow a // tree recursion; must retain grid pointer in the LatticeView class which sucks // Use a different method, and make it void *. // Perhaps a conformable method. ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> accelerator_inline void GridFromExpression(GridBase *&grid, const T1 &lat) // Lattice leaf { lat.Conformable(grid); } template ::value, T1>::type * = nullptr> accelerator_inline void GridFromExpression(GridBase *&grid,const T1 ¬lat) // non-lattice leaf {} template accelerator_inline void GridFromExpression(GridBase *&grid,const LatticeUnaryExpression &expr) { GridFromExpression(grid, expr.arg1); // recurse } template accelerator_inline void GridFromExpression(GridBase *&grid, const LatticeBinaryExpression &expr) { GridFromExpression(grid, expr.arg1); // recurse GridFromExpression(grid, expr.arg2); } template accelerator_inline void GridFromExpression(GridBase *&grid, const LatticeTrinaryExpression &expr) { GridFromExpression(grid, expr.arg1); // recurse GridFromExpression(grid, expr.arg2); // recurse GridFromExpression(grid, expr.arg3); // recurse } ////////////////////////////////////////////////////////////////////////// // Obtain the CB from an expression, ensuring conformable. This must follow a // tree recursion ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> inline void CBFromExpression(int &cb, const T1 &lat) // Lattice leaf { if ((cb == Odd) || (cb == Even)) { assert(cb == lat.Checkerboard()); } cb = lat.Checkerboard(); } template ::value, T1>::type * = nullptr> inline void CBFromExpression(int &cb, const T1 ¬lat) {} // non-lattice leaf template inline void CBFromExpression(int &cb,const LatticeUnaryExpression &expr) { CBFromExpression(cb, expr.arg1); // recurse AST } template inline void CBFromExpression(int &cb,const LatticeBinaryExpression &expr) { CBFromExpression(cb, expr.arg1); // recurse AST CBFromExpression(cb, expr.arg2); // recurse AST } template inline void CBFromExpression(int &cb, const LatticeTrinaryExpression &expr) { CBFromExpression(cb, expr.arg1); // recurse AST CBFromExpression(cb, expr.arg2); // recurse AST CBFromExpression(cb, expr.arg3); // recurse AST } ////////////////////////////////////////////////////////////////////////// // ViewOpen ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> inline void ExpressionViewOpen(T1 &lat) // Lattice leaf { lat.ViewOpen(AcceleratorRead); } template ::value, T1>::type * = nullptr> inline void ExpressionViewOpen(T1 ¬lat) {} template inline void ExpressionViewOpen(LatticeUnaryExpression &expr) { ExpressionViewOpen(expr.arg1); // recurse AST } template inline void ExpressionViewOpen(LatticeBinaryExpression &expr) { ExpressionViewOpen(expr.arg1); // recurse AST ExpressionViewOpen(expr.arg2); // rrecurse AST } template inline void ExpressionViewOpen(LatticeTrinaryExpression &expr) { ExpressionViewOpen(expr.arg1); // recurse AST ExpressionViewOpen(expr.arg2); // recurse AST ExpressionViewOpen(expr.arg3); // recurse AST } ////////////////////////////////////////////////////////////////////////// // ViewClose ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> inline void ExpressionViewClose( T1 &lat) // Lattice leaf { lat.ViewClose(); } template ::value, T1>::type * = nullptr> inline void ExpressionViewClose(T1 ¬lat) {} template inline void ExpressionViewClose(LatticeUnaryExpression &expr) { ExpressionViewClose(expr.arg1); // recurse AST } template inline void ExpressionViewClose(LatticeBinaryExpression &expr) { ExpressionViewClose(expr.arg1); // recurse AST ExpressionViewClose(expr.arg2); // recurse AST } template inline void ExpressionViewClose(LatticeTrinaryExpression &expr) { ExpressionViewClose(expr.arg1); // recurse AST ExpressionViewClose(expr.arg2); // recurse AST ExpressionViewClose(expr.arg3); // recurse AST } //////////////////////////////////////////// // Unary operators and funcs //////////////////////////////////////////// #define GridUnopClass(name, ret) \ struct name { \ template static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \ }; GridUnopClass(UnarySub, -a); GridUnopClass(UnaryNot, Not(a)); GridUnopClass(UnaryTrace, trace(a)); GridUnopClass(UnaryTranspose, transpose(a)); GridUnopClass(UnaryTa, Ta(a)); GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a)); GridUnopClass(UnaryTimesI, timesI(a)); GridUnopClass(UnaryTimesMinusI, timesMinusI(a)); GridUnopClass(UnaryAbs, abs(a)); GridUnopClass(UnarySqrt, sqrt(a)); GridUnopClass(UnarySin, sin(a)); GridUnopClass(UnaryCos, cos(a)); GridUnopClass(UnaryAsin, asin(a)); GridUnopClass(UnaryAcos, acos(a)); GridUnopClass(UnaryLog, log(a)); GridUnopClass(UnaryExp, exp(a)); //////////////////////////////////////////// // Binary operators //////////////////////////////////////////// #define GridBinOpClass(name, combination) \ struct name { \ template \ static auto accelerator_inline \ func(const _left &lhs, const _right &rhs) \ -> decltype(combination) const \ { \ return combination; \ } \ }; GridBinOpClass(BinaryAdd, lhs + rhs); GridBinOpClass(BinarySub, lhs - rhs); GridBinOpClass(BinaryMul, lhs *rhs); GridBinOpClass(BinaryDiv, lhs /rhs); GridBinOpClass(BinaryAnd, lhs &rhs); GridBinOpClass(BinaryOr, lhs | rhs); GridBinOpClass(BinaryAndAnd, lhs &&rhs); GridBinOpClass(BinaryOrOr, lhs || rhs); //////////////////////////////////////////////////// // Trinary conditional op //////////////////////////////////////////////////// #define GridTrinOpClass(name, combination) \ struct name { \ template \ static auto accelerator_inline \ func(const _predicate &pred, const _left &lhs, const _right &rhs) \ -> decltype(combination) const \ { \ return combination; \ } \ }; GridTrinOpClass(TrinaryWhere, (predicatedWhere< typename std::remove_reference<_predicate>::type, typename std::remove_reference<_left>::type, typename std::remove_reference<_right>::type>(pred, lhs,rhs))); //////////////////////////////////////////// // Operator syntactical glue //////////////////////////////////////////// #define GRID_UNOP(name) name #define GRID_BINOP(name) name #define GRID_TRINOP(name) name #define GRID_DEF_UNOP(op, name) \ template ::value||is_lattice_expr::value,T1>::type * = nullptr> \ inline auto op(const T1 &arg) ->decltype(LatticeUnaryExpression(GRID_UNOP(name)(), arg)) \ { \ return LatticeUnaryExpression(GRID_UNOP(name)(), arg); \ } #define GRID_BINOP_LEFT(op, name) \ template ::value||is_lattice_expr::value,T1>::type * = nullptr> \ inline auto op(const T1 &lhs, const T2 &rhs) \ ->decltype(LatticeBinaryExpression(GRID_BINOP(name)(),lhs,rhs)) \ { \ return LatticeBinaryExpression(GRID_BINOP(name)(),lhs,rhs);\ } #define GRID_BINOP_RIGHT(op, name) \ template ::value&&!is_lattice_expr::value,T1>::type * = nullptr, \ typename std::enable_if< is_lattice::value|| is_lattice_expr::value,T2>::type * = nullptr> \ inline auto op(const T1 &lhs, const T2 &rhs) \ ->decltype(LatticeBinaryExpression(GRID_BINOP(name)(),lhs, rhs)) \ { \ return LatticeBinaryExpression(GRID_BINOP(name)(),lhs, rhs); \ } #define GRID_DEF_BINOP(op, name) \ GRID_BINOP_LEFT(op, name); \ GRID_BINOP_RIGHT(op, name); #define GRID_DEF_TRINOP(op, name) \ template \ inline auto op(const T1 &pred, const T2 &lhs, const T3 &rhs) \ ->decltype(LatticeTrinaryExpression(GRID_TRINOP(name)(),pred, lhs, rhs)) \ { \ return LatticeTrinaryExpression(GRID_TRINOP(name)(),pred, lhs, rhs); \ } //////////////////////// // Operator definitions //////////////////////// GRID_DEF_UNOP(operator-, UnarySub); GRID_DEF_UNOP(Not, UnaryNot); GRID_DEF_UNOP(operator!, UnaryNot); //GRID_DEF_UNOP(adj, UnaryAdj); //GRID_DEF_UNOP(conjugate, UnaryConj); GRID_DEF_UNOP(trace, UnaryTrace); GRID_DEF_UNOP(transpose, UnaryTranspose); GRID_DEF_UNOP(Ta, UnaryTa); GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup); GRID_DEF_UNOP(timesI, UnaryTimesI); GRID_DEF_UNOP(timesMinusI, UnaryTimesMinusI); GRID_DEF_UNOP(abs, UnaryAbs); // abs overloaded in cmath C++98; DON'T do the // abs-fabs-dabs-labs thing GRID_DEF_UNOP(sqrt, UnarySqrt); GRID_DEF_UNOP(sin, UnarySin); GRID_DEF_UNOP(cos, UnaryCos); GRID_DEF_UNOP(asin, UnaryAsin); GRID_DEF_UNOP(acos, UnaryAcos); GRID_DEF_UNOP(log, UnaryLog); GRID_DEF_UNOP(exp, UnaryExp); GRID_DEF_BINOP(operator+, BinaryAdd); GRID_DEF_BINOP(operator-, BinarySub); GRID_DEF_BINOP(operator*, BinaryMul); GRID_DEF_BINOP(operator/, BinaryDiv); GRID_DEF_BINOP(operator&, BinaryAnd); GRID_DEF_BINOP(operator|, BinaryOr); GRID_DEF_BINOP(operator&&, BinaryAndAnd); GRID_DEF_BINOP(operator||, BinaryOrOr); GRID_DEF_TRINOP(where, TrinaryWhere); ///////////////////////////////////////////////////////////// // Closure convenience to force expression to evaluate ///////////////////////////////////////////////////////////// template auto closure(const LatticeUnaryExpression &expr) -> Lattice::type > { Lattice::type > ret(expr); return ret; } template auto closure(const LatticeBinaryExpression &expr) -> Lattice::type > { Lattice::type > ret(expr); return ret; } template auto closure(const LatticeTrinaryExpression &expr) -> Lattice::type > { Lattice::type > ret(expr); return ret; } #define EXPRESSION_CLOSURE(function) \ template::value,void>::type * = nullptr> \ auto function(Expression &expr) -> decltype(function(closure(expr))) \ { \ return function(closure(expr)); \ } #undef GRID_UNOP #undef GRID_BINOP #undef GRID_TRINOP #undef GRID_DEF_UNOP #undef GRID_DEF_BINOP #undef GRID_DEF_TRINOP NAMESPACE_END(Grid); #endif