diff --git a/lib/lattice/Lattice_comparison_utils.h b/lib/lattice/Lattice_comparison_utils.h index d5e3b3e5..d25c49f2 100644 --- a/lib/lattice/Lattice_comparison_utils.h +++ b/lib/lattice/Lattice_comparison_utils.h @@ -5,140 +5,197 @@ namespace Grid { ///////////////////////////////////////// // This implementation is a bit poor. - // Only support logical operations (== etc) - // on scalar objects. Strip any tensor structures. + // + // Only support relational logical operations (<, > etc) + // on scalar objects. Therefore can strip any tensor structures. + // // Should guard this with isGridTensor<> enable if? ///////////////////////////////////////// - // Generic list of functors - template class veq { + // + // Generic list of functors + // + template class veq { + public: + vInteger operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) == (rhs); + } + }; + template class vne { + public: + vInteger operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) != (rhs); + } + }; + template class vlt { + public: + vInteger operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) < (rhs); + } + }; + template class vle { + public: + vInteger operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) <= (rhs); + } + }; + template class vgt { + public: + vInteger operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) > (rhs); + } + }; + template class vge { public: - vInteger operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) == TensorRemove(rhs); - } - }; - template class vne { - public: - vInteger operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) != TensorRemove(rhs); - } - }; - template class vlt { - public: - vInteger operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) < TensorRemove(rhs); - } - }; - template class vle { - public: - vInteger operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) <= TensorRemove(rhs); - } - }; - template class vgt { - public: - vInteger operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) > TensorRemove(rhs); - } - }; - template class vge { - public: - vInteger operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) >= TensorRemove(rhs); - } - }; + vInteger operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) >= (rhs); + } + }; + + // Generic list of functors + template class seq { + public: + Integer operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) == (rhs); + } + }; + template class sne { + public: + Integer operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) != (rhs); + } + }; + template class slt { + public: + Integer operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) < (rhs); + } + }; + template class sle { + public: + Integer operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) <= (rhs); + } + }; + template class sgt { + public: + Integer operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) > (rhs); + } + }; + template class sge { + public: + Integer operator()(const lobj &lhs, const robj &rhs) + { + return (lhs) >= (rhs); + } + }; - // Generic list of functors - template class seq { - public: - Integer operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) == TensorRemove(rhs); - } - }; - template class sne { - public: - Integer operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) != TensorRemove(rhs); - } - }; - template class slt { - public: - Integer operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) < TensorRemove(rhs); - } - }; - template class sle { - public: - Integer operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) <= TensorRemove(rhs); - } - }; - template class sgt { - public: - Integer operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) > TensorRemove(rhs); - } - }; - template class sge { - public: - Integer operator()(const lobj &lhs, const robj &rhs) - { - return TensorRemove(lhs) >= TensorRemove(rhs); - } - }; - - - ////////////////////////////////////////////////////////////////////////////////////////////////////// - // Integer gets extra relational functions. Could also implement these for RealF, RealD etc.. - ////////////////////////////////////////////////////////////////////////////////////////////////////// - template - inline vInteger Comparison(sfunctor sop,const vInteger & lhs, const vInteger & rhs) + ////////////////////////////////////////////////////////////////////////////////////////////////////// + // Integer and real get extra relational functions. + ////////////////////////////////////////////////////////////////////////////////////////////////////// + template = 0> + inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const vsimd & rhs) { - std::vector vlhs(vInteger::Nsimd()); // Use functors to reduce this to single implementation - std::vector vrhs(vInteger::Nsimd()); + typedef typename vsimd::scalar_type scalar; + std::vector vlhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation + std::vector vrhs(vsimd::Nsimd()); + std::vector vpred(vsimd::Nsimd()); vInteger ret; - extract(lhs,vlhs); - extract(rhs,vrhs); - for(int s=0;s(lhs,vlhs); + extract(rhs,vrhs); + for(int s=0;s(ret,vlhs); + merge(ret,vpred); return ret; } - inline vInteger operator < (const vInteger & lhs, const vInteger & rhs) + + template = 0> + inline vInteger Comparison(sfunctor sop,const vsimd & lhs, const typename vsimd::scalar_type & rhs) { - return Comparison(slt(),lhs,rhs); + typedef typename vsimd::scalar_type scalar; + std::vector vlhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation + std::vector vpred(vsimd::Nsimd()); + vInteger ret; + extract(lhs,vlhs); + for(int s=0;s(ret,vpred); + return ret; } - inline vInteger operator <= (const vInteger & lhs, const vInteger & rhs) + + template = 0> + inline vInteger Comparison(sfunctor sop,const typename vsimd::scalar_type & lhs, const vsimd & rhs) { - return Comparison(sle(),lhs,rhs); - } - inline vInteger operator > (const vInteger & lhs, const vInteger & rhs) - { - return Comparison(sgt(),lhs,rhs); - } - inline vInteger operator >= (const vInteger & lhs, const vInteger & rhs) - { - return Comparison(sge(),lhs,rhs); - } - inline vInteger operator == (const vInteger & lhs, const vInteger & rhs) - { - return Comparison(seq(),lhs,rhs); - } - inline vInteger operator != (const vInteger & lhs, const vInteger & rhs) - { - return Comparison(sne(),lhs,rhs); + typedef typename vsimd::scalar_type scalar; + std::vector vrhs(vsimd::Nsimd()); // Use functors to reduce this to single implementation + std::vector vpred(vsimd::Nsimd()); + vInteger ret; + extract(rhs,vrhs); + for(int s=0;s(ret,vpred); + return ret; } + +#define DECLARE_RELATIONAL(op,functor) \ + template = 0>\ + inline vInteger operator op (const vsimd & lhs, const vsimd & rhs)\ + {\ + typedef typename vsimd::scalar_type scalar;\ + return Comparison(functor(),lhs,rhs);\ + }\ + template = 0>\ + inline vInteger operator op (const vsimd & lhs, const typename vsimd::scalar_type & rhs) \ + {\ + typedef typename vsimd::scalar_type scalar;\ + return Comparison(functor(),lhs,rhs);\ + }\ + template = 0>\ + inline vInteger operator op (const typename vsimd::scalar_type & lhs, const vsimd & rhs) \ + {\ + typedef typename vsimd::scalar_type scalar;\ + return Comparison(functor(),lhs,rhs);\ + }\ + template\ + inline vInteger operator op(const iScalar &lhs,const iScalar &rhs)\ + { \ + return lhs._internal op rhs._internal; \ + } \ + template\ + inline vInteger operator op(const iScalar &lhs,const typename vsimd::scalar_type &rhs) \ + { \ + return lhs._internal op rhs; \ + } \ + template\ + inline vInteger operator op(const typename vsimd::scalar_type &lhs,const iScalar &rhs) \ + { \ + return lhs op rhs._internal; \ + } + + +DECLARE_RELATIONAL(<,slt); +DECLARE_RELATIONAL(<=,sle); +DECLARE_RELATIONAL(>,sgt); +DECLARE_RELATIONAL(>=,sge); +DECLARE_RELATIONAL(==,seq); +DECLARE_RELATIONAL(!=,sne); + +#undef DECLARE_RELATIONAL + }