mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-16 02:35:36 +00:00
307 lines
8.7 KiB
C++
307 lines
8.7 KiB
C++
|
|
/*************************************************************************************
|
|
|
|
Grid physics library, www.github.com/paboyle/Grid
|
|
|
|
Source file: ./lib/Cshift.h
|
|
|
|
Copyright (C) 2015
|
|
|
|
Author: Peter Boyle <paboyle@ph.ed.ac.uk>
|
|
|
|
This program is free software; you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation; either version 2 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License along
|
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
|
|
See the full license in the file "LICENSE" in the top level distribution directory
|
|
*************************************************************************************/
|
|
/* END LEGAL */
|
|
#ifndef _GRID_FFT_H_
|
|
#define _GRID_FFT_H_
|
|
|
|
#ifdef HAVE_FFTW
|
|
#ifdef USE_MKL
|
|
#include <fftw/fftw3.h>
|
|
#else
|
|
#include <fftw3.h>
|
|
#endif
|
|
#endif
|
|
|
|
|
|
namespace Grid {
|
|
|
|
template<class scalar> struct FFTW { };
|
|
|
|
#ifdef HAVE_FFTW
|
|
template<> struct FFTW<ComplexD> {
|
|
public:
|
|
|
|
typedef fftw_complex FFTW_scalar;
|
|
typedef fftw_plan FFTW_plan;
|
|
|
|
static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany,
|
|
FFTW_scalar *in, const int *inembed,
|
|
int istride, int idist,
|
|
FFTW_scalar *out, const int *onembed,
|
|
int ostride, int odist,
|
|
int sign, unsigned flags) {
|
|
return ::fftw_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags);
|
|
}
|
|
|
|
static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){
|
|
::fftw_flops(p,add,mul,fmas);
|
|
}
|
|
|
|
inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) {
|
|
::fftw_execute_dft(p,in,out);
|
|
}
|
|
inline static void fftw_destroy_plan(const FFTW_plan p) {
|
|
::fftw_destroy_plan(p);
|
|
}
|
|
};
|
|
|
|
template<> struct FFTW<ComplexF> {
|
|
public:
|
|
|
|
typedef fftwf_complex FFTW_scalar;
|
|
typedef fftwf_plan FFTW_plan;
|
|
|
|
static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany,
|
|
FFTW_scalar *in, const int *inembed,
|
|
int istride, int idist,
|
|
FFTW_scalar *out, const int *onembed,
|
|
int ostride, int odist,
|
|
int sign, unsigned flags) {
|
|
return ::fftwf_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags);
|
|
}
|
|
|
|
static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){
|
|
::fftwf_flops(p,add,mul,fmas);
|
|
}
|
|
|
|
inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) {
|
|
::fftwf_execute_dft(p,in,out);
|
|
}
|
|
inline static void fftw_destroy_plan(const FFTW_plan p) {
|
|
::fftwf_destroy_plan(p);
|
|
}
|
|
};
|
|
|
|
#endif
|
|
|
|
#ifndef FFTW_FORWARD
|
|
#define FFTW_FORWARD (-1)
|
|
#define FFTW_BACKWARD (+1)
|
|
#endif
|
|
|
|
class FFT {
|
|
private:
|
|
|
|
GridCartesian *vgrid;
|
|
GridCartesian *sgrid;
|
|
|
|
int Nd;
|
|
double flops;
|
|
double flops_call;
|
|
uint64_t usec;
|
|
|
|
std::vector<int> dimensions;
|
|
std::vector<int> processors;
|
|
std::vector<int> processor_coor;
|
|
|
|
public:
|
|
|
|
static const int forward=FFTW_FORWARD;
|
|
static const int backward=FFTW_BACKWARD;
|
|
|
|
double Flops(void) {return flops;}
|
|
double MFlops(void) {return flops/usec;}
|
|
double USec(void) {return (double)usec;}
|
|
|
|
FFT ( GridCartesian * grid ) :
|
|
vgrid(grid),
|
|
Nd(grid->_ndimension),
|
|
dimensions(grid->_fdimensions),
|
|
processors(grid->_processors),
|
|
processor_coor(grid->_processor_coor)
|
|
{
|
|
flops=0;
|
|
usec =0;
|
|
std::vector<int> layout(Nd,1);
|
|
sgrid = new GridCartesian(dimensions,layout,processors);
|
|
};
|
|
|
|
~FFT ( void) {
|
|
delete sgrid;
|
|
}
|
|
|
|
template<class vobj>
|
|
void FFT_dim_mask(Lattice<vobj> &result,const Lattice<vobj> &source,std::vector<int> mask,int sign){
|
|
|
|
conformable(result._grid,vgrid);
|
|
conformable(source._grid,vgrid);
|
|
Lattice<vobj> tmp(vgrid);
|
|
tmp = source;
|
|
for(int d=0;d<Nd;d++){
|
|
if( mask[d] ) {
|
|
FFT_dim(result,tmp,d,sign);
|
|
tmp=result;
|
|
}
|
|
}
|
|
}
|
|
|
|
template<class vobj>
|
|
void FFT_all_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int sign){
|
|
std::vector<int> mask(Nd,1);
|
|
FFT_dim_mask(result,source,mask,sign);
|
|
}
|
|
|
|
|
|
template<class vobj>
|
|
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int sign){
|
|
#ifndef HAVE_FFTW
|
|
assert(0);
|
|
#else
|
|
conformable(result._grid,vgrid);
|
|
conformable(source._grid,vgrid);
|
|
|
|
int L = vgrid->_ldimensions[dim];
|
|
int G = vgrid->_fdimensions[dim];
|
|
|
|
std::vector<int> layout(Nd,1);
|
|
std::vector<int> pencil_gd(vgrid->_fdimensions);
|
|
|
|
pencil_gd[dim] = G*processors[dim];
|
|
|
|
// Pencil global vol LxLxGxLxL per node
|
|
GridCartesian pencil_g(pencil_gd,layout,processors);
|
|
|
|
// Construct pencils
|
|
typedef typename vobj::scalar_object sobj;
|
|
typedef typename sobj::scalar_type scalar;
|
|
|
|
Lattice<sobj> pgbuf(&pencil_g);
|
|
|
|
|
|
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
|
|
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
|
|
|
|
int Ncomp = sizeof(sobj)/sizeof(scalar);
|
|
int Nlow = 1;
|
|
for(int d=0;d<dim;d++){
|
|
Nlow*=vgrid->_ldimensions[d];
|
|
}
|
|
|
|
int rank = 1; /* 1d transforms */
|
|
int n[] = {G}; /* 1d transforms of length G */
|
|
int howmany = Ncomp;
|
|
int odist,idist,istride,ostride;
|
|
idist = odist = 1; /* Distance between consecutive FT's */
|
|
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
|
|
int *inembed = n, *onembed = n;
|
|
|
|
scalar div;
|
|
if ( sign == backward ) div = 1.0/G;
|
|
else if ( sign == forward ) div = 1.0;
|
|
else assert(0);
|
|
|
|
FFTW_plan p;
|
|
{
|
|
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[0];
|
|
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[0];
|
|
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
|
|
in,inembed,
|
|
istride,idist,
|
|
out,onembed,
|
|
ostride, odist,
|
|
sign,FFTW_ESTIMATE);
|
|
}
|
|
|
|
// Barrel shift and collect global pencil
|
|
std::vector<int> lcoor(Nd), gcoor(Nd);
|
|
result = source;
|
|
int pc = processor_coor[dim];
|
|
for(int p=0;p<processors[dim];p++) {
|
|
PARALLEL_REGION
|
|
{
|
|
std::vector<int> cbuf(Nd);
|
|
sobj s;
|
|
|
|
PARALLEL_FOR_LOOP_INTERN
|
|
for(int idx=0;idx<sgrid->lSites();idx++) {
|
|
sgrid->LocalIndexToLocalCoor(idx,cbuf);
|
|
peekLocalSite(s,result,cbuf);
|
|
cbuf[dim]+=((pc+p) % processors[dim])*L;
|
|
// cbuf[dim]+=p*L;
|
|
pokeLocalSite(s,pgbuf,cbuf);
|
|
}
|
|
}
|
|
if (p != processors[dim] - 1)
|
|
{
|
|
result = Cshift(result,dim,L);
|
|
}
|
|
}
|
|
|
|
// Loop over orthog coords
|
|
int NN=pencil_g.lSites();
|
|
GridStopWatch timer;
|
|
timer.Start();
|
|
PARALLEL_REGION
|
|
{
|
|
std::vector<int> cbuf(Nd);
|
|
|
|
PARALLEL_FOR_LOOP_INTERN
|
|
for(int idx=0;idx<NN;idx++) {
|
|
pencil_g.LocalIndexToLocalCoor(idx, cbuf);
|
|
if ( cbuf[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
|
|
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[idx];
|
|
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[idx];
|
|
FFTW<scalar>::fftw_execute_dft(p,in,out);
|
|
}
|
|
}
|
|
}
|
|
timer.Stop();
|
|
|
|
// performance counting
|
|
double add,mul,fma;
|
|
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
|
|
flops_call = add+mul+2.0*fma;
|
|
usec += timer.useconds();
|
|
flops+= flops_call*NN;
|
|
|
|
// writing out result
|
|
PARALLEL_REGION
|
|
{
|
|
std::vector<int> clbuf(Nd), cgbuf(Nd);
|
|
sobj s;
|
|
|
|
PARALLEL_FOR_LOOP_INTERN
|
|
for(int idx=0;idx<sgrid->lSites();idx++) {
|
|
sgrid->LocalIndexToLocalCoor(idx,clbuf);
|
|
cgbuf = clbuf;
|
|
cgbuf[dim] = clbuf[dim]+L*pc;
|
|
peekLocalSite(s,pgbuf,cgbuf);
|
|
pokeLocalSite(s,result,clbuf);
|
|
}
|
|
}
|
|
result = result*div;
|
|
|
|
// destroying plan
|
|
FFTW<scalar>::fftw_destroy_plan(p);
|
|
#endif
|
|
}
|
|
};
|
|
}
|
|
|
|
#endif
|