mirror of
https://github.com/paboyle/Grid.git
synced 2026-05-15 22:54:30 +01:00
adding a version check to handle rocblas type change
This commit is contained in:
@@ -28,6 +28,7 @@ Author: Peter Boyle <pboyle@bnl.gov>
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
|
#include <hip/hip_version.h>
|
||||||
#include <hipblas/hipblas.h>
|
#include <hipblas/hipblas.h>
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
@@ -255,16 +256,29 @@ public:
|
|||||||
if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
|
if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
|
||||||
if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
|
if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
|
||||||
if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
|
if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
|
||||||
|
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7)
|
||||||
auto err = hipblasZgemmBatched(gridblasHandle,
|
auto err = hipblasZgemmBatched(gridblasHandle,
|
||||||
hOpA,
|
hOpA,
|
||||||
hOpB,
|
hOpB,
|
||||||
m,n,k,
|
m,n,k,
|
||||||
(hipblasDoubleComplex *) &alpha_p[0],
|
(hipDoubleComplex *) &alpha_p[0],
|
||||||
(hipblasDoubleComplex **)&Amk[0], lda,
|
(hipDoubleComplex **)&Amk[0], lda,
|
||||||
(hipblasDoubleComplex **)&Bkn[0], ldb,
|
(hipDoubleComplex **)&Bkn[0], ldb,
|
||||||
(hipblasDoubleComplex *) &beta_p[0],
|
(hipDoubleComplex *) &beta_p[0],
|
||||||
(hipblasDoubleComplex **)&Cmn[0], ldc,
|
(hipDoubleComplex **)&Cmn[0], ldc,
|
||||||
batchCount);
|
batchCount);
|
||||||
|
#else
|
||||||
|
auto err = hipblasZgemmBatched(gridblasHandle,
|
||||||
|
hOpA,
|
||||||
|
hOpB,
|
||||||
|
m,n,k,
|
||||||
|
(hipblasDoubleComplex *) &alpha_p[0],
|
||||||
|
(hipblasDoubleComplex **)&Amk[0], lda,
|
||||||
|
(hipblasDoubleComplex **)&Bkn[0], ldb,
|
||||||
|
(hipblasDoubleComplex *) &beta_p[0],
|
||||||
|
(hipblasDoubleComplex **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
#endif
|
||||||
// std::cout << " hipblas return code " <<(int)err<<std::endl;
|
// std::cout << " hipblas return code " <<(int)err<<std::endl;
|
||||||
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
@@ -503,17 +517,30 @@ public:
|
|||||||
if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
|
if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
|
||||||
if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
|
if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
|
||||||
if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
|
if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
|
||||||
|
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7)
|
||||||
auto err = hipblasCgemmBatched(gridblasHandle,
|
auto err = hipblasCgemmBatched(gridblasHandle,
|
||||||
hOpA,
|
hOpA,
|
||||||
hOpB,
|
hOpB,
|
||||||
m,n,k,
|
m,n,k,
|
||||||
(hipblasComplex *) &alpha_p[0],
|
(hipComplex *) &alpha_p[0],
|
||||||
(hipblasComplex **)&Amk[0], lda,
|
(hipComplex **)&Amk[0], lda,
|
||||||
(hipblasComplex **)&Bkn[0], ldb,
|
(hipComplex **)&Bkn[0], ldb,
|
||||||
(hipblasComplex *) &beta_p[0],
|
(hipComplex *) &beta_p[0],
|
||||||
(hipblasComplex **)&Cmn[0], ldc,
|
(hipComplex **)&Cmn[0], ldc,
|
||||||
batchCount);
|
batchCount);
|
||||||
|
#else
|
||||||
|
auto err = hipblasCgemmBatched(gridblasHandle,
|
||||||
|
hOpA,
|
||||||
|
hOpB,
|
||||||
|
m,n,k,
|
||||||
|
(hipblasComplex *) &alpha_p[0],
|
||||||
|
(hipblasComplex **)&Amk[0], lda,
|
||||||
|
(hipblasComplex **)&Bkn[0], ldb,
|
||||||
|
(hipblasComplex *) &beta_p[0],
|
||||||
|
(hipblasComplex **)&Cmn[0], ldc,
|
||||||
|
batchCount);
|
||||||
|
|
||||||
|
#endif
|
||||||
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
@@ -1094,11 +1121,19 @@ public:
|
|||||||
GRID_ASSERT(info.size()==batchCount);
|
GRID_ASSERT(info.size()==batchCount);
|
||||||
|
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
|
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7)
|
||||||
auto err = hipblasZgetrfBatched(gridblasHandle,(int)n,
|
auto err = hipblasZgetrfBatched(gridblasHandle,(int)n,
|
||||||
(hipblasDoubleComplex **)&Ann[0], (int)n,
|
(hipDoubleComplex **)&Ann[0], (int)n,
|
||||||
(int*) &ipiv[0],
|
(int*) &ipiv[0],
|
||||||
(int*) &info[0],
|
(int*) &info[0],
|
||||||
(int)batchCount);
|
(int)batchCount);
|
||||||
|
#else
|
||||||
|
auto err = hipblasZgetrfBatched(gridblasHandle,(int)n,
|
||||||
|
(hipblasDoubleComplex **)&Ann[0], (int)n,
|
||||||
|
(int*) &ipiv[0],
|
||||||
|
(int*) &info[0],
|
||||||
|
(int)batchCount);
|
||||||
|
#endif
|
||||||
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
@@ -1124,11 +1159,20 @@ public:
|
|||||||
GRID_ASSERT(info.size()==batchCount);
|
GRID_ASSERT(info.size()==batchCount);
|
||||||
|
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
|
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7)
|
||||||
auto err = hipblasCgetrfBatched(gridblasHandle,(int)n,
|
auto err = hipblasCgetrfBatched(gridblasHandle,(int)n,
|
||||||
(hipblasComplex **)&Ann[0], (int)n,
|
(hipComplex **)&Ann[0], (int)n,
|
||||||
(int*) &ipiv[0],
|
(int*) &ipiv[0],
|
||||||
(int*) &info[0],
|
(int*) &info[0],
|
||||||
(int)batchCount);
|
(int)batchCount);
|
||||||
|
#else
|
||||||
|
auto err = hipblasCgetrfBatched(gridblasHandle,(int)n,
|
||||||
|
(hipblasComplex **)&Ann[0], (int)n,
|
||||||
|
(int*) &ipiv[0],
|
||||||
|
(int*) &info[0],
|
||||||
|
(int)batchCount);
|
||||||
|
#endif
|
||||||
|
|
||||||
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
@@ -1201,12 +1245,22 @@ public:
|
|||||||
GRID_ASSERT(Cnn.size()==batchCount);
|
GRID_ASSERT(Cnn.size()==batchCount);
|
||||||
|
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
|
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7)
|
||||||
auto err = hipblasZgetriBatched(gridblasHandle,(int)n,
|
auto err = hipblasZgetriBatched(gridblasHandle,(int)n,
|
||||||
(hipblasDoubleComplex **)&Ann[0], (int)n,
|
(hipDoubleComplex **)&Ann[0], (int)n,
|
||||||
(int*) &ipiv[0],
|
(int*) &ipiv[0],
|
||||||
(hipblasDoubleComplex **)&Cnn[0], (int)n,
|
(hipDoubleComplex **)&Cnn[0], (int)n,
|
||||||
(int*) &info[0],
|
(int*) &info[0],
|
||||||
(int)batchCount);
|
(int)batchCount);
|
||||||
|
#else
|
||||||
|
auto err = hipblasZgetriBatched(gridblasHandle,(int)n,
|
||||||
|
(hipblasDoubleComplex **)&Ann[0], (int)n,
|
||||||
|
(int*) &ipiv[0],
|
||||||
|
(hipblasDoubleComplex **)&Cnn[0], (int)n,
|
||||||
|
(int*) &info[0],
|
||||||
|
(int)batchCount);
|
||||||
|
|
||||||
|
#endif
|
||||||
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
@@ -1235,12 +1289,21 @@ public:
|
|||||||
GRID_ASSERT(Cnn.size()==batchCount);
|
GRID_ASSERT(Cnn.size()==batchCount);
|
||||||
|
|
||||||
#ifdef GRID_HIP
|
#ifdef GRID_HIP
|
||||||
|
#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >=7)
|
||||||
auto err = hipblasCgetriBatched(gridblasHandle,(int)n,
|
auto err = hipblasCgetriBatched(gridblasHandle,(int)n,
|
||||||
(hipblasComplex **)&Ann[0], (int)n,
|
(hipComplex **)&Ann[0], (int)n,
|
||||||
(int*) &ipiv[0],
|
(int*) &ipiv[0],
|
||||||
(hipblasComplex **)&Cnn[0], (int)n,
|
(hipComplex **)&Cnn[0], (int)n,
|
||||||
(int*) &info[0],
|
(int*) &info[0],
|
||||||
(int)batchCount);
|
(int)batchCount);
|
||||||
|
#else
|
||||||
|
auto err = hipblasCgetriBatched(gridblasHandle,(int)n,
|
||||||
|
(hipblasComplex **)&Ann[0], (int)n,
|
||||||
|
(int*) &ipiv[0],
|
||||||
|
(hipblasComplex **)&Cnn[0], (int)n,
|
||||||
|
(int*) &info[0],
|
||||||
|
(int)batchCount);
|
||||||
|
#endif
|
||||||
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
GRID_ASSERT(err==HIPBLAS_STATUS_SUCCESS);
|
||||||
#endif
|
#endif
|
||||||
#ifdef GRID_CUDA
|
#ifdef GRID_CUDA
|
||||||
|
|||||||
Reference in New Issue
Block a user