mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-10-30 19:44:32 +00:00 
			
		
		
		
	Almost there to coalesced ET
This commit is contained in:
		| @@ -42,9 +42,25 @@ NAMESPACE_BEGIN(Grid); | |||||||
| //////////////////////////////////////////////////// | //////////////////////////////////////////////////// | ||||||
| // Predicated where support | // Predicated where support | ||||||
| //////////////////////////////////////////////////// | //////////////////////////////////////////////////// | ||||||
|  | #ifdef GRID_SIMT | ||||||
| template <class iobj, class vobj, class robj> | template <class iobj, class vobj, class robj> | ||||||
| accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, | accelerator_inline vobj predicatedWhere(const iobj &predicate,  | ||||||
|                             const robj &iffalse) { | 					const vobj &iftrue,  | ||||||
|  | 					const robj &iffalse)  | ||||||
|  | { | ||||||
|  |   // should drop to sccalar in SIMT | ||||||
|  |   //  typename std::remove_const<vobj>::type ret; | ||||||
|  |   //  Integer mask = TensorRemove(predicate); | ||||||
|  |   //  ret = iffalse; | ||||||
|  |   //  if (TensorRemove(mask)) ret=iftrue; | ||||||
|  |   return iftrue; | ||||||
|  | } | ||||||
|  | #else | ||||||
|  | template <class iobj, class vobj, class robj> | ||||||
|  | accelerator_inline vobj predicatedWhere(const iobj &predicate,  | ||||||
|  | 					const vobj &iftrue,  | ||||||
|  | 					const robj &iffalse)  | ||||||
|  | { | ||||||
|   typename std::remove_const<vobj>::type ret; |   typename std::remove_const<vobj>::type ret; | ||||||
|  |  | ||||||
|   typedef typename vobj::scalar_object scalar_object; |   typedef typename vobj::scalar_object scalar_object; | ||||||
| @@ -68,6 +84,7 @@ accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftru | |||||||
|   merge(ret, falsevals); |   merge(ret, falsevals); | ||||||
|   return ret; |   return ret; | ||||||
| } | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
| ///////////////////////////////////////////////////// | ///////////////////////////////////////////////////// | ||||||
| //Specialization of getVectorType for lattices | //Specialization of getVectorType for lattices | ||||||
| @@ -86,13 +103,45 @@ sobj eval(const uint64_t ss, const sobj &arg) | |||||||
| { | { | ||||||
|   return arg; |   return arg; | ||||||
| } | } | ||||||
|  |  | ||||||
| template <class lobj> accelerator_inline  | template <class lobj> accelerator_inline  | ||||||
| const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg)  | const lobj & eval(const uint64_t ss, const LatticeView<lobj> &arg)  | ||||||
| { | { | ||||||
|   return arg[ss]; |   return arg[ss]; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template<class sobj, | ||||||
|  |   typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>  | ||||||
|  | accelerator_inline  | ||||||
|  | sobj coalescedEval(const uint64_t ss, const sobj &arg) | ||||||
|  | { | ||||||
|  |   return arg; | ||||||
|  | } | ||||||
|  | #ifdef GRID_SIMT | ||||||
|  | #warning device | ||||||
|  | template <class lobj> accelerator_inline  | ||||||
|  | typename lobj::scalar_object coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)  | ||||||
|  | { | ||||||
|  |   auto ret = arg(ss); | ||||||
|  |   return ret; | ||||||
|  | } | ||||||
|  | #else | ||||||
|  | #warning host | ||||||
|  | template <class lobj> accelerator_inline  | ||||||
|  | lobj coalescedEval(const uint64_t ss, const LatticeView<lobj> &arg)  | ||||||
|  | { | ||||||
|  |   //   return coalescedRead(arg[ss]); | ||||||
|  |   return arg[ss]; | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  | /* | ||||||
|  | template <class lobj> accelerator_inline  | ||||||
|  | typename lobj::scalar_type coalescedEval(const uint64_t ss, const Lattice<lobj> &arg)  | ||||||
|  | { | ||||||
|  |   assert(0); | ||||||
|  |   return coalescedRead(arg[ss]); | ||||||
|  | } | ||||||
|  | */ | ||||||
|  |  | ||||||
| // What needs this? | // What needs this? | ||||||
| // Cannot be legal on accelerator | // Cannot be legal on accelerator | ||||||
| // Comparison must convert | // Comparison must convert | ||||||
| @@ -115,18 +164,14 @@ auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr) | |||||||
| { | { | ||||||
|   return expr.op.func( eval(ss, expr.arg1) ); |   return expr.op.func( eval(ss, expr.arg1) ); | ||||||
| } | } | ||||||
| /////////////////////// |  | ||||||
| // eval two operands | // eval two operands | ||||||
| /////////////////////// |  | ||||||
| template <typename Op, typename T1, typename T2> accelerator_inline | template <typename Op, typename T1, typename T2> accelerator_inline | ||||||
| auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)   | auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)   | ||||||
|   -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2))) |   -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2))) | ||||||
| { | { | ||||||
|   return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) ); |   return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) ); | ||||||
| } | } | ||||||
| /////////////////////// |  | ||||||
| // eval three operands | // eval three operands | ||||||
| /////////////////////// |  | ||||||
| template <typename Op, typename T1, typename T2, typename T3> accelerator_inline | template <typename Op, typename T1, typename T2, typename T3> accelerator_inline | ||||||
| auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)   | auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)   | ||||||
|   -> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) |   -> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) | ||||||
| @@ -134,6 +179,34 @@ auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &exp | |||||||
|   return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); |   return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /////////////////////////////////////////////////// | ||||||
|  | // handle nodes in syntax tree- eval one operand coalesced | ||||||
|  | /////////////////////////////////////////////////// | ||||||
|  | template <typename Op, typename T1> accelerator_inline  | ||||||
|  | auto coalescedEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)   | ||||||
|  |   -> decltype(expr.op.func( coalescedEval(ss, expr.arg1))) | ||||||
|  | { | ||||||
|  |   return expr.op.func( coalescedEval(ss, expr.arg1) ); | ||||||
|  | } | ||||||
|  | // eval two operands | ||||||
|  | template <typename Op, typename T1, typename T2> accelerator_inline | ||||||
|  | auto coalescedEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)   | ||||||
|  |   -> decltype(expr.op.func( coalescedEval(ss,expr.arg1),coalescedEval(ss,expr.arg2))) | ||||||
|  | { | ||||||
|  |   return expr.op.func( coalescedEval(ss,expr.arg1), coalescedEval(ss,expr.arg2) ); | ||||||
|  | } | ||||||
|  | // eval three operands | ||||||
|  | template <typename Op, typename T1, typename T2, typename T3> accelerator_inline | ||||||
|  | auto coalescedEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)   | ||||||
|  |   -> decltype(expr.op.func(coalescedEval(ss, expr.arg1),  | ||||||
|  | 			   coalescedEval(ss, expr.arg2),  | ||||||
|  | 			   coalescedEval(ss, expr.arg3))) | ||||||
|  | { | ||||||
|  |   return expr.op.func(coalescedEval(ss, expr.arg1),  | ||||||
|  | 		      coalescedEval(ss, expr.arg2),  | ||||||
|  | 		      coalescedEval(ss, expr.arg3)); | ||||||
|  | } | ||||||
|  |  | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| // Obtain the grid from an expression, ensuring conformable. This must follow a | // Obtain the grid from an expression, ensuring conformable. This must follow a | ||||||
| // tree recursion; must retain grid pointer in the LatticeView class which sucks | // tree recursion; must retain grid pointer in the LatticeView class which sucks | ||||||
| @@ -275,7 +348,7 @@ inline void ExpressionViewClose(LatticeTrinaryExpression<Op, T1, T2, T3> &expr) | |||||||
| #define GridUnopClass(name, ret)					\ | #define GridUnopClass(name, ret)					\ | ||||||
|   template <class arg>							\ |   template <class arg>							\ | ||||||
|   struct name {								\ |   struct name {								\ | ||||||
|     static auto accelerator_inline func(const arg a) -> decltype(ret) { return ret; } \ |     template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \ | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
| GridUnopClass(UnarySub, -a); | GridUnopClass(UnarySub, -a); | ||||||
| @@ -308,8 +381,9 @@ GridUnopClass(UnaryExp, exp(a)); | |||||||
| #define GridBinOpClass(name, combination)			\ | #define GridBinOpClass(name, combination)			\ | ||||||
|   template <class left, class right>				\ |   template <class left, class right>				\ | ||||||
|   struct name {							\ |   struct name {							\ | ||||||
|  |     template <class _left, class _right>			\ | ||||||
|     static auto accelerator_inline				\ |     static auto accelerator_inline				\ | ||||||
|     func(const left &lhs, const right &rhs)			\ |     func(const _left &lhs, const _right &rhs)			\ | ||||||
|       -> decltype(combination) const				\ |       -> decltype(combination) const				\ | ||||||
|     {								\ |     {								\ | ||||||
|       return combination;					\ |       return combination;					\ | ||||||
| @@ -331,8 +405,9 @@ GridBinOpClass(BinaryOrOr, lhs || rhs); | |||||||
| #define GridTrinOpClass(name, combination)				\ | #define GridTrinOpClass(name, combination)				\ | ||||||
|   template <class predicate, class left, class right>			\ |   template <class predicate, class left, class right>			\ | ||||||
|   struct name {								\ |   struct name {								\ | ||||||
|  |     template <class _predicate,class _left, class _right>		\ | ||||||
|     static auto accelerator_inline					\ |     static auto accelerator_inline					\ | ||||||
|     func(const predicate &pred, const left &lhs, const right &rhs)	\ |     func(const _predicate &pred, const _left &lhs, const _right &rhs)	\ | ||||||
|       -> decltype(combination) const					\ |       -> decltype(combination) const					\ | ||||||
|     {									\ |     {									\ | ||||||
|       return combination;						\ |       return combination;						\ | ||||||
| @@ -340,9 +415,10 @@ GridBinOpClass(BinaryOrOr, lhs || rhs); | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
| GridTrinOpClass(TrinaryWhere, | GridTrinOpClass(TrinaryWhere, | ||||||
| 		(predicatedWhere<predicate,  | 		(predicatedWhere< | ||||||
| 		 typename std::remove_reference<left>::type, | 		 typename std::remove_reference<_predicate>::type,  | ||||||
| 		 typename std::remove_reference<right>::type>(pred, lhs,rhs))); | 		 typename std::remove_reference<_left>::type, | ||||||
|  | 		 typename std::remove_reference<_right>::type>(pred, lhs,rhs))); | ||||||
|  |  | ||||||
| //////////////////////////////////////////// | //////////////////////////////////////////// | ||||||
| // Operator syntactical glue | // Operator syntactical glue | ||||||
|   | |||||||
| @@ -124,8 +124,9 @@ public: | |||||||
|     ExpressionViewOpen(exprCopy); |     ExpressionViewOpen(exprCopy); | ||||||
|     auto me  = View(AcceleratorWriteDiscard); |     auto me  = View(AcceleratorWriteDiscard); | ||||||
|     accelerator_for(ss,me.size(),1,{ |     accelerator_for(ss,me.size(),1,{ | ||||||
|       auto tmp = eval(ss,exprCopy); |       auto tmp = coalescedEval(ss,exprCopy); | ||||||
|       vstream(me[ss],tmp); |       coalescedWrite(me[ss],tmp); | ||||||
|  |       //      me[ss]=tmp; | ||||||
|     }); |     }); | ||||||
|     me.ViewClose(); |     me.ViewClose(); | ||||||
|     ExpressionViewClose(exprCopy); |     ExpressionViewClose(exprCopy); | ||||||
| @@ -147,8 +148,9 @@ public: | |||||||
|     ExpressionViewOpen(exprCopy); |     ExpressionViewOpen(exprCopy); | ||||||
|     auto me  = View(AcceleratorWriteDiscard); |     auto me  = View(AcceleratorWriteDiscard); | ||||||
|     accelerator_for(ss,me.size(),1,{ |     accelerator_for(ss,me.size(),1,{ | ||||||
|       auto tmp = eval(ss,exprCopy); |       auto tmp = coalescedEval(ss,exprCopy); | ||||||
|       vstream(me[ss],tmp); |       coalescedWrite(me[ss],tmp); | ||||||
|  |       //me[ss]=tmp; | ||||||
|     }); |     }); | ||||||
|     me.ViewClose(); |     me.ViewClose(); | ||||||
|     ExpressionViewClose(exprCopy); |     ExpressionViewClose(exprCopy); | ||||||
| @@ -169,8 +171,9 @@ public: | |||||||
|     ExpressionViewOpen(exprCopy); |     ExpressionViewOpen(exprCopy); | ||||||
|     auto me  = View(AcceleratorWriteDiscard); |     auto me  = View(AcceleratorWriteDiscard); | ||||||
|     accelerator_for(ss,me.size(),1,{ |     accelerator_for(ss,me.size(),1,{ | ||||||
|       auto tmp = eval(ss,exprCopy); |       auto tmp = coalescedEval(ss,exprCopy); | ||||||
|       vstream(me[ss],tmp); |       coalescedWrite(me[ss],tmp); | ||||||
|  |       //      me[ss]=tmp; | ||||||
|     }); |     }); | ||||||
|     me.ViewClose(); |     me.ViewClose(); | ||||||
|     ExpressionViewClose(exprCopy); |     ExpressionViewClose(exprCopy); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user