1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

USQCD benchmark

This commit is contained in:
Peter Boyle 2024-03-01 00:05:04 -05:00
parent 04ca065281
commit c805f86343
4 changed files with 78 additions and 45 deletions

View File

@ -0,0 +1,34 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
Source file: BatchedBlas.h
Copyright (C) 2023
Author: Peter Boyle <pboyle@bnl.gov>
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/* END LEGAL */
#include <Grid/GridCore.h>
#include <Grid/algorithms/blas/BatchedBlas.h>
NAMESPACE_BEGIN(Grid);
gridblasHandle_t GridBLAS::gridblasHandle;
int GridBLAS::gridblasInit;
NAMESPACE_END(Grid);

View File

@ -615,9 +615,10 @@ public:
deviceVector<ComplexD> beta_p(1); deviceVector<ComplexD> beta_p(1);
acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD)); acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD));
acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD)); acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD));
std::cout << "blasZgemmStridedBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
std::cout << "blasZgemmStridedBatched ld "<<lda<<","<<ldb<<","<<ldc<<std::endl; // std::cout << "blasZgemmStridedBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
std::cout << "blasZgemmStridedBatched sd "<<sda<<","<<sdb<<","<<sdc<<std::endl; // std::cout << "blasZgemmStridedBatched ld "<<lda<<","<<ldb<<","<<ldc<<std::endl;
// std::cout << "blasZgemmStridedBatched sd "<<sda<<","<<sdb<<","<<sdc<<std::endl;
#ifdef GRID_HIP #ifdef GRID_HIP
auto err = hipblasZgemmStridedBatched(gridblasHandle, auto err = hipblasZgemmStridedBatched(gridblasHandle,
HIPBLAS_OP_N, HIPBLAS_OP_N,
@ -672,8 +673,9 @@ public:
ComplexD alpha(1.0); ComplexD alpha(1.0);
ComplexD beta (1.0); ComplexD beta (1.0);
RealD flops = 8.0*M*N*K*BATCH; RealD flops = 8.0*M*N*K*BATCH;
for(int i=0;i<10;i++){ int ncall=10;
RealD t0 = usecond(); RealD t0 = usecond();
for(int i=0;i<ncall;i++){
gemmStridedBatched(M,N,K, gemmStridedBatched(M,N,K,
alpha, alpha,
&A[0], // m x k &A[0], // m x k
@ -681,12 +683,13 @@ public:
beta, beta,
&C[0], // m x n &C[0], // m x n
BATCH); BATCH);
}
synchronise(); synchronise();
RealD t1 = usecond(); RealD t1 = usecond();
RealD bytes = 1.0*sizeof(ComplexD)*(M*N*2+N*K+M*K)*BATCH; RealD bytes = 1.0*sizeof(ComplexD)*(M*N*2+N*K+M*K)*BATCH;
flops = 8.0*M*N*K*BATCH*ncall;
flops = flops/(t1-t0)/1.e3; flops = flops/(t1-t0)/1.e3;
} return flops; // Returns gigaflops
return flops;
} }

View File

