1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-09 23:45:36 +00:00

BLAS everywhere

This commit is contained in:
Peter Boyle 2024-07-25 18:09:02 +00:00
parent baac1127d0
commit f8f408e7a9

View File

@@ -261,23 +261,25 @@ public:
fprintf(FP,"\n\n");
};
template<class CComplex>
static void BLAS(void)
{
//int nbasis, int nrhs, int coarseVol
int basis[] = { 16,32,64 };
int rhs[] = { 8,16,32 };
int vol = 4*4*4*4;
int rhs[] = { 8,12,16 };
int vol = 8*8*8*8;
int blk = 4*4*4*4;
GridBLAS blas;
int fpbits = sizeof(CComplex)*4;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << "= batched GEMM (double precision) "<<std::endl;
std::cout<<GridLogMessage << "= batched GEMM fp"<<fpbits<<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl;
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank\n");
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank fp%d\n",fpbits);
for(int b=0;b<3;b++){
for(int r=0;r<3;r++){
@@ -285,7 +287,7 @@ public:
int N=rhs[r];
int K=basis[b];
int BATCH=vol;
double p=blas.benchmark(M,N,K,BATCH);
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
@@ -299,9 +301,9 @@ public:
for(int r=0;r<3;r++){
int M=basis[b];
int N=rhs[r];
int K=vol;
int K=blk;
int BATCH=vol;
double p=blas.benchmark(M,N,K,BATCH);
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3)
@@ -313,10 +315,10 @@ public:
for(int b=0;b<3;b++){
for(int r=0;r<3;r++){
int M=rhs[r];
int N=vol;
int N=blk;
int K=basis[b];
int BATCH=vol;
double p=blas.benchmark(M,N,K,BATCH);
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
std::cout<<GridLogMessage<<std::setprecision(3)
@@ -867,6 +869,7 @@ int main (int argc, char ** argv)
int do_memory=1;
int do_comms =1;
int do_blas =1;
int do_dslash=1;
int sel=4;
std::vector<int> L_list({8,12,16,24,32});
@@ -877,6 +880,7 @@ int main (int argc, char ** argv)
std::vector<double> staggered;
int Ls=1;
if (do_dslash){
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@@ -909,7 +913,8 @@ int main (int argc, char ** argv)
std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]<<" \t\t "<<dwf4[l] << " \t\t "<< staggered[l]<<std::endl;
}
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
}
int NN=NN_global;
if ( do_memory ) {
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@@ -919,12 +924,11 @@ int main (int argc, char ** argv)
}
if ( do_blas ) {
#if defined(GRID_CUDA) || defined(GRID_HIP) || defined(GRID_SYCL)
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Batched BLAS benchmark " <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
Benchmark::BLAS();
#endif
Benchmark::BLAS<ComplexD>();
Benchmark::BLAS<ComplexF>();
}
if ( do_su4 ) {
@@ -941,6 +945,7 @@ int main (int argc, char ** argv)
Benchmark::Comms();
}
if(do_dslash){
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << " Per Node Summary table Ls="<<Ls <<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@@ -954,7 +959,6 @@ int main (int argc, char ** argv)
fprintf(FP,"%d , %.0f, %.0f, %.0f\n",L_list[l],clover[l]/NN/1000.,dwf4[l]/NN/1000.,staggered[l]/NN/1000.);
}
fprintf(FP,"\n");
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
@@ -962,7 +966,8 @@ int main (int argc, char ** argv)
std::cout<<GridLogMessage << " Comparison point is 0.5*("<<dwf4[sel]/NN<<"+"<<dwf4[selm1]/NN << ") "<<std::endl;
std::cout<<std::setprecision(3);
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
}
Grid_finalize();
fclose(FP);
}