From 5fc8a273e7e8f1a29acebad0237b5e0bef2ecb97 Mon Sep 17 00:00:00 2001
From: Daniel Richtmann
Date: Mon, 6 Apr 2020 11:30:50 +0200
Subject: [PATCH] Fused innerProduct + norm2 on first argument operation

---
 Grid/lattice/Lattice_reduction.h |  58 +++++++++++++-
 tests/Test_innerproduct_norm.cc  | 126 +++++++++++++++++++++++++++++++
 2 files changed, 183 insertions(+), 1 deletion(-)
 create mode 100644 tests/Test_innerproduct_norm.cc

diff --git a/Grid/lattice/Lattice_reduction.h b/Grid/lattice/Lattice_reduction.h
index 3c5b03e5..de2efd72 100644
--- a/Grid/lattice/Lattice_reduction.h
+++ b/Grid/lattice/Lattice_reduction.h
@@ -204,8 +204,64 @@ axpby_norm_fast(Lattice &z,sobj a,sobj b,const Lattice &x,const Latt
   grid->GlobalSum(nrm);
   return nrm;
 }
-
+template<class vobj> strong_inline void // fused reduction: one sweep over the sites produces both outputs
+innerProduct_norm(ComplexD& ip, RealD &nrm, const Lattice<vobj> &left,const Lattice<vobj> &right) // ip: inner product of left with right; nrm: real part of inner product of left with itself
+{
+  conformable(left,right);
+
+  typedef typename vobj::scalar_type scalar_type;
+  typedef typename vobj::vector_typeD vector_type;
+  Vector<ComplexD> tmp(2); // tmp[0]: inner product, tmp[1]: norm (held as complex until the end)
+
+  GridBase *grid = left.Grid();
+
+  auto left_v=left.View();
+  auto right_v=right.View();
+
+  const uint64_t nsimd = grid->Nsimd();
+  const uint64_t sites = grid->oSites();
+
+#ifdef GRID_NVCC
+  // GPU: per-site products via innerProduct, reduced with sumD_gpu
+  typedef decltype(innerProduct(left_v[0],right_v[0])) inner_t;
+  typedef decltype(innerProduct(left_v[0],left_v[0])) norm_t;
+  Vector<inner_t> inner_tmp(sites);
+  Vector<norm_t> norm_tmp(sites);
+  auto inner_tmp_v = &inner_tmp[0];
+  auto norm_tmp_v = &norm_tmp[0];
+
+  accelerator_for( ss, sites, nsimd,{
+      auto left_tmp = left_v(ss); // read the left site once and use it in both products
+      coalescedWrite(inner_tmp_v[ss],innerProduct(left_tmp,right_v(ss)));
+      coalescedWrite(norm_tmp_v[ss],innerProduct(left_tmp,left_tmp));
+  });
+
+  tmp[0] = TensorRemove(sumD_gpu(inner_tmp_v,sites));
+  tmp[1] = TensorRemove(sumD_gpu(norm_tmp_v,sites));
+#else
+  // CPU: per-site products via innerProductD, reduced with sum
+  typedef decltype(innerProductD(left_v[0],right_v[0])) inner_t;
+  typedef decltype(innerProductD(left_v[0],left_v[0])) norm_t;
+  Vector<inner_t> inner_tmp(sites);
+  Vector<norm_t> norm_tmp(sites);
+  auto inner_tmp_v = &inner_tmp[0];
+  auto norm_tmp_v = &norm_tmp[0];
+
+  accelerator_for( ss, sites, nsimd,{
+      auto left_tmp = left_v(ss); // read the left site once and use it in both products
+      inner_tmp_v[ss] = innerProductD(left_tmp,right_v(ss));
+      norm_tmp_v[ss] = innerProductD(left_tmp,left_tmp);
+  });
+  // Already promoted to double
+  tmp[0] = TensorRemove(sum(inner_tmp_v,sites));
+  tmp[1] = TensorRemove(sum(norm_tmp_v,sites));
+#endif
+  grid->GlobalSumVector(&tmp[0],2); // keep norm Complex -> can use GlobalSumVector
+  ip = tmp[0];
+  nrm = real(tmp[1]);
+}
+
 template
 inline auto sum(const LatticeUnaryExpression & expr)
   ->typename decltype(expr.op.func(eval(0,expr.arg1)))::scalar_object
