1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-14 01:35:36 +00:00

OneMKL batched blas starting

This commit is contained in:
Peter Boyle 2024-03-05 23:58:20 +00:00
parent 30228214f7
commit 21bc8c24df

View File

@ -34,9 +34,14 @@ Author: Peter Boyle <pboyle@bnl.gov>
#include <hipblas/hipblas.h> #include <hipblas/hipblas.h>
#endif #endif
#ifdef GRID_SYCL #ifdef GRID_SYCL
#error // need oneMKL version #include <oneapi/mkl.hpp>
#endif
#if 0
#define GRID_ONE_MKL
#endif
#ifdef GRID_ONE_MKL
#include <oneapi/mkl.hpp>
#endif #endif
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
// Need to rearrange lattice data to be in the right format for a // Need to rearrange lattice data to be in the right format for a
// batched multiply. Might as well make these static, dense packed // batched multiply. Might as well make these static, dense packed
@ -49,9 +54,12 @@ NAMESPACE_BEGIN(Grid);
typedef cudablasHandle_t gridblasHandle_t; typedef cudablasHandle_t gridblasHandle_t;
#endif #endif
#ifdef GRID_SYCL #ifdef GRID_SYCL
typedef int32_t gridblasHandle_t; typedef cl::sycl::queue *gridblasHandle_t;
#endif #endif
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) #ifdef GRID_ONE_MKL
typedef cl::sycl::queue *gridblasHandle_t;
#endif
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) && !defined(GRID_ONE_MKL)
typedef int32_t gridblasHandle_t; typedef int32_t gridblasHandle_t;
#endif #endif
@ -76,6 +84,12 @@ public:
hipblasCreate(&gridblasHandle); hipblasCreate(&gridblasHandle);
#endif #endif
#ifdef GRID_SYCL #ifdef GRID_SYCL
gridblasHandle = theGridAccelerator;
#endif
#ifdef GRID_ONE_MKL
cl::sycl::cpu_selector selector;
cl::sycl::device selectedDevice { selector };
gridblasHandle =new sycl::queue (selectedDevice);
#endif #endif
gridblasInit=1; gridblasInit=1;
} }
@ -110,6 +124,9 @@ public:
#endif #endif
#ifdef GRID_SYCL #ifdef GRID_SYCL
accelerator_barrier(); accelerator_barrier();
#endif
#ifdef GRID_ONE_MKL
gridblasHandle->wait();
#endif #endif
} }
@ -644,10 +661,19 @@ public:
(cuDoubleComplex *) Cmn, ldc, sdc, (cuDoubleComplex *) Cmn, ldc, sdc,
batchCount); batchCount);
#endif #endif
#ifdef GRID_SYCL #if defined(GRID_SYCL) || defined(GRID_ONE_MKL)
#warning "oneMKL implementation not made " oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle,
oneapi::mkl::transpose::N,
oneapi::mkl::transpose::N,
m,n,k,
alpha,
(const ComplexD *)Amk,lda,sda,
(const ComplexD *)Bkn,ldb,sdb,
beta,
(ComplexD *)Cmn,ldc,sdc,
batchCount);
#endif #endif
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) #if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) && !defined(GRID_ONE_MKL)
// Need a default/reference implementation // Need a default/reference implementation
for (int p = 0; p < batchCount; ++p) { for (int p = 0; p < batchCount; ++p) {
for (int mm = 0; mm < m; ++mm) { for (int mm = 0; mm < m; ++mm) {