1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-17 15:27:06 +01:00

FFT double and single precision gives good performance now in multithreaded code.

This commit is contained in:
paboyle
2016-08-24 15:05:00 +01:00
parent 88be3b39bb
commit ff6da364e8
6 changed files with 298 additions and 492 deletions

199
lib/FFT.h
View File

@ -1,3 +1,4 @@
/*************************************************************************************
Grid physics library, www.github.com/paboyle/Grid
@ -28,11 +29,70 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
#ifndef _GRID_FFT_H_
#define _GRID_FFT_H_
#include <Grid/fftw/fftw3.h>
#ifdef HAVE_FFTW
#include <fftw3.h>
#endif
namespace Grid {
template<class scalar> struct FFTW { };
#ifdef HAVE_FFTW
template<> struct FFTW<ComplexD> {
public:
typedef fftw_complex FFTW_scalar;
typedef fftw_plan FFTW_plan;
static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany,
FFTW_scalar *in, const int *inembed,
int istride, int idist,
FFTW_scalar *out, const int *onembed,
int ostride, int odist,
int sign, unsigned flags) {
return ::fftw_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags);
}
static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){
::fftw_flops(p,add,mul,fmas);
}
inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) {
::fftw_execute_dft(p,in,out);
}
inline static void fftw_destroy_plan(const FFTW_plan p) {
::fftw_destroy_plan(p);
}
};
template<> struct FFTW<ComplexF> {
public:
typedef fftwf_complex FFTW_scalar;
typedef fftwf_plan FFTW_plan;
static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany,
FFTW_scalar *in, const int *inembed,
int istride, int idist,
FFTW_scalar *out, const int *onembed,
int ostride, int odist,
int sign, unsigned flags) {
return ::fftwf_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags);
}
static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){
::fftwf_flops(p,add,mul,fmas);
}
inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) {
::fftwf_execute_dft(p,in,out);
}
inline static void fftw_destroy_plan(const FFTW_plan p) {
::fftwf_destroy_plan(p);
}
};
#endif
class FFT {
private:
@ -40,6 +100,10 @@ namespace Grid {
GridCartesian *sgrid;
int Nd;
double flops;
double flops_call;
uint64_t usec;
std::vector<int> dimensions;
std::vector<int> processors;
std::vector<int> processor_coor;
@ -49,6 +113,9 @@ namespace Grid {
static const int forward=FFTW_FORWARD;
static const int backward=FFTW_BACKWARD;
double Flops(void) {return flops;}
double MFlops(void) {return flops/usec;}
FFT ( GridCartesian * grid ) :
vgrid(grid),
Nd(grid->_ndimension),
@ -56,6 +123,8 @@ namespace Grid {
processors(grid->_processors),
processor_coor(grid->_processor_coor)
{
flops=0;
usec =0;
std::vector<int> layout(Nd,1);
sgrid = new GridCartesian(dimensions,layout,processors);
};
@ -75,55 +144,62 @@ namespace Grid {
std::vector<int> layout(Nd,1);
std::vector<int> pencil_gd(vgrid->_fdimensions);
std::vector<int> pencil_ld(processors);
pencil_gd[dim] = G*processors[dim];
pencil_ld[dim] = G*processors[dim];
// Pencil global vol LxLxGxLxL per node
GridCartesian pencil_g(pencil_gd,layout,processors);
GridCartesian pencil_l(pencil_ld,layout,processors);
// Construct pencils
typedef typename vobj::scalar_object sobj;
typedef typename sobj::scalar_type scalar;
Lattice<vobj> ssource(vgrid); ssource =source;
Lattice<sobj> pgsource(&pencil_g);
Lattice<sobj> pgresult(&pencil_g);
Lattice<sobj> plsource(&pencil_l);
Lattice<sobj> plresult(&pencil_l);
Lattice<sobj> pgresult(&pencil_g); pgresult=zero;
#ifndef HAVE_FFTW
assert(0);
#else
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
{
assert(sizeof(typename sobj::scalar_type)==sizeof(ComplexD));
assert(sizeof(fftw_complex)==sizeof(ComplexD));
assert(sizeof(fftw_complex)==sizeof(ComplexD));
int Ncomp = sizeof(sobj)/sizeof(scalar);
int Nlow = 1;
for(int d=0;d<dim;d++){
Nlow*=vgrid->_ldimensions[d];
}
int Ncomp = sizeof(sobj)/sizeof(fftw_complex);
int rank = 1; /* not 2: we are computing 1d transforms */
int rank = 1; /* 1d transforms */
int n[] = {G}; /* 1d transforms of length G */
int howmany = Ncomp;
int odist,idist,istride,ostride;
idist = odist = 1;
istride = ostride = Ncomp; /* distance between two elements in the same column */
idist = odist = 1; /* Distance between consecutive FT's */
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
int *inembed = n, *onembed = n;
fftw_complex *in = (fftw_complex *)&plsource._odata[0];
fftw_complex *out= (fftw_complex *)&plresult._odata[0];
int sign = FFTW_FORWARD;
if (inverse) sign = FFTW_BACKWARD;
#ifdef HAVE_FFTW
fftw_plan p = fftw_plan_many_dft(rank,n,howmany,
in,inembed,
istride,idist,
out,onembed,
ostride, odist,
sign,FFTW_ESTIMATE);
#else
fftw_plan p ;
assert(0);
#endif
FFTW_plan p;
{
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[0];
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[0];
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
in,inembed,
istride,idist,
out,onembed,
ostride, odist,
sign,FFTW_ESTIMATE);
}
double add,mul,fma;
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
flops_call = add+mul+2.0*fma;
std::cout << "FFT Flops per call "<<flops_call<<std::endl;
GridStopWatch timer;
// Barrel shift and collect global pencil
for(int p=0;p<processors[dim];p++) {
@ -146,43 +222,46 @@ namespace Grid {
}
// Loop over orthog coords
for(int idx=0;idx<sgrid->lSites();idx++) {
int NN=pencil_g.lSites();
GridStopWatch Timer;
Timer.Start();
PARALLEL_FOR_LOOP
for(int idx=0;idx<NN;idx++) {
std::vector<int> pcoor(Nd,0);
std::vector<int> lcoor(Nd);
sgrid->LocalIndexToLocalCoor(idx,lcoor);
pencil_g.LocalIndexToLocalCoor(idx,lcoor);
if ( lcoor[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
// Project to local pencil array
for(int l=0;l<G;l++){
sobj s;
pcoor[dim]=l;
lcoor[dim]=l;
peekLocalSite(s,pgsource,lcoor);
pokeLocalSite(s,plsource,pcoor);
}
// FFT the pencil
#ifdef HAVE_FFTW
fftw_execute(p);
#endif
// Extract the result
for(int l=0;l<L;l++){
sobj s;
int p = processor_coor[dim];
lcoor[dim] = l;
pcoor[dim] = l+L*p;
peekLocalSite(s,plresult,pcoor);
pokeLocalSite(s,result,lcoor);
}
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[idx];
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[idx];
FFTW<scalar>::fftw_execute_dft(p,in,out);
}
}
fftw_destroy_plan(p);
Timer.Stop();
usec += Timer.useconds();
flops+= flops_call*NN;
int pc = processor_coor[dim];
for(int idx=0;idx<sgrid->lSites();idx++) {
std::vector<int> lcoor(Nd);
sgrid->LocalIndexToLocalCoor(idx,lcoor);
std::vector<int> gcoor = lcoor;
// extract the result
sobj s;
gcoor[dim] = lcoor[dim]+L*pc;
peekLocalSite(s,pgresult,gcoor);
pokeLocalSite(s,result,lcoor);
}
FFTW<scalar>::fftw_destroy_plan(p);
}
#endif
}
};