mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-10 07:55:35 +00:00
fix inner product with thrust reduction
This commit is contained in:
parent
e3c56fd9b3
commit
1a82533d22
@ -23,12 +23,7 @@ Author: paboyle <paboyle@ph.ed.ac.uk>
|
|||||||
|
|
||||||
#include <Grid/Grid_Eigen_Dense.h>
|
#include <Grid/Grid_Eigen_Dense.h>
|
||||||
#ifdef GRID_NVCC
|
#ifdef GRID_NVCC
|
||||||
#include <thrust/host_vector.h>
|
#include <thrust/inner_product.h>
|
||||||
#include <thrust/device_vector.h>
|
|
||||||
#include <thrust/generate.h>
|
|
||||||
#include <thrust/reduce.h>
|
|
||||||
#include <thrust/functional.h>
|
|
||||||
#include <thrust/reduce.h>
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
NAMESPACE_BEGIN(Grid);
|
NAMESPACE_BEGIN(Grid);
|
||||||
@ -41,23 +36,12 @@ template<class vobj> inline RealD norm2(const Lattice<vobj> &arg){
|
|||||||
return real(nrm);
|
return real(nrm);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if 0
|
#ifdef GRID_NVCC
|
||||||
//#warning "ThrustReduce compiled"
|
template<class T, class R>
|
||||||
//#include <thrust/execution_policy.h>
|
struct innerProductFunctor : public thrust::binary_function<T,T,R>
|
||||||
template<class vobj>
|
|
||||||
vobj ThrustNorm(const Lattice<vobj> &lat)
|
|
||||||
{
|
{
|
||||||
typedef typename vobj::scalar_type scalar_type;
|
accelerator R operator()(T x, T y) { return innerProduct(x,y); }
|
||||||
auto lat_v=lat.View();
|
};
|
||||||
Integer s0=0;
|
|
||||||
Integer sN=lat_v.end();
|
|
||||||
scalar_type sum = 0;
|
|
||||||
scalar_type * begin = (scalar_type *)&lat_v[s0];
|
|
||||||
scalar_type * end = (scalar_type *)&lat_v[sN];
|
|
||||||
thrust::reduce(begin,end,sum);
|
|
||||||
std::cout <<" thrust::reduce sum "<< sum << std::endl;
|
|
||||||
return sum;
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Double inner product
|
// Double inner product
|
||||||
@ -75,24 +59,17 @@ inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &righ
|
|||||||
auto left_v = left.View();
|
auto left_v = left.View();
|
||||||
auto right_v=right.View();
|
auto right_v=right.View();
|
||||||
|
|
||||||
#if 0
|
#ifdef GRID_NVCC
|
||||||
|
|
||||||
typedef decltype(TensorRemove(innerProduct(left_v[0],right_v[0]))) inner_t;
|
typedef decltype(innerProduct(left_v[0],right_v[0])) inner_t;
|
||||||
|
thrust::plus<inner_t> binary_sum;
|
||||||
|
innerProductFunctor<vobj,inner_t> binary_inner_p;
|
||||||
|
Integer sN = left_v.end();
|
||||||
|
inner_t zero = Zero();
|
||||||
|
// is there a way of using the efficient thrust reduction while maintaining memory coalescing?
|
||||||
|
inner_t vnrm = thrust::inner_product(thrust::device, &left_v[0], &left_v[sN], &right_v[0], zero, binary_sum, binary_inner_p);
|
||||||
|
nrm = Reduce(TensorRemove(vnrm));// sum across simd
|
||||||
|
|
||||||
Lattice<inner_t> inner_tmp(grid);
|
|
||||||
|
|
||||||
/////////////////////////
|
|
||||||
// localInnerProduct
|
|
||||||
/////////////////////////
|
|
||||||
auto inner_tmp_v = inner_tmp.View();
|
|
||||||
accelerator_loop(ss,left_v,{
|
|
||||||
inner_tmp_v[ss] = TensorRemove(innerProduct(left_v[ss],right_v[ss]));
|
|
||||||
});
|
|
||||||
/////////////////////////
|
|
||||||
// and site sum the scalars
|
|
||||||
/////////////////////////
|
|
||||||
inner_t vnrm = ThrustNorm(inner_tmp);
|
|
||||||
auto vvnrm = vnrm;
|
|
||||||
#else
|
#else
|
||||||
thread_loop( (int thr=0;thr<grid->SumArraySize();thr++),{
|
thread_loop( (int thr=0;thr<grid->SumArraySize();thr++),{
|
||||||
int mywork, myoff;
|
int mywork, myoff;
|
||||||
@ -109,8 +86,8 @@ inline ComplexD innerProduct(const Lattice<vobj> &left,const Lattice<vobj> &righ
|
|||||||
for(int i=0;i<grid->SumArraySize();i++){
|
for(int i=0;i<grid->SumArraySize();i++){
|
||||||
vvnrm = vvnrm+sumarray[i];
|
vvnrm = vvnrm+sumarray[i];
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
nrm = Reduce(vvnrm);// sum across simd
|
nrm = Reduce(vvnrm);// sum across simd
|
||||||
|
#endif
|
||||||
right.Grid()->GlobalSum(nrm);
|
right.Grid()->GlobalSum(nrm);
|
||||||
return nrm;
|
return nrm;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user