/************************************************************************************* Grid physics library, www.github.com/paboyle/Grid Source file: ./lib/Cshift.h Copyright (C) 2015 Author: Peter Boyle This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. See the full license in the file "LICENSE" in the top level distribution directory *************************************************************************************/ /* END LEGAL */ #ifndef _GRID_FFT_H_ #define _GRID_FFT_H_ #ifdef HAVE_FFTW #ifdef USE_MKL #include #else #include #endif #endif NAMESPACE_BEGIN(Grid); template struct FFTW { }; #ifdef HAVE_FFTW template<> struct FFTW { public: typedef fftw_complex FFTW_scalar; typedef fftw_plan FFTW_plan; static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany, FFTW_scalar *in, const int *inembed, int istride, int idist, FFTW_scalar *out, const int *onembed, int ostride, int odist, int sign, unsigned flags) { return ::fftw_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags); } static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){ ::fftw_flops(p,add,mul,fmas); } inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) { ::fftw_execute_dft(p,in,out); } inline static void fftw_destroy_plan(const FFTW_plan p) { ::fftw_destroy_plan(p); } }; template<> struct FFTW { public: typedef fftwf_complex FFTW_scalar; typedef fftwf_plan FFTW_plan; static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany, FFTW_scalar *in, const int *inembed, int istride, int idist, FFTW_scalar *out, const int *onembed, int ostride, int odist, int sign, unsigned flags) { return ::fftwf_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags); } static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){ ::fftwf_flops(p,add,mul,fmas); } inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) { ::fftwf_execute_dft(p,in,out); } inline static void fftw_destroy_plan(const FFTW_plan p) { ::fftwf_destroy_plan(p); } }; #endif #ifndef FFTW_FORWARD #define FFTW_FORWARD (-1) #define FFTW_BACKWARD (+1) #endif class FFT { private: GridCartesian *vgrid; GridCartesian *sgrid; int Nd; double flops; double flops_call; uint64_t usec; Coordinate dimensions; Coordinate processors; Coordinate processor_coor; public: static const int forward=FFTW_FORWARD; static const int backward=FFTW_BACKWARD; double Flops(void) {return flops;} double MFlops(void) {return flops/usec;} double USec(void) {return (double)usec;} FFT ( GridCartesian * grid ) : vgrid(grid), Nd(grid->_ndimension), dimensions(grid->_fdimensions), processors(grid->_processors), processor_coor(grid->_processor_coor) { flops=0; usec =0; Coordinate layout(Nd,1); sgrid = new GridCartesian(dimensions,layout,processors); }; ~FFT ( void) { delete sgrid; } template void FFT_dim_mask(Lattice &result,const Lattice &source,Coordinate mask,int sign){ conformable(result.Grid(),vgrid); conformable(source.Grid(),vgrid); Lattice tmp(vgrid); tmp = source; for(int d=0;d void FFT_all_dim(Lattice &result,const Lattice &source,int sign){ Coordinate mask(Nd,1); FFT_dim_mask(result,source,mask,sign); } template void FFT_dim(Lattice &result,const Lattice &source,int dim, int sign){ #ifndef HAVE_FFTW assert(0); #else conformable(result.Grid(),vgrid); conformable(source.Grid(),vgrid); int L = vgrid->_ldimensions[dim]; int G = vgrid->_fdimensions[dim]; Coordinate layout(Nd,1); Coordinate pencil_gd(vgrid->_fdimensions); pencil_gd[dim] = G*processors[dim]; // Pencil global vol LxLxGxLxL per node GridCartesian pencil_g(pencil_gd,layout,processors); // Construct pencils typedef typename vobj::scalar_object sobj; typedef typename sobj::scalar_type scalar; Lattice pgbuf(&pencil_g); auto pgbuf_v = pgbuf.View(); typedef typename FFTW::FFTW_scalar FFTW_scalar; typedef typename FFTW::FFTW_plan FFTW_plan; int Ncomp = sizeof(sobj)/sizeof(scalar); int Nlow = 1; for(int d=0;d_ldimensions[d]; } int rank = 1; /* 1d transforms */ int n[] = {G}; /* 1d transforms of length G */ int howmany = Ncomp; int odist,idist,istride,ostride; idist = odist = 1; /* Distance between consecutive FT's */ istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */ int *inembed = n, *onembed = n; scalar div; if ( sign == backward ) div = 1.0/G; else if ( sign == forward ) div = 1.0; else assert(0); FFTW_plan p; { FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[0]; FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[0]; p = FFTW::fftw_plan_many_dft(rank,n,howmany, in,inembed, istride,idist, out,onembed, ostride, odist, sign,FFTW_ESTIMATE); } // Barrel shift and collect global pencil Coordinate lcoor(Nd), gcoor(Nd); result = source; int pc = processor_coor[dim]; for(int p=0;plSites();idx++), { Coordinate cbuf(Nd); sobj s; sgrid->LocalIndexToLocalCoor(idx,cbuf); peekLocalSite(s,result,cbuf); cbuf[dim]+=((pc+p) % processors[dim])*L; // cbuf[dim]+=p*L; pokeLocalSite(s,pgbuf,cbuf); }); if (p != processors[dim] - 1) { result = Cshift(result,dim,L); } } // Loop over orthog coords int NN=pencil_g.lSites(); GridStopWatch timer; timer.Start(); thread_loop( (int idx=0;idx::fftw_execute_dft(p,in,out); } }); timer.Stop(); // performance counting double add,mul,fma; FFTW::fftw_flops(p,&add,&mul,&fma); flops_call = add+mul+2.0*fma; usec += timer.useconds(); flops+= flops_call*NN; // writing out result thread_loop( (int idx=0;idxlSites();idx++), { Coordinate clbuf(Nd), cgbuf(Nd); sobj s; sgrid->LocalIndexToLocalCoor(idx,clbuf); cgbuf = clbuf; cgbuf[dim] = clbuf[dim]+L*pc; peekLocalSite(s,pgbuf,cgbuf); pokeLocalSite(s,result,clbuf); }); result = result*div; // destroying plan FFTW::fftw_destroy_plan(p); #endif } }; NAMESPACE_END(Grid); #endif