diff --git a/Grid/lattice/Lattice_ET.h b/Grid/lattice/Lattice_ET.h index 24ec812c..c43844f8 100644 --- a/Grid/lattice/Lattice_ET.h +++ b/Grid/lattice/Lattice_ET.h @@ -43,17 +43,16 @@ NAMESPACE_BEGIN(Grid); // Predicated where support //////////////////////////////////////////////////// #ifdef GRID_SIMT +// drop to scalar in SIMT; cleaner in fact template accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse) { - // should drop to sccalar in SIMT - // typename std::remove_const::type ret; - // Integer mask = TensorRemove(predicate); - // ret = iffalse; - // if (TensorRemove(mask)) ret=iftrue; - return iftrue; + Integer mask = TensorRemove(predicate); + typename std::remove_const::type ret= iffalse; + if (mask) ret=iftrue; + return ret; } #else template @@ -98,65 +97,62 @@ struct getVectorType >{ //-- recursive evaluation of expressions; -- // handle leaves of syntax tree /////////////////////////////////////////////////// -template accelerator_inline +template::value&&!is_lattice_expr::value,sobj>::type * = nullptr> +accelerator_inline sobj eval(const uint64_t ss, const sobj &arg) { return arg; } template accelerator_inline -const lobj & eval(const uint64_t ss, const LatticeView &arg) +auto eval(const uint64_t ss, const LatticeView &arg) -> decltype(arg(ss)) { - return arg[ss]; + return arg(ss); } -template::value&&!is_lattice_expr::value,sobj>::type * = nullptr> -accelerator_inline -sobj coalescedEval(const uint64_t ss, const sobj &arg) +//////////////////////////////////////////// +//-- recursive evaluation of expressions; -- +// whole vector return, used only for expression return type inference +/////////////////////////////////////////////////// +template accelerator_inline +sobj vecEval(const uint64_t ss, const sobj &arg) { return arg; } -#ifdef GRID_SIMT -#warning device template accelerator_inline -typename lobj::scalar_object coalescedEval(const uint64_t ss, const LatticeView &arg) +const lobj & vecEval(const uint64_t ss, const LatticeView &arg) { - auto ret = arg(ss); - return ret; -} -#else -#warning host -template accelerator_inline -lobj coalescedEval(const uint64_t ss, const LatticeView &arg) -{ - // return coalescedRead(arg[ss]); return arg[ss]; } -#endif -/* -template accelerator_inline -typename lobj::scalar_type coalescedEval(const uint64_t ss, const Lattice &arg) -{ - assert(0); - return coalescedRead(arg[ss]); -} -*/ - -// What needs this? -// Cannot be legal on accelerator -// Comparison must convert -#if 1 -template accelerator_inline -const lobj & eval(const uint64_t ss, const Lattice &arg) -{ - assert(0); - auto view = arg.View(AcceleratorRead); - return view[ss]; -} -#endif /////////////////////////////////////////////////// // handle nodes in syntax tree- eval one operand +// vecEval needed (but never called as all expressions offloaded) to infer the return type +// in SIMT contexts of closure. +/////////////////////////////////////////////////// +template accelerator_inline +auto vecEval(const uint64_t ss, const LatticeUnaryExpression &expr) + -> decltype(expr.op.func( vecEval(ss, expr.arg1))) +{ + return expr.op.func( vecEval(ss, expr.arg1) ); +} +// vecEval two operands +template accelerator_inline +auto vecEval(const uint64_t ss, const LatticeBinaryExpression &expr) + -> decltype(expr.op.func( vecEval(ss,expr.arg1),vecEval(ss,expr.arg2))) +{ + return expr.op.func( vecEval(ss,expr.arg1), vecEval(ss,expr.arg2) ); +} +// vecEval three operands +template accelerator_inline +auto vecEval(const uint64_t ss, const LatticeTrinaryExpression &expr) + -> decltype(expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3))) +{ + return expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3)); +} + +/////////////////////////////////////////////////// +// handle nodes in syntax tree- eval one operand coalesced /////////////////////////////////////////////////// template accelerator_inline auto eval(const uint64_t ss, const LatticeUnaryExpression &expr) @@ -174,37 +170,31 @@ auto eval(const uint64_t ss, const LatticeBinaryExpression &expr) // eval three operands template accelerator_inline auto eval(const uint64_t ss, const LatticeTrinaryExpression &expr) - -> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) + -> decltype(expr.op.func(eval(ss, expr.arg1), + eval(ss, expr.arg2), + eval(ss, expr.arg3))) { - return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); -} +#ifdef GRID_SIMT + // Handles Nsimd (vInteger) != Nsimd(ComplexD) + typedef decltype(vecEval(ss, expr.arg2)) rvobj; + typedef typename std::remove_reference::type vobj; -/////////////////////////////////////////////////// -// handle nodes in syntax tree- eval one operand coalesced -/////////////////////////////////////////////////// -template accelerator_inline -auto coalescedEval(const uint64_t ss, const LatticeUnaryExpression &expr) - -> decltype(expr.op.func( coalescedEval(ss, expr.arg1))) -{ - return expr.op.func( coalescedEval(ss, expr.arg1) ); -} -// eval two operands -template accelerator_inline -auto coalescedEval(const uint64_t ss, const LatticeBinaryExpression &expr) - -> decltype(expr.op.func( coalescedEval(ss,expr.arg1),coalescedEval(ss,expr.arg2))) -{ - return expr.op.func( coalescedEval(ss,expr.arg1), coalescedEval(ss,expr.arg2) ); -} -// eval three operands -template accelerator_inline -auto coalescedEval(const uint64_t ss, const LatticeTrinaryExpression &expr) - -> decltype(expr.op.func(coalescedEval(ss, expr.arg1), - coalescedEval(ss, expr.arg2), - coalescedEval(ss, expr.arg3))) -{ - return expr.op.func(coalescedEval(ss, expr.arg1), - coalescedEval(ss, expr.arg2), - coalescedEval(ss, expr.arg3)); + const int Nsimd = vobj::vector_type::Nsimd(); + + auto vpred = vecEval(ss,expr.arg1); + + ExtractBuffer mask(Nsimd); + extract(TensorRemove(vpred), mask); + + int s = acceleratorSIMTlane(Nsimd); + return expr.op.func(mask[s], + eval(ss, expr.arg2), + eval(ss, expr.arg3)); +#else + return expr.op.func(eval(ss, expr.arg1), + eval(ss, expr.arg2), + eval(ss, expr.arg3)); +#endif } ////////////////////////////////////////////////////////////////////////// @@ -302,7 +292,7 @@ template inline void ExpressionViewOpen(LatticeBinaryExpression &expr) { ExpressionViewOpen(expr.arg1); // recurse AST - ExpressionViewOpen(expr.arg2); // recurse AST + ExpressionViewOpen(expr.arg2); // rrecurse AST } template inline void ExpressionViewOpen(LatticeTrinaryExpression &expr) @@ -346,7 +336,6 @@ inline void ExpressionViewClose(LatticeTrinaryExpression &expr) // Unary operators and funcs //////////////////////////////////////////// #define GridUnopClass(name, ret) \ - template \ struct name { \ template static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \ }; @@ -359,8 +348,6 @@ GridUnopClass(UnaryTrace, trace(a)); GridUnopClass(UnaryTranspose, transpose(a)); GridUnopClass(UnaryTa, Ta(a)); GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a)); -GridUnopClass(UnaryReal, real(a)); -GridUnopClass(UnaryImag, imag(a)); GridUnopClass(UnaryToReal, toReal(a)); GridUnopClass(UnaryToComplex, toComplex(a)); GridUnopClass(UnaryTimesI, timesI(a)); @@ -379,7 +366,6 @@ GridUnopClass(UnaryExp, exp(a)); // Binary operators //////////////////////////////////////////// #define GridBinOpClass(name, combination) \ - template \ struct name { \ template \ static auto accelerator_inline \ @@ -403,7 +389,6 @@ GridBinOpClass(BinaryOrOr, lhs || rhs); // Trinary conditional op //////////////////////////////////////////////////// #define GridTrinOpClass(name, combination) \ - template \ struct name { \ template \ static auto accelerator_inline \ @@ -423,10 +408,9 @@ GridTrinOpClass(TrinaryWhere, //////////////////////////////////////////// // Operator syntactical glue //////////////////////////////////////////// - -#define GRID_UNOP(name) name -#define GRID_BINOP(name) name -#define GRID_TRINOP(name) name +#define GRID_UNOP(name) name +#define GRID_BINOP(name) name +#define GRID_TRINOP(name) name #define GRID_DEF_UNOP(op, name) \ template ::value||is_lattice_expr::value,T1>::type * = nullptr> \ @@ -478,8 +462,6 @@ GRID_DEF_UNOP(trace, UnaryTrace); GRID_DEF_UNOP(transpose, UnaryTranspose); GRID_DEF_UNOP(Ta, UnaryTa); GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup); -GRID_DEF_UNOP(real, UnaryReal); -GRID_DEF_UNOP(imag, UnaryImag); GRID_DEF_UNOP(toReal, UnaryToReal); GRID_DEF_UNOP(toComplex, UnaryToComplex); GRID_DEF_UNOP(timesI, UnaryTimesI); @@ -512,29 +494,36 @@ GRID_DEF_TRINOP(where, TrinaryWhere); ///////////////////////////////////////////////////////////// template auto closure(const LatticeUnaryExpression &expr) - -> Lattice + -> Lattice { - Lattice ret(expr); + Lattice ret(expr); return ret; } template auto closure(const LatticeBinaryExpression &expr) - -> Lattice + -> Lattice { - Lattice ret(expr); + Lattice ret(expr); return ret; } template auto closure(const LatticeTrinaryExpression &expr) - -> Lattice + -> Lattice { - Lattice ret(expr); + Lattice ret(expr); return ret; } +#define EXPRESSION_CLOSURE(function) \ + template::value,void>::type * = nullptr> \ + auto function(Expression &expr) -> decltype(function(closure(expr))) \ + { \ + return function(closure(expr)); \ + } + #undef GRID_UNOP #undef GRID_BINOP