mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
Almost there to coalesced ET
This commit is contained in:
parent
47b89d2739
commit
3448b7387c
@ -42,9 +42,25 @@ NAMESPACE_BEGIN(Grid);
|
||||
////////////////////////////////////////////////////
|
||||
// Predicated where support
|
||||
////////////////////////////////////////////////////
|
||||
#ifdef GRID_SIMT
|
||||
template <class iobj, class vobj, class robj>
|
||||
accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue,
|
||||
const robj &iffalse) {
|
||||
accelerator_inline vobj predicatedWhere(const iobj &predicate,
|
||||
const vobj &iftrue,
|
||||
const robj &iffalse)
|
||||
{
|
||||
// should drop to sccalar in SIMT
|
||||
// typename std::remove_const<vobj>::type ret;
|
||||
// Integer mask = TensorRemove(predicate);
|
||||
// ret = iffalse;
|
||||
// if (TensorRemove(mask)) ret=iftrue;
|
||||
return iftrue;
|
||||
}
|
||||
#else
|
||||
template <class iobj, class vobj, class robj>
|
||||
accelerator_inline vobj predicatedWhere(const iobj &predicate,
|
||||
const vobj &iftrue,
|
||||
const robj &iffalse)
|
||||
{
|
||||
typename std::remove_const<vobj>::type ret;
|
||||
|
||||
typedef typename vobj::scalar_object scalar_object;
|
||||
@ -68,6 +84,7 @@ accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftru
|
||||
merge(ret, falsevals);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
//Specialization of getVectorType for lattices
|
||||
@ -86,13 +103,45 @@ sobj eval(const uint64_t ss, const sobj &arg)
|
||||
{
|
||||
return arg;
|
||||
}
|
||||
|
||||
template <class lobj> accelerator_inline
|
||||
const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg)
|
||||
{
|
||||
return arg[ss];
|
||||
}
|
||||
|
||||
template<class sobj,
|
||||
typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>
|
||||
accelerator_inline
|
||||
sobj coalescedEval(const uint64_t ss, const sobj &arg)
|
||||
{
|
||||
return arg;
|
||||
}
|
||||
#ifdef GRID_SIMT
|
||||
#warning device
|
||||
template <class lobj> accelerator_inline
|
||||
typename lobj::scalar_object coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
|
||||
{
|
||||
auto ret = arg(ss);
|
||||
return ret;
|
||||
}
|
||||
#else
|
||||
#warning host
|
||||
template <class lobj> accelerator_inline
|
||||
lobj coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
|
||||
{
|
||||
// return coalescedRead(arg[ss]);
|
||||
return arg[ss];
|
||||
}
|
||||
#endif
|
||||
/*
|
||||
template <class lobj> accelerator_inline
|
||||
typename lobj::scalar_type coalescedEval(const uint64_t ss, const Lattice<lobj> &arg)
|
||||
{
|
||||
assert(0);
|
||||
return coalescedRead(arg[ss]);
|
||||
}
|
||||
*/
|
||||
|
||||
// What needs this?
|
||||
// Cannot be legal on accelerator
|
||||
// Comparison must convert
|
||||
@ -115,18 +164,14 @@ auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
|
||||
{
|
||||
return expr.op.func( eval(ss, expr.arg1) );
|
||||
}
|
||||
///////////////////////
|
||||
// eval two operands
|
||||
///////////////////////
|
||||
template <typename Op, typename T1, typename T2> accelerator_inline
|
||||
auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
|
||||
-> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2)))
|
||||
{
|
||||
return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) );
|
||||
}
|
||||
///////////////////////
|
||||
// eval three operands
|
||||
///////////////////////
|
||||
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
|
||||
auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||
-> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)))
|
||||
@ -134,6 +179,34 @@ auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &exp
|
||||
return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////
|
||||
// handle nodes in syntax tree- eval one operand coalesced
|
||||
///////////////////////////////////////////////////
|
||||
template <typename Op, typename T1> accelerator_inline
|
||||
auto coalescedEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
|
||||
-> decltype(expr.op.func( coalescedEval(ss, expr.arg1)))
|
||||
{
|
||||
return expr.op.func( coalescedEval(ss, expr.arg1) );
|
||||
}
|
||||
// eval two operands
|
||||
template <typename Op, typename T1, typename T2> accelerator_inline
|
||||
auto coalescedEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
|
||||
-> decltype(expr.op.func( coalescedEval(ss,expr.arg1),coalescedEval(ss,expr.arg2)))
|
||||
{
|
||||
return expr.op.func( coalescedEval(ss,expr.arg1), coalescedEval(ss,expr.arg2) );
|
||||
}
|
||||
// eval three operands
|
||||
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
|
||||
auto coalescedEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||
-> decltype(expr.op.func(coalescedEval(ss, expr.arg1),
|
||||
coalescedEval(ss, expr.arg2),
|
||||
coalescedEval(ss, expr.arg3)))
|
||||
{
|
||||
return expr.op.func(coalescedEval(ss, expr.arg1),
|
||||
coalescedEval(ss, expr.arg2),
|
||||
coalescedEval(ss, expr.arg3));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Obtain the grid from an expression, ensuring conformable. This must follow a
|
||||
// tree recursion; must retain grid pointer in the LatticeView class which sucks
|
||||
@ -275,7 +348,7 @@ inline void ExpressionViewClose(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
|
||||
#define GridUnopClass(name, ret) \
|
||||
template <class arg> \
|
||||
struct name { \
|
||||
static auto accelerator_inline func(const arg a) -> decltype(ret) { return ret; } \
|
||||
template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \
|
||||
};
|
||||
|
||||
GridUnopClass(UnarySub, -a);
|
||||
@ -308,8 +381,9 @@ GridUnopClass(UnaryExp, exp(a));
|
||||
#define GridBinOpClass(name, combination) \
|
||||
template <class left, class right> \
|
||||
struct name { \
|
||||
template <class _left, class _right> \
|
||||
static auto accelerator_inline \
|
||||
func(const left &lhs, const right &rhs) \
|
||||
func(const _left &lhs, const _right &rhs) \
|
||||
-> decltype(combination) const \
|
||||
{ \
|
||||
return combination; \
|
||||
@ -331,8 +405,9 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
|
||||
#define GridTrinOpClass(name, combination) \
|
||||
template <class predicate, class left, class right> \
|
||||
struct name { \
|
||||
template <class _predicate,class _left, class _right> \
|
||||
static auto accelerator_inline \
|
||||
func(const predicate &pred, const left &lhs, const right &rhs) \
|
||||
func(const _predicate &pred, const _left &lhs, const _right &rhs) \
|
||||
-> decltype(combination) const \
|
||||
{ \
|
||||
return combination; \
|
||||
@ -340,9 +415,10 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
|
||||
};
|
||||
|
||||
GridTrinOpClass(TrinaryWhere,
|
||||
(predicatedWhere<predicate,
|
||||
typename std::remove_reference<left>::type,
|
||||
typename std::remove_reference<right>::type>(pred, lhs,rhs)));
|
||||
(predicatedWhere<
|
||||
typename std::remove_reference<_predicate>::type,
|
||||
typename std::remove_reference<_left>::type,
|
||||
typename std::remove_reference<_right>::type>(pred, lhs,rhs)));
|
||||
|
||||
////////////////////////////////////////////
|
||||
// Operator syntactical glue
|
||||
|
@ -124,8 +124,9 @@ public:
|
||||
ExpressionViewOpen(exprCopy);
|
||||
auto me = View(AcceleratorWriteDiscard);
|
||||
accelerator_for(ss,me.size(),1,{
|
||||
auto tmp = eval(ss,exprCopy);
|
||||
vstream(me[ss],tmp);
|
||||
auto tmp = coalescedEval(ss,exprCopy);
|
||||
coalescedWrite(me[ss],tmp);
|
||||
// me[ss]=tmp;
|
||||
});
|
||||
me.ViewClose();
|
||||
ExpressionViewClose(exprCopy);
|
||||
@ -147,8 +148,9 @@ public:
|
||||
ExpressionViewOpen(exprCopy);
|
||||
auto me = View(AcceleratorWriteDiscard);
|
||||
accelerator_for(ss,me.size(),1,{
|
||||
auto tmp = eval(ss,exprCopy);
|
||||
vstream(me[ss],tmp);
|
||||
auto tmp = coalescedEval(ss,exprCopy);
|
||||
coalescedWrite(me[ss],tmp);
|
||||
//me[ss]=tmp;
|
||||
});
|
||||
me.ViewClose();
|
||||
ExpressionViewClose(exprCopy);
|
||||
@ -169,8 +171,9 @@ public:
|
||||
ExpressionViewOpen(exprCopy);
|
||||
auto me = View(AcceleratorWriteDiscard);
|
||||
accelerator_for(ss,me.size(),1,{
|
||||
auto tmp = eval(ss,exprCopy);
|
||||
vstream(me[ss],tmp);
|
||||
auto tmp = coalescedEval(ss,exprCopy);
|
||||
coalescedWrite(me[ss],tmp);
|
||||
// me[ss]=tmp;
|
||||
});
|
||||
me.ViewClose();
|
||||
ExpressionViewClose(exprCopy);
|
||||
|
Loading…
Reference in New Issue
Block a user