From 9c902e4c2d39c9df4edf5f2bd2e82301622642ba Mon Sep 17 00:00:00 2001 From: Peter Boyle Date: Thu, 11 Jul 2024 15:19:49 +0000 Subject: [PATCH] Batched blas, but not working yet on OneAPI --- Grid/algorithms/blas/BatchedBlas.h | 120 +++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 23 deletions(-) diff --git a/Grid/algorithms/blas/BatchedBlas.h b/Grid/algorithms/blas/BatchedBlas.h index a7edb485..22353d49 100644 --- a/Grid/algorithms/blas/BatchedBlas.h +++ b/Grid/algorithms/blas/BatchedBlas.h @@ -89,9 +89,10 @@ public: gridblasHandle = theGridAccelerator; #endif #ifdef GRID_ONE_MKL - cl::sycl::cpu_selector selector; + cl::sycl::gpu_selector selector; cl::sycl::device selectedDevice { selector }; - gridblasHandle =new sycl::queue (selectedDevice); + cl::sycl::property_list q_prop{cl::sycl::property::queue::in_order()}; + gridblasHandle =new sycl::queue (selectedDevice,q_prop); #endif gridblasInit=1; } @@ -266,8 +267,46 @@ public: assert(err==CUBLAS_STATUS_SUCCESS); #endif #ifdef GRID_SYCL - //MKL’s cblas_gemm_batch & OneAPI -#warning "oneMKL implementation not built " + std::cerr << " Calling SYCL batched ZGEMM "<()); + synchronise(); + std::cerr << " Called SYCL batched ZGEMM "< A(m*k); // pointer list to matrices + std::vector B(k*n); + std::vector C(m*n); + int sda = lda*k; + int sdb = ldb*k; + int sdc = ldc*n; + for (int p = 0; p < 1; ++p) { + acceleratorCopyFromDevice((void *)&Amk[p][0],(void *)&A[0],m*k*sizeof(ComplexD)); + acceleratorCopyFromDevice((void *)&Bkn[p][0],(void *)&B[0],k*n*sizeof(ComplexD)); + acceleratorCopyFromDevice((void *)&Cmn[p][0],(void *)&C[0],m*n*sizeof(ComplexD)); + for (int mm = 0; mm < m; ++mm) { + for (int nn = 0; nn < n; ++nn) { + ComplexD c_mn(0.0); + for (int kk = 0; kk < k; ++kk) + c_mn += A[mm + kk*lda ] * B[kk + nn*ldb]; + std::cout << " beta "<gemm_batch & OneAPI -#warning "oneMKL implementation not built " + int64_t m64=m; + int64_t n64=n; + int64_t k64=k; + int64_t lda64=lda; + int64_t ldb64=ldb; + int64_t ldc64=ldc; + int64_t batchCount64=batchCount; + oneapi::mkl::transpose notransp =oneapi::mkl::transpose::N; + oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle, + ¬ransp, + ¬ransp, + &m64,&n64,&k64, + (ComplexF *) &alpha_p[0], + (const ComplexF **)&Amk[0], (const int64_t *)&lda64, + (const ComplexF **)&Bkn[0], (const int64_t *)&ldb64, + (ComplexF *) &beta_p[0], + (ComplexF **)&Cmn[0], (const int64_t *)&ldc64, + (int64_t)1,&batchCount64,std::vector()); + synchronise(); #endif #if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) int sda = lda*k; @@ -467,8 +522,25 @@ public: assert(err==CUBLAS_STATUS_SUCCESS); #endif #ifdef GRID_SYCL - //MKL’s cblas_gemm_batch & OneAPI -#warning "oneMKL implementation not built " + int64_t m64=m; + int64_t n64=n; + int64_t k64=k; + int64_t lda64=lda; + int64_t ldb64=ldb; + int64_t ldc64=ldc; + int64_t batchCount64=batchCount; + oneapi::mkl::transpose notransp =oneapi::mkl::transpose::N; + oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle, + ¬ransp, + ¬ransp, + &m64,&n64,&k64, + (float *) &alpha_p[0], + (const float **)&Amk[0], (const int64_t *)&lda64, + (const float **)&Bkn[0], (const int64_t *)&ldb64, + (float *) &beta_p[0], + (float **)&Cmn[0], (const int64_t *)&ldc64, + (int64_t)1,&batchCount64,std::vector()); + synchronise(); #endif #if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) int sda = lda*k; @@ -568,24 +640,25 @@ public: assert(err==CUBLAS_STATUS_SUCCESS); #endif #ifdef GRID_SYCL - /* int64_t m64=m; int64_t n64=n; int64_t k64=k; + int64_t lda64=lda; + int64_t ldb64=ldb; + int64_t ldc64=ldc; int64_t batchCount64=batchCount; - oneapi::mkl::blas::column_major::gemm_batch(*theGridAccelerator, - onemkl::transpose::N, - onemkl::transpose::N, - &m64,&n64,&k64, - (double *) &alpha_p[0], - (double **)&Amk[0], lda, - (double **)&Bkn[0], ldb, - (double *) &beta_p[0], - (double **)&Cmn[0], ldc, - 1,&batchCount64); - */ - //MKL’s cblas_gemm_batch & OneAPI -#warning "oneMKL implementation not built " + oneapi::mkl::transpose notransp =oneapi::mkl::transpose::N; + oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle, + ¬ransp, + ¬ransp, + &m64,&n64,&k64, + (double *) &alpha_p[0], + (const double **)&Amk[0], (const int64_t *)&lda64, + (const double **)&Bkn[0], (const int64_t *)&ldb64, + (double *) &beta_p[0], + (double **)&Cmn[0], (const int64_t *)&ldc64, + (int64_t)1,&batchCount64,std::vector()); + synchronise(); #endif #if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) int sda = lda*k; @@ -673,6 +746,7 @@ public: beta, (ComplexD *)Cmn,ldc,sdc, batchCount); + synchronise(); #endif #if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) && !defined(GRID_ONE_MKL) // Need a default/reference implementation