1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Introduce accelerator friendly expression template rewrite.

Must obtain and access lattice indexing through a view object that is safe
to copy construct in copy to GPU (without copying the lattice).
This commit is contained in:
paboyle 2018-03-04 16:03:19 +00:00
parent dad7862f91
commit 0e6197fbed
16 changed files with 470 additions and 513 deletions

View File

@ -68,149 +68,139 @@ accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftru
return ret; return ret;
} }
//////////////////////////////////////////// /////////////////////////////////////////////////////
// recursive evaluation of expressions; Could
// switch to generic approach with variadics, a la
// Antonin's Lat Sim but the repack to variadic with popped
// from tuple is hideous; C++14 introduces std::make_index_sequence for this
////////////////////////////////////////////
// leaf eval of lattice ; should enable if protect using traits
template <typename T>
using is_lattice = std::is_base_of<LatticeBase, T>;
template <typename T>
using is_lattice_expr = std::is_base_of<LatticeExpressionBase, T>;
template <typename T> using is_lattice_expr = std::is_base_of<LatticeExpressionBase,T >;
//Specialization of getVectorType for lattices //Specialization of getVectorType for lattices
/////////////////////////////////////////////////////
template<typename T> template<typename T>
struct getVectorType<Lattice<T> >{ struct getVectorType<Lattice<T> >{
typedef typename Lattice<T>::vector_object type; typedef typename Lattice<T>::vector_object type;
}; };
template<class sobj> ////////////////////////////////////////////
inline sobj eval(const unsigned int ss, const sobj &arg) //-- recursive evaluation of expressions; --
// handle leaves of syntax tree
///////////////////////////////////////////////////
template<class sobj> accelerator_inline
sobj eval(const unsigned int ss, const sobj &arg)
{ {
return arg; return arg;
} }
template <class lobj>
inline const lobj &eval(const unsigned int ss, const Lattice<lobj> &arg) { template <class lobj> accelerator_inline
const lobj & eval(const unsigned int ss, const LatticeView<lobj> &arg)
{
return arg[ss]; return arg[ss];
} }
template <class lobj> accelerator_inline
// handle nodes in syntax tree const lobj & eval(const unsigned int ss, const Lattice<lobj> &arg)
template <typename Op, typename T1> {
auto inline eval( auto view = arg.View();
const unsigned int ss, return view[ss];
const LatticeUnaryExpression<Op, T1> &expr) // eval one operand
-> decltype(expr.first.func(eval(ss, std::get<0>(expr.second)))) {
return expr.first.func(eval(ss, std::get<0>(expr.second)));
} }
template <typename Op, typename T1, typename T2> ///////////////////////////////////////////////////
auto inline eval( // handle nodes in syntax tree- eval one operand
const unsigned int ss, ///////////////////////////////////////////////////
const LatticeBinaryExpression<Op, T1, T2> &expr) // eval two operands template <typename Op, typename T1> accelerator_inline
-> decltype(expr.first.func(eval(ss, std::get<0>(expr.second)), auto eval(const unsigned int ss, const LatticeUnaryExpression<Op, T1> &expr)
eval(ss, std::get<1>(expr.second)))) { -> decltype(expr.op.func( eval(ss, expr.arg1)))
return expr.first.func(eval(ss, std::get<0>(expr.second)), {
eval(ss, std::get<1>(expr.second))); return expr.op.func( eval(ss, expr.arg1) );
} }
///////////////////////
template <typename Op, typename T1, typename T2, typename T3> // eval two operands
auto inline eval(const unsigned int ss, ///////////////////////
const LatticeTrinaryExpression<Op, T1, T2, T3> template <typename Op, typename T1, typename T2> accelerator_inline
&expr) // eval three operands auto eval(const unsigned int ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
-> decltype(expr.first.func(eval(ss, std::get<0>(expr.second)), -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2)))
eval(ss, std::get<1>(expr.second)), {
eval(ss, std::get<2>(expr.second)))) { return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) );
return expr.first.func(eval(ss, std::get<0>(expr.second)), }
eval(ss, std::get<1>(expr.second)), ///////////////////////
eval(ss, std::get<2>(expr.second))); // eval three operands
///////////////////////
template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
auto eval(const unsigned int ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)))
{
return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3));
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// Obtain the grid from an expression, ensuring conformable. This must follow a // Obtain the grid from an expression, ensuring conformable. This must follow a
// tree recursion // tree recursion; must retain grid pointer in the LatticeView class which sucks
// Use a different method, and make it void *.
// Perhaps a conformable method.
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <class T1, template <class T1,typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
inline void GridFromExpression(GridBase *&grid, const T1 &lat) // Lattice leaf inline void GridFromExpression(GridBase *&grid, const T1 &lat) // Lattice leaf
{ {
if (grid) { lat.Conformable(grid);
conformable(grid, lat.Grid());
}
grid = lat.Grid();
} }
template <class T1,
typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr> template <class T1,typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
inline void GridFromExpression(GridBase *&grid, accelerator_inline
const T1 &notlat) // non-lattice leaf void GridFromExpression(GridBase *&grid,const T1 &notlat) // non-lattice leaf
{} {}
template <typename Op, typename T1> template <typename Op, typename T1>
inline void GridFromExpression(GridBase *&grid, accelerator_inline
const LatticeUnaryExpression<Op, T1> &expr) { void GridFromExpression(GridBase *&grid,const LatticeUnaryExpression<Op, T1> &expr)
GridFromExpression(grid, std::get<0>(expr.second)); // recurse {
GridFromExpression(grid, expr.arg1); // recurse
} }
template <typename Op, typename T1, typename T2> template <typename Op, typename T1, typename T2>
inline void GridFromExpression( accelerator_inline
GridBase *&grid, const LatticeBinaryExpression<Op, T1, T2> &expr) { void GridFromExpression(GridBase *&grid, const LatticeBinaryExpression<Op, T1, T2> &expr)
GridFromExpression(grid, std::get<0>(expr.second)); // recurse {
GridFromExpression(grid, std::get<1>(expr.second)); GridFromExpression(grid, expr.arg1); // recurse
GridFromExpression(grid, expr.arg2);
} }
template <typename Op, typename T1, typename T2, typename T3> template <typename Op, typename T1, typename T2, typename T3>
inline void GridFromExpression( accelerator_inline
GridBase *&grid, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr) { void GridFromExpression(GridBase *&grid, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
GridFromExpression(grid, std::get<0>(expr.second)); // recurse {
GridFromExpression(grid, std::get<1>(expr.second)); GridFromExpression(grid, expr.arg1); // recurse
GridFromExpression(grid, std::get<2>(expr.second)); GridFromExpression(grid, expr.arg2); // recurse
GridFromExpression(grid, expr.arg3); // recurse
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// Obtain the CB from an expression, ensuring conformable. This must follow a // Obtain the CB from an expression, ensuring conformable. This must follow a
// tree recursion // tree recursion
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <class T1, template <class T1,typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
inline void CBFromExpression(int &cb, const T1 &lat) // Lattice leaf inline void CBFromExpression(int &cb, const T1 &lat) // Lattice leaf
{ {
if ((cb == Odd) || (cb == Even)) { if ((cb == Odd) || (cb == Even)) {
assert(cb == lat.Checkerboard()); assert(cb == lat.Checkerboard());
} }
cb = lat.Checkerboard(); cb = lat.Checkerboard();
// std::cout<<GridLogMessage<<"Lattice leaf cb "<<cb<<std::endl;
} }
template <class T1, template <class T1,typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
inline void CBFromExpression(int &cb, const T1 &notlat) // non-lattice leaf inline void CBFromExpression(int &cb, const T1 &notlat) // non-lattice leaf
{ {
// std::cout<<GridLogMessage<<"Non lattice leaf cb"<<cb<<std::endl;
}
template <typename Op, typename T1>
inline void CBFromExpression(int &cb,
const LatticeUnaryExpression<Op, T1> &expr) {
CBFromExpression(cb, std::get<0>(expr.second)); // recurse
// std::cout<<GridLogMessage<<"Unary node cb "<<cb<<std::endl;
} }
template <typename Op, typename T1, typename T2> template <typename Op, typename T1> inline
inline void CBFromExpression(int &cb, void CBFromExpression(int &cb,const LatticeUnaryExpression<Op, T1> &expr)
const LatticeBinaryExpression<Op, T1, T2> &expr) { {
CBFromExpression(cb, std::get<0>(expr.second)); // recurse CBFromExpression(cb, expr.arg1); // recurse AST
CBFromExpression(cb, std::get<1>(expr.second)); }
// std::cout<<GridLogMessage<<"Binary node cb "<<cb<<std::endl;
template <typename Op, typename T1, typename T2> inline
void CBFromExpression(int &cb,const LatticeBinaryExpression<Op, T1, T2> &expr)
{
CBFromExpression(cb, expr.arg1); // recurse AST
CBFromExpression(cb, expr.arg2); // recurse AST
} }
template <typename Op, typename T1, typename T2, typename T3> template <typename Op, typename T1, typename T2, typename T3>
inline void CBFromExpression( inline void CBFromExpression(int &cb, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
int &cb, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr) { {
CBFromExpression(cb, std::get<0>(expr.second)); // recurse CBFromExpression(cb, expr.arg1); // recurse AST
CBFromExpression(cb, std::get<1>(expr.second)); CBFromExpression(cb, expr.arg2); // recurse AST
CBFromExpression(cb, std::get<2>(expr.second)); CBFromExpression(cb, expr.arg3); // recurse AST
// std::cout<<GridLogMessage<<"Trinary node cb "<<cb<<std::endl;
} }
//////////////////////////////////////////// ////////////////////////////////////////////
@ -253,15 +243,16 @@ GridUnopClass(UnaryExp, exp(a));
template <class left, class right> \ template <class left, class right> \
struct name { \ struct name { \
static auto inline func(const left &lhs, const right &rhs) \ static auto inline func(const left &lhs, const right &rhs) \
-> decltype(combination) const { \ -> decltype(combination) const \
{ \
return combination; \ return combination; \
} \ } \
} };
GridBinOpClass(BinaryAdd, lhs + rhs); GridBinOpClass(BinaryAdd, lhs + rhs);
GridBinOpClass(BinarySub, lhs - rhs); GridBinOpClass(BinarySub, lhs - rhs);
GridBinOpClass(BinaryMul, lhs *rhs); GridBinOpClass(BinaryMul, lhs *rhs);
GridBinOpClass(BinaryDiv, lhs /rhs); GridBinOpClass(BinaryDiv, lhs /rhs);
GridBinOpClass(BinaryAnd, lhs &rhs); GridBinOpClass(BinaryAnd, lhs &rhs);
GridBinOpClass(BinaryOr, lhs | rhs); GridBinOpClass(BinaryOr, lhs | rhs);
GridBinOpClass(BinaryAndAnd, lhs &&rhs); GridBinOpClass(BinaryAndAnd, lhs &&rhs);
@ -273,68 +264,50 @@ GridBinOpClass(BinaryOrOr, lhs || rhs);
#define GridTrinOpClass(name, combination) \ #define GridTrinOpClass(name, combination) \
template <class predicate, class left, class right> \ template <class predicate, class left, class right> \
struct name { \ struct name { \
static auto inline func(const predicate &pred, const left &lhs, \ static auto inline func(const predicate &pred, const left &lhs, const right &rhs) \
const right &rhs) -> decltype(combination) const { \ -> decltype(combination) const \
{ \
return combination; \ return combination; \
} \ } \
} };
GridTrinOpClass( GridTrinOpClass(TrinaryWhere,
TrinaryWhere, (predicatedWhere<predicate,
(predicatedWhere<predicate, typename std::remove_reference<left>::type, typename std::remove_reference<left>::type,
typename std::remove_reference<right>::type>(pred, lhs, typename std::remove_reference<right>::type>(pred, lhs,rhs)));
rhs)));
//////////////////////////////////////////// ////////////////////////////////////////////
// Operator syntactical glue // Operator syntactical glue
//////////////////////////////////////////// ////////////////////////////////////////////
#define GRID_UNOP(name) name<decltype(eval(0, arg))> #define GRID_UNOP(name) name<decltype(eval(0, arg))>
#define GRID_BINOP(name) name<decltype(eval(0, lhs)), decltype(eval(0, rhs))> #define GRID_BINOP(name) name<decltype(eval(0, lhs)), decltype(eval(0, rhs))>
#define GRID_TRINOP(name) \ #define GRID_TRINOP(name) name<decltype(eval(0, pred)), decltype(eval(0, lhs)), decltype(eval(0, rhs))>
name<decltype(eval(0, pred)), decltype(eval(0, lhs)), decltype(eval(0, rhs))>
#define GRID_DEF_UNOP(op, name) \ #define GRID_DEF_UNOP(op, name) \
template <typename T1, \ template <typename T1, typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
typename std::enable_if<is_lattice<T1>::value || \ inline auto op(const T1 &arg) ->decltype(LatticeUnaryExpression<GRID_UNOP(name),T1>(GRID_UNOP(name)(), arg)) \
is_lattice_expr<T1>::value, \ { \
T1>::type * = nullptr> \ return LatticeUnaryExpression<GRID_UNOP(name),T1>(GRID_UNOP(name)(), arg); \
inline auto op(const T1 &arg) \
->decltype(LatticeUnaryExpression<GRID_UNOP(name), const T1 &>( \
std::make_pair(GRID_UNOP(name)(), std::forward_as_tuple(arg)))) { \
return LatticeUnaryExpression<GRID_UNOP(name), const T1 &>( \
std::make_pair(GRID_UNOP(name)(), std::forward_as_tuple(arg))); \
} }
#define GRID_BINOP_LEFT(op, name) \ #define GRID_BINOP_LEFT(op, name) \
template <typename T1, typename T2, \ template <typename T1, typename T2, \
typename std::enable_if<is_lattice<T1>::value || \ typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
is_lattice_expr<T1>::value, \
T1>::type * = nullptr> \
inline auto op(const T1 &lhs, const T2 &rhs) \ inline auto op(const T1 &lhs, const T2 &rhs) \
->decltype( \ ->decltype(LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs,rhs)) \
LatticeBinaryExpression<GRID_BINOP(name), const T1 &, const T2 &>( \ { \
std::make_pair(GRID_BINOP(name)(), \ return LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs,rhs);\
std::forward_as_tuple(lhs, rhs)))) { \
return LatticeBinaryExpression<GRID_BINOP(name), const T1 &, const T2 &>( \
std::make_pair(GRID_BINOP(name)(), std::forward_as_tuple(lhs, rhs))); \
} }
#define GRID_BINOP_RIGHT(op, name) \ #define GRID_BINOP_RIGHT(op, name) \
template <typename T1, typename T2, \ template <typename T1, typename T2, \
typename std::enable_if<!is_lattice<T1>::value && \ typename std::enable_if<!is_lattice<T1>::value&&!is_lattice_expr<T1>::value,T1>::type * = nullptr, \
!is_lattice_expr<T1>::value, \ typename std::enable_if< is_lattice<T2>::value|| is_lattice_expr<T2>::value,T2>::type * = nullptr> \
T1>::type * = nullptr, \
typename std::enable_if<is_lattice<T2>::value || \
is_lattice_expr<T2>::value, \
T2>::type * = nullptr> \
inline auto op(const T1 &lhs, const T2 &rhs) \ inline auto op(const T1 &lhs, const T2 &rhs) \
->decltype( \ ->decltype(LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs, rhs)) \
LatticeBinaryExpression<GRID_BINOP(name), const T1 &, const T2 &>( \ { \
std::make_pair(GRID_BINOP(name)(), \ return LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs, rhs); \
std::forward_as_tuple(lhs, rhs)))) { \
return LatticeBinaryExpression<GRID_BINOP(name), const T1 &, const T2 &>( \
std::make_pair(GRID_BINOP(name)(), std::forward_as_tuple(lhs, rhs))); \
} }
#define GRID_DEF_BINOP(op, name) \ #define GRID_DEF_BINOP(op, name) \
@ -344,18 +317,14 @@ GridTrinOpClass(
#define GRID_DEF_TRINOP(op, name) \ #define GRID_DEF_TRINOP(op, name) \
template <typename T1, typename T2, typename T3> \ template <typename T1, typename T2, typename T3> \
inline auto op(const T1 &pred, const T2 &lhs, const T3 &rhs) \ inline auto op(const T1 &pred, const T2 &lhs, const T3 &rhs) \
->decltype( \ ->decltype(LatticeTrinaryExpression<GRID_TRINOP(name),T1,T2,T3>(GRID_TRINOP(name)(),pred, lhs, rhs)) \
LatticeTrinaryExpression<GRID_TRINOP(name), const T1 &, const T2 &, \ { \
const T3 &>(std::make_pair( \ return LatticeTrinaryExpression<GRID_TRINOP(name),T1,T2,T3>(GRID_TRINOP(name)(),pred, lhs, rhs); \
GRID_TRINOP(name)(), std::forward_as_tuple(pred, lhs, rhs)))) { \
return LatticeTrinaryExpression<GRID_TRINOP(name), const T1 &, const T2 &, \
const T3 &>(std::make_pair( \
GRID_TRINOP(name)(), std::forward_as_tuple(pred, lhs, rhs))); \
} }
//////////////////////// ////////////////////////
// Operator definitions // Operator definitions
//////////////////////// ////////////////////////
GRID_DEF_UNOP(operator-, UnarySub); GRID_DEF_UNOP(operator-, UnarySub);
GRID_DEF_UNOP(Not, UnaryNot); GRID_DEF_UNOP(Not, UnaryNot);
GRID_DEF_UNOP(operator!, UnaryNot); GRID_DEF_UNOP(operator!, UnaryNot);
@ -399,29 +368,27 @@ GRID_DEF_TRINOP(where, TrinaryWhere);
///////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////
template <class Op, class T1> template <class Op, class T1>
auto closure(const LatticeUnaryExpression<Op, T1> &expr) auto closure(const LatticeUnaryExpression<Op, T1> &expr)
-> Lattice<decltype(expr.first.func(eval(0, std::get<0>(expr.second))))> { -> Lattice<decltype(expr.op.func(eval(0, expr.arg1)))>
Lattice<decltype(expr.first.func(eval(0, std::get<0>(expr.second))))> ret( {
expr); Lattice<decltype(expr.op.func(eval(0, expr.arg1)))> ret(expr);
return ret; return ret;
} }
template <class Op, class T1, class T2> template <class Op, class T1, class T2>
auto closure(const LatticeBinaryExpression<Op, T1, T2> &expr) auto closure(const LatticeBinaryExpression<Op, T1, T2> &expr)
-> Lattice<decltype(expr.first.func(eval(0, std::get<0>(expr.second)), -> Lattice<decltype(expr.op.func(eval(0, expr.arg1),eval(0, expr.arg2)))>
eval(0, std::get<1>(expr.second))))> { {
Lattice<decltype(expr.first.func(eval(0, std::get<0>(expr.second)), Lattice<decltype(expr.op.func(eval(0, expr.arg1),eval(0, expr.arg2)))> ret(expr);
eval(0, std::get<1>(expr.second))))>
ret(expr);
return ret; return ret;
} }
template <class Op, class T1, class T2, class T3> template <class Op, class T1, class T2, class T3>
auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr) auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
-> Lattice<decltype(expr.first.func(eval(0, std::get<0>(expr.second)), -> Lattice<decltype(expr.op.func(eval(0, expr.arg1),
eval(0, std::get<1>(expr.second)), eval(0, expr.arg2),
eval(0, std::get<2>(expr.second))))> { eval(0, expr.arg3)))>
Lattice<decltype(expr.first.func(eval(0, std::get<0>(expr.second)), {
eval(0, std::get<1>(expr.second)), Lattice<decltype(expr.op.func(eval(0, expr.arg1),
eval(0, std::get<2>(expr.second))))> eval(0, expr.arg2),
ret(expr); eval(0, expr.arg3)))> ret(expr);
return ret; return ret;
} }
@ -432,34 +399,7 @@ auto closure(const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
#undef GRID_DEF_UNOP #undef GRID_DEF_UNOP
#undef GRID_DEF_BINOP #undef GRID_DEF_BINOP
#undef GRID_DEF_TRINOP #undef GRID_DEF_TRINOP
NAMESPACE_END(Grid); NAMESPACE_END(Grid);
#if 0
using namespace Grid;
int main(int argc,char **argv){
Lattice<double> v1(16);
Lattice<double> v2(16);
Lattice<double> v3(16);
BinaryAdd<double,double> tmp;
LatticeBinaryExpression<BinaryAdd<double,double>,Lattice<double> &,Lattice<double> &>
expr(std::make_pair(tmp,
std::forward_as_tuple(v1,v2)));
tmp.func(eval(0,v1),eval(0,v2));
auto var = v1+v2;
std::cout<<GridLogMessage<<typeid(var).name()<<std::endl;
v3=v1+v2;
v3=v1+v2+v1*v2;
};
void testit(Lattice<double> &v1,Lattice<double> &v2,Lattice<double> &v3)
{
v3=v1+v2+v1*v2;
}
#endif
#endif #endif

View File

@ -36,17 +36,20 @@ NAMESPACE_BEGIN(Grid);
template<class obj1,class obj2,class obj3> inline template<class obj1,class obj2,class obj3> inline
void mult(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const Lattice<obj3> &rhs){ void mult(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
auto ret_v = ret.View();
auto lhs_v = lhs.View();
auto rhs_v = rhs.View();
conformable(ret,rhs); conformable(ret,rhs);
conformable(lhs,rhs); conformable(lhs,rhs);
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
mult(&tmp,&lhs[ss],&rhs[ss]); mult(&tmp,&lhs_v[ss],&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
mult(&ret[ss],&lhs[ss],&rhs[ss]); mult(&ret_v[ss],&lhs_v[ss],&rhs_v[ss]);
}); });
#endif #endif
} }
@ -56,15 +59,18 @@ void mac(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
conformable(lhs,rhs); conformable(lhs,rhs);
auto ret_v = ret.View();
auto lhs_v = lhs.View();
auto rhs_v = rhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
mac(&tmp,&lhs[ss],&rhs[ss]); mac(&tmp,&lhs_v[ss],&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
mac(&ret[ss],&lhs[ss],&rhs[ss]); mac(&ret_v[ss],&lhs_v[ss],&rhs_v[ss]);
}); });
#endif #endif
} }
@ -74,15 +80,18 @@ void sub(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
conformable(lhs,rhs); conformable(lhs,rhs);
auto ret_v = ret.View();
auto lhs_v = lhs.View();
auto rhs_v = rhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
sub(&tmp,&lhs[ss],&rhs[ss]); sub(&tmp,&lhs_v[ss],&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
sub(&ret[ss],&lhs[ss],&rhs[ss]); sub(&ret[ss],&lhs_v[ss],&rhs_v[ss]);
}); });
#endif #endif
} }
@ -91,15 +100,18 @@ void add(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
conformable(lhs,rhs); conformable(lhs,rhs);
auto ret_v = ret.View();
auto lhs_v = lhs.View();
auto rhs_v = rhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
add(&tmp,&lhs[ss],&rhs[ss]); add(&tmp,&lhs_v[ss],&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
add(&ret[ss],&lhs[ss],&rhs[ss]); add(&ret_v[ss],&lhs_v[ss],&rhs_v[ss]);
}); });
#endif #endif
} }
@ -111,10 +123,12 @@ template<class obj1,class obj2,class obj3> inline
void mult(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){ void mult(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(lhs,ret); conformable(lhs,ret);
accelerator_loop(ss,lhs,{ auto ret_v = ret.View();
auto lhs_v = lhs.View();
accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
mult(&tmp,&lhs[ss],&rhs); mult(&tmp,&lhs_v[ss],&rhs);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
} }
@ -122,10 +136,12 @@ template<class obj1,class obj2,class obj3> inline
void mac(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){ void mac(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(ret,lhs); conformable(ret,lhs);
accelerator_loop(ss,lhs,{ auto ret_v = ret.View();
auto lhs_v = lhs.View();
accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
mac(&tmp,&lhs[ss],&rhs); mac(&tmp,&lhs_v[ss],&rhs);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
} }
@ -133,15 +149,17 @@ template<class obj1,class obj2,class obj3> inline
void sub(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){ void sub(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(ret,lhs); conformable(ret,lhs);
auto ret_v = ret.View();
auto lhs_v = lhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
sub(&tmp,&lhs[ss],&rhs); sub(&tmp,&lhs_v[ss],&rhs);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
sub(&ret[ss],&lhs[ss],&rhs); sub(&ret_v[ss],&lhs_v[ss],&rhs);
}); });
#endif #endif
} }
@ -149,15 +167,17 @@ template<class obj1,class obj2,class obj3> inline
void add(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){ void add(Lattice<obj1> &ret,const Lattice<obj2> &lhs,const obj3 &rhs){
ret.Checkerboard() = lhs.Checkerboard(); ret.Checkerboard() = lhs.Checkerboard();
conformable(lhs,ret); conformable(lhs,ret);
auto ret_v = ret.View();
auto lhs_v = lhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
obj1 tmp; obj1 tmp;
add(&tmp,&lhs[ss],&rhs); add(&tmp,&lhs_v[ss],&rhs);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,lhs,{ accelerator_loop(ss,lhs_v,{
add(&ret[ss],&lhs[ss],&rhs); add(&ret_v[ss],&lhs_v[ss],&rhs);
}); });
#endif #endif
} }
@ -169,15 +189,17 @@ template<class obj1,class obj2,class obj3> inline
void mult(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){ void mult(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
auto ret_v = ret.View();
auto rhs_v = lhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
obj1 tmp; obj1 tmp;
mult(&tmp,&lhs,&rhs[ss]); mult(&tmp,&lhs,&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
mult(&ret[ss],&lhs,&rhs[ss]); mult(&ret_v[ss],&lhs,&rhs_v[ss]);
}); });
#endif #endif
} }
@ -186,15 +208,17 @@ template<class obj1,class obj2,class obj3> inline
void mac(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){ void mac(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
auto ret_v = ret.View();
auto rhs_v = lhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
obj1 tmp; obj1 tmp;
mac(&tmp,&lhs,&rhs[ss]); mac(&tmp,&lhs,&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
mac(&ret[ss],&lhs,&rhs[ss]); mac(&ret_v[ss],&lhs,&rhs_v[ss]);
}); });
#endif #endif
} }
@ -203,15 +227,17 @@ template<class obj1,class obj2,class obj3> inline
void sub(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){ void sub(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
auto ret_v = ret.View();
auto rhs_v = lhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
obj1 tmp; obj1 tmp;
sub(&tmp,&lhs,&rhs[ss]); sub(&tmp,&lhs,&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
sub(&ret[ss],&lhs,&rhs[ss]); sub(&ret_v[ss],&lhs,&rhs_v[ss]);
}); });
#endif #endif
} }
@ -219,15 +245,17 @@ template<class obj1,class obj2,class obj3> inline
void add(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){ void add(Lattice<obj1> &ret,const obj2 &lhs,const Lattice<obj3> &rhs){
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs); conformable(ret,rhs);
auto ret_v = ret.View();
auto rhs_v = lhs.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
obj1 tmp; obj1 tmp;
add(&tmp,&lhs,&rhs[ss]); add(&tmp,&lhs,&rhs_v[ss]);
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs_v,{
add(&ret[ss],&lhs,&rhs[ss]); add(&ret_v[ss],&lhs,&rhs_v[ss]);
}); });
#endif #endif
} }
@ -237,14 +265,17 @@ void axpy(Lattice<vobj> &ret,sobj a,const Lattice<vobj> &x,const Lattice<vobj> &
ret.Checkerboard() = x.Checkerboard(); ret.Checkerboard() = x.Checkerboard();
conformable(ret,x); conformable(ret,x);
conformable(x,y); conformable(x,y);
auto ret_v = ret.View();
auto x_v = x.View();
auto y_v = y.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,x,{ accelerator_loop(ss,x_v,{
vobj tmp = a*x[ss]+y[ss]; vobj tmp = a*x_v[ss]+y_v[ss];
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,x,{ accelerator_loop(ss,x_v,{
ret[ss]=a*x[ss]+y[ss]; ret_v[ss]=a*x_v[ss]+y_v[ss];
}); });
#endif #endif
} }
@ -253,14 +284,17 @@ void axpby(Lattice<vobj> &ret,sobj a,sobj b,const Lattice<vobj> &x,const Lattice
ret.Checkerboard() = x.Checkerboard(); ret.Checkerboard() = x.Checkerboard();
conformable(ret,x); conformable(ret,x);
conformable(x,y); conformable(x,y);
auto ret_v = ret.View();
auto x_v = x.View();
auto y_v = y.View();
#ifdef STREAMING_STORES #ifdef STREAMING_STORES
accelerator_loop(ss,x,{ accelerator_loop(ss,x_v,{
vobj tmp = a*x[ss]+b*y[ss]; vobj tmp = a*x_v[ss]+b*y_v[ss];
vstream(ret[ss],tmp); vstream(ret_v[ss],tmp);
}); });
#else #else
accelerator_loop(ss,x,{ accelerator_loop(ss,x_v,{
ret[ss]=a*x[ss]+b*y[ss]; ret_v[ss]=a*x_v[ss]+b*y_v[ss];
}); });
#endif #endif
} }