diff --git a/tests/Test_innerproduct_norm.cc b/tests/Test_innerproduct_norm.cc
new file mode 100644
index 00000000..85c98521
--- /dev/null
+++ b/tests/Test_innerproduct_norm.cc
@@ -0,0 +1,126 @@
+/*************************************************************************************
+
+Grid physics library, www.github.com/paboyle/Grid
+
+Source file: ./tests/Test_innerproduct_norm.cc
+
+Copyright (C) 2015
+
+Author: Daniel Richtmann
+
+This program is free software; you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation; either version 2 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License along
+with this program; if not, write to the Free Software Foundation, Inc.,
+51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
+See the full license in the file "LICENSE" in the top level distribution directory
+*************************************************************************************/
+/* END LEGAL */
+#include <Grid/Grid.h>
+
+using namespace Grid;
+
+int main(int argc, char** argv) { // checks innerProduct_norm against separate innerProduct + norm2 calls (values and timing)
+  Grid_init(&argc, &argv);
+
+  const int nIter = 100; // repetitions of each timed loop
+
+  // clang-format off
+  GridCartesian *Grid_d = SpaceTimeGrid::makeFourDimGrid(GridDefaultLatt(), GridDefaultSimd(Nd, vComplexD::Nsimd()), GridDefaultMpi());
+  GridCartesian *Grid_f = SpaceTimeGrid::makeFourDimGrid(GridDefaultLatt(), GridDefaultSimd(Nd, vComplexF::Nsimd()), GridDefaultMpi());
+  // clang-format on
+
+  GridParallelRNG pRNG_d(Grid_d);
+  GridParallelRNG pRNG_f(Grid_f);
+
+  std::vector<int> seeds_d({1, 2, 3, 4});
+  std::vector<int> seeds_f({5, 6, 7, 8});
+
+  pRNG_d.SeedFixedIntegers(seeds_d);
+  pRNG_f.SeedFixedIntegers(seeds_f);
+
+  // clang-format off
+  LatticeFermionD x_d(Grid_d); random(pRNG_d, x_d);
+  LatticeFermionD y_d(Grid_d); random(pRNG_d, y_d);
+  LatticeFermionF x_f(Grid_f); random(pRNG_f, x_f);
+  LatticeFermionF y_f(Grid_f); random(pRNG_f, y_f);
+  // clang-format on
+
+  GridStopWatch sw_ref; // times the separate innerProduct + norm2 calls
+  GridStopWatch sw_res; // times the fused innerProduct_norm call
+
+  { // double precision
+    ComplexD ip_d_ref, ip_d_res, diff_ip_d;
+    RealD norm2_d_ref, norm2_d_res, diff_norm2_d;
+
+    sw_ref.Reset();
+    sw_ref.Start();
+    for(int i = 0; i < nIter; ++i) {
+      ip_d_ref = innerProduct(x_d, y_d);
+      norm2_d_ref = norm2(x_d);
+    }
+    sw_ref.Stop();
+
+    sw_res.Reset();
+    sw_res.Start();
+    for(int i = 0; i < nIter; ++i) { innerProduct_norm(ip_d_res, norm2_d_res, x_d, y_d); }
+    sw_res.Stop();
+
+    diff_ip_d = ip_d_ref - ip_d_res;
+    diff_norm2_d = norm2_d_ref - norm2_d_res;
+
+    // clang-format off
+    std::cout << GridLogMessage << "Double: ip_ref = " << ip_d_ref << " ip_res = " << ip_d_res << " diff = " << diff_ip_d << std::endl;
+    std::cout << GridLogMessage << "Double: norm2_ref = " << norm2_d_ref << " norm2_res = " << norm2_d_res << " diff = " << diff_norm2_d << std::endl;
+    std::cout << GridLogMessage << "Double: time_ref = " << sw_ref.Elapsed() << " time_res = " << sw_res.Elapsed() << std::endl;
+    // clang-format on
+
+    assert(diff_ip_d == 0.); // fused result is required to equal the reference exactly
+    assert(diff_norm2_d == 0.);
+
+    std::cout << GridLogMessage << "Double: all checks passed" << std::endl;
+  }
+
+  { // single precision
+    ComplexD ip_f_ref, ip_f_res, diff_ip_f;
+    RealD norm2_f_ref, norm2_f_res, diff_norm2_f;
+
+    sw_ref.Reset();
+    sw_ref.Start();
+    for(int i = 0; i < nIter; ++i) {
+      ip_f_ref = innerProduct(x_f, y_f);
+      norm2_f_ref = norm2(x_f);
+    }
+    sw_ref.Stop();
+
+    sw_res.Reset();
+    sw_res.Start();
+    for(int i = 0; i < nIter; ++i) { innerProduct_norm(ip_f_res, norm2_f_res, x_f, y_f); }
+    sw_res.Stop();
+
+    diff_ip_f = ip_f_ref - ip_f_res;
+    diff_norm2_f = norm2_f_ref - norm2_f_res;
+
+    // clang-format off
+    std::cout << GridLogMessage << "Single: ip_ref = " << ip_f_ref << " ip_res = " << ip_f_res << " diff = " << diff_ip_f << std::endl;
+    std::cout << GridLogMessage << "Single: norm2_ref = " << norm2_f_ref << " norm2_res = " << norm2_f_res << " diff = " << diff_norm2_f << std::endl;
+    std::cout << GridLogMessage << "Single: time_ref = " << sw_ref.Elapsed() << " time_res = " << sw_res.Elapsed() << std::endl;
+    // clang-format on
+
+    assert(diff_ip_f == 0.); // fused result is required to equal the reference exactly
+    assert(diff_norm2_f == 0.);
+
+    std::cout << GridLogMessage << "Single: all checks passed" << std::endl;
+  }
+
+  Grid_finalize();
+}