mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-09 23:45:36 +00:00
BLAS everywhere
This commit is contained in:
parent
baac1127d0
commit
f8f408e7a9
@ -261,23 +261,25 @@ public:
|
||||
fprintf(FP,"\n\n");
|
||||
};
|
||||
|
||||
|
||||
template<class CComplex>
|
||||
static void BLAS(void)
|
||||
{
|
||||
//int nbasis, int nrhs, int coarseVol
|
||||
int basis[] = { 16,32,64 };
|
||||
int rhs[] = { 8,16,32 };
|
||||
int vol = 4*4*4*4;
|
||||
int rhs[] = { 8,12,16 };
|
||||
int vol = 8*8*8*8;
|
||||
int blk = 4*4*4*4;
|
||||
|
||||
GridBLAS blas;
|
||||
|
||||
|
||||
int fpbits = sizeof(CComplex)*4;
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
std::cout<<GridLogMessage << "= batched GEMM (double precision) "<<std::endl;
|
||||
std::cout<<GridLogMessage << "= batched GEMM fp"<<fpbits<<std::endl;
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl;
|
||||
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
|
||||
|
||||
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank\n");
|
||||
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank fp%d\n",fpbits);
|
||||
|
||||
for(int b=0;b<3;b++){
|
||||
for(int r=0;r<3;r++){
|
||||
@ -285,7 +287,7 @@ public:
|
||||
int N=rhs[r];
|
||||
int K=basis[b];
|
||||
int BATCH=vol;
|
||||
double p=blas.benchmark(M,N,K,BATCH);
|
||||
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
|
||||
|
||||
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
||||
|
||||
@ -299,9 +301,9 @@ public:
|
||||
for(int r=0;r<3;r++){
|
||||
int M=basis[b];
|
||||
int N=rhs[r];
|
||||
int K=vol;
|
||||
int K=blk;
|
||||
int BATCH=vol;
|
||||
double p=blas.benchmark(M,N,K,BATCH);
|
||||
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
|
||||
|
||||
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
||||
std::cout<<GridLogMessage<<std::setprecision(3)
|
||||
@ -313,10 +315,10 @@ public:
|
||||
for(int b=0;b<3;b++){
|
||||
for(int r=0;r<3;r++){
|
||||
int M=rhs[r];
|
||||
int N=vol;
|
||||
int N=blk;
|
||||
int K=basis[b];
|
||||
int BATCH=vol;
|
||||
double p=blas.benchmark(M,N,K,BATCH);
|
||||
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
|
||||
|
||||
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
||||
std::cout<<GridLogMessage<<std::setprecision(3)
|
||||
@ -867,6 +869,7 @@ int main (int argc, char ** argv)
|
||||
int do_memory=1;
|
||||
int do_comms =1;
|
||||
int do_blas =1;
|
||||
int do_dslash=1;
|
||||
|
||||
int sel=4;
|
||||
std::vector<int> L_list({8,12,16,24,32});
|
||||
@ -877,6 +880,7 @@ int main (int argc, char ** argv)
|
||||
std::vector<double> staggered;
|
||||
|
||||
int Ls=1;
|
||||
if (do_dslash){
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl;
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
@ -909,7 +913,8 @@ int main (int argc, char ** argv)
|
||||
std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]<<" \t\t "<<dwf4[l] << " \t\t "<< staggered[l]<<std::endl;
|
||||
}
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
|
||||
}
|
||||
|
||||
int NN=NN_global;
|
||||
if ( do_memory ) {
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
@ -919,12 +924,11 @@ int main (int argc, char ** argv)
|
||||
}
|
||||
|
||||
if ( do_blas ) {
|
||||
#if defined(GRID_CUDA) || defined(GRID_HIP) || defined(GRID_SYCL)
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
std::cout<<GridLogMessage << " Batched BLAS benchmark " <<std::endl;
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
Benchmark::BLAS();
|
||||
#endif
|
||||
Benchmark::BLAS<ComplexD>();
|
||||
Benchmark::BLAS<ComplexF>();
|
||||
}
|
||||
|
||||
if ( do_su4 ) {
|
||||
@ -941,6 +945,7 @@ int main (int argc, char ** argv)
|
||||
Benchmark::Comms();
|
||||
}
|
||||
|
||||
if(do_dslash){
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
std::cout<<GridLogMessage << " Per Node Summary table Ls="<<Ls <<std::endl;
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
@ -954,7 +959,6 @@ int main (int argc, char ** argv)
|
||||
fprintf(FP,"%d , %.0f, %.0f, %.0f\n",L_list[l],clover[l]/NN/1000.,dwf4[l]/NN/1000.,staggered[l]/NN/1000.);
|
||||
}
|
||||
fprintf(FP,"\n");
|
||||
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
@ -962,7 +966,8 @@ int main (int argc, char ** argv)
|
||||
std::cout<<GridLogMessage << " Comparison point is 0.5*("<<dwf4[sel]/NN<<"+"<<dwf4[selm1]/NN << ") "<<std::endl;
|
||||
std::cout<<std::setprecision(3);
|
||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||
|
||||
}
|
||||
|
||||
Grid_finalize();
|
||||
fclose(FP);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user