View File

@ -47,8 +47,11 @@ template<class vfunctor,class lobj,class robj>
inline Lattice<vInteger> LLComparison(vfunctor op,const Lattice<lobj> &lhs,const Lattice<robj> &rhs) inline Lattice<vInteger> LLComparison(vfunctor op,const Lattice<lobj> &lhs,const Lattice<robj> &rhs)
{ {
Lattice<vInteger> ret(rhs.Grid()); Lattice<vInteger> ret(rhs.Grid());
accelerator_loop( ss, rhs, { auto lhs_v = lhs.View();
ret[ss]=op(lhs[ss],rhs[ss]); auto rhs_v = rhs.View();
auto ret_v = ret.View();
accelerator_loop( ss, rhs_v, {
ret_v[ss]=op(lhs_v[ss],rhs_v[ss]);
}); });
return ret; return ret;
} }
@ -59,8 +62,10 @@ template<class vfunctor,class lobj,class robj>
inline Lattice<vInteger> LSComparison(vfunctor op,const Lattice<lobj> &lhs,const robj &rhs) inline Lattice<vInteger> LSComparison(vfunctor op,const Lattice<lobj> &lhs,const robj &rhs)
{ {
Lattice<vInteger> ret(lhs.Grid()); Lattice<vInteger> ret(lhs.Grid());
accelerator_loop( ss, lhs, { auto lhs_v = lhs.View();
ret[ss]=op(lhs[ss],rhs); auto ret_v = ret.View();
accelerator_loop( ss, lhs_v, {
ret_v[ss]=op(lhs_v[ss],rhs);
}); });
return ret; return ret;
} }
@ -71,8 +76,10 @@ template<class vfunctor,class lobj,class robj>
inline Lattice<vInteger> SLComparison(vfunctor op,const lobj &lhs,const Lattice<robj> &rhs) inline Lattice<vInteger> SLComparison(vfunctor op,const lobj &lhs,const Lattice<robj> &rhs)
{ {
Lattice<vInteger> ret(rhs.Grid()); Lattice<vInteger> ret(rhs.Grid());
accelerator_loop( ss, rhs, { auto rhs_v = rhs.View();
ret[ss]=op(lhs[ss],rhs); auto ret_v = ret.View();
accelerator_loop( ss, rhs_v, {
ret_v[ss]=op(lhs,rhs_v[ss]);
}); });
return ret; return ret;
} }

