mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-11-04 14:04:32 +00:00 
			
		
		
		
	Fused innerProduct + norm2 on first argument operation
This commit is contained in:
		@@ -204,8 +204,64 @@ axpby_norm_fast(Lattice<vobj> &z,sobj a,sobj b,const Lattice<vobj> &x,const Latt
 | 
				
			|||||||
  grid->GlobalSum(nrm);
 | 
					  grid->GlobalSum(nrm);
 | 
				
			||||||
  return nrm; 
 | 
					  return nrm; 
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
 | 
					template<class vobj> strong_inline void
 | 
				
			||||||
 | 
					innerProduct_norm(ComplexD& ip, RealD &nrm, const Lattice<vobj> &left,const Lattice<vobj> &right)
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  conformable(left,right);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  typedef typename vobj::scalar_type scalar_type;
 | 
				
			||||||
 | 
					  typedef typename vobj::vector_typeD vector_type;
 | 
				
			||||||
 | 
					  Vector<ComplexD> tmp(2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  GridBase *grid = left.Grid();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto left_v=left.View();
 | 
				
			||||||
 | 
					  auto right_v=right.View();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const uint64_t nsimd = grid->Nsimd();
 | 
				
			||||||
 | 
					  const uint64_t sites = grid->oSites();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#ifdef GRID_NVCC
 | 
				
			||||||
 | 
					  // GPU
 | 
				
			||||||
 | 
					  typedef decltype(innerProduct(left_v[0],right_v[0])) inner_t;
 | 
				
			||||||
 | 
					  typedef decltype(innerProduct(left_v[0],left_v[0])) norm_t;
 | 
				
			||||||
 | 
					  Vector<inner_t> inner_tmp(sites);
 | 
				
			||||||
 | 
					  Vector<norm_t> norm_tmp(sites);
 | 
				
			||||||
 | 
					  auto inner_tmp_v = &inner_tmp[0];
 | 
				
			||||||
 | 
					  auto norm_tmp_v = &norm_tmp[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  accelerator_for( ss, sites, nsimd,{
 | 
				
			||||||
 | 
					      auto left_tmp = left_v(ss);
 | 
				
			||||||
 | 
					      coalescedWrite(inner_tmp_v[ss],innerProduct(left_tmp,right_v(ss)));
 | 
				
			||||||
 | 
					      coalescedWrite(norm_tmp_v[ss],innerProduct(left_tmp,left_tmp)));
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  tmp[0] = TensorRemove(sumD_gpu(inner_tmp_v,sites));
 | 
				
			||||||
 | 
					  tmp[1] = TensorRemove(sumD_gpu(norm_tmp_v,sites));
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					  // CPU
 | 
				
			||||||
 | 
					  typedef decltype(innerProductD(left_v[0],right_v[0])) inner_t;
 | 
				
			||||||
 | 
					  typedef decltype(innerProductD(left_v[0],left_v[0])) norm_t;
 | 
				
			||||||
 | 
					  Vector<inner_t> inner_tmp(sites);
 | 
				
			||||||
 | 
					  Vector<norm_t> norm_tmp(sites);
 | 
				
			||||||
 | 
					  auto inner_tmp_v = &inner_tmp[0];
 | 
				
			||||||
 | 
					  auto norm_tmp_v = &norm_tmp[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  accelerator_for( ss, sites, nsimd,{
 | 
				
			||||||
 | 
					      auto left_tmp = left_v(ss);
 | 
				
			||||||
 | 
					      inner_tmp_v[ss] = innerProductD(left_tmp,right_v(ss));
 | 
				
			||||||
 | 
					      norm_tmp_v[ss] = innerProductD(left_tmp,left_tmp);
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					  // Already promoted to double
 | 
				
			||||||
 | 
					  tmp[0] = TensorRemove(sum(inner_tmp_v,sites));
 | 
				
			||||||
 | 
					  tmp[1] = TensorRemove(sum(norm_tmp_v,sites));
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					  grid->GlobalSumVector(&tmp[0],2); // keep norm Complex -> can use GlobalSumVector
 | 
				
			||||||
 | 
					  ip = tmp[0];
 | 
				
			||||||
 | 
					  nrm = real(tmp[1]);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template<class Op,class T1>
 | 
					template<class Op,class T1>
 | 
				
			||||||
inline auto sum(const LatticeUnaryExpression<Op,T1> & expr)
 | 
					inline auto sum(const LatticeUnaryExpression<Op,T1> & expr)
 | 
				
			||||||
  ->typename decltype(expr.op.func(eval(0,expr.arg1)))::scalar_object
 | 
					  ->typename decltype(expr.op.func(eval(0,expr.arg1)))::scalar_object
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										126
									
								
								tests/Test_innerproduct_norm.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								tests/Test_innerproduct_norm.cc
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,126 @@
 | 
				
			|||||||
 | 
					/*************************************************************************************
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Grid physics library, www.github.com/paboyle/Grid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Source file: ./tests/Test_innerproduct_norm.cc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright (C) 2015
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Author: Daniel Richtmann <daniel.richtmann@ur.de>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This program is free software; you can redistribute it and/or modify
 | 
				
			||||||
 | 
					it under the terms of the GNU General Public License as published by
 | 
				
			||||||
 | 
					the Free Software Foundation; either version 2 of the License, or
 | 
				
			||||||
 | 
					(at your option) any later version.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This program is distributed in the hope that it will be useful,
 | 
				
			||||||
 | 
					but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
				
			||||||
 | 
					MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
				
			||||||
 | 
					GNU General Public License for more details.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					You should have received a copy of the GNU General Public License along
 | 
				
			||||||
 | 
					with this program; if not, write to the Free Software Foundation, Inc.,
 | 
				
			||||||
 | 
					51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					See the full license in the file "LICENSE" in the top level distribution directory
 | 
				
			||||||
 | 
					*************************************************************************************/
 | 
				
			||||||
 | 
					/*  END LEGAL */
 | 
				
			||||||
 | 
					#include <Grid/Grid.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace Grid;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int main(int argc, char** argv) {
 | 
				
			||||||
 | 
					  Grid_init(&argc, &argv);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const int nIter = 100;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // clang-format off
 | 
				
			||||||
 | 
					  GridCartesian *Grid_d = SpaceTimeGrid::makeFourDimGrid(GridDefaultLatt(), GridDefaultSimd(Nd, vComplexD::Nsimd()), GridDefaultMpi());
 | 
				
			||||||
 | 
					  GridCartesian *Grid_f = SpaceTimeGrid::makeFourDimGrid(GridDefaultLatt(), GridDefaultSimd(Nd, vComplexF::Nsimd()), GridDefaultMpi());
 | 
				
			||||||
 | 
					  // clang-format on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  GridParallelRNG pRNG_d(Grid_d);
 | 
				
			||||||
 | 
					  GridParallelRNG pRNG_f(Grid_f);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<int> seeds_d({1, 2, 3, 4});
 | 
				
			||||||
 | 
					  std::vector<int> seeds_f({5, 6, 7, 8});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  pRNG_d.SeedFixedIntegers(seeds_d);
 | 
				
			||||||
 | 
					  pRNG_f.SeedFixedIntegers(seeds_f);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // clang-format off
 | 
				
			||||||
 | 
					  LatticeFermionD x_d(Grid_d); random(pRNG_d, x_d);
 | 
				
			||||||
 | 
					  LatticeFermionD y_d(Grid_d); random(pRNG_d, y_d);
 | 
				
			||||||
 | 
					  LatticeFermionF x_f(Grid_f); random(pRNG_f, x_f);
 | 
				
			||||||
 | 
					  LatticeFermionF y_f(Grid_f); random(pRNG_f, y_f);
 | 
				
			||||||
 | 
					  // clang-format on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  GridStopWatch sw_ref;
 | 
				
			||||||
 | 
					  GridStopWatch sw_res;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  { // double precision
 | 
				
			||||||
 | 
					    ComplexD ip_d_ref, ip_d_res, diff_ip_d;
 | 
				
			||||||
 | 
					    RealD    norm2_d_ref, norm2_d_res, diff_norm2_d;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sw_ref.Reset();
 | 
				
			||||||
 | 
					    sw_ref.Start();
 | 
				
			||||||
 | 
					    for(int i = 0; i < nIter; ++i) {
 | 
				
			||||||
 | 
					      ip_d_ref    = innerProduct(x_d, y_d);
 | 
				
			||||||
 | 
					      norm2_d_ref = norm2(x_d);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    sw_ref.Stop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sw_res.Reset();
 | 
				
			||||||
 | 
					    sw_res.Start();
 | 
				
			||||||
 | 
					    for(int i = 0; i < nIter; ++i) { innerProduct_norm(ip_d_res, norm2_d_res, x_d, y_d); }
 | 
				
			||||||
 | 
					    sw_res.Stop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    diff_ip_d    = ip_d_ref - ip_d_res;
 | 
				
			||||||
 | 
					    diff_norm2_d = norm2_d_ref - norm2_d_res;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // clang-format off
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Double: ip_ref = " << ip_d_ref << " ip_res = " << ip_d_res << " diff = " << diff_ip_d << std::endl;
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Double: norm2_ref = " << norm2_d_ref << " norm2_res = " << norm2_d_res << " diff = " << diff_norm2_d << std::endl;
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Double: time_ref = " << sw_ref.Elapsed() << " time_res = " << sw_res.Elapsed() << std::endl;
 | 
				
			||||||
 | 
					    // clang-format on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert(diff_ip_d == 0.);
 | 
				
			||||||
 | 
					    assert(diff_norm2_d == 0.);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Double: all checks passed" << std::endl;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  { // single precision
 | 
				
			||||||
 | 
					    ComplexD ip_f_ref, ip_f_res, diff_ip_f;
 | 
				
			||||||
 | 
					    RealD    norm2_f_ref, norm2_f_res, diff_norm2_f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sw_ref.Reset();
 | 
				
			||||||
 | 
					    sw_ref.Start();
 | 
				
			||||||
 | 
					    for(int i = 0; i < nIter; ++i) {
 | 
				
			||||||
 | 
					      ip_f_ref    = innerProduct(x_f, y_f);
 | 
				
			||||||
 | 
					      norm2_f_ref = norm2(x_f);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    sw_ref.Stop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    sw_res.Reset();
 | 
				
			||||||
 | 
					    sw_res.Start();
 | 
				
			||||||
 | 
					    for(int i = 0; i < nIter; ++i) { innerProduct_norm(ip_f_res, norm2_f_res, x_f, y_f); }
 | 
				
			||||||
 | 
					    sw_res.Stop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    diff_ip_f    = ip_f_ref - ip_f_res;
 | 
				
			||||||
 | 
					    diff_norm2_f = norm2_f_ref - norm2_f_res;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // clang-format off
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Single: ip_ref = " << ip_f_ref << " ip_res = " << ip_f_res << " diff = " << diff_ip_f << std::endl;
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Single: norm2_ref = " << norm2_f_ref << " norm2_res = " << norm2_f_res << " diff = " << diff_norm2_f << std::endl;
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Single: time_ref = " << sw_ref.Elapsed() << " time_res = " << sw_res.Elapsed() << std::endl;
 | 
				
			||||||
 | 
					    // clang-format on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert(diff_ip_f == 0.);
 | 
				
			||||||
 | 
					    assert(diff_norm2_f == 0.);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::cout << GridLogMessage << "Single: all checks passed" << std::endl;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Grid_finalize();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user