mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
Multi dim FFT, and normalisation fix
This commit is contained in:
parent
3475f45ce7
commit
9005b82c6d
42
lib/FFT.h
42
lib/FFT.h
@ -32,6 +32,8 @@ Author: Peter Boyle <paboyle@ph.ed.ac.uk>
|
||||
#ifdef HAVE_FFTW
|
||||
#include <fftw3.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace Grid {
|
||||
|
||||
template<class scalar> struct FFTW { };
|
||||
@ -115,7 +117,7 @@ namespace Grid {
|
||||
|
||||
public:
|
||||
|
||||
static const int forward=FFTW_FORWARD;
|
||||
static const int forward =FFTW_FORWARD;
|
||||
static const int backward=FFTW_BACKWARD;
|
||||
|
||||
double Flops(void) {return flops;}
|
||||
@ -139,7 +141,29 @@ namespace Grid {
|
||||
}
|
||||
|
||||
template<class vobj>
|
||||
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){
|
||||
void FFT_dim_mask(Lattice<vobj> &result,const Lattice<vobj> &source,std::vector<int> mask,int sign){
|
||||
|
||||
conformable(result._grid,vgrid);
|
||||
conformable(source._grid,vgrid);
|
||||
Lattice<vobj> tmp(vgrid);
|
||||
tmp = source;
|
||||
for(int d=0;d<Nd;d++){
|
||||
if( mask[d] ) {
|
||||
FFT_dim(result,tmp,d,sign);
|
||||
tmp=result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class vobj>
|
||||
void FFT_all_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int sign){
|
||||
std::vector<int> mask(Nd,1);
|
||||
FFT_dim_mask(result,source,mask,sign);
|
||||
}
|
||||
|
||||
|
||||
template<class vobj>
|
||||
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int sign){
|
||||
|
||||
conformable(result._grid,vgrid);
|
||||
conformable(source._grid,vgrid);
|
||||
@ -158,6 +182,12 @@ namespace Grid {
|
||||
// Construct pencils
|
||||
typedef typename vobj::scalar_object sobj;
|
||||
typedef typename sobj::scalar_type scalar;
|
||||
|
||||
/*
|
||||
std::cout << "FFT : vobj "<<demangle(typeid(vobj).name()) <<std::endl;
|
||||
std::cout << "FFT : sobj "<<demangle(typeid(sobj).name()) <<std::endl;
|
||||
std::cout << "FFT : scalar "<<demangle(typeid(scalar).name()) <<std::endl;
|
||||
*/
|
||||
|
||||
Lattice<vobj> ssource(vgrid); ssource =source;
|
||||
Lattice<sobj> pgsource(&pencil_g);
|
||||
@ -184,9 +214,10 @@ namespace Grid {
|
||||
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
|
||||
int *inembed = n, *onembed = n;
|
||||
|
||||
|
||||
int sign = FFTW_FORWARD;
|
||||
if (inverse) sign = FFTW_BACKWARD;
|
||||
scalar div;
|
||||
if ( sign == backward ) div = 1.0/G;
|
||||
else if ( sign == forward ) div = 1.0;
|
||||
else assert(0);
|
||||
|
||||
FFTW_plan p;
|
||||
{
|
||||
@ -258,6 +289,7 @@ PARALLEL_FOR_LOOP
|
||||
sobj s;
|
||||
gcoor[dim] = lcoor[dim]+L*pc;
|
||||
peekLocalSite(s,pgresult,gcoor);
|
||||
s = s * div;
|
||||
pokeLocalSite(s,result,lcoor);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user