diff --git a/Grid/lattice/Lattice_slicesum_core.h b/Grid/lattice/Lattice_slicesum_core.h index 63737517..36580f57 100644 --- a/Grid/lattice/Lattice_slicesum_core.h +++ b/Grid/lattice/Lattice_slicesum_core.h @@ -131,9 +131,8 @@ template inline void sliceSumReduction_cub(const Lattice &Data #if defined(GRID_SYCL) -template inline void sliceSumReduction_sycl(const Lattice &Data, Vector &lvSum, const int &rd, const int &e1, const int &e2, const int &stride, const int &ostride, const int &Nsimd) +template inline void sliceSumReduction_sycl_small(const vobj *Data, Vector &lvSum, const int &rd, const int &e1, const int &e2, const int &stride, const int &ostride, const int &Nsimd) { - typedef typename vobj::scalar_object sobj; size_t subvol_size = e1*e2; vobj *mysum = (vobj *) malloc_shared(rd*sizeof(vobj),*theGridAccelerator); @@ -147,7 +146,7 @@ template inline void sliceSumReduction_sycl(const Lattice &Dat auto rb_p = &reduction_buffer[0]; - autoView(Data_v, Data, AcceleratorRead); + // autoView(Data_v, Data, AcceleratorRead); //prepare reduction buffer accelerator_for2d( s,subvol_size, r,rd, (size_t)Nsimd,{ @@ -157,7 +156,7 @@ template inline void sliceSumReduction_sycl(const Lattice &Dat int so=r*ostride; // base offset for start of plane int ss= so+n*stride+b; - coalescedWrite(rb_p[r*subvol_size+s], coalescedRead(Data_v[ss])); + coalescedWrite(rb_p[r*subvol_size+s], coalescedRead(Data[ss])); }); @@ -180,6 +179,45 @@ template inline void sliceSumReduction_sycl(const Lattice &Dat } free(mysum,*theGridAccelerator); } + + +template inline void sliceSumReduction_sycl_large(const vobj *Data, Vector &lvSum, const int rd, const int e1, const int e2, const int stride, const int ostride, const int Nsimd) { + typedef typename vobj::vector_type vector; + const int words = sizeof(vobj)/sizeof(vector); + const int osites = rd*e1*e2; + commVectorbuffer(osites); + vector *dat = (vector *)Data; + vector *buf = &buffer[0]; + Vector lvSum_small(rd); + vector *lvSum_ptr = (vector *)&lvSum[0]; + + for (int w = 0; w < words; w++) { + accelerator_for(ss,osites,1,{ + buf[ss] = dat[ss*words+w]; + }); + + sliceSumReduction_sycl_small(buf,lvSum_small,rd,e1,e2,stride, ostride,Nsimd); + + for (int r = 0; r < rd; r++) { + lvSum_ptr[w+words*r]=lvSum_small[r]; + } + + } +} + + +template inline void sliceSumReduction_sycl(const Lattice &Data, Vector &lvSum, const int rd, const int e1, const int e2, const int stride, const int ostride, const int Nsimd) +{ + autoView(Data_v, Data, AcceleratorRead); //hipcub/cub cannot deal with large vobjs so we split into small/large case. + if constexpr (sizeof(vobj) <= 256) { + sliceSumReduction_sycl_small(&Data_v[0], lvSum, rd, e1, e2, stride, ostride, Nsimd); + } + else { + sliceSumReduction_sycl_large(&Data_v[0], lvSum, rd, e1, e2, stride, ostride, Nsimd); + } +} + + #endif template inline void sliceSumReduction_cpu(const Lattice &Data, Vector &lvSum, const int &rd, const int &e1, const int &e2, const int &stride, const int &ostride, const int &Nsimd)