1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-04-09 21:50:45 +01:00

Multi dim FFT, and normalisation fix

This commit is contained in:
paboyle 2016-08-31 00:24:52 +01:00
parent 3475f45ce7
commit 9005b82c6d

View File

@ -32,6 +32,8 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
#ifdef HAVE_FFTW #ifdef HAVE_FFTW
#include <fftw3.h> #include <fftw3.h>
#endif #endif
namespace Grid { namespace Grid {
template<class scalar> struct FFTW { }; template<class scalar> struct FFTW { };
@ -115,7 +117,7 @@ namespace Grid {
public: public:
static const int forward=FFTW_FORWARD; static const int forward =FFTW_FORWARD;
static const int backward=FFTW_BACKWARD; static const int backward=FFTW_BACKWARD;
double Flops(void) {return flops;} double Flops(void) {return flops;}
@ -139,7 +141,29 @@ namespace Grid {
} }
template<class vobj> template<class vobj>
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){ void FFT_dim_mask(Lattice<vobj> &result,const Lattice<vobj> &source,std::vector<int> mask,int sign){
conformable(result._grid,vgrid);
conformable(source._grid,vgrid);
Lattice<vobj> tmp(vgrid);
tmp = source;
for(int d=0;d<Nd;d++){
if( mask[d] ) {
FFT_dim(result,tmp,d,sign);
tmp=result;
}
}
}
template<class vobj>
void FFT_all_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int sign){
std::vector<int> mask(Nd,1);
FFT_dim_mask(result,source,mask,sign);
}
template<class vobj>
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int sign){
conformable(result._grid,vgrid); conformable(result._grid,vgrid);
conformable(source._grid,vgrid); conformable(source._grid,vgrid);
@ -158,6 +182,12 @@ namespace Grid {
// Construct pencils // Construct pencils
typedef typename vobj::scalar_object sobj; typedef typename vobj::scalar_object sobj;
typedef typename sobj::scalar_type scalar; typedef typename sobj::scalar_type scalar;
/*
std::cout << "FFT : vobj "<<demangle(typeid(vobj).name()) <<std::endl;
std::cout << "FFT : sobj "<<demangle(typeid(sobj).name()) <<std::endl;
std::cout << "FFT : scalar "<<demangle(typeid(scalar).name()) <<std::endl;
*/
Lattice<vobj> ssource(vgrid); ssource =source; Lattice<vobj> ssource(vgrid); ssource =source;
Lattice<sobj> pgsource(&pencil_g); Lattice<sobj> pgsource(&pencil_g);
@ -184,9 +214,10 @@ namespace Grid {
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */ istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
int *inembed = n, *onembed = n; int *inembed = n, *onembed = n;
scalar div;
int sign = FFTW_FORWARD; if ( sign == backward ) div = 1.0/G;
if (inverse) sign = FFTW_BACKWARD; else if ( sign == forward ) div = 1.0;
else assert(0);
FFTW_plan p; FFTW_plan p;
{ {
@ -258,6 +289,7 @@ PARALLEL_FOR_LOOP
sobj s; sobj s;
gcoor[dim] = lcoor[dim]+L*pc; gcoor[dim] = lcoor[dim]+L*pc;
peekLocalSite(s,pgresult,gcoor); peekLocalSite(s,pgresult,gcoor);
s = s * div;
pokeLocalSite(s,result,lcoor); pokeLocalSite(s,result,lcoor);
} }