1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-14 01:35:36 +00:00

BLAS everywhere

This commit is contained in:
Peter Boyle 2024-07-25 18:09:02 +00:00
parent baac1127d0
commit f8f408e7a9

View File

@ -261,23 +261,25 @@ public:
fprintf(FP,"\n\n"); fprintf(FP,"\n\n");
}; };
template<class CComplex>
static void BLAS(void) static void BLAS(void)
{ {
//int nbasis, int nrhs, int coarseVol //int nbasis, int nrhs, int coarseVol
int basis[] = { 16,32,64 }; int basis[] = { 16,32,64 };
int rhs[] = { 8,16,32 }; int rhs[] = { 8,12,16 };
int vol = 4*4*4*4; int vol = 8*8*8*8;
int blk = 4*4*4*4;
GridBLAS blas; GridBLAS blas;
int fpbits = sizeof(CComplex)*4;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << "= batched GEMM (double precision) "<<std::endl; std::cout<<GridLogMessage << "= batched GEMM fp"<<fpbits<<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl; std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl;
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl; std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank\n"); fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank fp%d\n",fpbits);
for(int b=0;b<3;b++){ for(int b=0;b<3;b++){
for(int r=0;r<3;r++){ for(int r=0;r<3;r++){
@ -285,7 +287,7 @@ public:
int N=rhs[r]; int N=rhs[r];
int K=basis[b]; int K=basis[b];
int BATCH=vol; int BATCH=vol;
double p=blas.benchmark(M,N,K,BATCH); double p=blas.benchmark<CComplex>(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p); fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
@ -299,9 +301,9 @@ public:
for(int r=0;r<3;r++){ for(int r=0;r<3;r++){
int M=basis[b]; int M=basis[b];
int N=rhs[r]; int N=rhs[r];
int K=vol; int K=blk;
int BATCH=vol; int BATCH=vol;
double p=blas.benchmark(M,N,K,BATCH); double p=blas.benchmark<CComplex>(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p); fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3) std::cout<<GridLogMessage<<std::setprecision(3)
@ -313,10 +315,10 @@ public:
for(int b=0;b<3;b++){ for(int b=0;b<3;b++){
for(int r=0;r<3;r++){ for(int r=0;r<3;r++){
int M=rhs[r]; int M=rhs[r];
int N=vol; int N=blk;
int K=basis[b]; int K=basis[b];
int BATCH=vol; int BATCH=vol;
double p=blas.benchmark(M,N,K,BATCH); double p=blas.benchmark<CComplex>(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p); fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3) std::cout<<GridLogMessage<<std::setprecision(3)
@ -867,6 +869,7 @@ int main (int argc, char ** argv)
int do_memory=1; int do_memory=1;
int do_comms =1; int do_comms =1;
int do_blas =1; int do_blas =1;
int do_dslash=1;
int sel=4; int sel=4;
std::vector<int> L_list({8,12,16,24,32}); std::vector<int> L_list({8,12,16,24,32});
@ -877,6 +880,7 @@ int main (int argc, char ** argv)
std::vector<double> staggered; std::vector<double> staggered;
int Ls=1; int Ls=1;
if (do_dslash){
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl; std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@ -909,6 +913,7 @@ int main (int argc, char ** argv)
std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]<<" \t\t "<<dwf4[l] << " \t\t "<< staggered[l]<<std::endl; std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]<<" \t\t "<<dwf4[l] << " \t\t "<< staggered[l]<<std::endl;
} }
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
}
int NN=NN_global; int NN=NN_global;
if ( do_memory ) { if ( do_memory ) {
@ -919,12 +924,11 @@ int main (int argc, char ** argv)
} }
if ( do_blas ) { if ( do_blas ) {
#if defined(GRID_CUDA) || defined(GRID_HIP) || defined(GRID_SYCL)
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Batched BLAS benchmark " <<std::endl; std::cout<<GridLogMessage << " Batched BLAS benchmark " <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
Benchmark::BLAS(); Benchmark::BLAS<ComplexD>();
#endif Benchmark::BLAS<ComplexF>();
} }
if ( do_su4 ) { if ( do_su4 ) {
@ -941,6 +945,7 @@ int main (int argc, char ** argv)
Benchmark::Comms(); Benchmark::Comms();
} }
if(do_dslash){
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Per Node Summary table Ls="<<Ls <<std::endl; std::cout<<GridLogMessage << " Per Node Summary table Ls="<<Ls <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@ -954,7 +959,6 @@ int main (int argc, char ** argv)
fprintf(FP,"%d , %.0f, %.0f, %.0f\n",L_list[l],clover[l]/NN/1000.,dwf4[l]/NN/1000.,staggered[l]/NN/1000.); fprintf(FP,"%d , %.0f, %.0f, %.0f\n",L_list[l],clover[l]/NN/1000.,dwf4[l]/NN/1000.,staggered[l]/NN/1000.);
} }
fprintf(FP,"\n"); fprintf(FP,"\n");
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@ -962,6 +966,7 @@ int main (int argc, char ** argv)
std::cout<<GridLogMessage << " Comparison point is 0.5*("<<dwf4[sel]/NN<<"+"<<dwf4[selm1]/NN << ") "<<std::endl; std::cout<<GridLogMessage << " Comparison point is 0.5*("<<dwf4[sel]/NN<<"+"<<dwf4[selm1]/NN << ") "<<std::endl;
std::cout<<std::setprecision(3); std::cout<<std::setprecision(3);
std::cout<<GridLogMessage << "=================================================================================="<<std::endl; std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
}
Grid_finalize(); Grid_finalize();
fclose(FP); fclose(FP);