From 1a82533d22da6a42de27c885d42deb31c28dc8f1 Mon Sep 17 00:00:00 2001 From: gfilaci Date: Tue, 14 May 2019 15:35:54 +0100 Subject: [PATCH] fix inner product with thrust reduction --- Grid/lattice/Lattice_reduction.h | 57 ++++++++++---------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/Grid/lattice/Lattice_reduction.h b/Grid/lattice/Lattice_reduction.h index 2d92ead8..ba871d1f 100644 --- a/Grid/lattice/Lattice_reduction.h +++ b/Grid/lattice/Lattice_reduction.h @@ -23,12 +23,7 @@ Author: paboyle #include #ifdef GRID_NVCC -#include -#include -#include -#include -#include -#include +#include #endif NAMESPACE_BEGIN(Grid); @@ -41,23 +36,12 @@ template inline RealD norm2(const Lattice &arg){ return real(nrm); } -#if 0 -//#warning "ThrustReduce compiled" -//#include -template -vobj ThrustNorm(const Lattice &lat) +#ifdef GRID_NVCC +template +struct innerProductFunctor : public thrust::binary_function { - typedef typename vobj::scalar_type scalar_type; - auto lat_v=lat.View(); - Integer s0=0; - Integer sN=lat_v.end(); - scalar_type sum = 0; - scalar_type * begin = (scalar_type *)&lat_v[s0]; - scalar_type * end = (scalar_type *)&lat_v[sN]; - thrust::reduce(begin,end,sum); - std::cout <<" thrust::reduce sum "<< sum << std::endl; - return sum; -} + accelerator R operator()(T x, T y) { return innerProduct(x,y); } +}; #endif // Double inner product @@ -75,24 +59,17 @@ inline ComplexD innerProduct(const Lattice &left,const Lattice &righ auto left_v = left.View(); auto right_v=right.View(); -#if 0 +#ifdef GRID_NVCC - typedef decltype(TensorRemove(innerProduct(left_v[0],right_v[0]))) inner_t; + typedef decltype(innerProduct(left_v[0],right_v[0])) inner_t; + thrust::plus binary_sum; + innerProductFunctor binary_inner_p; + Integer sN = left_v.end(); + inner_t zero = Zero(); + // is there a way of using the efficient thrust reduction while maintaining memory coalescing? + inner_t vnrm = thrust::inner_product(thrust::device, &left_v[0], &left_v[sN], &right_v[0], zero, binary_sum, binary_inner_p); + nrm = Reduce(TensorRemove(vnrm));// sum across simd - Lattice inner_tmp(grid); - - ///////////////////////// - // localInnerProduct - ///////////////////////// - auto inner_tmp_v = inner_tmp.View(); - accelerator_loop(ss,left_v,{ - inner_tmp_v[ss] = TensorRemove(innerProduct(left_v[ss],right_v[ss])); - }); - ///////////////////////// - // and site sum the scalars - ///////////////////////// - inner_t vnrm = ThrustNorm(inner_tmp); - auto vvnrm = vnrm; #else thread_loop( (int thr=0;thrSumArraySize();thr++),{ int mywork, myoff; @@ -108,9 +85,9 @@ inline ComplexD innerProduct(const Lattice &left,const Lattice &righ vector_type vvnrm; vvnrm=Zero(); // sum across threads for(int i=0;iSumArraySize();i++){ vvnrm = vvnrm+sumarray[i]; - } -#endif + } nrm = Reduce(vvnrm);// sum across simd +#endif right.Grid()->GlobalSum(nrm); return nrm; }