From 7d62f1d6d20790f681f37fe1755e3712d2a4e2b0 Mon Sep 17 00:00:00 2001 From: Makis Kappas Date: Wed, 11 Jan 2023 21:26:25 +0000 Subject: [PATCH] Populate the Cshift_table in the GPU Cshift is allocated in Unified memory and used in the LambdaApply kernels but also populated from the host. This creates a lot of Unified HtoD and DtoH mem operations and has a negative effect in performance. With this commit we populate the Cshift table in the device with the populate_Cshift_table() kernel. --- Grid/cshift/Cshift_common.h | 40 +++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/Grid/cshift/Cshift_common.h b/Grid/cshift/Cshift_common.h index cf902b58..742c99da 100644 --- a/Grid/cshift/Cshift_common.h +++ b/Grid/cshift/Cshift_common.h @@ -297,6 +297,30 @@ template void Scatter_plane_merge(Lattice &rhs,ExtractPointerA } } +#if (defined(GRID_CUDA) || defined(GRID_HIP)) && defined(ACCELERATOR_CSHIFT) + +template +T iDivUp(T a, T b) // Round a / b to nearest higher integer value +{ return (a % b != 0) ? (a / b + 1) : (a / b); } + +template +__global__ void populate_Cshift_table(T* vector, T lo, T ro, T e1, T e2, T stride) +{ + int idx = blockIdx.x*blockDim.x + threadIdx.x; + if (idx >= e1*e2) return; + + int n, b, o; + + n = idx / e2; + b = idx % e2; + o = n*stride + b; + + vector[2*idx + 0] = lo + o; + vector[2*idx + 1] = ro + o; +} + +#endif + ////////////////////////////////////////////////////// // local to node block strided copies ////////////////////////////////////////////////////// @@ -321,12 +345,20 @@ template void Copy_plane(Lattice& lhs,const Lattice &rhs int ent=0; if(cbmask == 0x3 ){ +#if (defined(GRID_CUDA) || defined(GRID_HIP)) && defined(ACCELERATOR_CSHIFT) + ent = e1*e2; + dim3 blockSize(acceleratorThreads()); + dim3 gridSize(iDivUp((unsigned int)ent, blockSize.x)); + populate_Cshift_table<<>>(&Cshift_table[0].first, lo, ro, e1, e2, stride); + accelerator_barrier(); +#else for(int n=0;n(lo+o,ro+o); } } +#endif } else { for(int n=0;n void Copy_plane_permute(Lattice& lhs,const Lattice>>(&Cshift_table[0].first, lo, ro, e1, e2, stride); + accelerator_barrier(); +#else for(int n=0;n(lo+o+b,ro+o+b); }} +#endif } else { for(int n=0;n