/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/lattice/Lattice_ET.h Copyright (C) 2015 Author: Azusa Yamaguchi Author: Peter Boyle Author: neo This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef GRID_LATTICE_ET_H #define GRID_LATTICE_ET_H #include #include #include #include NAMESPACE_BEGIN(Grid); //////////////////////////////////////////////////// // Predicated where support //////////////////////////////////////////////////// template inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse) { typename std::remove_const::type ret; typedef typename vobj::scalar_object scalar_object; typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; const int Nsimd = vobj::vector_type::Nsimd(); std::vector mask(Nsimd); std::vector truevals(Nsimd); std::vector falsevals(Nsimd); extract(iftrue, truevals); extract(iffalse, falsevals); extract(TensorRemove(predicate), mask); for (int s = 0; s < Nsimd; s++) { if (mask[s]) falsevals[s] = truevals[s]; } merge(ret, falsevals); return ret; } //////////////////////////////////////////// // recursive evaluation of expressions; Could // switch to generic approach with variadics, a la // Antonin's Lat Sim but the repack to variadic with popped // from tuple is hideous; C++14 introduces std::make_index_sequence for this //////////////////////////////////////////// // leaf eval of lattice ; should enable if protect using traits template using is_lattice = std::is_base_of; template using is_lattice_expr = std::is_base_of; template using is_lattice_expr = std::is_base_of; //Specialization of getVectorType for lattices template struct getVectorType >{ typedef typename Lattice::vector_object type; }; template inline sobj eval(const unsigned int ss, const sobj &arg) { return arg; } template inline const lobj &eval(const unsigned int ss, const Lattice &arg) { return arg[ss]; } // handle nodes in syntax tree template auto inline eval( const unsigned int ss, const LatticeUnaryExpression &expr) // eval one operand -> decltype(expr.first.func(eval(ss, std::get<0>(expr.second)))) { return expr.first.func(eval(ss, std::get<0>(expr.second))); } template auto inline eval( const unsigned int ss, const LatticeBinaryExpression &expr) // eval two operands -> decltype(expr.first.func(eval(ss, std::get<0>(expr.second)), eval(ss, std::get<1>(expr.second)))) { return expr.first.func(eval(ss, std::get<0>(expr.second)), eval(ss, std::get<1>(expr.second))); } template auto inline eval(const unsigned int ss, const LatticeTrinaryExpression &expr) // eval three operands -> decltype(expr.first.func(eval(ss, std::get<0>(expr.second)), eval(ss, std::get<1>(expr.second)), eval(ss, std::get<2>(expr.second)))) { return expr.first.func(eval(ss, std::get<0>(expr.second)), eval(ss, std::get<1>(expr.second)), eval(ss, std::get<2>(expr.second))); } ////////////////////////////////////////////////////////////////////////// // Obtain the grid from an expression, ensuring conformable. This must follow a // tree recursion ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> inline void GridFromExpression(GridBase *&grid, const T1 &lat) // Lattice leaf { if (grid) { conformable(grid, lat._grid); } grid = lat._grid; } template ::value, T1>::type * = nullptr> inline void GridFromExpression(GridBase *&grid, const T1 ¬lat) // non-lattice leaf {} template inline void GridFromExpression(GridBase *&grid, const LatticeUnaryExpression &expr) { GridFromExpression(grid, std::get<0>(expr.second)); // recurse } template inline void GridFromExpression( GridBase *&grid, const LatticeBinaryExpression &expr) { GridFromExpression(grid, std::get<0>(expr.second)); // recurse GridFromExpression(grid, std::get<1>(expr.second)); } template inline void GridFromExpression( GridBase *&grid, const LatticeTrinaryExpression &expr) { GridFromExpression(grid, std::get<0>(expr.second)); // recurse GridFromExpression(grid, std::get<1>(expr.second)); GridFromExpression(grid, std::get<2>(expr.second)); } ////////////////////////////////////////////////////////////////////////// // Obtain the CB from an expression, ensuring conformable. This must follow a // tree recursion ////////////////////////////////////////////////////////////////////////// template ::value, T1>::type * = nullptr> inline void CBFromExpression(int &cb, const T1 &lat) // Lattice leaf { if ((cb == Odd) || (cb == Even)) { assert(cb == lat.Checkerboard()); } cb = lat.Checkerboard(); // std::cout<::value, T1>::type * = nullptr> inline void CBFromExpression(int &cb, const T1 ¬lat) // non-lattice leaf { // std::cout< inline void CBFromExpression(int &cb, const LatticeUnaryExpression &expr) { CBFromExpression(cb, std::get<0>(expr.second)); // recurse // std::cout< inline void CBFromExpression(int &cb, const LatticeBinaryExpression &expr) { CBFromExpression(cb, std::get<0>(expr.second)); // recurse CBFromExpression(cb, std::get<1>(expr.second)); // std::cout< inline void CBFromExpression( int &cb, const LatticeTrinaryExpression &expr) { CBFromExpression(cb, std::get<0>(expr.second)); // recurse CBFromExpression(cb, std::get<1>(expr.second)); CBFromExpression(cb, std::get<2>(expr.second)); // std::cout< \ struct name { \ static auto inline func(const arg a) -> decltype(ret) { return ret; } \ }; GridUnopClass(UnarySub, -a); GridUnopClass(UnaryNot, Not(a)); GridUnopClass(UnaryAdj, adj(a)); GridUnopClass(UnaryConj, conjugate(a)); GridUnopClass(UnaryTrace, trace(a)); GridUnopClass(UnaryTranspose, transpose(a)); GridUnopClass(UnaryTa, Ta(a)); GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a)); GridUnopClass(UnaryReal, real(a)); GridUnopClass(UnaryImag, imag(a)); GridUnopClass(UnaryToReal, toReal(a)); GridUnopClass(UnaryToComplex, toComplex(a)); GridUnopClass(UnaryTimesI, timesI(a)); GridUnopClass(UnaryTimesMinusI, timesMinusI(a)); GridUnopClass(UnaryAbs, abs(a)); GridUnopClass(UnarySqrt, sqrt(a)); GridUnopClass(UnaryRsqrt, rsqrt(a)); GridUnopClass(UnarySin, sin(a)); GridUnopClass(UnaryCos, cos(a)); GridUnopClass(UnaryAsin, asin(a)); GridUnopClass(UnaryAcos, acos(a)); GridUnopClass(UnaryLog, log(a)); GridUnopClass(UnaryExp, exp(a)); //////////////////////////////////////////// // Binary operators //////////////////////////////////////////// #define GridBinOpClass(name, combination) \ template \ struct name { \ static auto inline func(const left &lhs, const right &rhs) \ -> decltype(combination) const { \ return combination; \ } \ } GridBinOpClass(BinaryAdd, lhs + rhs); GridBinOpClass(BinarySub, lhs - rhs); GridBinOpClass(BinaryMul, lhs *rhs); GridBinOpClass(BinaryDiv, lhs /rhs); GridBinOpClass(BinaryAnd, lhs &rhs); GridBinOpClass(BinaryOr, lhs | rhs); GridBinOpClass(BinaryAndAnd, lhs &&rhs); GridBinOpClass(BinaryOrOr, lhs || rhs); //////////////////////////////////////////////////// // Trinary conditional op //////////////////////////////////////////////////// #define GridTrinOpClass(name, combination) \ template \ struct name { \ static auto inline func(const predicate &pred, const left &lhs, \ const right &rhs) -> decltype(combination) const { \ return combination; \ } \ } GridTrinOpClass( TrinaryWhere, (predicatedWhere::type, typename std::remove_reference::type>(pred, lhs, rhs))); //////////////////////////////////////////// // Operator syntactical glue //////////////////////////////////////////// #define GRID_UNOP(name) name #define GRID_BINOP(name) name #define GRID_TRINOP(name) \ name #define GRID_DEF_UNOP(op, name) \ template ::value || \ is_lattice_expr::value, \ T1>::type * = nullptr> \ inline auto op(const T1 &arg) \ ->decltype(LatticeUnaryExpression( \ std::make_pair(GRID_UNOP(name)(), std::forward_as_tuple(arg)))) { \ return LatticeUnaryExpression( \ std::make_pair(GRID_UNOP(name)(), std::forward_as_tuple(arg))); \ } #define GRID_BINOP_LEFT(op, name) \ template ::value || \ is_lattice_expr::value, \ T1>::type * = nullptr> \ inline auto op(const T1 &lhs, const T2 &rhs) \ ->decltype( \ LatticeBinaryExpression( \ std::make_pair(GRID_BINOP(name)(), \ std::forward_as_tuple(lhs, rhs)))) { \ return LatticeBinaryExpression( \ std::make_pair(GRID_BINOP(name)(), std::forward_as_tuple(lhs, rhs))); \ } #define GRID_BINOP_RIGHT(op, name) \ template ::value && \ !is_lattice_expr::value, \ T1>::type * = nullptr, \ typename std::enable_if::value || \ is_lattice_expr::value, \ T2>::type * = nullptr> \ inline auto op(const T1 &lhs, const T2 &rhs) \ ->decltype( \ LatticeBinaryExpression( \ std::make_pair(GRID_BINOP(name)(), \ std::forward_as_tuple(lhs, rhs)))) { \ return LatticeBinaryExpression( \ std::make_pair(GRID_BINOP(name)(), std::forward_as_tuple(lhs, rhs))); \ } #define GRID_DEF_BINOP(op, name) \ GRID_BINOP_LEFT(op, name); \ GRID_BINOP_RIGHT(op, name); #define GRID_DEF_TRINOP(op, name) \ template \ inline auto op(const T1 &pred, const T2 &lhs, const T3 &rhs) \ ->decltype( \ LatticeTrinaryExpression(std::make_pair( \ GRID_TRINOP(name)(), std::forward_as_tuple(pred, lhs, rhs)))) { \ return LatticeTrinaryExpression(std::make_pair( \ GRID_TRINOP(name)(), std::forward_as_tuple(pred, lhs, rhs))); \ } //////////////////////// // Operator definitions //////////////////////// GRID_DEF_UNOP(operator-, UnarySub); GRID_DEF_UNOP(Not, UnaryNot); GRID_DEF_UNOP(operator!, UnaryNot); GRID_DEF_UNOP(adj, UnaryAdj); GRID_DEF_UNOP(conjugate, UnaryConj); GRID_DEF_UNOP(trace, UnaryTrace); GRID_DEF_UNOP(transpose, UnaryTranspose); GRID_DEF_UNOP(Ta, UnaryTa); GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup); GRID_DEF_UNOP(real, UnaryReal); GRID_DEF_UNOP(imag, UnaryImag); GRID_DEF_UNOP(toReal, UnaryToReal); GRID_DEF_UNOP(toComplex, UnaryToComplex); GRID_DEF_UNOP(timesI, UnaryTimesI); GRID_DEF_UNOP(timesMinusI, UnaryTimesMinusI); GRID_DEF_UNOP(abs, UnaryAbs); // abs overloaded in cmath C++98; DON'T do the // abs-fabs-dabs-labs thing GRID_DEF_UNOP(sqrt, UnarySqrt); GRID_DEF_UNOP(rsqrt, UnaryRsqrt); GRID_DEF_UNOP(sin, UnarySin); GRID_DEF_UNOP(cos, UnaryCos); GRID_DEF_UNOP(asin, UnaryAsin); GRID_DEF_UNOP(acos, UnaryAcos); GRID_DEF_UNOP(log, UnaryLog); GRID_DEF_UNOP(exp, UnaryExp); GRID_DEF_BINOP(operator+, BinaryAdd); GRID_DEF_BINOP(operator-, BinarySub); GRID_DEF_BINOP(operator*, BinaryMul); GRID_DEF_BINOP(operator/, BinaryDiv); GRID_DEF_BINOP(operator&, BinaryAnd); GRID_DEF_BINOP(operator|, BinaryOr); GRID_DEF_BINOP(operator&&, BinaryAndAnd); GRID_DEF_BINOP(operator||, BinaryOrOr); GRID_DEF_TRINOP(where, TrinaryWhere); ///////////////////////////////////////////////////////////// // Closure convenience to force expression to evaluate ///////////////////////////////////////////////////////////// template auto closure(const LatticeUnaryExpression &expr) -> Lattice(expr.second))))> { Lattice(expr.second))))> ret( expr); return ret; } template auto closure(const LatticeBinaryExpression &expr) -> Lattice(expr.second)), eval(0, std::get<1>(expr.second))))> { Lattice(expr.second)), eval(0, std::get<1>(expr.second))))> ret(expr); return ret; } template auto closure(const LatticeTrinaryExpression &expr) -> Lattice(expr.second)), eval(0, std::get<1>(expr.second)), eval(0, std::get<2>(expr.second))))> { Lattice(expr.second)), eval(0, std::get<1>(expr.second)), eval(0, std::get<2>(expr.second))))> ret(expr); return ret; } #undef GRID_UNOP #undef GRID_BINOP #undef GRID_TRINOP #undef GRID_DEF_UNOP #undef GRID_DEF_BINOP #undef GRID_DEF_TRINOP NAMESPACE_END(Grid); #if 0 using namespace Grid; int main(int argc,char **argv){ Lattice v1(16); Lattice v2(16); Lattice v3(16); BinaryAdd tmp; LatticeBinaryExpression,Lattice &,Lattice &> expr(std::make_pair(tmp, std::forward_as_tuple(v1,v2))); tmp.func(eval(0,v1),eval(0,v2)); auto var = v1+v2; std::cout< &v1,Lattice &v2,Lattice &v3) { v3=v1+v2+v1*v2; } #endif #endif