1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-11 06:30:45 +01:00

more FFT optimisations

This commit is contained in:
Antonin Portelli 2016-10-26 17:36:26 +01:00
parent 33d199a0ad
commit 14ddf2c234
2 changed files with 121 additions and 136 deletions

View File

@ -140,7 +140,9 @@ namespace Grid {
template<class vobj>
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){
#ifndef HAVE_FFTW
assert(0);
#else
conformable(result._grid,vgrid);
conformable(source._grid,vgrid);
@ -159,17 +161,12 @@ namespace Grid {
typedef typename vobj::scalar_object sobj;
typedef typename sobj::scalar_type scalar;
Lattice<vobj> ssource(vgrid); ssource =source;
Lattice<sobj> pgsource(&pencil_g);
Lattice<sobj> pgresult(&pencil_g); pgresult=zero;
Lattice<sobj> pgbuf(&pencil_g);
#ifndef HAVE_FFTW
assert(0);
#else
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
{
int Ncomp = sizeof(sobj)/sizeof(scalar);
int Nlow = 1;
for(int d=0;d<dim;d++){
@ -184,14 +181,13 @@ namespace Grid {
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
int *inembed = n, *onembed = n;
int sign = FFTW_FORWARD;
if (inverse) sign = FFTW_BACKWARD;
FFTW_plan p;
{
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[0];
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[0];
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[0];
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[0];
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
in,inembed,
istride,idist,
@ -200,72 +196,59 @@ namespace Grid {
sign,FFTW_ESTIMATE);
}
std::vector<int> lcoor(Nd), gcoor(Nd);
// Barrel shift and collect global pencil
std::vector<int> lcoor(Nd), gcoor(Nd);
result = source;
for(int p=0;p<processors[dim];p++) {
for(int idx=0;idx<sgrid->lSites();idx++) {
sgrid->LocalIndexToLocalCoor(idx,lcoor);
sobj s;
peekLocalSite(s,ssource,lcoor);
peekLocalSite(s,result,lcoor);
lcoor[dim]+=p*L;
pokeLocalSite(s,pgsource,lcoor);
pokeLocalSite(s,pgbuf,lcoor);
}
ssource = Cshift(ssource,dim,L);
result = Cshift(result,dim,L);
}
// Loop over orthog coords
int NN=pencil_g.lSites();
GridStopWatch timer;
timer.Start();
//PARALLEL_FOR_LOOP
//PARALLEL_FOR_LOOP
for(int idx=0;idx<NN;idx++) {
pencil_g.LocalIndexToLocalCoor(idx,lcoor);
if ( lcoor[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[idx];
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[idx];
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[idx];
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[idx];
FFTW<scalar>::fftw_execute_dft(p,in,out);
}
}
timer.Stop();
// performance counting
double add,mul,fma;
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
flops_call = add+mul+2.0*fma;
usec += timer.useconds();
flops+= flops_call*NN;
// writing out result
int pc = processor_coor[dim];
for(int idx=0;idx<sgrid->lSites();idx++) {
sgrid->LocalIndexToLocalCoor(idx,lcoor);
gcoor = lcoor;
// extract the result
sobj s;
gcoor[dim] = lcoor[dim]+L*pc;
peekLocalSite(s,pgresult,gcoor);
peekLocalSite(s,pgbuf,gcoor);
pokeLocalSite(s,result,lcoor);
}
// destroying plan
FFTW<scalar>::fftw_destroy_plan(p);
}
#endif
}
};
}
#endif

View File

@ -76,12 +76,13 @@ int main (int argc, char ** argv)
S=zero;
S = S+C;
Ctilde = C;
FFT theFFT(&Fine);
theFFT.FFT_dim(Ctilde,C,0,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,C,1,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,C,2,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,C,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,Ctilde,0,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,Ctilde,1,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,Ctilde,2,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Ctilde,Ctilde,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
// C=zero;
// Ctilde = where(abs(Ctilde)<1.0e-10,C,Ctilde);
@ -92,11 +93,12 @@ int main (int argc, char ** argv)
pokeSite(cVol,C,p);
C=C-Ctilde;
std::cout << "diff scalar "<<norm2(C) << std::endl;
Stilde = S;
theFFT.FFT_dim(Stilde,S,0,FFT::forward); S=Stilde; std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,S,1,FFT::forward); S=Stilde;std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,S,2,FFT::forward); S=Stilde;std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,S,3,FFT::forward);std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,0,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,1,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,2,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
theFFT.FFT_dim(Stilde,Stilde,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
SpinMatrixD Sp;
Sp = zero; Sp = Sp+cVol;