mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
more FFT optimisations
This commit is contained in:
parent
33d199a0ad
commit
14ddf2c234
233
lib/FFT.h
233
lib/FFT.h
@ -98,174 +98,157 @@ namespace Grid {
|
||||
#define FFTW_BACKWARD (+1)
|
||||
#endif
|
||||
|
||||
class FFT {
|
||||
class FFT {
|
||||
private:
|
||||
|
||||
|
||||
GridCartesian *vgrid;
|
||||
GridCartesian *sgrid;
|
||||
|
||||
|
||||
int Nd;
|
||||
double flops;
|
||||
double flops_call;
|
||||
uint64_t usec;
|
||||
|
||||
|
||||
std::vector<int> dimensions;
|
||||
std::vector<int> processors;
|
||||
std::vector<int> processor_coor;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
|
||||
static const int forward=FFTW_FORWARD;
|
||||
static const int backward=FFTW_BACKWARD;
|
||||
|
||||
|
||||
double Flops(void) {return flops;}
|
||||
double MFlops(void) {return flops/usec;}
|
||||
|
||||
FFT ( GridCartesian * grid ) :
|
||||
vgrid(grid),
|
||||
Nd(grid->_ndimension),
|
||||
dimensions(grid->_fdimensions),
|
||||
processors(grid->_processors),
|
||||
processor_coor(grid->_processor_coor)
|
||||
|
||||
FFT ( GridCartesian * grid ) :
|
||||
vgrid(grid),
|
||||
Nd(grid->_ndimension),
|
||||
dimensions(grid->_fdimensions),
|
||||
processors(grid->_processors),
|
||||
processor_coor(grid->_processor_coor)
|
||||
{
|
||||
flops=0;
|
||||
usec =0;
|
||||
std::vector<int> layout(Nd,1);
|
||||
sgrid = new GridCartesian(dimensions,layout,processors);
|
||||
};
|
||||
|
||||
~FFT ( void) {
|
||||
delete sgrid;
|
||||
|
||||
~FFT ( void) {
|
||||
delete sgrid;
|
||||
}
|
||||
|
||||
template<class vobj>
|
||||
void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int inverse){
|
||||
|
||||
#ifndef HAVE_FFTW
|
||||
assert(0);
|
||||
#else
|
||||
conformable(result._grid,vgrid);
|
||||
conformable(source._grid,vgrid);
|
||||
|
||||
|
||||
int L = vgrid->_ldimensions[dim];
|
||||
int G = vgrid->_fdimensions[dim];
|
||||
|
||||
|
||||
std::vector<int> layout(Nd,1);
|
||||
std::vector<int> pencil_gd(vgrid->_fdimensions);
|
||||
|
||||
pencil_gd[dim] = G*processors[dim];
|
||||
|
||||
|
||||
pencil_gd[dim] = G*processors[dim];
|
||||
|
||||
// Pencil global vol LxLxGxLxL per node
|
||||
GridCartesian pencil_g(pencil_gd,layout,processors);
|
||||
|
||||
|
||||
// Construct pencils
|
||||
typedef typename vobj::scalar_object sobj;
|
||||
typedef typename sobj::scalar_type scalar;
|
||||
|
||||
Lattice<sobj> pgbuf(&pencil_g);
|
||||
|
||||
|
||||
Lattice<vobj> ssource(vgrid); ssource =source;
|
||||
Lattice<sobj> pgsource(&pencil_g);
|
||||
Lattice<sobj> pgresult(&pencil_g); pgresult=zero;
|
||||
|
||||
#ifndef HAVE_FFTW
|
||||
assert(0);
|
||||
#else
|
||||
typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
|
||||
typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
|
||||
|
||||
{
|
||||
int Ncomp = sizeof(sobj)/sizeof(scalar);
|
||||
int Nlow = 1;
|
||||
for(int d=0;d<dim;d++){
|
||||
Nlow*=vgrid->_ldimensions[d];
|
||||
}
|
||||
|
||||
int rank = 1; /* 1d transforms */
|
||||
int n[] = {G}; /* 1d transforms of length G */
|
||||
int howmany = Ncomp;
|
||||
int odist,idist,istride,ostride;
|
||||
idist = odist = 1; /* Distance between consecutive FT's */
|
||||
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
|
||||
int *inembed = n, *onembed = n;
|
||||
|
||||
|
||||
int sign = FFTW_FORWARD;
|
||||
if (inverse) sign = FFTW_BACKWARD;
|
||||
|
||||
FFTW_plan p;
|
||||
{
|
||||
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[0];
|
||||
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[0];
|
||||
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
|
||||
in,inembed,
|
||||
istride,idist,
|
||||
out,onembed,
|
||||
ostride, odist,
|
||||
sign,FFTW_ESTIMATE);
|
||||
}
|
||||
|
||||
std::vector<int> lcoor(Nd), gcoor(Nd);
|
||||
|
||||
// Barrel shift and collect global pencil
|
||||
for(int p=0;p<processors[dim];p++) {
|
||||
|
||||
for(int idx=0;idx<sgrid->lSites();idx++) {
|
||||
|
||||
|
||||
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
||||
|
||||
sobj s;
|
||||
|
||||
peekLocalSite(s,ssource,lcoor);
|
||||
|
||||
lcoor[dim]+=p*L;
|
||||
|
||||
pokeLocalSite(s,pgsource,lcoor);
|
||||
}
|
||||
|
||||
ssource = Cshift(ssource,dim,L);
|
||||
}
|
||||
|
||||
// Loop over orthog coords
|
||||
int NN=pencil_g.lSites();
|
||||
GridStopWatch timer;
|
||||
timer.Start();
|
||||
|
||||
//PARALLEL_FOR_LOOP
|
||||
for(int idx=0;idx<NN;idx++) {
|
||||
pencil_g.LocalIndexToLocalCoor(idx,lcoor);
|
||||
|
||||
if ( lcoor[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
|
||||
FFTW_scalar *in = (FFTW_scalar *)&pgsource._odata[idx];
|
||||
FFTW_scalar *out= (FFTW_scalar *)&pgresult._odata[idx];
|
||||
FFTW<scalar>::fftw_execute_dft(p,in,out);
|
||||
}
|
||||
}
|
||||
|
||||
timer.Stop();
|
||||
|
||||
double add,mul,fma;
|
||||
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
|
||||
flops_call = add+mul+2.0*fma;
|
||||
usec += timer.useconds();
|
||||
flops+= flops_call*NN;
|
||||
int pc = processor_coor[dim];
|
||||
for(int idx=0;idx<sgrid->lSites();idx++) {
|
||||
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
||||
gcoor = lcoor;
|
||||
// extract the result
|
||||
sobj s;
|
||||
gcoor[dim] = lcoor[dim]+L*pc;
|
||||
peekLocalSite(s,pgresult,gcoor);
|
||||
pokeLocalSite(s,result,lcoor);
|
||||
}
|
||||
|
||||
FFTW<scalar>::fftw_destroy_plan(p);
|
||||
|
||||
int Ncomp = sizeof(sobj)/sizeof(scalar);
|
||||
int Nlow = 1;
|
||||
for(int d=0;d<dim;d++){
|
||||
Nlow*=vgrid->_ldimensions[d];
|
||||
}
|
||||
|
||||
int rank = 1; /* 1d transforms */
|
||||
int n[] = {G}; /* 1d transforms of length G */
|
||||
int howmany = Ncomp;
|
||||
int odist,idist,istride,ostride;
|
||||
idist = odist = 1; /* Distance between consecutive FT's */
|
||||
istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
|
||||
int *inembed = n, *onembed = n;
|
||||
|
||||
int sign = FFTW_FORWARD;
|
||||
if (inverse) sign = FFTW_BACKWARD;
|
||||
|
||||
FFTW_plan p;
|
||||
{
|
||||
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[0];
|
||||
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[0];
|
||||
p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
|
||||
in,inembed,
|
||||
istride,idist,
|
||||
out,onembed,
|
||||
ostride, odist,
|
||||
sign,FFTW_ESTIMATE);
|
||||
}
|
||||
|
||||
// Barrel shift and collect global pencil
|
||||
std::vector<int> lcoor(Nd), gcoor(Nd);
|
||||
result = source;
|
||||
for(int p=0;p<processors[dim];p++) {
|
||||
for(int idx=0;idx<sgrid->lSites();idx++) {
|
||||
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
||||
sobj s;
|
||||
peekLocalSite(s,result,lcoor);
|
||||
lcoor[dim]+=p*L;
|
||||
pokeLocalSite(s,pgbuf,lcoor);
|
||||
}
|
||||
result = Cshift(result,dim,L);
|
||||
}
|
||||
|
||||
// Loop over orthog coords
|
||||
int NN=pencil_g.lSites();
|
||||
GridStopWatch timer;
|
||||
timer.Start();
|
||||
//PARALLEL_FOR_LOOP
|
||||
for(int idx=0;idx<NN;idx++) {
|
||||
pencil_g.LocalIndexToLocalCoor(idx,lcoor);
|
||||
|
||||
if ( lcoor[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
|
||||
FFTW_scalar *in = (FFTW_scalar *)&pgbuf._odata[idx];
|
||||
FFTW_scalar *out= (FFTW_scalar *)&pgbuf._odata[idx];
|
||||
FFTW<scalar>::fftw_execute_dft(p,in,out);
|
||||
}
|
||||
}
|
||||
timer.Stop();
|
||||
|
||||
// performance counting
|
||||
double add,mul,fma;
|
||||
FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
|
||||
flops_call = add+mul+2.0*fma;
|
||||
usec += timer.useconds();
|
||||
flops+= flops_call*NN;
|
||||
|
||||
// writing out result
|
||||
int pc = processor_coor[dim];
|
||||
for(int idx=0;idx<sgrid->lSites();idx++) {
|
||||
sgrid->LocalIndexToLocalCoor(idx,lcoor);
|
||||
gcoor = lcoor;
|
||||
sobj s;
|
||||
gcoor[dim] = lcoor[dim]+L*pc;
|
||||
peekLocalSite(s,pgbuf,gcoor);
|
||||
pokeLocalSite(s,result,lcoor);
|
||||
}
|
||||
|
||||
// destroying plan
|
||||
FFTW<scalar>::fftw_destroy_plan(p);
|
||||
#endif
|
||||
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -76,13 +76,14 @@ int main (int argc, char ** argv)
|
||||
S=zero;
|
||||
S = S+C;
|
||||
|
||||
Ctilde = C;
|
||||
FFT theFFT(&Fine);
|
||||
|
||||
theFFT.FFT_dim(Ctilde,C,0,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Ctilde,C,1,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Ctilde,C,2,FFT::forward); C=Ctilde; std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Ctilde,C,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
|
||||
theFFT.FFT_dim(Ctilde,Ctilde,0,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Ctilde,Ctilde,1,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Ctilde,Ctilde,2,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Ctilde,Ctilde,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
|
||||
// C=zero;
|
||||
// Ctilde = where(abs(Ctilde)<1.0e-10,C,Ctilde);
|
||||
TComplexD cVol;
|
||||
@ -92,12 +93,13 @@ int main (int argc, char ** argv)
|
||||
pokeSite(cVol,C,p);
|
||||
C=C-Ctilde;
|
||||
std::cout << "diff scalar "<<norm2(C) << std::endl;
|
||||
|
||||
theFFT.FFT_dim(Stilde,S,0,FFT::forward); S=Stilde; std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Stilde,S,1,FFT::forward); S=Stilde;std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Stilde,S,2,FFT::forward); S=Stilde;std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Stilde,S,3,FFT::forward);std::cout << theFFT.MFlops()<<std::endl;
|
||||
|
||||
Stilde = S;
|
||||
|
||||
theFFT.FFT_dim(Stilde,Stilde,0,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Stilde,Stilde,1,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Stilde,Stilde,2,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
theFFT.FFT_dim(Stilde,Stilde,3,FFT::forward); std::cout << theFFT.MFlops()<<std::endl;
|
||||
|
||||
SpinMatrixD Sp;
|
||||
Sp = zero; Sp = Sp+cVol;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user