View File

@ -44,42 +44,42 @@ NAMESPACE_BEGIN(Grid);
// //
template<class lobj,class robj> class veq { template<class lobj,class robj> class veq {
public: public:
vInteger operator()(const lobj &lhs, const robj &rhs) accelerator vInteger operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) == (rhs); return (lhs) == (rhs);
} }
}; };
template<class lobj,class robj> class vne { template<class lobj,class robj> class vne {
public: public:
vInteger operator()(const lobj &lhs, const robj &rhs) accelerator vInteger operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) != (rhs); return (lhs) != (rhs);
} }
}; };
template<class lobj,class robj> class vlt { template<class lobj,class robj> class vlt {
public: public:
vInteger operator()(const lobj &lhs, const robj &rhs) accelerator vInteger operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) < (rhs); return (lhs) < (rhs);
} }
}; };
template<class lobj,class robj> class vle { template<class lobj,class robj> class vle {
public: public:
vInteger operator()(const lobj &lhs, const robj &rhs) accelerator vInteger operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) <= (rhs); return (lhs) <= (rhs);
} }
}; };
template<class lobj,class robj> class vgt { template<class lobj,class robj> class vgt {
public: public:
vInteger operator()(const lobj &lhs, const robj &rhs) accelerator vInteger operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) > (rhs); return (lhs) > (rhs);
} }
}; };
template<class lobj,class robj> class vge { template<class lobj,class robj> class vge {
public: public:
vInteger operator()(const lobj &lhs, const robj &rhs) accelerator vInteger operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) >= (rhs); return (lhs) >= (rhs);
} }
@ -88,42 +88,42 @@ public:
// Generic list of functors // Generic list of functors
template<class lobj,class robj> class seq { template<class lobj,class robj> class seq {
public: public:
Integer operator()(const lobj &lhs, const robj &rhs) accelerator Integer operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) == (rhs); return (lhs) == (rhs);
} }
}; };
template<class lobj,class robj> class sne { template<class lobj,class robj> class sne {
public: public:
Integer operator()(const lobj &lhs, const robj &rhs) accelerator Integer operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) != (rhs); return (lhs) != (rhs);
} }
}; };
template<class lobj,class robj> class slt { template<class lobj,class robj> class slt {
public: public:
Integer operator()(const lobj &lhs, const robj &rhs) accelerator Integer operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) < (rhs); return (lhs) < (rhs);
} }
}; };
template<class lobj,class robj> class sle { template<class lobj,class robj> class sle {
public: public:
Integer operator()(const lobj &lhs, const robj &rhs) accelerator Integer operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) <= (rhs); return (lhs) <= (rhs);
} }
}; };
template<class lobj,class robj> class sgt { template<class lobj,class robj> class sgt {
public: public:
Integer operator()(const lobj &lhs, const robj &rhs) accelerator Integer operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) > (rhs); return (lhs) > (rhs);
} }
}; };
template<class lobj,class robj> class sge { template<class lobj,class robj> class sge {
public: public:
Integer operator()(const lobj &lhs, const robj &rhs) accelerator Integer operator()(const lobj &lhs, const robj &rhs)
{ {
return (lhs) >= (rhs); return (lhs) >= (rhs);
} }
@ -133,7 +133,7 @@ public:
// Integer and real get extra relational functions. // Integer and real get extra relational functions.
////////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////////
template<class sfunctor, class vsimd,IfNotComplex<vsimd> = 0> template<class sfunctor, class vsimd,IfNotComplex<vsimd> = 0>
inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const vsimd & rhs) accelerator_inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const vsimd & rhs)
{ {
typedef typename vsimd::scalar_type scalar; typedef typename vsimd::scalar_type scalar;
ExtractBuffer<scalar> vlhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation ExtractBuffer<scalar> vlhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation
@ -150,7 +150,7 @@ inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const vsimd & rhs)
} }
template<class sfunctor, class vsimd,IfNotComplex<vsimd> = 0> template<class sfunctor, class vsimd,IfNotComplex<vsimd> = 0>
inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const typename vsimd::scalar_type & rhs) accelerator_inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const typename vsimd::scalar_type & rhs)
{ {
typedef typename vsimd::scalar_type scalar; typedef typename vsimd::scalar_type scalar;
ExtractBuffer<scalar> vlhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation ExtractBuffer<scalar> vlhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation
@ -165,7 +165,7 @@ inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const typename vsimd:
} }
template<class sfunctor, class vsimd,IfNotComplex<vsimd> = 0> template<class sfunctor, class vsimd,IfNotComplex<vsimd> = 0>
inline vInteger Comparison(sfunctor sop,const typename vsimd::scalar_type & lhs, const vsimd & rhs) accelerator_inline vInteger Comparison(sfunctor sop,const typename vsimd::scalar_type & lhs, const vsimd & rhs)
{ {
typedef typename vsimd::scalar_type scalar; typedef typename vsimd::scalar_type scalar;
ExtractBuffer<scalar> vrhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation ExtractBuffer<scalar> vrhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation
@ -181,35 +181,35 @@ inline vInteger Comparison(sfunctor sop,const typename vsimd::scalar_type & lhs,
#define DECLARE_RELATIONAL(op,functor) \ #define DECLARE_RELATIONAL(op,functor) \
template<class vsimd,IfSimd<vsimd> = 0> \ template<class vsimd,IfSimd<vsimd> = 0> \
inline vInteger operator op (const vsimd & lhs, const vsimd & rhs) \ accelerator_inline vInteger operator op (const vsimd & lhs, const vsimd & rhs) \
{ \ { \
typedef typename vsimd::scalar_type scalar; \ typedef typename vsimd::scalar_type scalar; \
return Comparison(functor<scalar,scalar>(),lhs,rhs); \ return Comparison(functor<scalar,scalar>(),lhs,rhs); \
} \ } \
template<class vsimd,IfSimd<vsimd> = 0> \ template<class vsimd,IfSimd<vsimd> = 0> \
inline vInteger operator op (const vsimd & lhs, const typename vsimd::scalar_type & rhs) \ accelerator_inline vInteger operator op (const vsimd & lhs, const typename vsimd::scalar_type & rhs) \
{ \ { \
typedef typename vsimd::scalar_type scalar; \ typedef typename vsimd::scalar_type scalar; \
return Comparison(functor<scalar,scalar>(),lhs,rhs); \ return Comparison(functor<scalar,scalar>(),lhs,rhs); \
} \ } \
template<class vsimd,IfSimd<vsimd> = 0> \ template<class vsimd,IfSimd<vsimd> = 0> \
inline vInteger operator op (const typename vsimd::scalar_type & lhs, const vsimd & rhs) \ accelerator_inline vInteger operator op (const typename vsimd::scalar_type & lhs, const vsimd & rhs) \
{ \ { \
typedef typename vsimd::scalar_type scalar; \ typedef typename vsimd::scalar_type scalar; \
return Comparison(functor<scalar,scalar>(),lhs,rhs); \ return Comparison(functor<scalar,scalar>(),lhs,rhs); \
} \ } \
template<class vsimd> \ template<class vsimd> \
inline vInteger operator op(const iScalar<vsimd> &lhs,const iScalar<vsimd> &rhs) \ accelerator_inline vInteger operator op(const iScalar<vsimd> &lhs,const iScalar<vsimd> &rhs) \
{ \ { \
return lhs._internal op rhs._internal; \ return lhs._internal op rhs._internal; \
} \ } \
template<class vsimd> \ template<class vsimd> \
inline vInteger operator op(const iScalar<vsimd> &lhs,const typename vsimd::scalar_type &rhs) \ accelerator_inline vInteger operator op(const iScalar<vsimd> &lhs,const typename vsimd::scalar_type &rhs) \
{ \ { \
return lhs._internal op rhs; \ return lhs._internal op rhs; \
} \ } \
template<class vsimd> \ template<class vsimd> \
inline vInteger operator op(const typename vsimd::scalar_type &lhs,const iScalar<vsimd> &rhs) \ accelerator_inline vInteger operator op(const typename vsimd::scalar_type &lhs,const iScalar<vsimd> &rhs) \
{ \ { \
return lhs op rhs._internal; \ return lhs op rhs._internal; \
} }

View File

@ -42,20 +42,22 @@ template<class iobj> inline void LatticeCoordinate(Lattice<iobj> &l,int mu)
ExtractBuffer<scalar_type> mergebuf(Nsimd); ExtractBuffer<scalar_type> mergebuf(Nsimd);
vector_type vI; vector_type vI;
auto l_v = l.View();
for(int o=0;o<grid->oSites();o++){ for(int o=0;o<grid->oSites();o++){
for(int i=0;i<grid->iSites();i++){ for(int i=0;i<grid->iSites();i++){
grid->RankIndexToGlobalCoor(grid->ThisRank(),o,i,gcoor); grid->RankIndexToGlobalCoor(grid->ThisRank(),o,i,gcoor);
mergebuf[i]=(Integer)gcoor[mu]; mergebuf[i]=(Integer)gcoor[mu];
} }
merge<vector_type,scalar_type>(vI,mergebuf); merge<vector_type,scalar_type>(vI,mergebuf);
l[o]=vI; l_v[o]=vI;
} }
}; };
// LatticeCoordinate(); // LatticeCoordinate();
// FIXME for debug; deprecate this; made obscelete by // FIXME for debug; deprecate this; made obscelete by
template<class vobj> void lex_sites(Lattice<vobj> &l){ template<class vobj> void lex_sites(Lattice<vobj> &l){
Real *v_ptr = (Real *)&l[0]; auto l_v = l.View();
Real *v_ptr = (Real *)&l_v[0];
size_t o_len = l.Grid()->oSites(); size_t o_len = l.Grid()->oSites();
size_t v_len = sizeof(vobj)/sizeof(vRealF); size_t v_len = sizeof(vobj)/sizeof(vRealF);
size_t vec_len = vRealF::Nsimd(); size_t vec_len = vRealF::Nsimd();

View File

@ -43,8 +43,10 @@ template<class vobj>
inline auto localNorm2 (const Lattice<vobj> &rhs)-> Lattice<typename vobj::tensor_reduced> inline auto localNorm2 (const Lattice<vobj> &rhs)-> Lattice<typename vobj::tensor_reduced>
{ {
Lattice<typename vobj::tensor_reduced> ret(rhs.Grid()); Lattice<typename vobj::tensor_reduced> ret(rhs.Grid());
accelerator_loop(ss,rhs,{ auto rhs_v = rhs.View();
ret[ss]=innerProduct(rhs[ss],rhs[ss]); auto ret_v = ret.View();
accelerator_loop(ss,rhs_v,{
ret_v[ss]=innerProduct(rhs_v[ss],rhs_v[ss]);
}); });
return ret; return ret;
} }
@ -54,8 +56,11 @@ template<class vobj>
inline auto localInnerProduct (const Lattice<vobj> &lhs,const Lattice<vobj> &rhs) -> Lattice<typename vobj::tensor_reduced> inline auto localInnerProduct (const Lattice<vobj> &lhs,const Lattice<vobj> &rhs) -> Lattice<typename vobj::tensor_reduced>
{ {
Lattice<typename vobj::tensor_reduced> ret(rhs.Grid()); Lattice<typename vobj::tensor_reduced> ret(rhs.Grid());
accelerator_loop(ss,rhs,{ auto lhs_v = lhs.View();
ret[ss]=innerProduct(lhs[ss],rhs[ss]); auto rhs_v = rhs.View();
auto ret_v = ret.View();
accelerator_loop(ss,rhs_v,{
ret_v[ss]=innerProduct(lhs_v[ss],rhs_v[ss]);
}); });
return ret; return ret;
} }
@ -63,11 +68,14 @@ inline auto localInnerProduct (const Lattice<vobj> &lhs,const Lattice<vobj> &rhs
// outerProduct Scalar x Scalar -> Scalar // outerProduct Scalar x Scalar -> Scalar
// Vector x Vector -> Matrix // Vector x Vector -> Matrix
template<class ll,class rr> template<class ll,class rr>
inline auto outerProduct (const Lattice<ll> &lhs,const Lattice<rr> &rhs) -> Lattice<decltype(outerProduct(lhs[0],rhs[0]))> inline auto outerProduct (const Lattice<ll> &lhs,const Lattice<rr> &rhs) -> Lattice<decltype(outerProduct(ll(),rr()))>
{ {
Lattice<decltype(outerProduct(lhs[0],rhs[0]))> ret(rhs.Grid()); Lattice<decltype(outerProduct(ll(),rr()))> ret(rhs.Grid());
accelerator_loop(ss,rhs,{ auto lhs_v = lhs.View();
ret[ss]=outerProduct(lhs[ss],rhs[ss]); auto rhs_v = rhs.View();
auto ret_v = ret.View();
accelerator_loop(ss,rhs_v,{
ret_v[ss]=outerProduct(lhs_v[ss],rhs_v[ss]);
}); });
return ret; return ret;
} }

View File

@ -51,6 +51,9 @@ static void sliceMaddMatrix (Lattice<vobj> &R,Eigen::MatrixXcd &aa,const Lattice
int block =FullGrid->_slice_block [Orthog]; int block =FullGrid->_slice_block [Orthog];
int nblock=FullGrid->_slice_nblock[Orthog]; int nblock=FullGrid->_slice_nblock[Orthog];
int ostride=FullGrid->_ostride[Orthog]; int ostride=FullGrid->_ostride[Orthog];
auto X_v = X.View();
auto Y_v = Y.View();
auto R_v = R.View();
thread_region thread_region
{ {
std::vector<vobj> s_x(Nblock); std::vector<vobj> s_x(Nblock);
@ -60,16 +63,16 @@ static void sliceMaddMatrix (Lattice<vobj> &R,Eigen::MatrixXcd &aa,const Lattice
int o = n*stride + b; int o = n*stride + b;
for(int i=0;i<Nblock;i++){ for(int i=0;i<Nblock;i++){
s_x[i] = X[o+i*ostride]; s_x[i] = X_v[o+i*ostride];
} }
vobj dot; vobj dot;
for(int i=0;i<Nblock;i++){ for(int i=0;i<Nblock;i++){
dot = Y[o+i*ostride]; dot = Y_v[o+i*ostride];
for(int j=0;j<Nblock;j++){ for(int j=0;j<Nblock;j++){
dot = dot + s_x[j]*(scale*aa(j,i)); dot = dot + s_x[j]*(scale*aa(j,i));
} }
R[o+i*ostride]=dot; R_v[o+i*ostride]=dot;
} }
}}); }});
} }
@ -85,14 +88,7 @@ static void sliceMulMatrix (Lattice<vobj> &R,Eigen::MatrixXcd &aa,const Lattice<
int Nblock = X.Grid()->GlobalDimensions()[Orthog]; int Nblock = X.Grid()->GlobalDimensions()[Orthog];
GridBase *FullGrid = X.Grid(); GridBase *FullGrid = X.Grid();
// GridBase *SliceGrid = makeSubSliceGrid(FullGrid,Orthog);
// Lattice<vobj> Xslice(SliceGrid);
// Lattice<vobj> Rslice(SliceGrid);
assert( FullGrid->_simd_layout[Orthog]==1); assert( FullGrid->_simd_layout[Orthog]==1);
// int nh = FullGrid->_ndimension;
// int nl = SliceGrid->_ndimension;
// int nl=1;
//FIXME package in a convenient iterator //FIXME package in a convenient iterator
//Should loop over a plane orthogonal to direction "Orthog" //Should loop over a plane orthogonal to direction "Orthog"
@ -100,16 +96,20 @@ static void sliceMulMatrix (Lattice<vobj> &R,Eigen::MatrixXcd &aa,const Lattice<
int block =FullGrid->_slice_block [Orthog]; int block =FullGrid->_slice_block [Orthog];
int nblock=FullGrid->_slice_nblock[Orthog]; int nblock=FullGrid->_slice_nblock[Orthog];
int ostride=FullGrid->_ostride[Orthog]; int ostride=FullGrid->_ostride[Orthog];
auto X_v = X.View();
auto R_v = R.View();
thread_region thread_region
{ {
std::vector<vobj> s_x(Nblock); std::vector<vobj> s_x(Nblock);
thread_loop_collapse2( (int n=0;n<nblock;n++),{ thread_loop_collapse2( (int n=0;n<nblock;n++),{
for(int b=0;b<block;b++){ for(int b=0;b<block;b++){
int o = n*stride + b; int o = n*stride + b;
for(int i=0;i<Nblock;i++){ for(int i=0;i<Nblock;i++){
s_x[i] = X[o+i*ostride]; s_x[i] = X_v[o+i*ostride];
} }
vobj dot; vobj dot;
@ -118,7 +118,7 @@ static void sliceMulMatrix (Lattice<vobj> &R,Eigen::MatrixXcd &aa,const Lattice<
for(int j=1;j<Nblock;j++){ for(int j=1;j<Nblock;j++){
dot = dot + s_x[j]*(scale*aa(j,i)); dot = dot + s_x[j]*(scale*aa(j,i));
} }
R[o+i*ostride]=dot; R_v[o+i*ostride]=dot;
} }
}}); }});
} }
@ -156,7 +156,8 @@ static void sliceInnerProductMatrix( Eigen::MatrixXcd &mat, const Lattice<vobj>
int ostride=FullGrid->_ostride[Orthog]; int ostride=FullGrid->_ostride[Orthog];
typedef typename vobj::vector_typeD vector_typeD; typedef typename vobj::vector_typeD vector_typeD;
auto lhs_v = lhs.View();
auto rhs_v = rhs.View();
thread_region { thread_region {
std::vector<vobj> Left(Nblock); std::vector<vobj> Left(Nblock);
std::vector<vobj> Right(Nblock); std::vector<vobj> Right(Nblock);
@ -168,8 +169,8 @@ static void sliceInnerProductMatrix( Eigen::MatrixXcd &mat, const Lattice<vobj>
int o = n*stride + b; int o = n*stride + b;
for(int i=0;i<Nblock;i++){ for(int i=0;i<Nblock;i++){
Left [i] = lhs[o+i*ostride]; Left [i] = lhs_v[o+i*ostride];
Right[i] = rhs[o+i*ostride]; Right[i] = rhs_v[o+i*ostride];
} }
for(int i=0;i<Nblock;i++){ for(int i=0;i<Nblock;i++){

View File

@ -42,22 +42,26 @@ NAMESPACE_BEGIN(Grid);
// Peek internal indices of a Lattice object // Peek internal indices of a Lattice object
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int Index,class vobj> template<int Index,class vobj>
auto PeekIndex(const Lattice<vobj> &lhs,int i) -> Lattice<decltype(peekIndex<Index>(lhs[0],i))> auto PeekIndex(const Lattice<vobj> &lhs,int i) -> Lattice<decltype(peekIndex<Index>(vobj(),i))>
{ {
Lattice<decltype(peekIndex<Index>(lhs[0],i))> ret(lhs.Grid()); Lattice<decltype(peekIndex<Index>(vobj(),i))> ret(lhs.Grid());
ret.Checkerboard()=lhs.Checkerboard(); ret.Checkerboard()=lhs.Checkerboard();
cpu_loop( ss, lhs, { auto ret_v = ret.View();
ret[ss] = peekIndex<Index>(lhs[ss],i); auto lhs_v = lhs.View();
cpu_loop( ss, lhs_v, {
ret_v[ss] = peekIndex<Index>(lhs_v[ss],i);
}); });
return ret; return ret;
}; };
template<int Index,class vobj> template<int Index,class vobj>
auto PeekIndex(const Lattice<vobj> &lhs,int i,int j) -> Lattice<decltype(peekIndex<Index>(lhs[0],i,j))> auto PeekIndex(const Lattice<vobj> &lhs,int i,int j) -> Lattice<decltype(peekIndex<Index>(vobj(),i,j))>
{ {
Lattice<decltype(peekIndex<Index>(lhs[0],i,j))> ret(lhs.Grid()); Lattice<decltype(peekIndex<Index>(vobj(),i,j))> ret(lhs.Grid());
ret.Checkerboard()=lhs.Checkerboard(); ret.Checkerboard()=lhs.Checkerboard();
cpu_loop( ss, lhs, { auto ret_v = ret.View();
ret[ss] = peekIndex<Index>(lhs[ss],i,j); auto lhs_v = lhs.View();
cpu_loop( ss, lhs_v, {
ret_v[ss] = peekIndex<Index>(lhs_v[ss],i,j);
}); });
return ret; return ret;
}; };
@ -66,17 +70,21 @@ auto PeekIndex(const Lattice<vobj> &lhs,int i,int j) -> Lattice<decltype(peekInd
// Poke internal indices of a Lattice object // Poke internal indices of a Lattice object
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int Index,class vobj> template<int Index,class vobj>
void PokeIndex(Lattice<vobj> &lhs,const Lattice<decltype(peekIndex<Index>(lhs[0],0))> & rhs,int i) void PokeIndex(Lattice<vobj> &lhs,const Lattice<decltype(peekIndex<Index>(vobj(),0))> & rhs,int i)
{ {
cpu_loop( ss, lhs, { auto rhs_v = rhs.View();
pokeIndex<Index>(lhs[ss],rhs[ss],i); auto lhs_v = lhs.View();
cpu_loop( ss, lhs_v, {
pokeIndex<Index>(lhs_v[ss],rhs_v[ss],i);
}); });
} }
template<int Index,class vobj> template<int Index,class vobj>
void PokeIndex(Lattice<vobj> &lhs,const Lattice<decltype(peekIndex<Index>(lhs[0],0,0))> & rhs,int i,int j) void PokeIndex(Lattice<vobj> &lhs,const Lattice<decltype(peekIndex<Index>(vobj(),0,0))> & rhs,int i,int j)
{ {
cpu_loop( ss, lhs, { auto rhs_v = rhs.View();
pokeIndex<Index>(lhs[ss],rhs[ss],i,j); auto lhs_v = lhs.View();
cpu_loop( ss, lhs_v, {
pokeIndex<Index>(lhs_v[ss],rhs_v[ss],i,j);
}); });
} }
@ -103,10 +111,11 @@ void pokeSite(const sobj &s,Lattice<vobj> &l,const Coordinate &site){
// extract-modify-merge cycle is easiest way and this is not perf critical // extract-modify-merge cycle is easiest way and this is not perf critical
ExtractBuffer<sobj> buf(Nsimd); ExtractBuffer<sobj> buf(Nsimd);
auto l_v = l.View();
if ( rank == grid->ThisRank() ) { if ( rank == grid->ThisRank() ) {
extract(l[odx],buf); extract(l_v[odx],buf);
buf[idx] = s; buf[idx] = s;
merge(l[odx],buf); merge(l_v[odx],buf);
} }
return; return;
@ -132,7 +141,8 @@ void peekSite(sobj &s,const Lattice<vobj> &l,const Coordinate &site){
grid->GlobalCoorToRankIndex(rank,odx,idx,site); grid->GlobalCoorToRankIndex(rank,odx,idx,site);
ExtractBuffer<sobj> buf(Nsimd); ExtractBuffer<sobj> buf(Nsimd);
extract(l[odx],buf); auto l_v = l.View();
extract(l_v[odx],buf);
s = buf[idx]; s = buf[idx];
@ -162,8 +172,9 @@ void peekLocalSite(sobj &s,const Lattice<vobj> &l,Coordinate &site){
int odx,idx; int odx,idx;
idx= grid->iIndex(site); idx= grid->iIndex(site);
odx= grid->oIndex(site); odx= grid->oIndex(site);
scalar_type * vp = (scalar_type *)&l[odx]; auto l_v = l.View();
scalar_type * vp = (scalar_type *)&l_v[odx];
scalar_type * pt = (scalar_type *)&s; scalar_type * pt = (scalar_type *)&s;
for(int w=0;w<words;w++){ for(int w=0;w<words;w++){
@ -191,9 +202,9 @@ void pokeLocalSite(const sobj &s,Lattice<vobj> &l,Coordinate &site){
idx= grid->iIndex(site); idx= grid->iIndex(site);
odx= grid->oIndex(site); odx= grid->oIndex(site);
scalar_type * vp = (scalar_type *)&l[odx]; auto l_v = l.View();
scalar_type * vp = (scalar_type *)&l_v[odx];
scalar_type * pt = (scalar_type *)&s; scalar_type * pt = (scalar_type *)&s;
for(int w=0;w<words;w++){ for(int w=0;w<words;w++){
vp[idx+w*Nsimd] = pt[w]; vp[idx+w*Nsimd] = pt[w];
} }

View File

@ -40,16 +40,20 @@ NAMESPACE_BEGIN(Grid);
template<class vobj> inline Lattice<vobj> adj(const Lattice<vobj> &lhs){ template<class vobj> inline Lattice<vobj> adj(const Lattice<vobj> &lhs){
Lattice<vobj> ret(lhs.Grid()); Lattice<vobj> ret(lhs.Grid());
accelerator_loop( ss, lhs, { auto lhs_v = lhs.View();
ret[ss] = adj(lhs[ss]); auto ret_v = ret.View();
accelerator_loop( ss, lhs_v, {
ret_v[ss] = adj(lhs_v[ss]);
}); });
return ret; return ret;
}; };
template<class vobj> inline Lattice<vobj> conjugate(const Lattice<vobj> &lhs){ template<class vobj> inline Lattice<vobj> conjugate(const Lattice<vobj> &lhs){
Lattice<vobj> ret(lhs.Grid()); Lattice<vobj> ret(lhs.Grid());
accelerator_loop( ss, lhs, { auto lhs_v = lhs.View();
ret[ss] = conjugate(lhs[ss]); auto ret_v = ret.View();
accelerator_loop( ss, lhs_v, {
ret_v[ss] = conjugate(lhs_v[ss]);
}); });
return ret; return ret;
}; };

View File

@ -47,14 +47,17 @@ inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &righ
GridBase *grid = left.Grid(); GridBase *grid = left.Grid();
std::vector<vector_type,alignedAllocator<vector_type> > sumarray(grid->SumArraySize()); std::vector<vector_type,alignedAllocator<vector_type> > sumarray(grid->SumArraySize());
auto left_v = left.View();
auto right_v=right.View();
thread_loop( (int thr=0;thr<grid->SumArraySize();thr++),{ thread_loop( (int thr=0;thr<grid->SumArraySize();thr++),{
int mywork, myoff; int mywork, myoff;
GridThread::GetWork(left.Grid()->oSites(),thr,mywork,myoff); GridThread::GetWork(left.Grid()->oSites(),thr,mywork,myoff);
decltype(innerProductD(left[0],right[0])) vnrm=Zero(); // private to thread; sub summation decltype(innerProductD(left_v[0],right_v[0])) vnrm=Zero(); // private to thread; sub summation
for(int ss=myoff;ss<mywork+myoff; ss++){ for(int ss=myoff;ss<mywork+myoff; ss++){
vnrm = vnrm + innerProductD(left[ss],right[ss]); vnrm = vnrm + innerProductD(left_v[ss],right_v[ss]);
} }
sumarray[thr]=TensorRemove(vnrm) ; sumarray[thr]=TensorRemove(vnrm) ;
}); });
@ -70,14 +73,14 @@ inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &righ
template<class Op,class T1> template<class Op,class T1>
inline auto sum(const LatticeUnaryExpression<Op,T1> & expr) inline auto sum(const LatticeUnaryExpression<Op,T1> & expr)
->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second))))::scalar_object ->typename decltype(expr.op.func(eval(0,expr.arg1)))::scalar_object
{ {
return sum(closure(expr)); return sum(closure(expr));
} }
template<class Op,class T1,class T2> template<class Op,class T1,class T2>
inline auto sum(const LatticeBinaryExpression<Op,T1,T2> & expr) inline auto sum(const LatticeBinaryExpression<Op,T1,T2> & expr)
->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second)),eval(0,std::get<1>(expr.second))))::scalar_object ->typename decltype(expr.op.func(eval(0,expr.arg1,eval(0,expr.arg2))))::scalar_object
{ {
return sum(closure(expr)); return sum(closure(expr));
} }
@ -85,10 +88,10 @@ inline auto sum(const LatticeBinaryExpression<Op,T1,T2> & expr)
template<class Op,class T1,class T2,class T3> template<class Op,class T1,class T2,class T3>
inline auto sum(const LatticeTrinaryExpression<Op,T1,T2,T3> & expr) inline auto sum(const LatticeTrinaryExpression<Op,T1,T2,T3> & expr)
->typename decltype(expr.first.func(eval(0,std::get<0>(expr.second)), ->typename decltype(expr.op.func(eval(0,expr.arg1),
eval(0,std::get<1>(expr.second)), eval(0,expr.arg2),
eval(0,std::get<2>(expr.second)) eval(0,expr.arg3)
))::scalar_object ))::scalar_object
{ {
return sum(closure(expr)); return sum(closure(expr));
} }
@ -103,14 +106,14 @@ inline typename vobj::scalar_object sum(const Lattice<vobj> &arg)
for(int i=0;i<grid->SumArraySize();i++){ for(int i=0;i<grid->SumArraySize();i++){
sumarray[i]=Zero(); sumarray[i]=Zero();
} }
auto arg_v = arg.View();
thread_loop( (int thr=0;thr<grid->SumArraySize();thr++),{ thread_loop( (int thr=0;thr<grid->SumArraySize();thr++),{
int mywork, myoff; int mywork, myoff;
GridThread::GetWork(grid->oSites(),thr,mywork,myoff); GridThread::GetWork(grid->oSites(),thr,mywork,myoff);
vobj vvsum=Zero(); vobj vvsum=Zero();
for(int ss=myoff;ss<mywork+myoff; ss++){ for(int ss=myoff;ss<mywork+myoff; ss++){
vvsum = vvsum + arg[ss]; vvsum = vvsum + arg_v[ss];
} }
sumarray[thr]=vvsum; sumarray[thr]=vvsum;
}); });
@ -172,6 +175,7 @@ template<class vobj> inline void sliceSum(const Lattice<vobj> &Data,std::vector<
int stride=grid->_slice_stride[orthogdim]; int stride=grid->_slice_stride[orthogdim];
// sum over reduced dimension planes, breaking out orthog dir // sum over reduced dimension planes, breaking out orthog dir
auto Data_v = Data.View();
thread_loop( (int r=0;r<rd;r++),{ thread_loop( (int r=0;r<rd;r++),{
int so=r*grid->_ostride[orthogdim]; // base offset for start of plane int so=r*grid->_ostride[orthogdim]; // base offset for start of plane
@ -179,7 +183,7 @@ template<class vobj> inline void sliceSum(const Lattice<vobj> &Data,std::vector<
for(int n=0;n<e1;n++){ for(int n=0;n<e1;n++){
for(int b=0;b<e2;b++){ for(int b=0;b<e2;b++){
int ss= so+n*stride+b; int ss= so+n*stride+b;
lvSum[r]=lvSum[r]+Data[ss]; lvSum[r]=lvSum[r]+Data_v[ss];
} }
} }
}); });
@ -251,6 +255,8 @@ static void sliceInnerProductVector( std::vector<ComplexD> & result, const Latti
int e2= grid->_slice_block [orthogdim]; int e2= grid->_slice_block [orthogdim];
int stride=grid->_slice_stride[orthogdim]; int stride=grid->_slice_stride[orthogdim];
auto lhs_v = lhs.View();
auto rhs_v = rhs.View();
thread_loop( (int r=0;r<rd;r++),{ thread_loop( (int r=0;r<rd;r++),{
int so=r*grid->_ostride[orthogdim]; // base offset for start of plane int so=r*grid->_ostride[orthogdim]; // base offset for start of plane
@ -258,7 +264,7 @@ static void sliceInnerProductVector( std::vector<ComplexD> & result, const Latti
for(int n=0;n<e1;n++){ for(int n=0;n<e1;n++){
for(int b=0;b<e2;b++){ for(int b=0;b<e2;b++){
int ss= so+n*stride+b; int ss= so+n*stride+b;
vector_type vv = TensorRemove(innerProduct(lhs[ss],rhs[ss])); vector_type vv = TensorRemove(innerProduct(lhs_v[ss],rhs_v[ss]));
lvSum[r]=lvSum[r]+vv; lvSum[r]=lvSum[r]+vv;
} }
} }
@ -358,10 +364,13 @@ static void sliceMaddVector(Lattice<vobj> &R,std::vector<RealD> &a,const Lattice
tensor_reduced at; at=av; tensor_reduced at; at=av;
auto X_v = X.View();
auto Y_v = Y.View();
auto R_v = R.View();
thread_loop_collapse2( (int n=0;n<e1;n++),{ thread_loop_collapse2( (int n=0;n<e1;n++),{
for(int b=0;b<e2;b++){ for(int b=0;b<e2;b++){
int ss= so+n*stride+b; int ss= so+n*stride+b;
R[ss] = at*X[ss]+Y[ss]; R_v[ss] = at*X_v[ss]+Y_v[ss];
} }
}); });
} }

View File

@ -346,7 +346,9 @@ public:
int osites = _grid->oSites(); // guaranteed to be <= l.Grid()->oSites() by a factor multiplicity int osites = _grid->oSites(); // guaranteed to be <= l.Grid()->oSites() by a factor multiplicity
int words = sizeof(scalar_object) / sizeof(scalar_type); int words = sizeof(scalar_object) / sizeof(scalar_type);
thread_loop( (int ss=0;ss<osites;ss++), { auto l_v = l.View();
// thread_loop( (int ss=0;ss<osites;ss++), {
for (int ss=0;ss<osites;ss++) {
ExtractBuffer<scalar_object> buf(Nsimd); ExtractBuffer<scalar_object> buf(Nsimd);
for (int m = 0; m < multiplicity; m++) { // Draw from same generator multiplicity times for (int m = 0; m < multiplicity; m++) { // Draw from same generator multiplicity times
@ -361,9 +363,10 @@ public:
fillScalar(pointer[idx], dist[gdx], _generators[gdx]); fillScalar(pointer[idx], dist[gdx], _generators[gdx]);
} }
// merge into SIMD lanes, FIXME suboptimal implementation // merge into SIMD lanes, FIXME suboptimal implementation
merge(l[sm], buf); merge(l_v[sm], buf);
} }
}); }
// });
_time_counter += usecond()- inner_time_counter; _time_counter += usecond()- inner_time_counter;
} }

View File

@ -38,12 +38,13 @@ NAMESPACE_BEGIN(Grid);
// Trace // Trace
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<class vobj> template<class vobj>
inline auto trace(const Lattice<vobj> &lhs) inline auto trace(const Lattice<vobj> &lhs) -> Lattice<decltype(trace(vobj()))>
-> Lattice<decltype(trace(lhs[0]))>
{ {
Lattice<decltype(trace(lhs[0]))> ret(lhs.Grid()); Lattice<decltype(trace(vobj()))> ret(lhs.Grid());
accelerator_loop( ss, lhs, { auto ret_v = ret.View();
ret[ss] = trace(lhs[ss]); auto lhs_v = lhs.View();
accelerator_loop( ss, lhs_v, {
ret_v[ss] = trace(lhs_v[ss]);
}); });
return ret; return ret;
}; };
@ -52,11 +53,13 @@ inline auto trace(const Lattice<vobj> &lhs)
// Trace Index level dependent operation // Trace Index level dependent operation
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int Index,class vobj> template<int Index,class vobj>
inline auto TraceIndex(const Lattice<vobj> &lhs) -> Lattice<decltype(traceIndex<Index>(lhs[0]))> inline auto TraceIndex(const Lattice<vobj> &lhs) -> Lattice<decltype(traceIndex<Index>(vobj()))>
{ {
Lattice<decltype(traceIndex<Index>(lhs[0]))> ret(lhs.Grid()); Lattice<decltype(traceIndex<Index>(vobj()))> ret(lhs.Grid());
accelerator_loop( ss, lhs, { auto ret_v = ret.View();
ret[ss] = traceIndex<Index>(lhs[ss]); auto lhs_v = lhs.View();
accelerator_loop( ss, lhs_v, {
ret_v[ss] = traceIndex<Index>(lhs_v[ss]);
}); });
return ret; return ret;
}; };

View File

@ -51,20 +51,24 @@ inline void subdivides(GridBase *coarse,GridBase *fine)
template<class vobj> inline void pickCheckerboard(int cb,Lattice<vobj> &half,const Lattice<vobj> &full){ template<class vobj> inline void pickCheckerboard(int cb,Lattice<vobj> &half,const Lattice<vobj> &full){
half.Checkerboard() = cb; half.Checkerboard() = cb;
auto half_v = half.View();
auto full_v = full.View();
thread_loop( (int ss=0;ss<full.Grid()->oSites();ss++),{ thread_loop( (int ss=0;ss<full.Grid()->oSites();ss++),{
int cbos; int cbos;
Coordinate coor; Coordinate coor;
full.Grid()->oCoorFromOindex(coor,ss); full.Grid()->oCoorFromOindex(coor,ss);
cbos=half.Grid()->CheckerBoard(coor); cbos=half.Grid()->CheckerBoard(coor);
if (cbos==cb) { if (cbos==cb) {
int ssh=half.Grid()->oIndex(coor); int ssh=half.Grid()->oIndex(coor);
half[ssh] = full[ss]; half_v[ssh] = full_v[ss];
} }
}); });
} }
template<class vobj> inline void setCheckerboard(Lattice<vobj> &full,const Lattice<vobj> &half){ template<class vobj> inline void setCheckerboard(Lattice<vobj> &full,const Lattice<vobj> &half){
int cb = half.Checkerboard(); int cb = half.Checkerboard();
auto half_v = half.View();
auto full_v = full.View();
thread_loop( (int ss=0;ss<full.Grid()->oSites();ss++), { thread_loop( (int ss=0;ss<full.Grid()->oSites();ss++), {
Coordinate coor; Coordinate coor;
int cbos; int cbos;
@ -74,7 +78,7 @@ template<class vobj> inline void setCheckerboard(Lattice<vobj> &full,const Latti
if (cbos==cb) { if (cbos==cb) {
int ssh=half.Grid()->oIndex(coor); int ssh=half.Grid()->oIndex(coor);
full[ss]=half[ssh]; full_v[ss]=half_v[ssh];
} }
}); });
} }
@ -105,6 +109,8 @@ inline void blockProject(Lattice<iVector<CComplex,nbasis > > &coarseData,
coarseData=Zero(); coarseData=Zero();
auto fineData_ = fineData.View();
auto coarseData_ = coarseData.View();
// Loop over coars parallel, and then loop over fine associated with coarse. // Loop over coars parallel, and then loop over fine associated with coarse.
thread_loop( (int sf=0;sf<fine->oSites();sf++),{ thread_loop( (int sf=0;sf<fine->oSites();sf++),{
@ -117,9 +123,8 @@ inline void blockProject(Lattice<iVector<CComplex,nbasis > > &coarseData,
thread_critical { thread_critical {
for(int i=0;i<nbasis;i++) { for(int i=0;i<nbasis;i++) {
coarseData[sc](i)=coarseData[sc](i) auto Basis_ = Basis[i].View();
+ innerProduct(Basis[i][sf],fineData[sf]); coarseData_[sc](i)=coarseData_[sc](i) + innerProduct(Basis_[sf],fineData_[sf]);
} }
} }
}); });
@ -151,6 +156,11 @@ inline void blockZAXPY(Lattice<vobj> &fineZ,
assert(block_r[d]*coarse->_rdimensions[d]==fine->_rdimensions[d]); assert(block_r[d]*coarse->_rdimensions[d]==fine->_rdimensions[d]);
} }
auto fineZ_ = fineZ.View();
auto fineX_ = fineX.View();
auto fineY_ = fineY.View();
auto coarseA_= coarseA.View();
thread_loop( (int sf=0;sf<fine->oSites();sf++),{ thread_loop( (int sf=0;sf<fine->oSites();sf++),{
int sc; int sc;
@ -162,7 +172,7 @@ inline void blockZAXPY(Lattice<vobj> &fineZ,
Lexicographic::IndexFromCoor(coor_c,sc,coarse->_rdimensions); Lexicographic::IndexFromCoor(coor_c,sc,coarse->_rdimensions);
// z = A x + y // z = A x + y
fineZ[sf]=coarseA[sc]*fineX[sf]+fineY[sf]; fineZ_[sf]=coarseA_[sc]*fineX_[sf]+fineY_[sf];
}); });
@ -173,7 +183,7 @@ inline void blockInnerProduct(Lattice<CComplex> &CoarseInner,
const Lattice<vobj> &fineX, const Lattice<vobj> &fineX,
const Lattice<vobj> &fineY) const Lattice<vobj> &fineY)
{ {
typedef decltype(innerProduct(fineX[0],fineY[0])) dotp; typedef decltype(innerProduct(vobj(),vobj())) dotp;
GridBase *coarse(CoarseInner.Grid()); GridBase *coarse(CoarseInner.Grid());
GridBase *fine (fineX.Grid()); GridBase *fine (fineX.Grid());
@ -182,10 +192,13 @@ inline void blockInnerProduct(Lattice<CComplex> &CoarseInner,
Lattice<dotp> coarse_inner(coarse); Lattice<dotp> coarse_inner(coarse);
// Precision promotion? // Precision promotion?
auto CoarseInner_ = CoarseInner.View();
auto coarse_inner_ = coarse_inner.View();
fine_inner = localInnerProduct(fineX,fineY); fine_inner = localInnerProduct(fineX,fineY);
blockSum(coarse_inner,fine_inner); blockSum(coarse_inner,fine_inner);
thread_loop( (int ss=0;ss<coarse->oSites();ss++),{ thread_loop( (int ss=0;ss<coarse->oSites();ss++),{
CoarseInner[ss] = coarse_inner[ss]; CoarseInner_[ss] = coarse_inner_[ss];
}); });
} }
template<class vobj,class CComplex> template<class vobj,class CComplex>
@ -218,24 +231,23 @@ inline void blockSum(Lattice<vobj> &coarseData,const Lattice<vobj> &fineData)
// Turn this around to loop threaded over sc and interior loop // Turn this around to loop threaded over sc and interior loop
// over sf would thread better // over sf would thread better
coarseData=Zero(); coarseData=Zero();
thread_region { auto coarseData_ = coarseData.View();
auto fineData_ = fineData.View();
thread_loop( (int sf=0;sf<fine->oSites();sf++),{
int sc; int sc;
Coordinate coor_c(_ndimension); Coordinate coor_c(_ndimension);
Coordinate coor_f(_ndimension); Coordinate coor_f(_ndimension);
thread_loop_in_region( (int sf=0;sf<fine->oSites();sf++),{
Lexicographic::CoorFromIndex(coor_f,sf,fine->_rdimensions); Lexicographic::CoorFromIndex(coor_f,sf,fine->_rdimensions);
for(int d=0;d<_ndimension;d++) coor_c[d]=coor_f[d]/block_r[d]; for(int d=0;d<_ndimension;d++) coor_c[d]=coor_f[d]/block_r[d];
Lexicographic::IndexFromCoor(coor_c,sc,coarse->_rdimensions); Lexicographic::IndexFromCoor(coor_c,sc,coarse->_rdimensions);
thread_critical { thread_critical {
coarseData[sc]=coarseData[sc]+fineData[sf]; coarseData_[sc]=coarseData_[sc]+fineData_[sf];
} }
}); });
}
return; return;
} }
@ -306,25 +318,25 @@ inline void blockPromote(const Lattice<iVector<CComplex,nbasis > > &coarseData,
for(int d=0 ; d<_ndimension;d++){ for(int d=0 ; d<_ndimension;d++){
block_r[d] = fine->_rdimensions[d] / coarse->_rdimensions[d]; block_r[d] = fine->_rdimensions[d] / coarse->_rdimensions[d];
} }
auto fineData_ = fineData.View();
auto coarseData_ = coarseData.View();
// Loop with a cache friendly loop ordering // Loop with a cache friendly loop ordering
thread_region { thread_loop( (int sf=0;sf<fine->oSites();sf++),{
int sc; int sc;
Coordinate coor_c(_ndimension); Coordinate coor_c(_ndimension);
Coordinate coor_f(_ndimension); Coordinate coor_f(_ndimension);
thread_loop_in_region( (int sf=0;sf<fine->oSites();sf++),{ Lexicographic::CoorFromIndex(coor_f,sf,fine->_rdimensions);
for(int d=0;d<_ndimension;d++) coor_c[d]=coor_f[d]/block_r[d];
Lexicographic::IndexFromCoor(coor_c,sc,coarse->_rdimensions);
Lexicographic::CoorFromIndex(coor_f,sf,fine->_rdimensions); for(int i=0;i<nbasis;i++) {
for(int d=0;d<_ndimension;d++) coor_c[d]=coor_f[d]/block_r[d]; auto basis_ = Basis[i].View();
Lexicographic::IndexFromCoor(coor_c,sc,coarse->_rdimensions); if(i==0) fineData_[sf]=coarseData_[sc](i) *basis_[sf];
else fineData_[sf]=fineData_[sf]+coarseData_[sc](i)*basis_[sf];
for(int i=0;i<nbasis;i++) { }
if(i==0) fineData[sf]=coarseData[sc](i) * Basis[i][sf]; });
else fineData[sf]=fineData[sf]+coarseData[sc](i)*Basis[i][sf];
}
});
}
return; return;
} }
@ -577,6 +589,7 @@ unvectorizeToLexOrdArray(std::vector<sobj> &out, const Lattice<vobj> &in)
} }
//loop over outer index //loop over outer index
auto in_v = in.View();
thread_loop( (int in_oidx = 0; in_oidx < in_grid->oSites(); in_oidx++),{ thread_loop( (int in_oidx = 0; in_oidx < in_grid->oSites(); in_oidx++),{
//Assemble vector of pointers to output elements //Assemble vector of pointers to output elements
ExtractPointerArray<sobj> out_ptrs(in_nsimd); ExtractPointerArray<sobj> out_ptrs(in_nsimd);
@ -587,16 +600,19 @@ unvectorizeToLexOrdArray(std::vector<sobj> &out, const Lattice<vobj> &in)
Coordinate lcoor(in_grid->Nd()); Coordinate lcoor(in_grid->Nd());
for(int lane=0; lane < in_nsimd; lane++){ for(int lane=0; lane < in_nsimd; lane++){
for(int mu=0;mu<ndim;mu++)
for(int mu=0;mu<ndim;mu++){
lcoor[mu] = in_ocoor[mu] + in_grid->_rdimensions[mu]*in_icoor[lane][mu]; lcoor[mu] = in_ocoor[mu] + in_grid->_rdimensions[mu]*in_icoor[lane][mu];
}
int lex; int lex;
Lexicographic::IndexFromCoor(lcoor, lex, in_grid->_ldimensions); Lexicographic::IndexFromCoor(lcoor, lex, in_grid->_ldimensions);
assert(lex < out.size());
out_ptrs[lane] = &out[lex]; out_ptrs[lane] = &out[lex];
} }
//Unpack into those ptrs //Unpack into those ptrs
const vobj & in_vobj = in[in_oidx]; const vobj & in_vobj = in_v[in_oidx];
extract(in_vobj, out_ptrs, 0); extract(in_vobj, out_ptrs, 0);
}); });
} }
@ -621,7 +637,7 @@ vectorizeFromLexOrdArray( std::vector<sobj> &in, Lattice<vobj> &out)
icoor[lane].resize(ndim); icoor[lane].resize(ndim);
grid->iCoorFromIindex(icoor[lane],lane); grid->iCoorFromIindex(icoor[lane],lane);
} }
auto out_v = out.View();
thread_loop( (uint64_t oidx = 0; oidx < grid->oSites(); oidx++),{ thread_loop( (uint64_t oidx = 0; oidx < grid->oSites(); oidx++),{
//Assemble vector of pointers to output elements //Assemble vector of pointers to output elements
ExtractPointerArray<sobj> ptrs(nsimd); ExtractPointerArray<sobj> ptrs(nsimd);
@ -644,7 +660,7 @@ vectorizeFromLexOrdArray( std::vector<sobj> &in, Lattice<vobj> &out)
//pack from those ptrs //pack from those ptrs
vobj vecobj; vobj vecobj;
merge(vecobj, ptrs, 0); merge(vecobj, ptrs, 0);
out[oidx] = vecobj; out_v[oidx] = vecobj;
}); });
} }
@ -673,6 +689,7 @@ void precisionChange(Lattice<VobjOut> &out, const Lattice<VobjIn> &in){
std::vector<SobjOut> in_slex_conv(in_grid->lSites()); std::vector<SobjOut> in_slex_conv(in_grid->lSites());
unvectorizeToLexOrdArray(in_slex_conv, in); unvectorizeToLexOrdArray(in_slex_conv, in);
auto out_v = out.View();
thread_loop( (uint64_t out_oidx=0;out_oidx<out_grid->oSites();out_oidx++),{ thread_loop( (uint64_t out_oidx=0;out_oidx<out_grid->oSites();out_oidx++),{
Coordinate out_ocoor(ndim); Coordinate out_ocoor(ndim);
out_grid->oCoorFromOindex(out_ocoor, out_oidx); out_grid->oCoorFromOindex(out_ocoor, out_oidx);
@ -688,7 +705,7 @@ void precisionChange(Lattice<VobjOut> &out, const Lattice<VobjIn> &in){
int llex; Lexicographic::IndexFromCoor(lcoor, llex, out_grid->_ldimensions); int llex; Lexicographic::IndexFromCoor(lcoor, llex, out_grid->_ldimensions);
ptrs[lane] = &in_slex_conv[llex]; ptrs[lane] = &in_slex_conv[llex];
} }
merge(out[out_oidx], ptrs, 0); merge(out_v[out_oidx], ptrs, 0);
}); });
} }

