diff --git a/Grid/algorithms/blas/BatchedBlas.h b/Grid/algorithms/blas/BatchedBlas.h index f4245319..3a7bbc44 100644 --- a/Grid/algorithms/blas/BatchedBlas.h +++ b/Grid/algorithms/blas/BatchedBlas.h @@ -208,8 +208,8 @@ public: assert(Bkn.size()==batchCount); assert(Cmn.size()==batchCount); - assert(OpA!=GridBLAS_OP_T); // Complex case expect no transpose - assert(OpB!=GridBLAS_OP_T); + //assert(OpA!=GridBLAS_OP_T); // Complex case expect no transpose + //assert(OpB!=GridBLAS_OP_T); int lda = m; // m x k column major int ldb = k; // k x n column major @@ -367,28 +367,67 @@ public: Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn ; + else + eCmn = alpha * eAmk * eBkn ; }); } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_N) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn ; + else + eCmn = alpha * eAmk.adjoint() * eBkn ; + }); + } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) { + thread_for (p, batchCount, { + Eigen::Map eAmk(Amk[p],k,m); + Eigen::Map eBkn(Bkn[p],k,n); + Eigen::Map eCmn(Cmn[p],m,n); + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ; + else + eCmn = alpha * eAmk.transpose() * eBkn ; }); } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_C) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn.adjoint() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn.adjoint() ; + else + eCmn = alpha * eAmk * eBkn.adjoint() ; + }); + } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) { + thread_for (p, batchCount, { + Eigen::Map eAmk(Amk[p],m,k); + Eigen::Map eBkn(Bkn[p],n,k); + Eigen::Map eCmn(Cmn[p],m,n); + eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ; }); } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_C) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn.adjoint() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn.adjoint() ; + else + eCmn = alpha * eAmk.adjoint() * eBkn.adjoint() ; + } ); + } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) { + thread_for (p, batchCount, { + Eigen::Map eAmk(Amk[p],k,m); + Eigen::Map eBkn(Bkn[p],n,k); + Eigen::Map eCmn(Cmn[p],m,n); + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ; + else + eCmn = alpha * eAmk.transpose() * eBkn.transpose() ; } ); } else { assert(0); @@ -414,8 +453,8 @@ public: RealD t2=usecond(); int32_t batchCount = Amk.size(); - assert(OpA!=GridBLAS_OP_T); // Complex case expect no transpose - assert(OpB!=GridBLAS_OP_T); + //assert(OpA!=GridBLAS_OP_T); // Complex case expect no transpose + //assert(OpB!=GridBLAS_OP_T); int lda = m; // m x k column major int ldb = k; // k x n column major @@ -514,28 +553,70 @@ public: Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn ; + else + eCmn = alpha * eAmk * eBkn ; }); } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_N) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn ; + else + eCmn = alpha * eAmk.adjoint() * eBkn ; + }); + } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) { + thread_for (p, batchCount, { + Eigen::Map eAmk(Amk[p],k,m); + Eigen::Map eBkn(Bkn[p],k,n); + Eigen::Map eCmn(Cmn[p],m,n); + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ; + else + eCmn = alpha * eAmk.transpose() * eBkn ; }); } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_C) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn.adjoint() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn.adjoint() ; + else + eCmn = alpha * eAmk * eBkn.adjoint() ; + }); + } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) { + thread_for (p, batchCount, { + Eigen::Map eAmk(Amk[p],m,k); + Eigen::Map eBkn(Bkn[p],n,k); + Eigen::Map eCmn(Cmn[p],m,n); + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ; + else + eCmn = alpha * eAmk * eBkn.transpose() ; }); } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_C) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn.adjoint() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn.adjoint() ; + else + eCmn = alpha * eAmk.adjoint() * eBkn.adjoint() ; + } ); + } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) { + thread_for (p, batchCount, { + Eigen::Map eAmk(Amk[p],k,m); + Eigen::Map eBkn(Bkn[p],n,k); + Eigen::Map eCmn(Cmn[p],m,n); + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ; + else + eCmn = alpha * eAmk.transpose() * eBkn.transpose() ; } ); } else { assert(0); @@ -661,29 +742,41 @@ public: Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn ; + else + eCmn = alpha * eAmk * eBkn ; }); } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ; + else + eCmn = alpha * eAmk.transpose() * eBkn ; }); } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ; + else + eCmn = alpha * eAmk * eBkn.transpose() ; }); } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ; - } ); + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ; + else + eCmn = alpha * eAmk.transpose() * eBkn.transpose() ; + }); } else { assert(0); } @@ -809,28 +902,40 @@ public: Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn ; + else + eCmn = alpha * eAmk * eBkn ; }); } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],k,n); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ; + else + eCmn = alpha * eAmk.transpose() * eBkn ; }); } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],m,k); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ; + else + eCmn = alpha * eAmk * eBkn.transpose() ; }); } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) { thread_for (p, batchCount, { Eigen::Map eAmk(Amk[p],k,m); Eigen::Map eBkn(Bkn[p],n,k); Eigen::Map eCmn(Cmn[p],m,n); - eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ; + if (std::abs(beta) != 0.0) + eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ; + else + eCmn = alpha * eAmk.transpose() * eBkn.transpose() ; }); } else { assert(0);