@ -65,7 +65,7 @@ struct time_statistics{
void comms_header(){ void comms_header(){
std::cout <<GridLogMessage << " L "<<"\t"<<" Ls "<<"\t" std::cout <<GridLogMessage << " L "<<"\t"<<" Ls "<<"\t"
<<"bytes\t MB/s uni (err/min/max) \t\t MB/s bidi (err/min/max)"<<std::endl; <<"bytes\t MB/s uni \t\t MB/s bidi "<<std::endl;
}; };
struct controls { struct controls {
@ -180,10 +180,9 @@ public:
std::cout<<GridLogMessage << lat<<"\t"<<Ls<<"\t " std::cout<<GridLogMessage << lat<<"\t"<<Ls<<"\t "
<< bytes << " \t " << bytes << " \t "
<<xbytes/timestat.mean<<" \t "<< xbytes*timestat.err/(timestat.mean*timestat.mean)<< " \t " <<xbytes/timestat.mean
<<xbytes/timestat.max <<" "<< xbytes/timestat.min << "\t\t"
<< "\t\t"<< bidibytes/timestat.mean<< " " << bidibytes*timestat.err/(timestat.mean*timestat.mean) << " " << bidibytes/timestat.mean<< std::endl;
<< bidibytes/timestat.max << " " << bidibytes/timestat.min << std::endl;
fprintf(FP,"%ld, %d, %f\n",(long)bytes,dir,bidibytes/timestat.mean/1000.); fprintf(FP,"%ld, %d, %f\n",(long)bytes,dir,bidibytes/timestat.mean/1000.);
} }
} }
@ -256,7 +255,7 @@ public:
<< lat<<"\t\t"<<bytes<<" \t\t"<<bytes/time<<"\t\t"<<flops/time<<"\t\t"<<(stop-start)/1000./1000. << lat<<"\t\t"<<bytes<<" \t\t"<<bytes/time<<"\t\t"<<flops/time<<"\t\t"<<(stop-start)/1000./1000.
<< "\t\t"<< bytes/time/NN <<std::endl; << "\t\t"<< bytes/time/NN <<std::endl;
fprintf(FP,"%ld, %f\n",(long)bytes,bytes/time/NN/1000.); fprintf(FP,"%ld, %f\n",(long)bytes,bytes/time/NN);
} }
fprintf(FP,"\n\n"); fprintf(FP,"\n\n");
@ -268,64 +267,61 @@ public:
//int nbasis, int nrhs, int coarseVol //int nbasis, int nrhs, int coarseVol
int basis[] = { 16,32,64 }; int basis[] = { 16,32,64 };
int rhs[] = { 8,16,32 }; int rhs[] = { 8,16,32 };
int vols[] = { 4*4*4*4, 8*8*8*8, 8*8*16*16 }; int vol = 4*4*4*4;
GridBLAS blas; GridBLAS blas;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << "= batched GEMM (double precision) "<<std::endl; std::cout<<GridLogMessage << "= batched GEMM (double precision) "<<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / node (coarse mrhs)"<<std::endl; std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl;
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl; std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank\n"); fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank\n");
for(int b=0;b<3;b++){ for(int b=0;b<3;b++){
for(int r=0;r<3;r++){ for(int r=0;r<3;r++){
for(int v=0;v<3;v++){
int M=basis[b]; int M=basis[b];
int N=rhs[r]; int N=rhs[r];
int K=basis[b]; int K=basis[b];
int BATCH=vols[v]; int BATCH=vol;
double p=blas.benchmark(M,rhs[r],vols[v],1); double p=blas.benchmark(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p); fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3) std::cout<<GridLogMessage<<std::setprecision(3)
<< M<<"\t\t"<<N<<"\t\t"<<K<<"\t\t"<<BATCH<<"\t\t"<<p<<std::endl; << M<<"\t\t"<<N<<"\t\t"<<K<<"\t\t"<<BATCH<<"\t\t"<<p<<std::endl;
}}} }}
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl; std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / node (block project)"<<std::endl; std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (block project)"<<std::endl;
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl; std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
for(int b=0;b<3;b++){ for(int b=0;b<3;b++){
for(int r=0;r<3;r++){ for(int r=0;r<3;r++){
for(int v=0;v<2;v++){
int M=basis[b]; int M=basis[b];
int N=rhs[r]; int N=rhs[r];
int K=vols[2]; int K=vol;
int BATCH=vols[v]; int BATCH=vol;
double p=blas.benchmark(M,rhs[r],vols[v],1); double p=blas.benchmark(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p); fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3) std::cout<<GridLogMessage<<std::setprecision(3)
<< M<<"\t\t"<<N<<"\t\t"<<K<<"\t\t"<<BATCH<<"\t\t"<<p<<std::endl; << M<<"\t\t"<<N<<"\t\t"<<K<<"\t\t"<<BATCH<<"\t\t"<<p<<std::endl;
}}} }}
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl; std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / node (block promote)"<<std::endl; std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (block promote)"<<std::endl;
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl; std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
for(int b=0;b<3;b++){ for(int b=0;b<3;b++){
for(int r=0;r<3;r++){ for(int r=0;r<3;r++){
for(int v=0;v<2;v++){
int M=rhs[r]; int M=rhs[r];
int N=vols[2]; int N=vol;
int K=basis[b]; int K=basis[b];
int BATCH=vols[v]; int BATCH=vol;
double p=blas.benchmark(M,rhs[r],vols[v],1); double p=blas.benchmark(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p); fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3) std::cout<<GridLogMessage<<std::setprecision(3)
<< M<<"\t\t"<<N<<"\t\t"<<K<<"\t\t"<<BATCH<<"\t\t"<<p<<std::endl; << M<<"\t\t"<<N<<"\t\t"<<K<<"\t\t"<<BATCH<<"\t\t"<<p<<std::endl;
}}} }}
fprintf(FP,"\n\n\n"); fprintf(FP,"\n\n\n");
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
}; };
@ -873,10 +869,10 @@ int main (int argc, char ** argv)
int Ls=1; int Ls=1;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Clover dslash 4D vectorised" <<std::endl; std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
for(int l=0;l<L_list.size();l++){ for(int l=0;l<L_list.size();l++){
clover.push_back(Benchmark::Clover(L_list[l])); clover.push_back(Benchmark::DWF(1,L_list[l]));
} }
Ls=12; Ls=12;
@ -942,7 +938,7 @@ int main (int argc, char ** argv)
std::cout<<GridLogMessage << " L \t\t Clover\t\t DWF4\t\t Staggered (GF/s per node)" <<std::endl; std::cout<<GridLogMessage << " L \t\t Clover\t\t DWF4\t\t Staggered (GF/s per node)" <<std::endl;
fprintf(FP,"Per node summary table\n"); fprintf(FP,"Per node summary table\n");
fprintf(FP,"\n"); fprintf(FP,"\n");
fprintf(FP,"L , Wilson, DWF4, Staggered\n"); fprintf(FP,"L , Wilson, DWF4, Staggered, GF/s per node\n");
fprintf(FP,"\n"); fprintf(FP,"\n");
for(int l=0;l<L_list.size();l++){ for(int l=0;l<L_list.size();l++){
std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]/NN<<" \t "<<dwf4[l]/NN<< " \t "<<staggered[l]/NN<<std::endl; std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]/NN<<" \t "<<dwf4[l]/NN<< " \t "<<staggered[l]/NN<<std::endl;

View File

@ -16,7 +16,7 @@ CLIME=`spack find --paths c-lime@2-3-9 | grep c-lime| cut -c 15-`
--disable-fermion-reps \ --disable-fermion-reps \
CXX=hipcc MPICXX=mpicxx \ CXX=hipcc MPICXX=mpicxx \
CXXFLAGS="-fPIC -I{$ROCM_PATH}/include/ -I${MPICH_DIR}/include -L/lib64 -fgpu-sanitize" \ CXXFLAGS="-fPIC -I{$ROCM_PATH}/include/ -I${MPICH_DIR}/include -L/lib64 -fgpu-sanitize" \
LDFLAGS="-L/lib64 -L${MPICH_DIR}/lib -lmpi -L${CRAY_MPICH_ROOTDIR}/gtl/lib -lmpi_gtl_hsa -lamdhip64 " LDFLAGS="-L/lib64 -L${MPICH_DIR}/lib -lmpi -L${CRAY_MPICH_ROOTDIR}/gtl/lib -lmpi_gtl_hsa -lamdhip64 -lhipblas -lrocblas"