1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Where working

This commit is contained in:
Peter Boyle 2020-08-31 23:53:46 -04:00
parent e14a84317d
commit 7d14a3c086

View File

@ -43,17 +43,16 @@ NAMESPACE_BEGIN(Grid);
// Predicated where support // Predicated where support
//////////////////////////////////////////////////// ////////////////////////////////////////////////////
#ifdef GRID_SIMT #ifdef GRID_SIMT
// drop to scalar in SIMT; cleaner in fact
template <class iobj, class vobj, class robj> template <class iobj, class vobj, class robj>
accelerator_inline vobj predicatedWhere(const iobj &predicate, accelerator_inline vobj predicatedWhere(const iobj &predicate,
const vobj &iftrue, const vobj &iftrue,
const robj &iffalse) const robj &iffalse)
{ {
// should drop to scalar in SIMT Integer mask = TensorRemove(predicate);
// typename std::remove_const<vobj>::type ret; typename std::remove_const<vobj>::type ret= iffalse;
// Integer mask = TensorRemove(predicate); if (mask) ret=iftrue;
// ret = iffalse; return ret;
// if (TensorRemove(mask)) ret=iftrue;
return iftrue;
} }
#else #else
template <class iobj, class vobj, class robj> template <class iobj, class vobj, class robj>
@ -98,65 +97,62 @@ struct getVectorType<Lattice<T> >{
//-- recursive evaluation of expressions; -- //-- recursive evaluation of expressions; --
// handle leaves of syntax tree // handle leaves of syntax tree
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
template<class sobj> accelerator_inline template<class sobj,
typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>
accelerator_inline
sobj eval(const uint64_t ss, const sobj &arg) sobj eval(const uint64_t ss, const sobj &arg)
{ {
return arg; return arg;
} }
template <class lobj> accelerator_inline template <class lobj> accelerator_inline
const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg) auto eval(const uint64_t ss, const LatticeView<lobj> &arg) -> decltype(arg(ss))
{ {
return arg[ss]; return arg(ss);
} }
template<class sobj, ////////////////////////////////////////////
typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr> //-- recursive evaluation of expressions; --
accelerator_inline // whole vector return, used only for expression return type inference
sobj coalescedEval(const uint64_t ss, const sobj &arg) ///////////////////////////////////////////////////
template<class sobj> accelerator_inline
sobj vecEval(const uint64_t ss, const sobj &arg)
{ {
return arg; return arg;
} }
#ifdef GRID_SIMT
#warning device
template <class lobj> accelerator_inline template <class lobj> accelerator_inline
typename lobj::scalar_object coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg) const lobj & vecEval(const uint64_t ss, const LatticeView<lobj> &arg)
{ {
auto ret = arg(ss);
return ret;
}
#else
#warning host
template <class lobj> accelerator_inline
lobj coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
{
// return coalescedRead(arg[ss]);
return arg[ss]; return arg[ss];
} }
#endif
/*
template <class lobj> accelerator_inline
typename lobj::scalar_type coalescedEval(const uint64_t ss, const Lattice<lobj> &arg)
{
assert(0);
return coalescedRead(arg[ss]);
}
*/
// What needs this?
// Cannot be legal on accelerator
// Comparison must convert
#if 1
template <class lobj> accelerator_inline
const lobj & eval(const uint64_t ss, const Lattice<lobj> &arg)
{
assert(0);
auto view = arg.View(AcceleratorRead);
return view[ss];
}
#endif
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
// handle nodes in syntax tree- eval one operand // handle nodes in syntax tree- eval one operand
// vecEval is needed only to infer the return type of closure in SIMT contexts;
// it is never actually called, because all expressions are offloaded.
///////////////////////////////////////////////////
template <typename Op, typename T1> accelerator_inline
auto vecEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
-> decltype(expr.op.func( vecEval(ss, expr.arg1)))
{
return expr.op.func( vecEval(ss, expr.arg1) );
}
// vecEval two operands
template <typename Op, typename T1, typename T2> accelerator_inline
auto vecEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
-> decltype(expr.op.func( vecEval(ss,expr.arg1),vecEval(ss,expr.arg2)))
{
return expr.op.func( vecEval(ss,expr.arg1), vecEval(ss,expr.arg2) );
}
// vecEval three operands
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
auto vecEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> decltype(expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3)))
{
return expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3));
}
///////////////////////////////////////////////////
// handle nodes in syntax tree- eval one operand coalesced
/////////////////////////////////////////////////// ///////////////////////////////////////////////////
template <typename Op, typename T1> accelerator_inline template <typename Op, typename T1> accelerator_inline
auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr) auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
@ -174,37 +170,31 @@ auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
// eval three operands // eval three operands
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr) auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) -> decltype(expr.op.func(eval(ss, expr.arg1),
eval(ss, expr.arg2),
eval(ss, expr.arg3)))
{ {
return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); #ifdef GRID_SIMT
} // Handles Nsimd (vInteger) != Nsimd(ComplexD)
typedef decltype(vecEval(ss, expr.arg2)) rvobj;
typedef typename std::remove_reference<rvobj>::type vobj;
/////////////////////////////////////////////////// const int Nsimd = vobj::vector_type::Nsimd();
// handle nodes in syntax tree- eval one operand coalesced
/////////////////////////////////////////////////// auto vpred = vecEval(ss,expr.arg1);
template <typename Op, typename T1> accelerator_inline
auto coalescedEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr) ExtractBuffer<Integer> mask(Nsimd);
-> decltype(expr.op.func( coalescedEval(ss, expr.arg1))) extract<vInteger, Integer>(TensorRemove(vpred), mask);
{
return expr.op.func( coalescedEval(ss, expr.arg1) ); int s = acceleratorSIMTlane(Nsimd);
} return expr.op.func(mask[s],
// eval two operands eval(ss, expr.arg2),
template <typename Op, typename T1, typename T2> accelerator_inline eval(ss, expr.arg3));
auto coalescedEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr) #else
-> decltype(expr.op.func( coalescedEval(ss,expr.arg1),coalescedEval(ss,expr.arg2))) return expr.op.func(eval(ss, expr.arg1),
{ eval(ss, expr.arg2),
return expr.op.func( coalescedEval(ss,expr.arg1), coalescedEval(ss,expr.arg2) ); eval(ss, expr.arg3));
} #endif
// eval three operands
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
auto coalescedEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> decltype(expr.op.func(coalescedEval(ss, expr.arg1),
coalescedEval(ss, expr.arg2),
coalescedEval(ss, expr.arg3)))
{
return expr.op.func(coalescedEval(ss, expr.arg1),
coalescedEval(ss, expr.arg2),
coalescedEval(ss, expr.arg3));
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -302,7 +292,7 @@ template <typename Op, typename T1, typename T2> inline
void ExpressionViewOpen(LatticeBinaryExpression<Op, T1, T2> &expr) void ExpressionViewOpen(LatticeBinaryExpression<Op, T1, T2> &expr)
{ {
ExpressionViewOpen(expr.arg1); // recurse AST ExpressionViewOpen(expr.arg1); // recurse AST
ExpressionViewOpen(expr.arg2); // recurse AST ExpressionViewOpen(expr.arg2); // recurse AST
} }
template <typename Op, typename T1, typename T2, typename T3> template <typename Op, typename T1, typename T2, typename T3>
inline void ExpressionViewOpen(LatticeTrinaryExpression<Op, T1, T2, T3> &expr) inline void ExpressionViewOpen(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
@ -346,7 +336,6 @@ inline void ExpressionViewClose(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
// Unary operators and funcs // Unary operators and funcs
//////////////////////////////////////////// ////////////////////////////////////////////
#define GridUnopClass(name, ret) \ #define GridUnopClass(name, ret) \
template <class arg> \
struct name { \ struct name { \
template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \ template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \
}; };
@ -359,8 +348,6 @@ GridUnopClass(UnaryTrace, trace(a));
GridUnopClass(UnaryTranspose, transpose(a)); GridUnopClass(UnaryTranspose, transpose(a));
GridUnopClass(UnaryTa, Ta(a)); GridUnopClass(UnaryTa, Ta(a));
GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a)); GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a));
GridUnopClass(UnaryReal, real(a));
GridUnopClass(UnaryImag, imag(a));
GridUnopClass(UnaryToReal, toReal(a)); GridUnopClass(UnaryToReal, toReal(a));
GridUnopClass(UnaryToComplex, toComplex(a)); GridUnopClass(UnaryToComplex, toComplex(a));
GridUnopClass(UnaryTimesI, timesI(a)); GridUnopClass(UnaryTimesI, timesI(a));
@ -379,7 +366,6 @@ GridUnopClass(UnaryExp, exp(a));
// Binary operators // Binary operators
//////////////////////////////////////////// ////////////////////////////////////////////
#define GridBinOpClass(name, combination) \ #define GridBinOpClass(name, combination) \
template <class left, class right> \
struct name { \ struct name { \
template <class _left, class _right> \ template <class _left, class _right> \
static auto accelerator_inline \ static auto accelerator_inline \
@ -403,7 +389,6 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
// Trinary conditional op // Trinary conditional op
//////////////////////////////////////////////////// ////////////////////////////////////////////////////
#define GridTrinOpClass(name, combination) \ #define GridTrinOpClass(name, combination) \
template <class predicate, class left, class right> \
struct name { \ struct name { \
template <class _predicate,class _left, class _right> \ template <class _predicate,class _left, class _right> \
static auto accelerator_inline \ static auto accelerator_inline \
@ -423,10 +408,9 @@ GridTrinOpClass(TrinaryWhere,
//////////////////////////////////////////// ////////////////////////////////////////////
// Operator syntactical glue // Operator syntactical glue
//////////////////////////////////////////// ////////////////////////////////////////////
#define GRID_UNOP(name) name
#define GRID_UNOP(name) name<decltype(eval(0, arg))> #define GRID_BINOP(name) name
#define GRID_BINOP(name) name<decltype(eval(0, lhs)), decltype(eval(0, rhs))> #define GRID_TRINOP(name) name
#define GRID_TRINOP(name) name<decltype(eval(0, pred)), decltype(eval(0, lhs)), decltype(eval(0, rhs))>
#define GRID_DEF_UNOP(op, name) \ #define GRID_DEF_UNOP(op, name) \
template <typename T1, typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \ template <typename T1, typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
@ -478,8 +462,6 @@ GRID_DEF_UNOP(trace, UnaryTrace);
GRID_DEF_UNOP(transpose, UnaryTranspose); GRID_DEF_UNOP(transpose, UnaryTranspose);
GRID_DEF_UNOP(Ta, UnaryTa); GRID_DEF_UNOP(Ta, UnaryTa);
GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup); GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup);
GRID_DEF_UNOP(real, UnaryReal);
GRID_DEF_UNOP(imag, UnaryImag);
GRID_DEF_UNOP(toReal, UnaryToReal); GRID_DEF_UNOP(toReal, UnaryToReal);
GRID_DEF_UNOP(toComplex, UnaryToComplex); GRID_DEF_UNOP(toComplex, UnaryToComplex);
GRID_DEF_UNOP(timesI, UnaryTimesI); GRID_DEF_UNOP(timesI, UnaryTimesI);
@ -512,29 +494,36 @@ GRID_DEF_TRINOP(where, TrinaryWhere);
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
template <class Op, class T1> template <class Op, class T1>
auto closure(const LatticeUnaryExpression<Op, T1> &expr) auto closure(const LatticeUnaryExpression<Op, T1> &expr)
-> Lattice<decltype(expr.op.func(eval(0, expr.arg1)))> -> Lattice<decltype(expr.op.func(vecEval(0, expr.arg1)))>
{ {
Lattice<decltype(expr.op.func(eval(0, expr.arg1)))> ret(expr); Lattice<decltype(expr.op.func(vecEval(0, expr.arg1)))> ret(expr);
return ret; return ret;
} }
template <class Op, class T1, class T2> template <class Op, class T1, class T2>
auto closure(const LatticeBinaryExpression<Op, T1, T2> &expr) auto closure(const LatticeBinaryExpression<Op, T1, T2> &expr)
-> Lattice<decltype(expr.op.func(eval(0, expr.arg1),eval(0, expr.arg2)))> -> Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),vecEval(0, expr.arg2)))>
{ {
Lattice<decltype(expr.op.func(eval(0, expr.arg1),eval(0, expr.arg2)))> ret(expr); Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),vecEval(0, expr.arg2)))> ret(expr);
return ret; return ret;
} }
template <class Op, class T1, class T2, class T3> template <class Op, class T1, class T2, class T3>
auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr) auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> Lattice<decltype(expr.op.func(eval(0, expr.arg1), -> Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),
eval(0, expr.arg2), vecEval(0, expr.arg2),
eval(0, expr.arg3)))> vecEval(0, expr.arg3)))>
{ {
Lattice<decltype(expr.op.func(eval(0, expr.arg1), Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),
eval(0, expr.arg2), vecEval(0, expr.arg2),
eval(0, expr.arg3)))> ret(expr); vecEval(0, expr.arg3)))> ret(expr);
return ret; return ret;
} }
#define EXPRESSION_CLOSURE(function) \
template<class Expression,typename std::enable_if<is_lattice_expr<Expression>::value,void>::type * = nullptr> \
auto function(Expression &expr) -> decltype(function(closure(expr))) \
{ \
return function(closure(expr)); \
}
#undef GRID_UNOP #undef GRID_UNOP
#undef GRID_BINOP #undef GRID_BINOP