diff --git a/Grid/threads/Threads.h b/Grid/threads/Threads.h index 6403bff9..580278d6 100644 --- a/Grid/threads/Threads.h +++ b/Grid/threads/Threads.h @@ -39,12 +39,7 @@ Author: paboyle #include // complex reductions -#pragma omp declare reduction(ComplexPlus: Grid::Complex: omp_out += omp_in) -#pragma omp declare reduction(GridVComplexPlus: Grid::vComplex: omp_out += omp_in) -#pragma omp declare reduction(ComplexDPlus: Grid::ComplexD: omp_out += omp_in) -#pragma omp declare reduction(GridVComplexDPlus: Grid::vComplexD: omp_out += omp_in) -#pragma omp declare reduction(ComplexFPlus: Grid::ComplexF: omp_out += omp_in) -#pragma omp declare reduction(GridVComplexFPlus: Grid::vComplexF: omp_out += omp_in) +#pragma omp declare reduction(ComplexPlus:Grid::ComplexD, Grid::vComplexD, Grid::ComplexF, Grid::vComplexF: omp_out += omp_in) #define PARALLEL_FOR_LOOP _Pragma("omp parallel for schedule(static)") #define PARALLEL_FOR_LOOP_INTERN _Pragma("omp for schedule(static)") diff --git a/Hadrons/A2AMatrix.hpp b/Hadrons/A2AMatrix.hpp index 9842e15b..5266d050 100644 --- a/Hadrons/A2AMatrix.hpp +++ b/Hadrons/A2AMatrix.hpp @@ -32,6 +32,10 @@ See the full license in the file "LICENSE" in the top level distribution directo #include #include #include +#ifdef USE_MKL +#include "mkl.h" +#include "mkl_cblas.h" +#endif #ifndef HADRONS_A2AM_NAME #define HADRONS_A2AM_NAME "a2aMatrix" @@ -58,6 +62,9 @@ using A2AMatrixSet = Eigen::TensorMap>; template using A2AMatrix = Eigen::Matrix; +template +using A2AMatrixTr = Eigen::Matrix; + /****************************************************************************** * Abstract class for A2A kernels * ******************************************************************************/ @@ -150,6 +157,198 @@ private: std::vector nodeIo_; }; +/****************************************************************************** + * A2A matrix contraction kernels * + ******************************************************************************/ +class A2AContraction +{ +public: + // accTrMul(acc, a, b): acc += tr(a*b) + template + static inline void accTrMul(C &acc, const MatLeft &a, const MatRight &b) + { + if ((MatLeft::Options == Eigen::RowMajor) and + (MatRight::Options == Eigen::ColMajor)) + { + parallel_for_reduce(ComplexPlus, acc) (unsigned int r = 0; r < a.rows(); ++r) + { +#ifdef USE_MKL + ComplexD tmp; + dotuRow(tmp, r, a, b); + acc += tmp; +#else + acc += a.row(r).conjugate().dot(b.col(r)); +#endif + } + } + else + { + parallel_for_reduce(ComplexPlus, acc) (unsigned int c = 0; c < a.cols(); ++c) + { +#ifdef USE_MKL + ComplexD tmp; + dotuCol(tmp, c, a, b); + acc += tmp; +#else + acc += a.col(c).conjugate().dot(b.row(c)); +#endif + } + } + } + + // mul(res, a, b): res = a*b +#ifdef USE_MKL + template