diff --git a/Grid/lattice/Lattice_reduction.h b/Grid/lattice/Lattice_reduction.h index 96654cef..a5366435 100644 --- a/Grid/lattice/Lattice_reduction.h +++ b/Grid/lattice/Lattice_reduction.h @@ -22,15 +22,15 @@ Author: paboyle #pragma once #include + + #ifdef GRID_NVCC #include #endif NAMESPACE_BEGIN(Grid); - -#ifndef GRID_NVCC template -inline typename vobj::scalar_object sum(const Lattice &arg) +inline typename vobj::scalar_object sum_cpu(const Lattice &arg) { GridBase *grid=arg.Grid(); int Nsimd = grid->Nsimd(); @@ -69,8 +69,16 @@ inline typename vobj::scalar_object sum(const Lattice &arg) return ssum; } -#endif +template +inline typename vobj::scalar_object sum(const Lattice &arg) +{ +#ifdef GRID_NVCC + return sum_gpu(arg); +#else + return sum_cpu(arg); +#endif +} //////////////////////////////////////////////////////////////////////////////////////////////////// // Deterministic Reduction operations @@ -109,7 +117,7 @@ inline ComplexD innerProduct(const Lattice &left,const Lattice &righ nrm = TensorRemove(sum(inner_tmp)); - right.Grid()->GlobalSum(nrm); + // right.Grid()->GlobalSum(nrm); return nrm; } @@ -157,7 +165,7 @@ axpby_norm_fast(Lattice &z,sobj a,sobj b,const Lattice &x,const Latt nrm = real(TensorRemove(sum(inner_tmp))); - z.Grid()->GlobalSum(nrm); + // z.Grid()->GlobalSum(nrm); return nrm; } diff --git a/Grid/lattice/Lattice_reduction_gpu.h b/Grid/lattice/Lattice_reduction_gpu.h index bdd3566f..d2906492 100644 --- a/Grid/lattice/Lattice_reduction_gpu.h +++ b/Grid/lattice/Lattice_reduction_gpu.h @@ -180,7 +180,7 @@ __global__ void reduceKernel(const LatticeView lat, typename vobj::scalar_ } template -inline typename vobj::scalar_object sum(const Lattice &lat) +inline typename vobj::scalar_object sum_gpu(const Lattice &lat) { LatticeView lat_v = lat.View();