mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-14 01:35:36 +00:00
BLAS everywhere
This commit is contained in:
parent
baac1127d0
commit
f8f408e7a9
@ -261,23 +261,25 @@ public:
|
|||||||
fprintf(FP,"\n\n");
|
fprintf(FP,"\n\n");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<class CComplex>
|
||||||
static void BLAS(void)
|
static void BLAS(void)
|
||||||
{
|
{
|
||||||
//int nbasis, int nrhs, int coarseVol
|
//int nbasis, int nrhs, int coarseVol
|
||||||
int basis[] = { 16,32,64 };
|
int basis[] = { 16,32,64 };
|
||||||
int rhs[] = { 8,16,32 };
|
int rhs[] = { 8,12,16 };
|
||||||
int vol = 4*4*4*4;
|
int vol = 8*8*8*8;
|
||||||
|
int blk = 4*4*4*4;
|
||||||
|
|
||||||
GridBLAS blas;
|
GridBLAS blas;
|
||||||
|
|
||||||
|
int fpbits = sizeof(CComplex)*4;
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
std::cout<<GridLogMessage << "= batched GEMM (double precision) "<<std::endl;
|
std::cout<<GridLogMessage << "= batched GEMM fp"<<fpbits<<std::endl;
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl;
|
std::cout<<GridLogMessage << " M "<<"\t\t"<<"N"<<"\t\t\t"<<"K"<<"\t\t"<<"Gflop/s / rank (coarse mrhs)"<<std::endl;
|
||||||
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
|
std::cout<<GridLogMessage << "----------------------------------------------------------"<<std::endl;
|
||||||
|
|
||||||
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank\n");
|
fprintf(FP,"GEMM\n\n M, N, K, BATCH, GF/s per rank fp%d\n",fpbits);
|
||||||
|
|
||||||
for(int b=0;b<3;b++){
|
for(int b=0;b<3;b++){
|
||||||
for(int r=0;r<3;r++){
|
for(int r=0;r<3;r++){
|
||||||
@ -285,7 +287,7 @@ public:
|
|||||||
int N=rhs[r];
|
int N=rhs[r];
|
||||||
int K=basis[b];
|
int K=basis[b];
|
||||||
int BATCH=vol;
|
int BATCH=vol;
|
||||||
double p=blas.benchmark(M,N,K,BATCH);
|
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
|
||||||
|
|
||||||
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
||||||
|
|
||||||
@ -299,9 +301,9 @@ public:
|
|||||||
for(int r=0;r<3;r++){
|
for(int r=0;r<3;r++){
|
||||||
int M=basis[b];
|
int M=basis[b];
|
||||||
int N=rhs[r];
|
int N=rhs[r];
|
||||||
int K=vol;
|
int K=blk;
|
||||||
int BATCH=vol;
|
int BATCH=vol;
|
||||||
double p=blas.benchmark(M,N,K,BATCH);
|
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
|
||||||
|
|
||||||
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
||||||
std::cout<<GridLogMessage<<std::setprecision(3)
|
std::cout<<GridLogMessage<<std::setprecision(3)
|
||||||
@ -313,10 +315,10 @@ public:
|
|||||||
for(int b=0;b<3;b++){
|
for(int b=0;b<3;b++){
|
||||||
for(int r=0;r<3;r++){
|
for(int r=0;r<3;r++){
|
||||||
int M=rhs[r];
|
int M=rhs[r];
|
||||||
int N=vol;
|
int N=blk;
|
||||||
int K=basis[b];
|
int K=basis[b];
|
||||||
int BATCH=vol;
|
int BATCH=vol;
|
||||||
double p=blas.benchmark(M,N,K,BATCH);
|
double p=blas.benchmark<CComplex>(M,N,K,BATCH);
|
||||||
|
|
||||||
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
fprintf(FP,"%d, %d, %d, %d, %f\n", M, N, K, BATCH, p);
|
||||||
std::cout<<GridLogMessage<<std::setprecision(3)
|
std::cout<<GridLogMessage<<std::setprecision(3)
|
||||||
@ -867,6 +869,7 @@ int main (int argc, char ** argv)
|
|||||||
int do_memory=1;
|
int do_memory=1;
|
||||||
int do_comms =1;
|
int do_comms =1;
|
||||||
int do_blas =1;
|
int do_blas =1;
|
||||||
|
int do_dslash=1;
|
||||||
|
|
||||||
int sel=4;
|
int sel=4;
|
||||||
std::vector<int> L_list({8,12,16,24,32});
|
std::vector<int> L_list({8,12,16,24,32});
|
||||||
@ -877,6 +880,7 @@ int main (int argc, char ** argv)
|
|||||||
std::vector<double> staggered;
|
std::vector<double> staggered;
|
||||||
|
|
||||||
int Ls=1;
|
int Ls=1;
|
||||||
|
if (do_dslash){
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl;
|
std::cout<<GridLogMessage << " Clover dslash 4D vectorised (temporarily Wilson)" <<std::endl;
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
@ -909,6 +913,7 @@ int main (int argc, char ** argv)
|
|||||||
std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]<<" \t\t "<<dwf4[l] << " \t\t "<< staggered[l]<<std::endl;
|
std::cout<<GridLogMessage << L_list[l] <<" \t\t "<< clover[l]<<" \t\t "<<dwf4[l] << " \t\t "<< staggered[l]<<std::endl;
|
||||||
}
|
}
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
int NN=NN_global;
|
int NN=NN_global;
|
||||||
if ( do_memory ) {
|
if ( do_memory ) {
|
||||||
@ -919,12 +924,11 @@ int main (int argc, char ** argv)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ( do_blas ) {
|
if ( do_blas ) {
|
||||||
#if defined(GRID_CUDA) || defined(GRID_HIP) || defined(GRID_SYCL)
|
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
std::cout<<GridLogMessage << " Batched BLAS benchmark " <<std::endl;
|
std::cout<<GridLogMessage << " Batched BLAS benchmark " <<std::endl;
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
Benchmark::BLAS();
|
Benchmark::BLAS<ComplexD>();
|
||||||
#endif
|
Benchmark::BLAS<ComplexF>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if ( do_su4 ) {
|
if ( do_su4 ) {
|
||||||
@ -941,6 +945,7 @@ int main (int argc, char ** argv)
|
|||||||
Benchmark::Comms();
|
Benchmark::Comms();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(do_dslash){
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
std::cout<<GridLogMessage << " Per Node Summary table Ls="<<Ls <<std::endl;
|
std::cout<<GridLogMessage << " Per Node Summary table Ls="<<Ls <<std::endl;
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
@ -954,7 +959,6 @@ int main (int argc, char ** argv)
|
|||||||
fprintf(FP,"%d , %.0f, %.0f, %.0f\n",L_list[l],clover[l]/NN/1000.,dwf4[l]/NN/1000.,staggered[l]/NN/1000.);
|
fprintf(FP,"%d , %.0f, %.0f, %.0f\n",L_list[l],clover[l]/NN/1000.,dwf4[l]/NN/1000.,staggered[l]/NN/1000.);
|
||||||
}
|
}
|
||||||
fprintf(FP,"\n");
|
fprintf(FP,"\n");
|
||||||
|
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
|
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
@ -962,6 +966,7 @@ int main (int argc, char ** argv)
|
|||||||
std::cout<<GridLogMessage << " Comparison point is 0.5*("<<dwf4[sel]/NN<<"+"<<dwf4[selm1]/NN << ") "<<std::endl;
|
std::cout<<GridLogMessage << " Comparison point is 0.5*("<<dwf4[sel]/NN<<"+"<<dwf4[selm1]/NN << ") "<<std::endl;
|
||||||
std::cout<<std::setprecision(3);
|
std::cout<<std::setprecision(3);
|
||||||
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
std::cout<<GridLogMessage << "=================================================================================="<<std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
Grid_finalize();
|
Grid_finalize();
|
||||||
fclose(FP);
|
fclose(FP);
|
||||||
|
Loading…
Reference in New Issue
Block a user