mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-09 23:45:36 +00:00
Batched SGEMM/DGEMM/ZGEMM/CGEMM
HIP and CUDA versions, plus a vanilla CPU version. oneMKL stub in comments, to be tested as different.
This commit is contained in:
parent
48d1f0df89
commit
dfa617c439
@ -73,7 +73,6 @@ public:
|
|||||||
hipblasCreate(&gridblasHandle);
|
hipblasCreate(&gridblasHandle);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_SYCL
|
#ifdef GRID_SYCL
|
||||||
#error
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -82,19 +81,19 @@ public:
|
|||||||
GridBLAS() { Init(); };
|
GridBLAS() { Init(); };
|
||||||
~GridBLAS() { };
|
~GridBLAS() { };
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////
|
||||||
// BLAS GEMM conventions:
|
// BLAS GEMM conventions:
|
||||||
/////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////
|
||||||
// - C = alpha A * B + beta C
|
// - C = alpha A * B + beta C
|
||||||
// Dimensions:
|
// Dimensions:
|
||||||
// - C_m.n
|
// - C_m.n
|
||||||
// - A_m.k
|
// - A_m.k
|
||||||
// - B_k.n
|
// - B_k.n
|
||||||
// - Flops = 8 M N K
|
// - Flops = 8 M N K
|
||||||
// - Bytes = 2*sizeof(word) * (MN+MK+KN)
|
// - Bytes = 2*sizeof(word) * (MN+MK+KN)
|
||||||
// M=60, N=12
|
// M=60, N=12
|
||||||
// Flop/Byte = 8 . 60.60.12 / (60.12+60.60+60.12)/16 = 4 so expect about 4 TF/s on a GCD
|
// Flop/Byte = 8 . 60.60.12 / (60.12+60.60+60.12)/16 = 4 so expect about 4 TF/s on a GCD
|
||||||
/////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////
|
||||||
void synchronise(void)
|
void synchronise(void)
|
||||||
{
|
{
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
@ -158,10 +157,10 @@ public:
|
|||||||
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD));
|
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD));
|
||||||
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD));
|
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD));
|
||||||
RealD t0=usecond();
|
RealD t0=usecond();
|
||||||
#ifdef GRID_HIP
|
// std::cout << "hipblasZgemmBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
||||||
std::cout << "hipblasZgemmBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
|
||||||
assert(Bkn.size()==batchCount);
|
assert(Bkn.size()==batchCount);
|
||||||
assert(Cmn.size()==batchCount);
|
assert(Cmn.size()==batchCount);
|
||||||
|
#ifdef GRID_HIP
|
||||||
auto err = hipblasZgemmBatched(gridblasHandle,
|
auto err = hipblasZgemmBatched(gridblasHandle,
|
||||||
HIPBLAS_OP_N,
|
HIPBLAS_OP_N,
|
||||||
HIPBLAS_OP_N,
|
HIPBLAS_OP_N,
|
||||||
@ -174,13 +173,23 @@ public:
|
|||||||
batchCount);
|
batchCount);
|
||||||
// std::cout << " hipblas return code " <<(int)err<<std::endl;
|
// std::cout << " hipblas return code " <<(int)err<<std::endl;
|
||||||
assert(err==HIPBLAS_STATUS_SUCCESS);
|
assert(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
synchronise();
|
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
#error "CUDA implemenetation "
|
auto err = cublasZgemmBatched(gridblasHandle,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(cuDoubleComplex *) &alpha_p[0],
|
||||||
|
(cuDoubleComplex **)&Amk[0], lda,
|
||||||
|
(cuDoubleComplex **)&Bkn[0], ldb,
|
||||||
|
(cuDoubleComplex *) &beta_p[0],
|
||||||
|
(cuDoubleComplex **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
assert(err==CUBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_SYCL
|
#ifdef GRID_SYCL
|
||||||
#error "oneMKL implemenetation "
|
//MKL’s cblas_<T>gemm_batch & OneAPI
|
||||||
|
#warning "oneMKL implementation not built "
|
||||||
#endif
|
#endif
|
||||||
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
||||||
// Need a default/reference implementation
|
// Need a default/reference implementation
|
||||||
@ -195,16 +204,269 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
synchronise();
|
||||||
RealD t1=usecond();
|
RealD t1=usecond();
|
||||||
// std::cout << " hipblas synchronised " <<std::endl;
|
|
||||||
RealD flops = 8.0*m*n*k*batchCount;
|
RealD flops = 8.0*m*n*k*batchCount;
|
||||||
RealD bytes = 1.0*sizeof(ComplexD)*(m*k+k*n+m*n)*batchCount;
|
RealD bytes = 1.0*sizeof(ComplexD)*(m*k+k*n+m*n)*batchCount;
|
||||||
std::cout << " batched Blas copy "<<(t0-t2)/1.e3 <<" ms "<<std::endl;
|
std::cout <<GridLogPerformance<< " batched Blas copy "<<(t0-t2)/1.e3 <<" ms "<<std::endl;
|
||||||
std::cout << " batched Blas call "<<m<<","<<n<<","<<k<<" "<< flops/(t1-t0)/1.e3 <<" GF/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< flops/(t1-t0)/1.e3 <<" GF/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
std::cout << " batched Blas call "<<m<<","<<n<<","<<k<<" "<< bytes/(t1-t0)/1.e3 <<" GB/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< bytes/(t1-t0)/1.e3 <<" GB/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void gemmBatched(int m,int n, int k,
|
||||||
|
ComplexF alpha,
|
||||||
|
deviceVector<ComplexF*> &Amk, // pointer list to matrices
|
||||||
|
deviceVector<ComplexF*> &Bkn,
|
||||||
|
ComplexF beta,
|
||||||
|
deviceVector<ComplexF*> &Cmn)
|
||||||
|
{
|
||||||
|
RealD t2=usecond();
|
||||||
|
int32_t batchCount = Amk.size();
|
||||||
|
// Use C-row major storage, so transpose calls
|
||||||
|
int lda = m; // m x k column major
|
||||||
|
int ldb = k; // k x n column major
|
||||||
|
int ldc = m; // m x b column major
|
||||||
|
static deviceVector<ComplexF> alpha_p(1);
|
||||||
|
static deviceVector<ComplexF> beta_p(1);
|
||||||
|
// can prestore the 1 and the zero on device
|
||||||
|
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexF));
|
||||||
|
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexF));
|
||||||
|
RealD t0=usecond();
|
||||||
|
// std::cout << "hipblasZgemmBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
||||||
|
assert(Bkn.size()==batchCount);
|
||||||
|
assert(Cmn.size()==batchCount);
|
||||||
|
#ifdef GRID_HIP
|
||||||
|
auto err = hipblasCgemmBatched(gridblasHandle,
|
||||||
|
HIPBLAS_OP_N,
|
||||||
|
HIPBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(hipblasComplex *) &alpha_p[0],
|
||||||
|
(hipblasComplex **)&Amk[0], lda,
|
||||||
|
(hipblasComplex **)&Bkn[0], ldb,
|
||||||
|
(hipblasComplex *) &beta_p[0],
|
||||||
|
(hipblasComplex **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
// std::cout << " hipblas return code " <<(int)err<<std::endl;
|
||||||
|
assert(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
|
#endif
|
||||||
|
#ifdef GRID_CUDA
|
||||||
|
auto err = cublasCgemmBatched(gridblasHandle,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(cuComplex *) &alpha_p[0],
|
||||||
|
(cuComplex **)&Amk[0], lda,
|
||||||
|
(cuComplex **)&Bkn[0], ldb,
|
||||||
|
(cuComplex *) &beta_p[0],
|
||||||
|
(cuComplex **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
assert(err==CUBLAS_STATUS_SUCCESS);
|
||||||
|
#endif
|
||||||
|
#ifdef GRID_SYCL
|
||||||
|
//MKL’s cblas_<T>gemm_batch & OneAPI
|
||||||
|
#warning "oneMKL implementation not built "
|
||||||
|
#endif
|
||||||
|
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
||||||
|
// Need a default/reference implementation
|
||||||
|
for (int p = 0; p < batchCount; ++p) {
|
||||||
|
for (int mm = 0; mm < m; ++mm) {
|
||||||
|
for (int nn = 0; nn < n; ++nn) {
|
||||||
|
ComplexD c_mn(0.0);
|
||||||
|
for (int kk = 0; kk < k, ++kk)
|
||||||
|
c_mn += Amk[mm + kk*lda + p*sda] * Bkn[kk + nn*ldb + p*sdb];
|
||||||
|
Cmn[mm + nn*ldc + p*sdc] = (*alpha_p)*c_mn + (*beta_p)*Cmn[mm + nn*ldc + p*sdc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
synchronise();
|
||||||
|
RealD t1=usecond();
|
||||||
|
RealD flops = 8.0*m*n*k*batchCount;
|
||||||
|
RealD bytes = 1.0*sizeof(ComplexF)*(m*k+k*n+m*n)*batchCount;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas copy "<<(t0-t2)/1.e3 <<" ms "<<std::endl;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< flops/(t1-t0)/1.e3 <<" GF/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< bytes/(t1-t0)/1.e3 <<" GB/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// Single precision real GEMM
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void gemmBatched(int m,int n, int k,
|
||||||
|
RealF alpha,
|
||||||
|
deviceVector<RealF*> &Amk, // pointer list to matrices
|
||||||
|
deviceVector<RealF*> &Bkn,
|
||||||
|
RealF beta,
|
||||||
|
deviceVector<RealF*> &Cmn)
|
||||||
|
{
|
||||||
|
RealD t2=usecond();
|
||||||
|
int32_t batchCount = Amk.size();
|
||||||
|
// Use C-row major storage, so transpose calls
|
||||||
|
int lda = m; // m x k column major
|
||||||
|
int ldb = k; // k x n column major
|
||||||
|
int ldc = m; // m x b column major
|
||||||
|
static deviceVector<RealF> alpha_p(1);
|
||||||
|
static deviceVector<RealF> beta_p(1);
|
||||||
|
// can prestore the 1 and the zero on device
|
||||||
|
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(RealF));
|
||||||
|
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(RealF));
|
||||||
|
RealD t0=usecond();
|
||||||
|
// std::cout << "hipblasZgemmBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
||||||
|
assert(Bkn.size()==batchCount);
|
||||||
|
assert(Cmn.size()==batchCount);
|
||||||
|
#ifdef GRID_HIP
|
||||||
|
auto err = hipblasSgemmBatched(gridblasHandle,
|
||||||
|
HIPBLAS_OP_N,
|
||||||
|
HIPBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(float *) &alpha_p[0],
|
||||||
|
(float **)&Amk[0], lda,
|
||||||
|
(float **)&Bkn[0], ldb,
|
||||||
|
(float *) &beta_p[0],
|
||||||
|
(float **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
assert(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
|
#endif
|
||||||
|
#ifdef GRID_CUDA
|
||||||
|
auto err = cublasSgemmBatched(gridblasHandle,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(float *) &alpha_p[0],
|
||||||
|
(float **)&Amk[0], lda,
|
||||||
|
(float **)&Bkn[0], ldb,
|
||||||
|
(float *) &beta_p[0],
|
||||||
|
(float **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
assert(err==CUBLAS_STATUS_SUCCESS);
|
||||||
|
#endif
|
||||||
|
#ifdef GRID_SYCL
|
||||||
|
//MKL’s cblas_<T>gemm_batch & OneAPI
|
||||||
|
#warning "oneMKL implementation not built "
|
||||||
|
#endif
|
||||||
|
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
||||||
|
// Need a default/reference implementation
|
||||||
|
for (int p = 0; p < batchCount; ++p) {
|
||||||
|
for (int mm = 0; mm < m; ++mm) {
|
||||||
|
for (int nn = 0; nn < n; ++nn) {
|
||||||
|
RealD c_mn(0.0);
|
||||||
|
for (int kk = 0; kk < k, ++kk)
|
||||||
|
c_mn += Amk[mm + kk*lda + p*sda] * Bkn[kk + nn*ldb + p*sdb];
|
||||||
|
Cmn[mm + nn*ldc + p*sdc] = (*alpha_p)*c_mn + (*beta_p)*Cmn[mm + nn*ldc + p*sdc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
synchronise();
|
||||||
|
RealD t1=usecond();
|
||||||
|
RealD flops = 8.0*m*n*k*batchCount;
|
||||||
|
RealD bytes = 1.0*sizeof(RealF)*(m*k+k*n+m*n)*batchCount;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas copy "<<(t0-t2)/1.e3 <<" ms "<<std::endl;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< flops/(t1-t0)/1.e3 <<" GF/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< bytes/(t1-t0)/1.e3 <<" GB/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
// Double precision real GEMM
|
||||||
|
///////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void gemmBatched(int m,int n, int k,
|
||||||
|
RealD alpha,
|
||||||
|
deviceVector<RealD*> &Amk, // pointer list to matrices
|
||||||
|
deviceVector<RealD*> &Bkn,
|
||||||
|
RealD beta,
|
||||||
|
deviceVector<RealD*> &Cmn)
|
||||||
|
{
|
||||||
|
RealD t2=usecond();
|
||||||
|
int32_t batchCount = Amk.size();
|
||||||
|
// Use C-row major storage, so transpose calls
|
||||||
|
int lda = m; // m x k column major
|
||||||
|
int ldb = k; // k x n column major
|
||||||
|
int ldc = m; // m x b column major
|
||||||
|
static deviceVector<RealD> alpha_p(1);
|
||||||
|
static deviceVector<RealD> beta_p(1);
|
||||||
|
// can prestore the 1 and the zero on device
|
||||||
|
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(RealD));
|
||||||
|
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(RealD));
|
||||||
|
RealD t0=usecond();
|
||||||
|
// std::cout << "hipblasZgemmBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
||||||
|
assert(Bkn.size()==batchCount);
|
||||||
|
assert(Cmn.size()==batchCount);
|
||||||
|
#ifdef GRID_HIP
|
||||||
|
auto err = hipblasDgemmBatched(gridblasHandle,
|
||||||
|
HIPBLAS_OP_N,
|
||||||
|
HIPBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(double *) &alpha_p[0],
|
||||||
|
(double **)&Amk[0], lda,
|
||||||
|
(double **)&Bkn[0], ldb,
|
||||||
|
(double *) &beta_p[0],
|
||||||
|
(double **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
assert(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
|
#endif
|
||||||
|
#ifdef GRID_CUDA
|
||||||
|
auto err = cublasDgemmBatched(gridblasHandle,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
m,n,k,
|
||||||
|
(double *) &alpha_p[0],
|
||||||
|
(double **)&Amk[0], lda,
|
||||||
|
(double **)&Bkn[0], ldb,
|
||||||
|
(double *) &beta_p[0],
|
||||||
|
(double **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
assert(err==CUBLAS_STATUS_SUCCESS);
|
||||||
|
#endif
|
||||||
|
#ifdef GRID_SYCL
|
||||||
|
/*
|
||||||
|
int64_t m64=m;
|
||||||
|
int64_t n64=n;
|
||||||
|
int64_t k64=k;
|
||||||
|
int64_t batchCount64=batchCount;
|
||||||
|
oneapi::mkl::blas::column_major::gemm_batch(*theGridAccelerator,
|
||||||
|
onemkl::transpose::N,
|
||||||
|
onemkl::transpose::N,
|
||||||
|
&m64,&n64,&k64,
|
||||||
|
(double *) &alpha_p[0],
|
||||||
|
(double **)&Amk[0], lda,
|
||||||
|
(double **)&Bkn[0], ldb,
|
||||||
|
(double *) &beta_p[0],
|
||||||
|
(double **)&Cmn[0], ldc,
|
||||||
|
1,&batchCount64);
|
||||||
|
*/
|
||||||
|
//MKL’s cblas_<T>gemm_batch & OneAPI
|
||||||
|
#warning "oneMKL implementation not built "
|
||||||
|
#endif
|
||||||
|
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
||||||
|
// Need a default/reference implementation
|
||||||
|
for (int p = 0; p < batchCount; ++p) {
|
||||||
|
for (int mm = 0; mm < m; ++mm) {
|
||||||
|
for (int nn = 0; nn < n; ++nn) {
|
||||||
|
RealD c_mn(0.0);
|
||||||
|
for (int kk = 0; kk < k, ++kk)
|
||||||
|
c_mn += Amk[mm + kk*lda + p*sda] * Bkn[kk + nn*ldb + p*sdb];
|
||||||
|
Cmn[mm + nn*ldc + p*sdc] = (*alpha_p)*c_mn + (*beta_p)*Cmn[mm + nn*ldc + p*sdc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
synchronise();
|
||||||
|
RealD t1=usecond();
|
||||||
|
RealD flops = 8.0*m*n*k*batchCount;
|
||||||
|
RealD bytes = 1.0*sizeof(RealD)*(m*k+k*n+m*n)*batchCount;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas copy "<<(t0-t2)/1.e3 <<" ms "<<std::endl;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< flops/(t1-t0)/1.e3 <<" GF/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
std::cout <<GridLogPerformance<< " batched Blas call "<<m<<","<<n<<","<<k<<" "<< bytes/(t1-t0)/1.e3 <<" GB/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Strided case used by benchmark, but generally unused in Grid
|
||||||
|
// Keep a code example in double complex, but don't generate the single and real variants for now
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
void gemmStridedBatched(int m,int n, int k,
|
void gemmStridedBatched(int m,int n, int k,
|
||||||
ComplexD alpha,
|
ComplexD alpha,
|
||||||
@ -225,11 +487,10 @@ public:
|
|||||||
deviceVector<ComplexD> beta_p(1);
|
deviceVector<ComplexD> beta_p(1);
|
||||||
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD));
|
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD));
|
||||||
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD));
|
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD));
|
||||||
|
std::cout << "blasZgemmStridedBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
||||||
|
std::cout << "blasZgemmStridedBatched ld "<<lda<<","<<ldb<<","<<ldc<<std::endl;
|
||||||
|
std::cout << "blasZgemmStridedBatched sd "<<sda<<","<<sdb<<","<<sdc<<std::endl;
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
std::cout << "hipblasZgemmStridedBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
|
|
||||||
std::cout << "hipblasZgemmStridedBatched ld "<<lda<<","<<ldb<<","<<ldc<<std::endl;
|
|
||||||
std::cout << "hipblasZgemmStridedBatched sd "<<sda<<","<<sdb<<","<<sdc<<std::endl;
|
|
||||||
{
|
|
||||||
auto err = hipblasZgemmStridedBatched(gridblasHandle,
|
auto err = hipblasZgemmStridedBatched(gridblasHandle,
|
||||||
HIPBLAS_OP_N,
|
HIPBLAS_OP_N,
|
||||||
HIPBLAS_OP_N,
|
HIPBLAS_OP_N,
|
||||||
@ -240,26 +501,24 @@ public:
|
|||||||
(hipblasDoubleComplex *) &beta_p[0],
|
(hipblasDoubleComplex *) &beta_p[0],
|
||||||
(hipblasDoubleComplex *) Cmn, ldc, sdc,
|
(hipblasDoubleComplex *) Cmn, ldc, sdc,
|
||||||
batchCount);
|
batchCount);
|
||||||
std::cout << " hipblas return code " <<(int)err<<std::endl;
|
|
||||||
assert(err==HIPBLAS_STATUS_SUCCESS);
|
assert(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
cublasZgemmStridedBatched(gridblasHandle,
|
cublasZgemmStridedBatched(gridblasHandle,
|
||||||
CUBLAS_OP_T,
|
CUBLAS_OP_N,
|
||||||
CUBLAS_OP_T,
|
CUBLAS_OP_N,
|
||||||
m,n,k,
|
m,n,k,
|
||||||
(cuDoubleComplex *)&alpha_p[0],
|
(cuDoubleComplex *) &alpha_p[0],
|
||||||
(cuDoubleComplex *) Amk, lda, sda,
|
(cuDoubleComplex *) Amk, lda, sda,
|
||||||
(cuDoubleComplex *) Bkn, ldb, sdb,
|
(cuDoubleComplex *) Bkn, ldb, sdb,
|
||||||
(cuDoubleComplex *)&beta_p[],
|
(cuDoubleComplex *) &beta_p[0],
|
||||||
(cuDoubleComplex *) Cmn, ldc, sdc,
|
(cuDoubleComplex *) Cmn, ldc, sdc,
|
||||||
batchCount);
|
batchCount);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_SYCL
|
#ifdef GRID_SYCL
|
||||||
#error "oneMKL implemenetation "
|
#warning "oneMKL implementation not made "
|
||||||
#endif
|
#endif
|
||||||
#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
#if !definte(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
|
||||||
// Need a default/reference implementation
|
// Need a default/reference implementation
|
||||||
for (int p = 0; p < batchCount; ++p) {
|
for (int p = 0; p < batchCount; ++p) {
|
||||||
for (int mm = 0; mm < m; ++mm) {
|
for (int mm = 0; mm < m; ++mm) {
|
||||||
@ -273,6 +532,10 @@ public:
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
NAMESPACE_END(Grid);
|
NAMESPACE_END(Grid);
|
||||||
|
Loading…
Reference in New Issue
Block a user