From 588c2f3cb191d9c438399e5039847285d3f9e4a9 Mon Sep 17 00:00:00 2001 From: Peter Boyle Date: Fri, 1 Jul 2022 09:44:58 -0400 Subject: [PATCH] Faster axpy_norm and innerProduct --- Grid/lattice/Lattice_reduction.h | 51 +++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/Grid/lattice/Lattice_reduction.h b/Grid/lattice/Lattice_reduction.h index 0ddac437..16feb856 100644 --- a/Grid/lattice/Lattice_reduction.h +++ b/Grid/lattice/Lattice_reduction.h @@ -232,6 +232,7 @@ inline ComplexD rankInnerProduct(const Lattice &left,const Lattice & const uint64_t sites = grid->oSites(); // Might make all code paths go this way. +#if 0 typedef decltype(innerProductD(vobj(),vobj())) inner_t; Vector inner_tmp(sites); auto inner_tmp_v = &inner_tmp[0]; @@ -241,15 +242,31 @@ inline ComplexD rankInnerProduct(const Lattice &left,const Lattice & autoView( right_v,right, AcceleratorRead); // GPU - SIMT lane compliance... - accelerator_for( ss, sites, 1,{ - auto x_l = left_v[ss]; - auto y_l = right_v[ss]; - inner_tmp_v[ss]=innerProductD(x_l,y_l); + accelerator_for( ss, sites, nsimd,{ + auto x_l = left_v(ss); + auto y_l = right_v(ss); + coalescedWrite(inner_tmp_v[ss],innerProductD(x_l,y_l)); }); } +#else + typedef decltype(innerProduct(vobj(),vobj())) inner_t; + Vector inner_tmp(sites); + auto inner_tmp_v = &inner_tmp[0]; + + { + autoView( left_v , left, AcceleratorRead); + autoView( right_v,right, AcceleratorRead); + // GPU - SIMT lane compliance... + accelerator_for( ss, sites, nsimd,{ + auto x_l = left_v(ss); + auto y_l = right_v(ss); + coalescedWrite(inner_tmp_v[ss],innerProduct(x_l,y_l)); + }); + } +#endif // This is in single precision and fails some tests - auto anrm = sum(inner_tmp_v,sites); + auto anrm = sumD(inner_tmp_v,sites); nrm = anrm; return nrm; } @@ -283,7 +300,7 @@ axpby_norm_fast(Lattice &z,sobj a,sobj b,const Lattice &x,const Latt conformable(x,y); typedef typename vobj::scalar_type scalar_type; - typedef typename vobj::vector_typeD vector_type; + // typedef typename vobj::vector_typeD vector_type; RealD nrm; GridBase *grid = x.Grid(); @@ -295,17 +312,29 @@ axpby_norm_fast(Lattice &z,sobj a,sobj b,const Lattice &x,const Latt autoView( x_v, x, AcceleratorRead); autoView( y_v, y, AcceleratorRead); autoView( z_v, z, AcceleratorWrite); - +#if 0 typedef decltype(innerProductD(x_v[0],y_v[0])) inner_t; Vector inner_tmp(sites); auto inner_tmp_v = &inner_tmp[0]; - accelerator_for( ss, sites, 1,{ - auto tmp = a*x_v[ss]+b*y_v[ss]; - inner_tmp_v[ss]=innerProductD(tmp,tmp); - z_v[ss]=tmp; + accelerator_for( ss, sites, nsimd,{ + auto tmp = a*x_v(ss)+b*y_v(ss); + coalescedWrite(inner_tmp_v[ss],innerProductD(tmp,tmp)); + coalescedWrite(z_v[ss],tmp); }); nrm = real(TensorRemove(sum(inner_tmp_v,sites))); +#else + typedef decltype(innerProduct(x_v[0],y_v[0])) inner_t; + Vector inner_tmp(sites); + auto inner_tmp_v = &inner_tmp[0]; + + accelerator_for( ss, sites, nsimd,{ + auto tmp = a*x_v(ss)+b*y_v(ss); + coalescedWrite(inner_tmp_v[ss],innerProduct(tmp,tmp)); + coalescedWrite(z_v[ss],tmp); + }); + nrm = real(TensorRemove(sumD(inner_tmp_v,sites))); +#endif grid->GlobalSum(nrm); return nrm; }