/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/lattice/Lattice_ET.h Copyright (C) 2015 Author: Azusa Yamaguchi Author: Peter Boyle Author: neo This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef GRID_LATTICE_ET_H #define GRID_LATTICE_ET_H #include #include #include #include NAMESPACE_BEGIN(Grid); //////////////////////////////////////////////////// // Predicated where support //////////////////////////////////////////////////// template accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse) { typename std::remove_const::type ret; typedef typename vobj::scalar_object scalar_object; typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; const int Nsimd = vobj::vector_type::Nsimd(); ExtractBuffer mask(Nsimd); ExtractBuffer truevals(Nsimd); ExtractBuffer falsevals(Nsimd); extract(iftrue, truevals); extract(iffalse, falsevals); extract(TensorRemove(predicate), mask); for (int s = 0; s < Nsimd; s++) { if (mask[s]) falsevals[s] = truevals[s]; } merge(ret, falsevals); return ret; } ///////////////////////////////////////////////////// //Specialization of getVectorType for lattices ///////////////////////////////////////////////////// template struct getVectorType >{ typedef typename Lattice::vector_object type; }; //////////////////////////////////////////// //-- recursive evaluation of expressions; -- // handle leaves of syntax tree /////////////////////////////////////////////////// template accelerator_inline sobj eval(const uint64_t ss, const sobj &arg) { return arg; } template accelerator_inline const lobj & eval(const uint64_t ss, const LatticeView &arg) { return arg[ss]; } template accelerator_inline const lobj & eval(const uint64_t ss, const Lattice &arg) { auto view = arg.View(); return view[ss]; } /////////////////////////////////////////////////// // handle nodes in syntax tree- eval one operand /////////////////////////////////////////////////// template accelerator_inline auto eval(const uint64_t ss, const LatticeUnaryExpression &expr) -> decltype(expr.op.func( eval(ss, expr.arg1))) { return expr.op.func( eval(ss, expr.arg1) ); } /////////////////////// // eval two operands /////////////////////// template accelerator_inline auto eval(const uint64_t ss, const LatticeBinaryExpression &expr) -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2))) { return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) ); } /////////////////////// // eval three operands /////////////////////// template accelerator_inline auto eval(const uint64_t ss, const LatticeTrinaryExpression &expr) -> decltype(expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3))) { return expr.op.func(eval(ss, expr.arg1), eval(ss, expr.arg2), eval(ss, expr.arg3)); } ////////////////////////////////////////////////////////////////////////// // Obtain the grid from an expression, ensuring conformable. This must follow a // tree recursion; must retain grid pointer in the LatticeView class which sucks // Use a different method, and make it void *. // Perhaps a conformable method. ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> accelerator_inline void GridFromExpression(GridBase *&grid, const T1 &lat) // Lattice leaf { lat.Conformable(grid); } template ::value, T1>::type * = nullptr> accelerator_inline void GridFromExpression(GridBase *&grid,const T1 ¬lat) // non-lattice leaf {} template accelerator_inline void GridFromExpression(GridBase *&grid,const LatticeUnaryExpression &expr) { GridFromExpression(grid, expr.arg1); // recurse } template accelerator_inline void GridFromExpression(GridBase *&grid, const LatticeBinaryExpression &expr) { GridFromExpression(grid, expr.arg1); // recurse GridFromExpression(grid, expr.arg2); } template accelerator_inline void GridFromExpression(GridBase *&grid, const LatticeTrinaryExpression &expr) { GridFromExpression(grid, expr.arg1); // recurse GridFromExpression(grid, expr.arg2); // recurse GridFromExpression(grid, expr.arg3); // recurse } ////////////////////////////////////////////////////////////////////////// // Obtain the CB from an expression, ensuring conformable. This must follow a // tree recursion ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> inline void CBFromExpression(int &cb, const T1 &lat) // Lattice leaf { if ((cb == Odd) || (cb == Even)) { assert(cb == lat.Checkerboard()); } cb = lat.Checkerboard(); } template ::value, T1>::type * = nullptr> inline void CBFromExpression(int &cb, const T1 ¬lat) // non-lattice leaf { } template inline void CBFromExpression(int &cb,const LatticeUnaryExpression &expr) { CBFromExpression(cb, expr.arg1); // recurse AST } template inline void CBFromExpression(int &cb,const LatticeBinaryExpression &expr) { CBFromExpression(cb, expr.arg1); // recurse AST CBFromExpression(cb, expr.arg2); // recurse AST } template inline void CBFromExpression(int &cb, const LatticeTrinaryExpression &expr) { CBFromExpression(cb, expr.arg1); // recurse AST CBFromExpression(cb, expr.arg2); // recurse AST CBFromExpression(cb, expr.arg3); // recurse AST } //////////////////////////////////////////// // Unary operators and funcs //////////////////////////////////////////// #define GridUnopClass(name, ret) \ template \ struct name { \ static auto accelerator_inline func(const arg a) -> decltype(ret) { return ret; } \ }; GridUnopClass(UnarySub, -a); GridUnopClass(UnaryNot, Not(a)); GridUnopClass(UnaryAdj, adj(a)); GridUnopClass(UnaryConj, conjugate(a)); GridUnopClass(UnaryTrace, trace(a)); GridUnopClass(UnaryTranspose, transpose(a)); GridUnopClass(UnaryTa, Ta(a)); GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a)); GridUnopClass(UnaryReal, real(a)); GridUnopClass(UnaryImag, imag(a)); GridUnopClass(UnaryToReal, toReal(a)); GridUnopClass(UnaryToComplex, toComplex(a)); GridUnopClass(UnaryTimesI, timesI(a)); GridUnopClass(UnaryTimesMinusI, timesMinusI(a)); GridUnopClass(UnaryAbs, abs(a)); GridUnopClass(UnarySqrt, sqrt(a)); GridUnopClass(UnaryRsqrt, rsqrt(a)); GridUnopClass(UnarySin, sin(a)); GridUnopClass(UnaryCos, cos(a)); GridUnopClass(UnaryAsin, asin(a)); GridUnopClass(UnaryAcos, acos(a)); GridUnopClass(UnaryLog, log(a)); GridUnopClass(UnaryExp, exp(a)); //////////////////////////////////////////// // Binary operators //////////////////////////////////////////// #define GridBinOpClass(name, combination) \ template \ struct name { \ static auto accelerator_inline \ func(const left &lhs, const right &rhs) \ -> decltype(combination) const \ { \ return combination; \ } \ }; GridBinOpClass(BinaryAdd, lhs + rhs); GridBinOpClass(BinarySub, lhs - rhs); GridBinOpClass(BinaryMul, lhs *rhs); GridBinOpClass(BinaryDiv, lhs /rhs); GridBinOpClass(BinaryAnd, lhs &rhs); GridBinOpClass(BinaryOr, lhs | rhs); GridBinOpClass(BinaryAndAnd, lhs &&rhs); GridBinOpClass(BinaryOrOr, lhs || rhs); //////////////////////////////////////////////////// // Trinary conditional op //////////////////////////////////////////////////// #define GridTrinOpClass(name, combination) \ template \ struct name { \ static auto accelerator_inline \ func(const predicate &pred, const left &lhs, const right &rhs) \ -> decltype(combination) const \ { \ return combination; \ } \ }; GridTrinOpClass(TrinaryWhere, (predicatedWhere::type, typename std::remove_reference::type>(pred, lhs,rhs))); //////////////////////////////////////////// // Operator syntactical glue //////////////////////////////////////////// #define GRID_UNOP(name) name #define GRID_BINOP(name) name #define GRID_TRINOP(name) name #define GRID_DEF_UNOP(op, name) \ template ::value||is_lattice_expr::value,T1>::type * = nullptr> \ inline auto op(const T1 &arg) ->decltype(LatticeUnaryExpression(GRID_UNOP(name)(), arg)) \ { \ return LatticeUnaryExpression(GRID_UNOP(name)(), arg); \ } #define GRID_BINOP_LEFT(op, name) \ template ::value||is_lattice_expr::value,T1>::type * = nullptr> \ inline auto op(const T1 &lhs, const T2 &rhs) \ ->decltype(LatticeBinaryExpression(GRID_BINOP(name)(),lhs,rhs)) \ { \ return LatticeBinaryExpression(GRID_BINOP(name)(),lhs,rhs);\ } #define GRID_BINOP_RIGHT(op, name) \ template ::value&&!is_lattice_expr::value,T1>::type * = nullptr, \ typename std::enable_if< is_lattice::value|| is_lattice_expr::value,T2>::type * = nullptr> \ inline auto op(const T1 &lhs, const T2 &rhs) \ ->decltype(LatticeBinaryExpression(GRID_BINOP(name)(),lhs, rhs)) \ { \ return LatticeBinaryExpression(GRID_BINOP(name)(),lhs, rhs); \ } #define GRID_DEF_BINOP(op, name) \ GRID_BINOP_LEFT(op, name); \ GRID_BINOP_RIGHT(op, name); #define GRID_DEF_TRINOP(op, name) \ template \ inline auto op(const T1 &pred, const T2 &lhs, const T3 &rhs) \ ->decltype(LatticeTrinaryExpression(GRID_TRINOP(name)(),pred, lhs, rhs)) \ { \ return LatticeTrinaryExpression(GRID_TRINOP(name)(),pred, lhs, rhs); \ } //////////////////////// // Operator definitions //////////////////////// GRID_DEF_UNOP(operator-, UnarySub); GRID_DEF_UNOP(Not, UnaryNot); GRID_DEF_UNOP(operator!, UnaryNot); GRID_DEF_UNOP(adj, UnaryAdj); GRID_DEF_UNOP(conjugate, UnaryConj); GRID_DEF_UNOP(trace, UnaryTrace); GRID_DEF_UNOP(transpose, UnaryTranspose); GRID_DEF_UNOP(Ta, UnaryTa); GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup); GRID_DEF_UNOP(real, UnaryReal); GRID_DEF_UNOP(imag, UnaryImag); GRID_DEF_UNOP(toReal, UnaryToReal); GRID_DEF_UNOP(toComplex, UnaryToComplex); GRID_DEF_UNOP(timesI, UnaryTimesI); GRID_DEF_UNOP(timesMinusI, UnaryTimesMinusI); GRID_DEF_UNOP(abs, UnaryAbs); // abs overloaded in cmath C++98; DON'T do the // abs-fabs-dabs-labs thing GRID_DEF_UNOP(sqrt, UnarySqrt); GRID_DEF_UNOP(rsqrt, UnaryRsqrt); GRID_DEF_UNOP(sin, UnarySin); GRID_DEF_UNOP(cos, UnaryCos); GRID_DEF_UNOP(asin, UnaryAsin); GRID_DEF_UNOP(acos, UnaryAcos); GRID_DEF_UNOP(log, UnaryLog); GRID_DEF_UNOP(exp, UnaryExp); GRID_DEF_BINOP(operator+, BinaryAdd); GRID_DEF_BINOP(operator-, BinarySub); GRID_DEF_BINOP(operator*, BinaryMul); GRID_DEF_BINOP(operator/, BinaryDiv); GRID_DEF_BINOP(operator&, BinaryAnd); GRID_DEF_BINOP(operator|, BinaryOr); GRID_DEF_BINOP(operator&&, BinaryAndAnd); GRID_DEF_BINOP(operator||, BinaryOrOr); GRID_DEF_TRINOP(where, TrinaryWhere); ///////////////////////////////////////////////////////////// // Closure convenience to force expression to evaluate ///////////////////////////////////////////////////////////// template auto closure(const LatticeUnaryExpression &expr) -> Lattice { Lattice ret(expr); return ret; } template auto closure(const LatticeBinaryExpression &expr) -> Lattice { Lattice ret(expr); return ret; } template auto closure(const LatticeTrinaryExpression &expr) -> Lattice { Lattice ret(expr); return ret; } #undef GRID_UNOP #undef GRID_BINOP #undef GRID_TRINOP #undef GRID_DEF_UNOP #undef GRID_DEF_BINOP #undef GRID_DEF_TRINOP NAMESPACE_END(Grid); #endif