View File

@ -41,8 +41,10 @@ NAMESPACE_BEGIN(Grid);
template<class vobj> template<class vobj>
inline Lattice<vobj> transpose(const Lattice<vobj> &lhs){ inline Lattice<vobj> transpose(const Lattice<vobj> &lhs){
Lattice<vobj> ret(lhs.Grid()); Lattice<vobj> ret(lhs.Grid());
accelerator_loop(ss,lhs,{ auto ret_v = ret.View();
ret[ss] = transpose(lhs[ss]); auto lhs_v = lhs.View();
accelerator_loop(ss,lhs_v,{
ret_v[ss] = transpose(lhs_v[ss]);
}); });
return ret; return ret;
}; };
@ -51,11 +53,13 @@ inline Lattice<vobj> transpose(const Lattice<vobj> &lhs){
// Index level dependent transpose // Index level dependent transpose
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<int Index,class vobj> template<int Index,class vobj>
inline auto TransposeIndex(const Lattice<vobj> &lhs) -> Lattice<decltype(transposeIndex<Index>(lhs[0]))> inline auto TransposeIndex(const Lattice<vobj> &lhs) -> Lattice<decltype(transposeIndex<Index>(vobj()))>
{ {
Lattice<decltype(transposeIndex<Index>(lhs[0]))> ret(lhs.Grid()); Lattice<decltype(transposeIndex<Index>(vobj()))> ret(lhs.Grid());
accelerator_loop(ss,lhs,{ auto ret_v = ret.View();
ret[ss] = transposeIndex<Index>(lhs[ss]); auto lhs_v = lhs.View();
accelerator_loop(ss,lhs_v,{
ret_v[ss] = transposeIndex<Index>(lhs_v[ss]);
}); });
return ret; return ret;
}; };

View File

@ -33,43 +33,47 @@ Author: paboyle <paboyle@ph.ed.ac.uk>
NAMESPACE_BEGIN(Grid); NAMESPACE_BEGIN(Grid);
template<class obj> Lattice<obj> pow(const Lattice<obj> &rhs,RealD y){ template<class obj> Lattice<obj> pow(const Lattice<obj> &rhs_i,RealD y){
Lattice<obj> ret(rhs.Grid()); Lattice<obj> ret_i(rhs_i.Grid());
auto rhs = rhs_i.View();
auto ret = ret_i.View();
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs);
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs,{
ret[ss]=pow(rhs[ss],y); ret[ss]=pow(rhs[ss],y);
}); });
return ret; return ret_i;
} }
template<class obj> Lattice<obj> mod(const Lattice<obj> &rhs,Integer y){ template<class obj> Lattice<obj> mod(const Lattice<obj> &rhs_i,Integer y){
Lattice<obj> ret(rhs.Grid()); Lattice<obj> ret_i(rhs_i.Grid());
auto rhs = rhs_i.View();
auto ret = ret_i.View();
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs);
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs,{
ret[ss]=mod(rhs[ss],y); ret[ss]=mod(rhs[ss],y);
}); });
return ret; return ret_i;
} }
template<class obj> Lattice<obj> div(const Lattice<obj> &rhs,Integer y){ template<class obj> Lattice<obj> div(const Lattice<obj> &rhs_i,Integer y){
Lattice<obj> ret(rhs.Grid()); Lattice<obj> ret_i(rhs_i.Grid());
ret.Checkerboard() = rhs.Checkerboard(); auto ret = ret_i.View();
conformable(ret,rhs); auto rhs = rhs_i.View();
ret.Checkerboard() = rhs_i.Checkerboard();
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs,{
ret[ss]=div(rhs[ss],y); ret[ss]=div(rhs[ss],y);
}); });
return ret; return ret_i;
} }
template<class obj> Lattice<obj> expMat(const Lattice<obj> &rhs, RealD alpha, Integer Nexp = DEFAULT_MAT_EXP){ template<class obj> Lattice<obj> expMat(const Lattice<obj> &rhs_i, RealD alpha, Integer Nexp = DEFAULT_MAT_EXP){
Lattice<obj> ret(rhs.Grid()); Lattice<obj> ret_i(rhs_i.Grid());
auto rhs = rhs_i.View();
auto ret = ret_i.View();
ret.Checkerboard() = rhs.Checkerboard(); ret.Checkerboard() = rhs.Checkerboard();
conformable(ret,rhs);
accelerator_loop(ss,rhs,{ accelerator_loop(ss,rhs,{
ret[ss]=Exponentiate(rhs[ss],alpha, Nexp); ret[ss]=Exponentiate(rhs[ss],alpha, Nexp);
}); });
return ret; return ret_i;
} }
NAMESPACE_END(Grid); NAMESPACE_END(Grid);

