mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-21 17:22:03 +01:00
Interface to reduced precision comms
This commit is contained in:
@ -154,6 +154,12 @@ public:
|
|||||||
StencilImpl Stencil;
|
StencilImpl Stencil;
|
||||||
StencilImpl StencilEven;
|
StencilImpl StencilEven;
|
||||||
StencilImpl StencilOdd;
|
StencilImpl StencilOdd;
|
||||||
|
void SloppyComms(int sloppy)
|
||||||
|
{
|
||||||
|
Stencil.SetSloppyComms(sloppy);
|
||||||
|
StencilEven.SetSloppyComms(sloppy);
|
||||||
|
StencilOdd.SetSloppyComms(sloppy);
|
||||||
|
}
|
||||||
|
|
||||||
// Copy of the gauge field , with even and odd subsets
|
// Copy of the gauge field , with even and odd subsets
|
||||||
DoubledGaugeField Umu;
|
DoubledGaugeField Umu;
|
||||||
|
@ -179,6 +179,12 @@ public:
|
|||||||
StencilImpl Stencil;
|
StencilImpl Stencil;
|
||||||
StencilImpl StencilEven;
|
StencilImpl StencilEven;
|
||||||
StencilImpl StencilOdd;
|
StencilImpl StencilOdd;
|
||||||
|
void SloppyComms(int sloppy)
|
||||||
|
{
|
||||||
|
Stencil.SetSloppyComms(sloppy);
|
||||||
|
StencilEven.SetSloppyComms(sloppy);
|
||||||
|
StencilOdd.SetSloppyComms(sloppy);
|
||||||
|
}
|
||||||
|
|
||||||
// Copy of the gauge field , with even and odd subsets
|
// Copy of the gauge field , with even and odd subsets
|
||||||
DoubledGaugeField Umu;
|
DoubledGaugeField Umu;
|
||||||
|
@ -146,6 +146,12 @@ public:
|
|||||||
StencilImpl Stencil;
|
StencilImpl Stencil;
|
||||||
StencilImpl StencilEven;
|
StencilImpl StencilEven;
|
||||||
StencilImpl StencilOdd;
|
StencilImpl StencilOdd;
|
||||||
|
void SloppyComms(int sloppy)
|
||||||
|
{
|
||||||
|
Stencil.SetSloppyComms(sloppy);
|
||||||
|
StencilEven.SetSloppyComms(sloppy);
|
||||||
|
StencilOdd.SetSloppyComms(sloppy);
|
||||||
|
}
|
||||||
|
|
||||||
// Copy of the gauge field , with even and odd subsets
|
// Copy of the gauge field , with even and odd subsets
|
||||||
DoubledGaugeField Umu;
|
DoubledGaugeField Umu;
|
||||||
|
@ -32,209 +32,6 @@ Author: paboyle <paboyle@ph.ed.ac.uk>
|
|||||||
|
|
||||||
NAMESPACE_BEGIN(Grid);
|
NAMESPACE_BEGIN(Grid);
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////
|
|
||||||
// Wilson compressor will need FaceGather policies for:
|
|
||||||
// Periodic, Dirichlet, and partial Dirichlet for DWF
|
|
||||||
///////////////////////////////////////////////////////////////
|
|
||||||
const int dwf_compressor_depth=2;
|
|
||||||
#define DWF_COMPRESS
|
|
||||||
class FaceGatherPartialDWF
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
#ifdef DWF_COMPRESS
|
|
||||||
static int PartialCompressionFactor(GridBase *grid) {return grid->_fdimensions[0]/(2*dwf_compressor_depth);};
|
|
||||||
#else
|
|
||||||
static int PartialCompressionFactor(GridBase *grid) { return 1;}
|
|
||||||
#endif
|
|
||||||
template<class vobj,class cobj,class compressor>
|
|
||||||
static void Gather_plane_simple (deviceVector<std::pair<int,int> >& table,
|
|
||||||
const Lattice<vobj> &rhs,
|
|
||||||
cobj *buffer,
|
|
||||||
compressor &compress,
|
|
||||||
int off,int so,int partial)
|
|
||||||
{
|
|
||||||
//DWF only hack: If a direction that is OFF node we use Partial Dirichlet
|
|
||||||
// Shrinks local and remote comms buffers
|
|
||||||
GridBase *Grid = rhs.Grid();
|
|
||||||
int Ls = Grid->_rdimensions[0];
|
|
||||||
#ifdef DWF_COMPRESS
|
|
||||||
int depth=dwf_compressor_depth;
|
|
||||||
#else
|
|
||||||
int depth=Ls/2;
|
|
||||||
#endif
|
|
||||||
std::pair<int,int> *table_v = & table[0];
|
|
||||||
auto rhs_v = rhs.View(AcceleratorRead);
|
|
||||||
int vol=table.size()/Ls;
|
|
||||||
accelerator_forNB( idx,table.size(), vobj::Nsimd(), {
|
|
||||||
Integer i=idx/Ls;
|
|
||||||
Integer s=idx%Ls;
|
|
||||||
Integer sc=depth+s-(Ls-depth);
|
|
||||||
if(s<depth) compress.Compress(buffer[off+i+s*vol],rhs_v[so+table_v[idx].second]);
|
|
||||||
if(s>=Ls-depth) compress.Compress(buffer[off+i+sc*vol],rhs_v[so+table_v[idx].second]);
|
|
||||||
});
|
|
||||||
rhs_v.ViewClose();
|
|
||||||
}
|
|
||||||
template<class decompressor,class Decompression>
|
|
||||||
static void DecompressFace(decompressor decompress,Decompression &dd)
|
|
||||||
{
|
|
||||||
auto Ls = dd.dims[0];
|
|
||||||
#ifdef DWF_COMPRESS
|
|
||||||
int depth=dwf_compressor_depth;
|
|
||||||
#else
|
|
||||||
int depth=Ls/2;
|
|
||||||
#endif
|
|
||||||
// Just pass in the Grid
|
|
||||||
auto kp = dd.kernel_p;
|
|
||||||
auto mp = dd.mpi_p;
|
|
||||||
int size= dd.buffer_size;
|
|
||||||
int vol= size/Ls;
|
|
||||||
accelerator_forNB(o,size,1,{
|
|
||||||
int idx=o/Ls;
|
|
||||||
int s=o%Ls;
|
|
||||||
if ( s < depth ) {
|
|
||||||
int oo=s*vol+idx;
|
|
||||||
kp[o]=mp[oo];
|
|
||||||
} else if ( s >= Ls-depth ) {
|
|
||||||
int sc = depth + s - (Ls-depth);
|
|
||||||
int oo=sc*vol+idx;
|
|
||||||
kp[o]=mp[oo];
|
|
||||||
} else {
|
|
||||||
kp[o] = Zero();//fill rest with zero if partial dirichlet
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Need to gather *interior portions* for ALL s-slices in simd directions
|
|
||||||
// Do the gather as need to treat SIMD lanes differently, and insert zeroes on receive side
|
|
||||||
// Reorder the fifth dim to be s=Ls-1 , s=0, s=1,...,Ls-2.
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
template<class vobj,class cobj,class compressor>
|
|
||||||
static void Gather_plane_exchange(deviceVector<std::pair<int,int> >& table,const Lattice<vobj> &rhs,
|
|
||||||
std::vector<cobj *> pointers,int dimension,int plane,int cbmask,
|
|
||||||
compressor &compress,int type,int partial)
|
|
||||||
{
|
|
||||||
GridBase *Grid = rhs.Grid();
|
|
||||||
int Ls = Grid->_rdimensions[0];
|
|
||||||
#ifdef DWF_COMPRESS
|
|
||||||
int depth=dwf_compressor_depth;
|
|
||||||
#else
|
|
||||||
int depth = Ls/2;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// insertion of zeroes...
|
|
||||||
assert( (table.size()&0x1)==0);
|
|
||||||
int num=table.size()/2;
|
|
||||||
int so = plane*rhs.Grid()->_ostride[dimension]; // base offset for start of plane
|
|
||||||
|
|
||||||
auto rhs_v = rhs.View(AcceleratorRead);
|
|
||||||
auto p0=&pointers[0][0];
|
|
||||||
auto p1=&pointers[1][0];
|
|
||||||
auto tp=&table[0];
|
|
||||||
int nnum=num/Ls;
|
|
||||||
accelerator_forNB(j, num, vobj::Nsimd(), {
|
|
||||||
// Reorders both local and remote comms buffers
|
|
||||||
//
|
|
||||||
int s = j % Ls;
|
|
||||||
int sp1 = (s+depth)%Ls; // peri incremented s slice
|
|
||||||
|
|
||||||
int hxyz= j/Ls;
|
|
||||||
|
|
||||||
int xyz0= hxyz*2; // xyzt part of coor
|
|
||||||
int xyz1= hxyz*2+1;
|
|
||||||
|
|
||||||
int jj= hxyz + sp1*nnum ; // 0,1,2,3 -> Ls-1 slice , 0-slice, 1-slice ....
|
|
||||||
|
|
||||||
int kk0= xyz0*Ls + s ; // s=0 goes to s=1
|
|
||||||
int kk1= xyz1*Ls + s ; // s=Ls-1 -> s=0
|
|
||||||
compress.CompressExchange(p0[jj],p1[jj],
|
|
||||||
rhs_v[so+tp[kk0 ].second], // Same s, consecutive xyz sites
|
|
||||||
rhs_v[so+tp[kk1 ].second],
|
|
||||||
type);
|
|
||||||
});
|
|
||||||
rhs_v.ViewClose();
|
|
||||||
}
|
|
||||||
// Merge routine is for SIMD faces
|
|
||||||
template<class decompressor,class Merger>
|
|
||||||
static void MergeFace(decompressor decompress,Merger &mm)
|
|
||||||
{
|
|
||||||
auto Ls = mm.dims[0];
|
|
||||||
#ifdef DWF_COMPRESS
|
|
||||||
int depth=dwf_compressor_depth;
|
|
||||||
#else
|
|
||||||
int depth = Ls/2;
|
|
||||||
#endif
|
|
||||||
int num= mm.buffer_size/2; // relate vol and Ls to buffer size
|
|
||||||
auto mp = &mm.mpointer[0];
|
|
||||||
auto vp0= &mm.vpointers[0][0]; // First arg is exchange first
|
|
||||||
auto vp1= &mm.vpointers[1][0];
|
|
||||||
auto type= mm.type;
|
|
||||||
int nnum = num/Ls;
|
|
||||||
accelerator_forNB(o,num,Merger::Nsimd,{
|
|
||||||
|
|
||||||
int s=o%Ls;
|
|
||||||
int hxyz=o/Ls; // xyzt related component
|
|
||||||
int xyz0=hxyz*2;
|
|
||||||
int xyz1=hxyz*2+1;
|
|
||||||
|
|
||||||
int sp = (s+depth)%Ls;
|
|
||||||
int jj= hxyz + sp*nnum ; // 0,1,2,3 -> Ls-1 slice , 0-slice, 1-slice ....
|
|
||||||
|
|
||||||
int oo0= s+xyz0*Ls;
|
|
||||||
int oo1= s+xyz1*Ls;
|
|
||||||
|
|
||||||
// same ss0, ss1 pair goes to new layout
|
|
||||||
decompress.Exchange(mp[oo0],mp[oo1],vp0[jj],vp1[jj],type);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
class FaceGatherDWFMixedBCs
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
#ifdef DWF_COMPRESS
|
|
||||||
static int PartialCompressionFactor(GridBase *grid) {return grid->_fdimensions[0]/(2*dwf_compressor_depth);};
|
|
||||||
#else
|
|
||||||
static int PartialCompressionFactor(GridBase *grid) {return 1;}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template<class vobj,class cobj,class compressor>
|
|
||||||
static void Gather_plane_simple (deviceVector<std::pair<int,int> >& table,
|
|
||||||
const Lattice<vobj> &rhs,
|
|
||||||
cobj *buffer,
|
|
||||||
compressor &compress,
|
|
||||||
int off,int so,int partial)
|
|
||||||
{
|
|
||||||
// std::cout << " face gather simple DWF partial "<<partial <<std::endl;
|
|
||||||
if(partial) FaceGatherPartialDWF::Gather_plane_simple(table,rhs,buffer,compress,off,so,partial);
|
|
||||||
else FaceGatherSimple::Gather_plane_simple(table,rhs,buffer,compress,off,so,partial);
|
|
||||||
}
|
|
||||||
template<class vobj,class cobj,class compressor>
|
|
||||||
static void Gather_plane_exchange(deviceVector<std::pair<int,int> >& table,const Lattice<vobj> &rhs,
|
|
||||||
std::vector<cobj *> pointers,int dimension,int plane,int cbmask,
|
|
||||||
compressor &compress,int type,int partial)
|
|
||||||
{
|
|
||||||
// std::cout << " face gather exch DWF partial "<<partial <<std::endl;
|
|
||||||
if(partial) FaceGatherPartialDWF::Gather_plane_exchange(table,rhs,pointers,dimension, plane,cbmask,compress,type,partial);
|
|
||||||
else FaceGatherSimple::Gather_plane_exchange (table,rhs,pointers,dimension, plane,cbmask,compress,type,partial);
|
|
||||||
}
|
|
||||||
template<class decompressor,class Merger>
|
|
||||||
static void MergeFace(decompressor decompress,Merger &mm)
|
|
||||||
{
|
|
||||||
int partial = mm.partial;
|
|
||||||
// std::cout << " merge DWF partial "<<partial <<std::endl;
|
|
||||||
if ( partial ) FaceGatherPartialDWF::MergeFace(decompress,mm);
|
|
||||||
else FaceGatherSimple::MergeFace(decompress,mm);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class decompressor,class Decompression>
|
|
||||||
static void DecompressFace(decompressor decompress,Decompression &dd)
|
|
||||||
{
|
|
||||||
int partial = dd.partial;
|
|
||||||
// std::cout << " decompress DWF partial "<<partial <<std::endl;
|
|
||||||
if ( partial ) FaceGatherPartialDWF::DecompressFace(decompress,dd);
|
|
||||||
else FaceGatherSimple::DecompressFace(decompress,dd);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// optimised versions supporting half precision too??? Deprecate
|
// optimised versions supporting half precision too??? Deprecate
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
@ -242,8 +39,7 @@ public:
|
|||||||
|
|
||||||
//Could make FaceGather a template param, but then behaviour is runtime not compile time
|
//Could make FaceGather a template param, but then behaviour is runtime not compile time
|
||||||
template<class _HCspinor,class _Hspinor,class _Spinor, class projector>
|
template<class _HCspinor,class _Hspinor,class _Spinor, class projector>
|
||||||
class WilsonCompressorTemplate : public FaceGatherDWFMixedBCs
|
class WilsonCompressorTemplate : public FaceGatherSimple
|
||||||
// : public FaceGatherSimple
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
@ -165,6 +165,12 @@ public:
|
|||||||
StencilImpl Stencil;
|
StencilImpl Stencil;
|
||||||
StencilImpl StencilEven;
|
StencilImpl StencilEven;
|
||||||
StencilImpl StencilOdd;
|
StencilImpl StencilOdd;
|
||||||
|
void SloppyComms(int sloppy)
|
||||||
|
{
|
||||||
|
Stencil.SetSloppyComms(sloppy);
|
||||||
|
StencilEven.SetSloppyComms(sloppy);
|
||||||
|
StencilOdd.SetSloppyComms(sloppy);
|
||||||
|
}
|
||||||
|
|
||||||
// Copy of the gauge field , with even and odd subsets
|
// Copy of the gauge field , with even and odd subsets
|
||||||
DoubledGaugeField Umu;
|
DoubledGaugeField Umu;
|
||||||
|
@ -205,6 +205,13 @@ public:
|
|||||||
DoubledGaugeField UmuEven;
|
DoubledGaugeField UmuEven;
|
||||||
DoubledGaugeField UmuOdd;
|
DoubledGaugeField UmuOdd;
|
||||||
|
|
||||||
|
|
||||||
|
void SloppyComms(int sloppy)
|
||||||
|
{
|
||||||
|
Stencil.SetSloppyComms(sloppy);
|
||||||
|
StencilEven.SetSloppyComms(sloppy);
|
||||||
|
StencilOdd.SetSloppyComms(sloppy);
|
||||||
|
}
|
||||||
// Comms buffer
|
// Comms buffer
|
||||||
// std::vector<SiteHalfSpinor,alignedAllocator<SiteHalfSpinor> > comm_buf;
|
// std::vector<SiteHalfSpinor,alignedAllocator<SiteHalfSpinor> > comm_buf;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user