diff --git a/Grid/lattice/Lattice_comparison.h b/Grid/lattice/Lattice_comparison.h index 3a54c114..bbed2ef5 100644 --- a/Grid/lattice/Lattice_comparison.h +++ b/Grid/lattice/Lattice_comparison.h @@ -40,18 +40,49 @@ NAMESPACE_BEGIN(Grid); //Query supporting logical &&, ||, ////////////////////////////////////////////////////////////////////////// +typedef iScalar vPredicate ; + +/* +template accelerator_inline +vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse) +{ + typename std::remove_const::type ret; + + typedef typename vobj::scalar_object scalar_object; + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_type vector_type; + + const int Nsimd = vobj::vector_type::Nsimd(); + + ExtractBuffer mask(Nsimd); + ExtractBuffer truevals(Nsimd); + ExtractBuffer falsevals(Nsimd); + + extract(iftrue, truevals); + extract(iffalse, falsevals); + extract(TensorRemove(predicate), mask); + + for (int s = 0; s < Nsimd; s++) { + if (mask[s]) falsevals[s] = truevals[s]; + } + + merge(ret, falsevals); + return ret; +} +*/ ////////////////////////////////////////////////////////////////////////// // compare lattice to lattice ////////////////////////////////////////////////////////////////////////// + template -inline Lattice LLComparison(vfunctor op,const Lattice &lhs,const Lattice &rhs) +inline Lattice LLComparison(vfunctor op,const Lattice &lhs,const Lattice &rhs) { - Lattice ret(rhs.Grid()); + Lattice ret(rhs.Grid()); auto lhs_v = lhs.View(); auto rhs_v = rhs.View(); auto ret_v = ret.View(); - accelerator_loop( ss, rhs_v, { - ret_v[ss]=op(lhs_v[ss],rhs_v[ss]); + thread_for( ss, rhs_v.size(), { + ret_v[ss]=op(lhs_v[ss],rhs_v[ss]); }); return ret; } @@ -59,12 +90,12 @@ inline Lattice LLComparison(vfunctor op,const Lattice &lhs,const // compare lattice to scalar ////////////////////////////////////////////////////////////////////////// template -inline Lattice LSComparison(vfunctor op,const Lattice &lhs,const robj &rhs) +inline Lattice LSComparison(vfunctor op,const Lattice &lhs,const robj &rhs) { - Lattice ret(lhs.Grid()); + Lattice ret(lhs.Grid()); auto lhs_v = lhs.View(); auto ret_v = ret.View(); - accelerator_loop( ss, lhs_v, { + thread_for( ss, lhs_v.size(), { ret_v[ss]=op(lhs_v[ss],rhs); }); return ret; @@ -73,12 +104,12 @@ inline Lattice LSComparison(vfunctor op,const Lattice &lhs,const // compare scalar to lattice ////////////////////////////////////////////////////////////////////////// template -inline Lattice SLComparison(vfunctor op,const lobj &lhs,const Lattice &rhs) +inline Lattice SLComparison(vfunctor op,const lobj &lhs,const Lattice &rhs) { - Lattice ret(rhs.Grid()); + Lattice ret(rhs.Grid()); auto rhs_v = rhs.View(); auto ret_v = ret.View(); - accelerator_loop( ss, rhs_v, { + thread_for( ss, rhs_v.size(), { ret_v[ss]=op(lhs,rhs_v[ss]); }); return ret; @@ -89,87 +120,87 @@ inline Lattice SLComparison(vfunctor op,const lobj &lhs,const Lattice< ////////////////////////////////////////////////////////////////////////// // Less than template -inline Lattice operator < (const Lattice & lhs, const Lattice & rhs) { +inline Lattice operator < (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vlt(),lhs,rhs); } template -inline Lattice operator < (const Lattice & lhs, const robj & rhs) { +inline Lattice operator < (const Lattice & lhs, const robj & rhs) { return LSComparison(vlt(),lhs,rhs); } template -inline Lattice operator < (const lobj & lhs, const Lattice & rhs) { +inline Lattice operator < (const lobj & lhs, const Lattice & rhs) { return SLComparison(vlt(),lhs,rhs); } // Less than equal template -inline Lattice operator <= (const Lattice & lhs, const Lattice & rhs) { +inline Lattice operator <= (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vle(),lhs,rhs); } template -inline Lattice operator <= (const Lattice & lhs, const robj & rhs) { +inline Lattice operator <= (const Lattice & lhs, const robj & rhs) { return LSComparison(vle(),lhs,rhs); } template -inline Lattice operator <= (const lobj & lhs, const Lattice & rhs) { +inline Lattice operator <= (const lobj & lhs, const Lattice & rhs) { return SLComparison(vle(),lhs,rhs); } // Greater than template -inline Lattice operator > (const Lattice & lhs, const Lattice & rhs) { +inline Lattice operator > (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vgt(),lhs,rhs); } template -inline Lattice operator > (const Lattice & lhs, const robj & rhs) { +inline Lattice operator > (const Lattice & lhs, const robj & rhs) { return LSComparison(vgt(),lhs,rhs); } template -inline Lattice operator > (const lobj & lhs, const Lattice & rhs) { +inline Lattice operator > (const lobj & lhs, const Lattice & rhs) { return SLComparison(vgt(),lhs,rhs); } // Greater than equal template -inline Lattice operator >= (const Lattice & lhs, const Lattice & rhs) { +inline Lattice operator >= (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vge(),lhs,rhs); } template -inline Lattice operator >= (const Lattice & lhs, const robj & rhs) { +inline Lattice operator >= (const Lattice & lhs, const robj & rhs) { return LSComparison(vge(),lhs,rhs); } template -inline Lattice operator >= (const lobj & lhs, const Lattice & rhs) { +inline Lattice operator >= (const lobj & lhs, const Lattice & rhs) { return SLComparison(vge(),lhs,rhs); } // equal template -inline Lattice operator == (const Lattice & lhs, const Lattice & rhs) { +inline Lattice operator == (const Lattice & lhs, const Lattice & rhs) { return LLComparison(veq(),lhs,rhs); } template -inline Lattice operator == (const Lattice & lhs, const robj & rhs) { +inline Lattice operator == (const Lattice & lhs, const robj & rhs) { return LSComparison(veq(),lhs,rhs); } template -inline Lattice operator == (const lobj & lhs, const Lattice & rhs) { +inline Lattice operator == (const lobj & lhs, const Lattice & rhs) { return SLComparison(veq(),lhs,rhs); } // not equal template -inline Lattice operator != (const Lattice & lhs, const Lattice & rhs) { +inline Lattice operator != (const Lattice & lhs, const Lattice & rhs) { return LLComparison(vne(),lhs,rhs); } template -inline Lattice operator != (const Lattice & lhs, const robj & rhs) { +inline Lattice operator != (const Lattice & lhs, const robj & rhs) { return LSComparison(vne(),lhs,rhs); } template -inline Lattice operator != (const lobj & lhs, const Lattice & rhs) { +inline Lattice operator != (const lobj & lhs, const Lattice & rhs) { return SLComparison(vne(),lhs,rhs); } NAMESPACE_END(Grid);