mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-09 21:50:45 +01:00
Where working
This commit is contained in:
parent
e14a84317d
commit
7d14a3c086
@ -43,17 +43,16 @@ NAMESPACE_BEGIN(Grid);
|
|||||||
// Predicated where support
|
// Predicated where support
|
||||||
////////////////////////////////////////////////////
|
////////////////////////////////////////////////////
|
||||||
#ifdef GRID_SIMT
|
#ifdef GRID_SIMT
|
||||||
|
// drop to scalar in SIMT; cleaner in fact
|
||||||
template <class iobj, class vobj, class robj>
|
template <class iobj, class vobj, class robj>
|
||||||
accelerator_inline vobj predicatedWhere(const iobj &predicate,
|
accelerator_inline vobj predicatedWhere(const iobj &predicate,
|
||||||
const vobj &iftrue,
|
const vobj &iftrue,
|
||||||
const robj &iffalse)
|
const robj &iffalse)
|
||||||
{
|
{
|
||||||
// should drop to scalar in SIMT
|
Integer mask = TensorRemove(predicate);
|
||||||
// typename std::remove_const<vobj>::type ret;
|
typename std::remove_const<vobj>::type ret= iffalse;
|
||||||
// Integer mask = TensorRemove(predicate);
|
if (mask) ret=iftrue;
|
||||||
// ret = iffalse;
|
return ret;
|
||||||
// if (TensorRemove(mask)) ret=iftrue;
|
|
||||||
return iftrue;
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
template <class iobj, class vobj, class robj>
|
template <class iobj, class vobj, class robj>
|
||||||
@ -98,65 +97,62 @@ struct getVectorType<Lattice<T> >{
|
|||||||
//-- recursive evaluation of expressions; --
|
//-- recursive evaluation of expressions; --
|
||||||
// handle leaves of syntax tree
|
// handle leaves of syntax tree
|
||||||
///////////////////////////////////////////////////
|
///////////////////////////////////////////////////
|
||||||
template<class sobj> accelerator_inline
|
template<class sobj,
|
||||||
|
typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>
|
||||||
|
accelerator_inline
|
||||||
sobj eval(const uint64_t ss, const sobj &arg)
|
sobj eval(const uint64_t ss, const sobj &arg)
|
||||||
{
|
{
|
||||||
return arg;
|
return arg;
|
||||||
}
|
}
|
||||||
template <class lobj> accelerator_inline
|
template <class lobj> accelerator_inline
|
||||||
const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg)
|
auto eval(const uint64_t ss, const LatticeView<lobj> &arg) -> decltype(arg(ss))
|
||||||
{
|
{
|
||||||
return arg[ss];
|
return arg(ss);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class sobj,
|
////////////////////////////////////////////
|
||||||
typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>
|
//-- recursive evaluation of expressions; --
|
||||||
accelerator_inline
|
// whole vector return, used only for expression return type inference
|
||||||
sobj coalescedEval(const uint64_t ss, const sobj &arg)
|
///////////////////////////////////////////////////
|
||||||
|
template<class sobj> accelerator_inline
|
||||||
|
sobj vecEval(const uint64_t ss, const sobj &arg)
|
||||||
{
|
{
|
||||||
return arg;
|
return arg;
|
||||||
}
|
}
|
||||||
#ifdef GRID_SIMT
|
|
||||||
#warning device
|
|
||||||
template <class lobj> accelerator_inline
|
template <class lobj> accelerator_inline
|
||||||
typename lobj::scalar_object coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
|
const lobj & vecEval(const uint64_t ss, const LatticeView<lobj> &arg)
|
||||||
{
|
{
|
||||||
auto ret = arg(ss);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#warning host
|
|
||||||
template <class lobj> accelerator_inline
|
|
||||||
lobj coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
|
|
||||||
{
|
|
||||||
// return coalescedRead(arg[ss]);
|
|
||||||
return arg[ss];
|
return arg[ss];
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
/*
|
|
||||||
template <class lobj> accelerator_inline
|
|
||||||
typename lobj::scalar_type coalescedEval(const uint64_t ss, const Lattice<lobj> &arg)
|
|
||||||
{
|
|
||||||
assert(0);
|
|
||||||
return coalescedRead(arg[ss]);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// What needs this?
|
|
||||||
// Cannot be legal on accelerator
|
|
||||||
// Comparison must convert
|
|
||||||
#if 1
|
|
||||||
template <class lobj> accelerator_inline
|
|
||||||
const lobj & eval(const uint64_t ss, const Lattice<lobj> &arg)
|
|
||||||
{
|
|
||||||
assert(0);
|
|
||||||
auto view = arg.View(AcceleratorRead);
|
|
||||||
return view[ss];
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////
|
///////////////////////////////////////////////////
|
||||||
// handle nodes in syntax tree- eval one operand
|
// handle nodes in syntax tree- eval one operand
|
||||||
|
// vecEval needed (but never called as all expressions offloaded) to infer the return type
|
||||||
|
// in SIMT contexts of closure.
|
||||||
|
///////////////////////////////////////////////////
|
||||||
|
template <typename Op, typename T1> accelerator_inline
|
||||||
|
auto vecEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
|
||||||
|
-> decltype(expr.op.func( vecEval(ss, expr.arg1)))
|
||||||
|
{
|
||||||
|
return expr.op.func( vecEval(ss, expr.arg1) );
|
||||||
|
}
|
||||||
|
// vecEval two operands
|
||||||
|
template <typename Op, typename T1, typename T2> accelerator_inline
|
||||||
|
auto vecEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
|
||||||
|
-> decltype(expr.op.func( vecEval(ss,expr.arg1),vecEval(ss,expr.arg2)))
|
||||||
|
{
|
||||||
|
return expr.op.func( vecEval(ss,expr.arg1), vecEval(ss,expr.arg2) );
|
||||||
|
}
|
||||||
|
// vecEval three operands
|
||||||
|
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
|
||||||
|
auto vecEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||||
|
-> decltype(expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3)))
|
||||||
|
{
|
||||||
|
return expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////
|
||||||
|
// handle nodes in syntax tree- eval one operand coalesced
|
||||||
///////////////////////////////////////////////////
|
///////////////////////////////////////////////////
|
||||||
template <typename Op, typename T1> accelerator_inline
|
template <typename Op, typename T1> accelerator_inline
|
||||||
auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
|
auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
|
||||||
@ -174,37 +170,31 @@ auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
|
|||||||
// eval three operands
|
// eval three operands
|
||||||
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
|
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
|
||||||
auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||||
-> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)))
|
-> decltype(expr.op.func(eval(ss, expr.arg1),
|
||||||
|
eval(ss, expr.arg2),
|
||||||
|
eval(ss, expr.arg3)))
|
||||||
{
|
{
|
||||||
return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3));
|
#ifdef GRID_SIMT
|
||||||
}
|
// Handles Nsimd (vInteger) != Nsimd(ComplexD)
|
||||||
|
typedef decltype(vecEval(ss, expr.arg2)) rvobj;
|
||||||
|
typedef typename std::remove_reference<rvobj>::type vobj;
|
||||||
|
|
||||||
///////////////////////////////////////////////////
|
const int Nsimd = vobj::vector_type::Nsimd();
|
||||||
// handle nodes in syntax tree- eval one operand coalesced
|
|
||||||
///////////////////////////////////////////////////
|
auto vpred = vecEval(ss,expr.arg1);
|
||||||
template <typename Op, typename T1> accelerator_inline
|
|
||||||
auto coalescedEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
|
ExtractBuffer<Integer> mask(Nsimd);
|
||||||
-> decltype(expr.op.func( coalescedEval(ss, expr.arg1)))
|
extract<vInteger, Integer>(TensorRemove(vpred), mask);
|
||||||
{
|
|
||||||
return expr.op.func( coalescedEval(ss, expr.arg1) );
|
int s = acceleratorSIMTlane(Nsimd);
|
||||||
}
|
return expr.op.func(mask[s],
|
||||||
// eval two operands
|
eval(ss, expr.arg2),
|
||||||
template <typename Op, typename T1, typename T2> accelerator_inline
|
eval(ss, expr.arg3));
|
||||||
auto coalescedEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
|
#else
|
||||||
-> decltype(expr.op.func( coalescedEval(ss,expr.arg1),coalescedEval(ss,expr.arg2)))
|
return expr.op.func(eval(ss, expr.arg1),
|
||||||
{
|
eval(ss, expr.arg2),
|
||||||
return expr.op.func( coalescedEval(ss,expr.arg1), coalescedEval(ss,expr.arg2) );
|
eval(ss, expr.arg3));
|
||||||
}
|
#endif
|
||||||
// eval three operands
|
|
||||||
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
|
|
||||||
auto coalescedEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
|
||||||
-> decltype(expr.op.func(coalescedEval(ss, expr.arg1),
|
|
||||||
coalescedEval(ss, expr.arg2),
|
|
||||||
coalescedEval(ss, expr.arg3)))
|
|
||||||
{
|
|
||||||
return expr.op.func(coalescedEval(ss, expr.arg1),
|
|
||||||
coalescedEval(ss, expr.arg2),
|
|
||||||
coalescedEval(ss, expr.arg3));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
@ -302,7 +292,7 @@ template <typename Op, typename T1, typename T2> inline
|
|||||||
void ExpressionViewOpen(LatticeBinaryExpression<Op, T1, T2> &expr)
|
void ExpressionViewOpen(LatticeBinaryExpression<Op, T1, T2> &expr)
|
||||||
{
|
{
|
||||||
ExpressionViewOpen(expr.arg1); // recurse AST
|
ExpressionViewOpen(expr.arg1); // recurse AST
|
||||||
ExpressionViewOpen(expr.arg2); // recurse AST
|
ExpressionViewOpen(expr.arg2); // recurse AST
|
||||||
}
|
}
|
||||||
template <typename Op, typename T1, typename T2, typename T3>
|
template <typename Op, typename T1, typename T2, typename T3>
|
||||||
inline void ExpressionViewOpen(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
inline void ExpressionViewOpen(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||||
@ -346,7 +336,6 @@ inline void ExpressionViewClose(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
|||||||
// Unary operators and funcs
|
// Unary operators and funcs
|
||||||
////////////////////////////////////////////
|
////////////////////////////////////////////
|
||||||
#define GridUnopClass(name, ret) \
|
#define GridUnopClass(name, ret) \
|
||||||
template <class arg> \
|
|
||||||
struct name { \
|
struct name { \
|
||||||
template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \
|
template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \
|
||||||
};
|
};
|
||||||
@ -359,8 +348,6 @@ GridUnopClass(UnaryTrace, trace(a));
|
|||||||
GridUnopClass(UnaryTranspose, transpose(a));
|
GridUnopClass(UnaryTranspose, transpose(a));
|
||||||
GridUnopClass(UnaryTa, Ta(a));
|
GridUnopClass(UnaryTa, Ta(a));
|
||||||
GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a));
|
GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a));
|
||||||
GridUnopClass(UnaryReal, real(a));
|
|
||||||
GridUnopClass(UnaryImag, imag(a));
|
|
||||||
GridUnopClass(UnaryToReal, toReal(a));
|
GridUnopClass(UnaryToReal, toReal(a));
|
||||||
GridUnopClass(UnaryToComplex, toComplex(a));
|
GridUnopClass(UnaryToComplex, toComplex(a));
|
||||||
GridUnopClass(UnaryTimesI, timesI(a));
|
GridUnopClass(UnaryTimesI, timesI(a));
|
||||||
@ -379,7 +366,6 @@ GridUnopClass(UnaryExp, exp(a));
|
|||||||
// Binary operators
|
// Binary operators
|
||||||
////////////////////////////////////////////
|
////////////////////////////////////////////
|
||||||
#define GridBinOpClass(name, combination) \
|
#define GridBinOpClass(name, combination) \
|
||||||
template <class left, class right> \
|
|
||||||
struct name { \
|
struct name { \
|
||||||
template <class _left, class _right> \
|
template <class _left, class _right> \
|
||||||
static auto accelerator_inline \
|
static auto accelerator_inline \
|
||||||
@ -403,7 +389,6 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
|
|||||||
// Trinary conditional op
|
// Trinary conditional op
|
||||||
////////////////////////////////////////////////////
|
////////////////////////////////////////////////////
|
||||||
#define GridTrinOpClass(name, combination) \
|
#define GridTrinOpClass(name, combination) \
|
||||||
template <class predicate, class left, class right> \
|
|
||||||
struct name { \
|
struct name { \
|
||||||
template <class _predicate,class _left, class _right> \
|
template <class _predicate,class _left, class _right> \
|
||||||
static auto accelerator_inline \
|
static auto accelerator_inline \
|
||||||
@ -423,10 +408,9 @@ GridTrinOpClass(TrinaryWhere,
|
|||||||
////////////////////////////////////////////
|
////////////////////////////////////////////
|
||||||
// Operator syntactical glue
|
// Operator syntactical glue
|
||||||
////////////////////////////////////////////
|
////////////////////////////////////////////
|
||||||
|
#define GRID_UNOP(name) name
|
||||||
#define GRID_UNOP(name) name<decltype(eval(0, arg))>
|
#define GRID_BINOP(name) name
|
||||||
#define GRID_BINOP(name) name<decltype(eval(0, lhs)), decltype(eval(0, rhs))>
|
#define GRID_TRINOP(name) name
|
||||||
#define GRID_TRINOP(name) name<decltype(eval(0, pred)), decltype(eval(0, lhs)), decltype(eval(0, rhs))>
|
|
||||||
|
|
||||||
#define GRID_DEF_UNOP(op, name) \
|
#define GRID_DEF_UNOP(op, name) \
|
||||||
template <typename T1, typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
|
template <typename T1, typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
|
||||||
@ -478,8 +462,6 @@ GRID_DEF_UNOP(trace, UnaryTrace);
|
|||||||
GRID_DEF_UNOP(transpose, UnaryTranspose);
|
GRID_DEF_UNOP(transpose, UnaryTranspose);
|
||||||
GRID_DEF_UNOP(Ta, UnaryTa);
|
GRID_DEF_UNOP(Ta, UnaryTa);
|
||||||
GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup);
|
GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup);
|
||||||
GRID_DEF_UNOP(real, UnaryReal);
|
|
||||||
GRID_DEF_UNOP(imag, UnaryImag);
|
|
||||||
GRID_DEF_UNOP(toReal, UnaryToReal);
|
GRID_DEF_UNOP(toReal, UnaryToReal);
|
||||||
GRID_DEF_UNOP(toComplex, UnaryToComplex);
|
GRID_DEF_UNOP(toComplex, UnaryToComplex);
|
||||||
GRID_DEF_UNOP(timesI, UnaryTimesI);
|
GRID_DEF_UNOP(timesI, UnaryTimesI);
|
||||||
@ -512,29 +494,36 @@ GRID_DEF_TRINOP(where, TrinaryWhere);
|
|||||||
/////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////
|
||||||
template <class Op, class T1>
|
template <class Op, class T1>
|
||||||
auto closure(const LatticeUnaryExpression<Op, T1> &expr)
|
auto closure(const LatticeUnaryExpression<Op, T1> &expr)
|
||||||
-> Lattice<decltype(expr.op.func(eval(0, expr.arg1)))>
|
-> Lattice<decltype(expr.op.func(vecEval(0, expr.arg1)))>
|
||||||
{
|
{
|
||||||
Lattice<decltype(expr.op.func(eval(0, expr.arg1)))> ret(expr);
|
Lattice<decltype(expr.op.func(vecEval(0, expr.arg1)))> ret(expr);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
template <class Op, class T1, class T2>
|
template <class Op, class T1, class T2>
|
||||||
auto closure(const LatticeBinaryExpression<Op, T1, T2> &expr)
|
auto closure(const LatticeBinaryExpression<Op, T1, T2> &expr)
|
||||||
-> Lattice<decltype(expr.op.func(eval(0, expr.arg1),eval(0, expr.arg2)))>
|
-> Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),vecEval(0, expr.arg2)))>
|
||||||
{
|
{
|
||||||
Lattice<decltype(expr.op.func(eval(0, expr.arg1),eval(0, expr.arg2)))> ret(expr);
|
Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),vecEval(0, expr.arg2)))> ret(expr);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
template <class Op, class T1, class T2, class T3>
|
template <class Op, class T1, class T2, class T3>
|
||||||
auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||||
-> Lattice<decltype(expr.op.func(eval(0, expr.arg1),
|
-> Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),
|
||||||
eval(0, expr.arg2),
|
vecEval(0, expr.arg2),
|
||||||
eval(0, expr.arg3)))>
|
vecEval(0, expr.arg3)))>
|
||||||
{
|
{
|
||||||
Lattice<decltype(expr.op.func(eval(0, expr.arg1),
|
Lattice<decltype(expr.op.func(vecEval(0, expr.arg1),
|
||||||
eval(0, expr.arg2),
|
vecEval(0, expr.arg2),
|
||||||
eval(0, expr.arg3)))> ret(expr);
|
vecEval(0, expr.arg3)))> ret(expr);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
#define EXPRESSION_CLOSURE(function) \
|
||||||
|
template<class Expression,typename std::enable_if<is_lattice_expr<Expression>::value,void>::type * = nullptr> \
|
||||||
|
auto function(Expression &expr) -> decltype(function(closure(expr))) \
|
||||||
|
{ \
|
||||||
|
return function(closure(expr)); \
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#undef GRID_UNOP
|
#undef GRID_UNOP
|
||||||
#undef GRID_BINOP
|
#undef GRID_BINOP
|
||||||
|
Loading…
x
Reference in New Issue
Block a user