1
0
mirror of https://github.com/paboyle/Grid.git synced 2024-11-10 07:55:35 +00:00

Multi dim FFT, and normalisation fix

This commit is contained in:
paboyle 2016-08-31 00:24:52 +01:00
parent 3475f45ce7
commit 9005b82c6d

View File

@ -32,6 +32,8 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
#ifdef HAVE_FFTW
#include <fftw3.h>
#endif
namespace Grid {
template<class scalar> struct FFTW { };
@ -115,7 +117,7 @@ namespace Grid {
public:
static const int forward=FFTW_FORWARD;
static const int forward =FFTW_FORWARD;
static const int backward=FFTW_BACKWARD;
double Flops(void) {return flops;}
@ -139,7 +141,29 @@ namespace Grid {
}
template<class vobj>
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){
void FFT_dim_mask(Lattice<vobj> &result,const Lattice<vobj> &source,std::vector<int> mask,int sign){
conformable(result._grid,vgrid);
conformable(source._grid,vgrid);
Lattice<vobj> tmp(vgrid);
tmp = source;
for(int d=0;d<Nd;d++){
if( mask[d] ) {
FFT_dim(result,tmp,d,sign);
tmp=result;
}
}
}
template<class vobj>
void FFT_all_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int sign){
std::vector<int> mask(Nd,1);
FFT_dim_mask(result,source,mask,sign);
}
template<class vobj>
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int sign){
conformable(result._grid,vgrid);
conformable(source._grid,vgrid);
@ -158,6 +182,12 @@ namespace Grid {
// Construct pencils
typedef typename vobj::scalar_object sobj;
typedef typename sobj::scalar_type scalar;
/*
std::cout << "FFT : vobj "<<demangle(typeid(vobj).name()) <<std::endl;
std::cout << "FFT : sobj "<<demangle(typeid(sobj).name()) <<std::endl;
std::cout << "FFT : scalar "<<demangle(typeid(scalar).name()) <<std::endl;
*/
Lattice<vobj> ssource(vgrid); ssource =source;
Lattice<sobj> pgsource(&pencil_g);
@ -184,9 +214,10 @@ namespace Grid {
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
int *inembed = n, *onembed = n;
int sign = FFTW_FORWARD;
if (inverse) sign = FFTW_BACKWARD;
scalar div;
if ( sign == backward ) div = 1.0/G;
else if ( sign == forward ) div = 1.0;
else assert(0);
FFTW_plan p;
{
@ -258,6 +289,7 @@ PARALLEL_FOR_LOOP
sobj s;
gcoor[dim] = lcoor[dim]+L*pc;
peekLocalSite(s,pgresult,gcoor);
s = s * div;
pokeLocalSite(s,result,lcoor);
}