diff --git a/lib/lattice/Lattice_arith.h b/lib/lattice/Lattice_arith.h index c3093167..d1cbc84a 100644 --- a/lib/lattice/Lattice_arith.h +++ b/lib/lattice/Lattice_arith.h @@ -244,19 +244,11 @@ namespace Grid { template strong_inline RealD axpy_norm(Lattice &ret,sobj a,const Lattice &x,const Lattice &y){ - ret.checkerboard = x.checkerboard; - conformable(ret,x); - conformable(x,y); - axpy(ret,a,x,y); - return norm2(ret); + return axpy_norm_fast(ret,a,x,y); } template strong_inline RealD axpby_norm(Lattice &ret,sobj a,sobj b,const Lattice &x,const Lattice &y){ - ret.checkerboard = x.checkerboard; - conformable(ret,x); - conformable(x,y); - axpby(ret,a,b,x,y); - return norm2(ret); // FIXME implement parallel norm in ss loop + return axpby_norm_fast(ret,a,b,x,y); } } diff --git a/lib/lattice/Lattice_reduction.h b/lib/lattice/Lattice_reduction.h index 8a3fbece..7e169baf 100644 --- a/lib/lattice/Lattice_reduction.h +++ b/lib/lattice/Lattice_reduction.h @@ -33,7 +33,7 @@ namespace Grid { // Deterministic Reduction operations //////////////////////////////////////////////////////////////////////////////////////////////////// template inline RealD norm2(const Lattice &arg){ - ComplexD nrm = innerProduct(arg,arg); + auto nrm = innerProduct(arg,arg); return std::real(nrm); } @@ -43,12 +43,12 @@ inline ComplexD innerProduct(const Lattice &left,const Lattice &righ { typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_typeD vector_type; - scalar_type nrm; - GridBase *grid = left._grid; - - std::vector > sumarray(grid->SumArraySize()); - + const int pad = 8; + + scalar_type nrm; + std::vector > sumarray(grid->SumArraySize()*pad); + parallel_for(int thr=0;thrSumArraySize();thr++){ int nwork, mywork, myoff; GridThread::GetWork(left._grid->oSites(),thr,mywork,myoff); @@ -57,17 +57,69 @@ inline ComplexD innerProduct(const Lattice &left,const Lattice &righ for(int ss=myoff;ssSumArraySize();i++){ - vvnrm = vvnrm+sumarray[i]; + nrm = nrm+sumarray[i*pad]; } - nrm = Reduce(vvnrm);// sum across simd right._grid->GlobalSum(nrm); return nrm; } + +///////////////////////// +// Fast axpby_norm +// z = a x + b y +// return norm z +///////////////////////// +template strong_inline RealD +axpy_norm_fast(Lattice &z,sobj a,const Lattice &x,const Lattice &y) +{ + sobj one(1.0); + return axpby_norm_fast(z,a,one,x,y); +} + +template strong_inline RealD +axpby_norm_fast(Lattice &z,sobj a,sobj b,const Lattice &x,const Lattice &y) +{ + const int pad = 8; + z.checkerboard = x.checkerboard; + conformable(z,x); + conformable(x,y); + + typedef typename vobj::scalar_type scalar_type; + typedef typename vobj::vector_typeD vector_type; + RealD nrm; + + GridBase *grid = x._grid; + + Vector sumarray(grid->SumArraySize()*pad); + + parallel_for(int thr=0;thrSumArraySize();thr++){ + int nwork, mywork, myoff; + GridThread::GetWork(x._grid->oSites(),thr,mywork,myoff); + + // private to thread; sub summation + decltype(innerProductD(z._odata[0],z._odata[0])) vnrm=zero; + for(int ss=myoff;ssSumArraySize();i++){ + nrm = nrm+sumarray[i*pad]; + } + z._grid->GlobalSum(nrm); + return nrm; +} + template inline auto sum(const LatticeUnaryExpression & expr)