View File

@ -1,90 +0,0 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: ./lib/lattice/Lattice_where.h
Copyright (C) 2015
Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
Author: Peter Boyle <paboyle@ph.ed.ac.uk>
Author: Peter Boyle <peterboyle@Peters-MacBook-Pro-2.local>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/* END LEGAL */
#ifndef GRID_LATTICE_WHERE_H
#define GRID_LATTICE_WHERE_H
NAMESPACE_BEGIN(Grid);
// Must implement the predicate gating the
// Must be able to reduce the predicate down to a single vInteger per site.
// Must be able to require the type be iScalar x iScalar x ....
// give a GetVtype method in iScalar
// and blow away the tensor structures.
//
template<class vobj,class iobj>
inline void whereWolf(Lattice<vobj> &ret,const Lattice<iobj> &predicate,Lattice<vobj> &iftrue,Lattice<vobj> &iffalse)
{
conformable(iftrue,iffalse);
conformable(iftrue,predicate);
conformable(iftrue,ret);
GridBase *grid=iftrue.Grid();
typedef typename vobj::scalar_object scalar_object;
typedef typename vobj::scalar_type scalar_type;
typedef typename vobj::vector_type vector_type;
typedef typename iobj::vector_type mask_type;
const int Nsimd = grid->Nsimd();
std::vector<Integer> mask(Nsimd);
std::vector<scalar_object> truevals (Nsimd);
std::vector<scalar_object> falsevals(Nsimd);
thread_loop( (int ss=iftrue.begin(); ss<iftrue.end();ss++) , {
extract(iftrue[ss] ,truevals);
extract(iffalse[ss] ,falsevals);
extract<vInteger,Integer>(TensorRemove(predicate[ss]),mask);
for(int s=0;s<Nsimd;s++){
if (mask[s]) falsevals[s]=truevals[s];
}
merge(ret[ss],falsevals);
}
);
}
template<class vobj,class iobj>
inline Lattice<vobj> whereWolf(const Lattice<iobj> &predicate,Lattice<vobj> &iftrue,Lattice<vobj> &iffalse)
{
conformable(iftrue,iffalse);
conformable(iftrue,predicate);
Lattice<vobj> ret(iftrue.Grid());
where(ret,predicate,iftrue,iffalse);
return ret;
}
NAMESPACE_END(Grid);
#endif