1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-08-25 15:37:10 +01:00

FFT offload to GPU and MUCH faster comms.

40x speed up on Frontier
This commit is contained in:
Peter Boyle
2025-08-21 16:44:55 -04:00
parent 76c0ada1e1
commit fe0db53842
8 changed files with 443 additions and 176 deletions

View File

@@ -29,7 +29,6 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
#include <Grid/Grid.h>
using namespace Grid;
;
int main (int argc, char ** argv)
{
@@ -116,10 +115,10 @@ int main (int argc, char ** argv)
Stilde=S;
std::cout<<" Benchmarking FFT of LatticeSpinMatrix "<<std::endl;
theFFT.FFT_dim(Stilde,S,0,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,S,1,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,S,2,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,S,3,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,0,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,1,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,2,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,3,FFT::forward); std::cout << theFFT.MFlops()<<" mflops "<<std::endl;
SpinMatrixD Sp;
Sp = Zero(); Sp = Sp+cVol;
@@ -202,11 +201,16 @@ int main (int argc, char ** argv)
FFT theFFT5(FGrid);
theFFT5.FFT_dim(result5,tmp5,1,FFT::forward); tmp5 = result5;
std::cout<<"Fourier xformed Ddwf 1 "<<norm2(result5)<<std::endl;
theFFT5.FFT_dim(result5,tmp5,2,FFT::forward); tmp5 = result5;
std::cout<<"Fourier xformed Ddwf 2 "<<norm2(result5)<<std::endl;
theFFT5.FFT_dim(result5,tmp5,3,FFT::forward); tmp5 = result5;
theFFT5.FFT_dim(result5,tmp5,4,FFT::forward); result5 = result5*ComplexD(::sqrt(1.0/vol),0.0);
std::cout<<"Fourier xformed Ddwf 3 "<<norm2(result5)<<std::endl;
theFFT5.FFT_dim(result5,tmp5,4,FFT::forward);
std::cout<<"Fourier xformed Ddwf 4 "<<norm2(result5)<<std::endl;
result5 = result5*ComplexD(::sqrt(1.0/vol),0.0);
std::cout<<"Fourier xformed Ddwf"<<std::endl;
std::cout<<"Fourier xformed Ddwf "<<norm2(result5)<<std::endl;
tmp5 = src5;
theFFT5.FFT_dim(src5_p,tmp5,1,FFT::forward); tmp5 = src5_p;
@@ -214,7 +218,7 @@ int main (int argc, char ** argv)
theFFT5.FFT_dim(src5_p,tmp5,3,FFT::forward); tmp5 = src5_p;
theFFT5.FFT_dim(src5_p,tmp5,4,FFT::forward); src5_p = src5_p*ComplexD(::sqrt(1.0/vol),0.0);
std::cout<<"Fourier xformed src5"<<std::endl;
std::cout<<"Fourier xformed src5"<< norm2(src5)<<" -> "<<norm2(src5_p)<<std::endl;
/////////////////////////////////////////////////////////////////
// work out the predicted from Fourier
@@ -251,7 +255,8 @@ int main (int argc, char ** argv)
Kinetic = Kinetic + sin(kmu)*ci*(Gamma(Gmu[mu])*src5_p);
}
std::cout << " src5 "<<norm2(src5_p)<<std::endl;
std::cout << " Kinetic "<<norm2(Kinetic)<<std::endl;
// NB implicit sum over mu
//
// 1-1/2 Dw = 1 - 1/2 ( eip+emip)
@@ -260,18 +265,23 @@ int main (int argc, char ** argv)
// = 2 sink/2 ink/2 = sk2
W = one - M5 + sk2;
std::cout << " W "<<norm2(W)<<std::endl;
Kinetic = Kinetic + W * src5_p;
std::cout << " Kinetic "<<norm2(Kinetic)<<std::endl;
LatticeCoordinate(scoor,sdir);
tmp5 = Cshift(src5_p,sdir,+1);
tmp5 = (tmp5 - G5*tmp5)*0.5;
tmp5 = where(scoor==Integer(Ls-1),mass*tmp5,-tmp5);
std::cout << " tmp5 "<<norm2(tmp5)<<std::endl;
Kinetic = Kinetic + tmp5;
tmp5 = Cshift(src5_p,sdir,-1);
tmp5 = (tmp5 + G5*tmp5)*0.5;
tmp5 = where(scoor==Integer(0),mass*tmp5,-tmp5);
std::cout << " tmp5 "<<norm2(tmp5)<<std::endl;
Kinetic = Kinetic + tmp5;
std::cout<<"Momentum space Ddwf "<< norm2(Kinetic)<<std::endl;
@@ -339,7 +349,7 @@ int main (int argc, char ** argv)
Ddwf.Mdag(src5,tmp5);
src5=tmp5;
MdagMLinearOperator<DomainWallFermionD,LatticeFermionD> HermOp(Ddwf);
ConjugateGradient<LatticeFermionD> CG(1.0e-16,10000);
ConjugateGradient<LatticeFermionD> CG(1.0e-8,10000);
CG(HermOp,src5,result5);
////////////////////////////////////////////////////////////////////////
@@ -423,7 +433,7 @@ int main (int argc, char ** argv)
Dov.Mdag(src5,tmp5);
src5=tmp5;
MdagMLinearOperator<OverlapWilsonCayleyTanhFermionD,LatticeFermionD> HermOp(Dov);
ConjugateGradient<LatticeFermionD> CG(1.0e-16,10000);
ConjugateGradient<LatticeFermionD> CG(1.0e-8,10000);
CG(HermOp,src5,result5);
////////////////////////////////////////////////////////////////////////