1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Almost there to coalesced ET

This commit is contained in:
Peter Boyle 2020-08-26 17:04:49 -04:00
parent 47b89d2739
commit 3448b7387c
2 changed files with 98 additions and 19 deletions

View File

@ -42,9 +42,25 @@ NAMESPACE_BEGIN(Grid);
//////////////////////////////////////////////////// ////////////////////////////////////////////////////
// Predicated where support // Predicated where support
//////////////////////////////////////////////////// ////////////////////////////////////////////////////
#ifdef GRID_SIMT
template <class iobj, class vobj, class robj> template <class iobj, class vobj, class robj>
accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, accelerator_inline vobj predicatedWhere(const iobj &predicate,
const robj &iffalse) { const vobj &iftrue,
const robj &iffalse)
{
// should drop to sccalar in SIMT
// typename std::remove_const<vobj>::type ret;
// Integer mask = TensorRemove(predicate);
// ret = iffalse;
// if (TensorRemove(mask)) ret=iftrue;
return iftrue;
}
#else
template <class iobj, class vobj, class robj>
accelerator_inline vobj predicatedWhere(const iobj &predicate,
const vobj &iftrue,
const robj &iffalse)
{
typename std::remove_const<vobj>::type ret; typename std::remove_const<vobj>::type ret;
typedef typename vobj::scalar_object scalar_object; typedef typename vobj::scalar_object scalar_object;
@ -68,6 +84,7 @@ accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftru
merge(ret, falsevals); merge(ret, falsevals);
return ret; return ret;
} }
#endif
///////////////////////////////////////////////////// /////////////////////////////////////////////////////
//Specialization of getVectorType for lattices //Specialization of getVectorType for lattices
@ -86,13 +103,45 @@ sobj eval(const uint64_t ss, const sobj &arg)
{ {
return arg; return arg;
} }
template <class lobj> accelerator_inline template <class lobj> accelerator_inline
const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg) const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg)
{ {
return arg[ss]; return arg[ss];
} }
template<class sobj,
typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>
accelerator_inline
sobj coalescedEval(const uint64_t ss, const sobj &arg)
{
return arg;
}
#ifdef GRID_SIMT
#warning device
template <class lobj> accelerator_inline
typename lobj::scalar_object coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
{
auto ret = arg(ss);
return ret;
}
#else
#warning host
template <class lobj> accelerator_inline
lobj coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)
{
// return coalescedRead(arg[ss]);
return arg[ss];
}
#endif
/*
template <class lobj> accelerator_inline
typename lobj::scalar_type coalescedEval(const uint64_t ss, const Lattice<lobj> &arg)
{
assert(0);
return coalescedRead(arg[ss]);
}
*/
// What needs this? // What needs this?
// Cannot be legal on accelerator // Cannot be legal on accelerator
// Comparison must convert // Comparison must convert
@ -115,18 +164,14 @@ auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
{ {
return expr.op.func( eval(ss, expr.arg1) ); return expr.op.func( eval(ss, expr.arg1) );
} }
///////////////////////
// eval two operands // eval two operands
///////////////////////
template <typename Op, typename T1, typename T2> accelerator_inline template <typename Op, typename T1, typename T2> accelerator_inline
auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr) auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
-> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2))) -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2)))
{ {
return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) ); return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) );
} }
///////////////////////
// eval three operands // eval three operands
///////////////////////
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr) auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) -> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)))
@ -134,6 +179,34 @@ auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &exp
return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3));
} }
///////////////////////////////////////////////////
// handle nodes in syntax tree- eval one operand coalesced
///////////////////////////////////////////////////
template <typename Op, typename T1> accelerator_inline
auto coalescedEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
-> decltype(expr.op.func( coalescedEval(ss, expr.arg1)))
{
return expr.op.func( coalescedEval(ss, expr.arg1) );
}
// eval two operands
template <typename Op, typename T1, typename T2> accelerator_inline
auto coalescedEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
-> decltype(expr.op.func( coalescedEval(ss,expr.arg1),coalescedEval(ss,expr.arg2)))
{
return expr.op.func( coalescedEval(ss,expr.arg1), coalescedEval(ss,expr.arg2) );
}
// eval three operands
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
auto coalescedEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> decltype(expr.op.func(coalescedEval(ss, expr.arg1),
coalescedEval(ss, expr.arg2),
coalescedEval(ss, expr.arg3)))
{
return expr.op.func(coalescedEval(ss, expr.arg1),
coalescedEval(ss, expr.arg2),
coalescedEval(ss, expr.arg3));
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// Obtain the grid from an expression, ensuring conformable. This must follow a // Obtain the grid from an expression, ensuring conformable. This must follow a
// tree recursion; must retain grid pointer in the LatticeView class which sucks // tree recursion; must retain grid pointer in the LatticeView class which sucks
@ -275,7 +348,7 @@ inline void ExpressionViewClose(LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
#define GridUnopClass(name, ret) \ #define GridUnopClass(name, ret) \
template <class arg> \ template <class arg> \
struct name { \ struct name { \
static auto accelerator_inline func(const arg a) -> decltype(ret) { return ret; } \ template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \
}; };
GridUnopClass(UnarySub, -a); GridUnopClass(UnarySub, -a);
@ -308,8 +381,9 @@ GridUnopClass(UnaryExp, exp(a));
#define GridBinOpClass(name, combination) \ #define GridBinOpClass(name, combination) \
template <class left, class right> \ template <class left, class right> \
struct name { \ struct name { \
template <class _left, class _right> \
static auto accelerator_inline \ static auto accelerator_inline \
func(const left &lhs, const right &rhs) \ func(const _left &lhs, const _right &rhs) \
-> decltype(combination) const \ -> decltype(combination) const \
{ \ { \
return combination; \ return combination; \
@ -331,8 +405,9 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
#define GridTrinOpClass(name, combination) \ #define GridTrinOpClass(name, combination) \
template <class predicate, class left, class right> \ template <class predicate, class left, class right> \
struct name { \ struct name { \
template <class _predicate,class _left, class _right> \
static auto accelerator_inline \ static auto accelerator_inline \
func(const predicate &pred, const left &lhs, const right &rhs) \ func(const _predicate &pred, const _left &lhs, const _right &rhs) \
-> decltype(combination) const \ -> decltype(combination) const \
{ \ { \
return combination; \ return combination; \
@ -340,9 +415,10 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
}; };
GridTrinOpClass(TrinaryWhere, GridTrinOpClass(TrinaryWhere,
(predicatedWhere<predicate, (predicatedWhere<
typename std::remove_reference<left>::type, typename std::remove_reference<_predicate>::type,
typename std::remove_reference<right>::type>(pred, lhs,rhs))); typename std::remove_reference<_left>::type,
typename std::remove_reference<_right>::type>(pred, lhs,rhs)));
//////////////////////////////////////////// ////////////////////////////////////////////
// Operator syntactical glue // Operator syntactical glue

View File

@ -124,8 +124,9 @@ public:
ExpressionViewOpen(exprCopy); ExpressionViewOpen(exprCopy);
auto me = View(AcceleratorWriteDiscard); auto me = View(AcceleratorWriteDiscard);
accelerator_for(ss,me.size(),1,{ accelerator_for(ss,me.size(),1,{
auto tmp = eval(ss,exprCopy); auto tmp = coalescedEval(ss,exprCopy);
vstream(me[ss],tmp); coalescedWrite(me[ss],tmp);
// me[ss]=tmp;
}); });
me.ViewClose(); me.ViewClose();
ExpressionViewClose(exprCopy); ExpressionViewClose(exprCopy);
@ -147,8 +148,9 @@ public:
ExpressionViewOpen(exprCopy); ExpressionViewOpen(exprCopy);
auto me = View(AcceleratorWriteDiscard); auto me = View(AcceleratorWriteDiscard);
accelerator_for(ss,me.size(),1,{ accelerator_for(ss,me.size(),1,{
auto tmp = eval(ss,exprCopy); auto tmp = coalescedEval(ss,exprCopy);
vstream(me[ss],tmp); coalescedWrite(me[ss],tmp);
//me[ss]=tmp;
}); });
me.ViewClose(); me.ViewClose();
ExpressionViewClose(exprCopy); ExpressionViewClose(exprCopy);
@ -169,8 +171,9 @@ public:
ExpressionViewOpen(exprCopy); ExpressionViewOpen(exprCopy);
auto me = View(AcceleratorWriteDiscard); auto me = View(AcceleratorWriteDiscard);
accelerator_for(ss,me.size(),1,{ accelerator_for(ss,me.size(),1,{
auto tmp = eval(ss,exprCopy); auto tmp = coalescedEval(ss,exprCopy);
vstream(me[ss],tmp); coalescedWrite(me[ss],tmp);
// me[ss]=tmp;
}); });
me.ViewClose(); me.ViewClose();
ExpressionViewClose(exprCopy); ExpressionViewClose(exprCopy);