mirror of
https://github.com/paboyle/Grid.git
synced 2025-06-19 16:27:05 +01:00
Fused innerProduct + norm2 on first argument operation
This commit is contained in:
@ -204,8 +204,64 @@ axpby_norm_fast(Lattice<vobj> &z,sobj a,sobj b,const Lattice<vobj> &x,const Latt
|
||||
grid->GlobalSum(nrm);
|
||||
return nrm;
|
||||
}
|
||||
|
||||
|
||||
template<class vobj> strong_inline void
|
||||
innerProduct_norm(ComplexD& ip, RealD &nrm, const Lattice<vobj> &left,const Lattice<vobj> &right)
|
||||
{
|
||||
conformable(left,right);
|
||||
|
||||
typedef typename vobj::scalar_type scalar_type;
|
||||
typedef typename vobj::vector_typeD vector_type;
|
||||
Vector<ComplexD> tmp(2);
|
||||
|
||||
GridBase *grid = left.Grid();
|
||||
|
||||
auto left_v=left.View();
|
||||
auto right_v=right.View();
|
||||
|
||||
const uint64_t nsimd = grid->Nsimd();
|
||||
const uint64_t sites = grid->oSites();
|
||||
|
||||
#ifdef GRID_NVCC
|
||||
// GPU
|
||||
typedef decltype(innerProduct(left_v[0],right_v[0])) inner_t;
|
||||
typedef decltype(innerProduct(left_v[0],left_v[0])) norm_t;
|
||||
Vector<inner_t> inner_tmp(sites);
|
||||
Vector<norm_t> norm_tmp(sites);
|
||||
auto inner_tmp_v = &inner_tmp[0];
|
||||
auto norm_tmp_v = &norm_tmp[0];
|
||||
|
||||
accelerator_for( ss, sites, nsimd,{
|
||||
auto left_tmp = left_v(ss);
|
||||
coalescedWrite(inner_tmp_v[ss],innerProduct(left_tmp,right_v(ss)));
|
||||
coalescedWrite(norm_tmp_v[ss],innerProduct(left_tmp,left_tmp)));
|
||||
});
|
||||
|
||||
tmp[0] = TensorRemove(sumD_gpu(inner_tmp_v,sites));
|
||||
tmp[1] = TensorRemove(sumD_gpu(norm_tmp_v,sites));
|
||||
#else
|
||||
// CPU
|
||||
typedef decltype(innerProductD(left_v[0],right_v[0])) inner_t;
|
||||
typedef decltype(innerProductD(left_v[0],left_v[0])) norm_t;
|
||||
Vector<inner_t> inner_tmp(sites);
|
||||
Vector<norm_t> norm_tmp(sites);
|
||||
auto inner_tmp_v = &inner_tmp[0];
|
||||
auto norm_tmp_v = &norm_tmp[0];
|
||||
|
||||
accelerator_for( ss, sites, nsimd,{
|
||||
auto left_tmp = left_v(ss);
|
||||
inner_tmp_v[ss] = innerProductD(left_tmp,right_v(ss));
|
||||
norm_tmp_v[ss] = innerProductD(left_tmp,left_tmp);
|
||||
});
|
||||
// Already promoted to double
|
||||
tmp[0] = TensorRemove(sum(inner_tmp_v,sites));
|
||||
tmp[1] = TensorRemove(sum(norm_tmp_v,sites));
|
||||
#endif
|
||||
grid->GlobalSumVector(&tmp[0],2); // keep norm Complex -> can use GlobalSumVector
|
||||
ip = tmp[0];
|
||||
nrm = real(tmp[1]);
|
||||
}
|
||||
|
||||
template<class Op,class T1>
|
||||
inline auto sum(const LatticeUnaryExpression<Op,T1> & expr)
|
||||
->typename decltype(expr.op.func(eval(0,expr.arg1)))::scalar_object
|
||||
|
Reference in New Issue
Block a user