mirror of
https://github.com/paboyle/Grid.git
synced 2025-04-09 21:50:45 +01:00
more FFT optimisations
This commit is contained in:
parent
33d199a0ad
commit
14ddf2c234
233
lib/FFT.h
233
lib/FFT.h
@ -98,174 +98,157 @@ namespace Grid {
|
|||||||
#define FFTW_BACKWARD (+1)
|
#define FFTW_BACKWARD (+1)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
class FFT {
|
class FFT {
|
||||||
private:
|
private:
|
||||||
|
|
||||||
GridCartesian *vgrid;
|
GridCartesian *vgrid;
|
||||||
GridCartesian *sgrid;
|
GridCartesian *sgrid;
|
||||||
|
|
||||||
int Nd;
|
int Nd;
|
||||||
double flops;
|
double flops;
|
||||||
double flops_call;
|
double flops_call;
|
||||||
uint64_t usec;
|
uint64_t usec;
|
||||||
|
|
||||||
std::vector<int> dimensions;
|
std::vector<int> dimensions;
|
||||||
std::vector<int> processors;
|
std::vector<int> processors;
|
||||||
std::vector<int> processor_coor;
|
std::vector<int> processor_coor;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
static const int forward=FFTW_FORWARD;
|
static const int forward=FFTW_FORWARD;
|
||||||
static const int backward=FFTW_BACKWARD;
|
static const int backward=FFTW_BACKWARD;
|
||||||
|
|
||||||
double Flops(void) {return flops;}
|
double Flops(void) {return flops;}
|
||||||
double MFlops(void) {return flops/usec;}
|
double MFlops(void) {return flops/usec;}
|
||||||
|
|
||||||
FFT ( GridCartesian * grid ) :
|
FFT ( GridCartesian * grid ) :
|
||||||
vgrid(grid),
|
vgrid(grid),
|
||||||
Nd(grid->_ndimension),
|
Nd(grid->_ndimension),
|
||||||
dimensions(grid->_fdimensions),
|
dimensions(grid->_fdimensions),
|
||||||
processors(grid->_processors),
|
processors(grid->_processors),
|
||||||
processor_coor(grid->_processor_coor)
|
processor_coor(grid->_processor_coor)
|
||||||
{
|
{
|
||||||
flops=0;
|
flops=0;
|
||||||
usec =0;
|
usec =0;
|
||||||
std::vector<int> layout(Nd,1);
|
std::vector<int> layout(Nd,1);
|
||||||
sgrid = new GridCartesian(dimensions,layout,processors);
|
sgrid = new GridCartesian(dimensions,layout,processors);
|
||||||
};
|
};
|
||||||
|
|
||||||
~FFT ( void) {
|
~FFT ( void) {
|
||||||
delete sgrid;
|
delete sgrid;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class vobj>
|
template<class vobj>
|
||||||
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){
|
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){
|
||||||
|
#ifndef HAVE_FFTW
|
||||||
|
assert(0);
|
||||||
|
#else
|
||||||
conformable(result._grid,vgrid);
|
conformable(result._grid,vgrid);
|
||||||
conformable(source._grid,vgrid);
|
conformable(source._grid,vgrid);
|
||||||
|
|
||||||
int L = vgrid->_ldimensions[dim];
|
int L = vgrid->_ldimensions[dim];
|
||||||
int G = vgrid->_fdimensions[dim];
|
int G = vgrid->_fdimensions[dim];
|
||||||
|
|
||||||
std::vector<int> layout(Nd,1);
|
std::vector<int> layout(Nd,1);
|
||||||
std::vector<int> pencil_gd(vgrid->_fdimensions);
|
std::vector<int> pencil_gd(vgrid->_fdimensions);
|
||||||
|
|
||||||
pencil_gd[dim] = G*processors[dim];
|
pencil_gd[dim] = G*processors[dim];
|
||||||
|
|
||||||
// Pencil global vol LxLxGxLxL per node
|
// Pencil global vol LxLxGxLxL per node
|
||||||
GridCartesian pencil_g(pencil_gd,layout,processors);
|
GridCartesian pencil_g(pencil_gd,layout,processors);
|
||||||
|
|
||||||
// Construct pencils
|
// Construct pencils
|
||||||
typedef typename vobj::scalar_object sobj;
|
typedef typename vobj::scalar_object sobj;
|
||||||
typedef typename sobj::scalar_type scalar;
|
typedef typename sobj::scalar_type scalar;
|
||||||
|
|
||||||
|
Lattice<sobj> pgbuf(&pencil_g);
|
||||||
|
|
||||||
|
|
||||||
Lattice<vobj> ssource(vgrid); ssource =source;
|
|
||||||
Lattice<sobj> pgsource(&pencil_g);
|
|
||||||
Lattice<sobj> pgresult(&pencil_g); pgresult=zero;
|
|
||||||
|
|
||||||
#ifndef HAVE_FFTW
|
|
||||||
assert(0);
|
|
||||||
#else
|
|
||||||
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
|
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
|
||||||
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
|
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
|
||||||
|
|
||||||
{
|
int Ncomp = sizeof(sobj)/sizeof(scalar);
|
||||||
int Ncomp = sizeof(sobj)/sizeof(scalar);
|
int Nlow = 1;
|
||||||
int Nlow = 1;
|
for(int d=0;d<dim;d++){
|
||||||
for(int d=0;d<dim;d++){
|
Nlow*=vgrid->_ldimensions[d];
|
||||||
Nlow*=vgrid->_ldimensions[d];
|
|
||||||
}
|
|
||||||
|
|
||||||
int rank = 1; /* 1d transforms */
|
|
||||||
int n[] = {G}; /* 1d transforms of length G */
|
|
||||||
int howmany = Ncomp;
|
|
||||||
int odist,idist,istride,ostride;
|
|
||||||
idist = odist = 1; /* Distance between consecutive FT's */
|
|
||||||
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
|
|
||||||
int *inembed = n, *onembed = n;
|
|
||||||
|
|
||||||
|
|
||||||
int sign = FFTW_FORWARD;
|
|
||||||
if (inverse) sign = FFTW_BACKWARD;
|
|
||||||
|
|
||||||
FFTW_plan p;
|
|
||||||
{
|
|
||||||
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[0];
|
|
||||||
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[0];
|
|
||||||
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
|
|
||||||
in,inembed,
|
|
||||||
istride,idist,
|
|
||||||
out,onembed,
|
|
||||||
ostride, odist,
|
|
||||||
sign,FFTW_ESTIMATE);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> lcoor(Nd), gcoor(Nd);
|
|
||||||
|
|
||||||
// Barrel shift and collect global pencil
|
|
||||||
for(int p=0;p<processors[dim];p++) {
|
|
||||||
|
|
||||||
for(int idx=0;idx<sgrid->lSites();idx++) {
|
|
||||||
|
|
||||||
|
|
||||||
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
|
||||||
|
|
||||||
sobj s;
|
|
||||||
|
|
||||||
peekLocalSite(s,ssource,lcoor);
|
|
||||||
|
|
||||||
lcoor[dim]+=p*L;
|
|
||||||
|
|
||||||
pokeLocalSite(s,pgsource,lcoor);
|
|
||||||
}
|
|
||||||
|
|
||||||
ssource = Cshift(ssource,dim,L);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop over orthog coords
|
|
||||||
int NN=pencil_g.lSites();
|
|
||||||
GridStopWatch timer;
|
|
||||||
timer.Start();
|
|
||||||
|
|
||||||
//PARALLEL_FOR_LOOP
|
|
||||||
for(int idx=0;idx<NN;idx++) {
|
|
||||||
pencil_g.LocalIndexToLocalCoor(idx,lcoor);
|
|
||||||
|
|
||||||
if ( lcoor[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
|
|
||||||
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[idx];
|
|
||||||
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[idx];
|
|
||||||
FFTW<scalar>::fftw_execute_dft(p,in,out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
timer.Stop();
|
|
||||||
|
|
||||||
double add,mul,fma;
|
|
||||||
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
|
|
||||||
flops_call = add+mul+2.0*fma;
|
|
||||||
usec += timer.useconds();
|
|
||||||
flops+= flops_call*NN;
|
|
||||||
int pc = processor_coor[dim];
|
|
||||||
for(int idx=0;idx<sgrid->lSites();idx++) {
|
|
||||||
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
|
||||||
gcoor = lcoor;
|
|
||||||
// extract the result
|
|
||||||
sobj s;
|
|
||||||
gcoor[dim] = lcoor[dim]+L*pc;
|
|
||||||
peekLocalSite(s,pgresult,gcoor);
|
|
||||||
pokeLocalSite(s,result,lcoor);
|
|
||||||
}
|
|
||||||
|
|
||||||
FFTW<scalar>::fftw_destroy_plan(p);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int rank = 1; /* 1d transforms */
|
||||||
|
int n[] = {G}; /* 1d transforms of length G */
|
||||||
|
int howmany = Ncomp;
|
||||||
|
int odist,idist,istride,ostride;
|
||||||
|
idist = odist = 1; /* Distance between consecutive FT's */
|
||||||
|
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
|
||||||
|
int *inembed = n, *onembed = n;
|
||||||
|
|
||||||
|
int sign = FFTW_FORWARD;
|
||||||
|
if (inverse) sign = FFTW_BACKWARD;
|
||||||
|
|
||||||
|
FFTW_plan p;
|
||||||
|
{
|
||||||
|
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[0];
|
||||||
|
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[0];
|
||||||
|
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
|
||||||
|
in,inembed,
|
||||||
|
istride,idist,
|
||||||
|
out,onembed,
|
||||||
|
ostride, odist,
|
||||||
|
sign,FFTW_ESTIMATE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Barrel shift and collect global pencil
|
||||||
|
std::vector<int> lcoor(Nd), gcoor(Nd);
|
||||||
|
result = source;
|
||||||
|
for(int p=0;p<processors[dim];p++) {
|
||||||
|
for(int idx=0;idx<sgrid->lSites();idx++) {
|
||||||
|
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
||||||
|
sobj s;
|
||||||
|
peekLocalSite(s,result,lcoor);
|
||||||
|
lcoor[dim]+=p*L;
|
||||||
|
pokeLocalSite(s,pgbuf,lcoor);
|
||||||
|
}
|
||||||
|
result = Cshift(result,dim,L);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop over orthog coords
|
||||||
|
int NN=pencil_g.lSites();
|
||||||
|
GridStopWatch timer;
|
||||||
|
timer.Start();
|
||||||
|
//PARALLEL_FOR_LOOP
|
||||||
|
for(int idx=0;idx<NN;idx++) {
|
||||||
|
pencil_g.LocalIndexToLocalCoor(idx,lcoor);
|
||||||
|
|
||||||
|
if ( lcoor[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
|
||||||
|
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[idx];
|
||||||
|
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[idx];
|
||||||
|
FFTW<scalar>::fftw_execute_dft(p,in,out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer.Stop();
|
||||||
|
|
||||||
|
// performance counting
|
||||||
|
double add,mul,fma;
|
||||||
|
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
|
||||||
|
flops_call = add+mul+2.0*fma;
|
||||||
|
usec += timer.useconds();
|
||||||
|
flops+= flops_call*NN;
|
||||||
|
|
||||||
|
// writing out result
|
||||||
|
int pc = processor_coor[dim];
|
||||||
|
for(int idx=0;idx<sgrid->lSites();idx++) {
|
||||||
|
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
||||||
|
gcoor = lcoor;
|
||||||
|
sobj s;
|
||||||
|
gcoor[dim] = lcoor[dim]+L*pc;
|
||||||
|
peekLocalSite(s,pgbuf,gcoor);
|
||||||
|
pokeLocalSite(s,result,lcoor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// destroying plan
|
||||||
|
FFTW<scalar>::fftw_destroy_plan(p);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -76,13 +76,14 @@ int main (int argc, char ** argv)
|
|||||||
S=zero;
|
S=zero;
|
||||||
S = S+C;
|
S = S+C;
|
||||||
|
|
||||||
|
Ctilde = C;
|
||||||
FFT theFFT(&Fine);
|
FFT theFFT(&Fine);
|
||||||
|
|
||||||
theFFT.FFT_dim(Ctilde,C,0,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Ctilde,Ctilde,0,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
theFFT.FFT_dim(Ctilde,C,1,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Ctilde,Ctilde,1,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
theFFT.FFT_dim(Ctilde,C,2,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Ctilde,Ctilde,2,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
theFFT.FFT_dim(Ctilde,C,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Ctilde,Ctilde,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
|
|
||||||
// C=zero;
|
// C=zero;
|
||||||
// Ctilde = where(abs(Ctilde)<1.0e-10,C,Ctilde);
|
// Ctilde = where(abs(Ctilde)<1.0e-10,C,Ctilde);
|
||||||
TComplexD cVol;
|
TComplexD cVol;
|
||||||
@ -92,12 +93,13 @@ int main (int argc, char ** argv)
|
|||||||
pokeSite(cVol,C,p);
|
pokeSite(cVol,C,p);
|
||||||
C=C-Ctilde;
|
C=C-Ctilde;
|
||||||
std::cout << "diff scalar "<<norm2(C) << std::endl;
|
std::cout << "diff scalar "<<norm2(C) << std::endl;
|
||||||
|
Stilde = S;
|
||||||
theFFT.FFT_dim(Stilde,S,0,FFT::forward); S=Stilde; std::cout << theFFT.MFlops()<<std::endl;
|
|
||||||
theFFT.FFT_dim(Stilde,S,1,FFT::forward); S=Stilde;std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Stilde,Stilde,0,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
theFFT.FFT_dim(Stilde,S,2,FFT::forward); S=Stilde;std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Stilde,Stilde,1,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
theFFT.FFT_dim(Stilde,S,3,FFT::forward);std::cout << theFFT.MFlops()<<std::endl;
|
theFFT.FFT_dim(Stilde,Stilde,2,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
|
theFFT.FFT_dim(Stilde,Stilde,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||||
|
|
||||||
SpinMatrixD Sp;
|
SpinMatrixD Sp;
|
||||||
Sp = zero; Sp = Sp+cVol;
|
Sp = zero; Sp = Sp+cVol;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user