diff --git a/lib/lattice/Grid_lattice_arith.h b/lib/lattice/Grid_lattice_arith.h index 4d859173..63ad3335 100644 --- a/lib/lattice/Grid_lattice_arith.h +++ b/lib/lattice/Grid_lattice_arith.h @@ -12,7 +12,7 @@ namespace Grid { Lattice ret(r._grid); #pragma omp parallel for for(int ss=0;ssoSites();ss++){ - ret._odata[ss]= -r._odata[ss]; + vstream(ret._odata[ss], -r._odata[ss]); } return ret; } @@ -23,20 +23,22 @@ namespace Grid { template void mult(Lattice &ret,const Lattice &lhs,const Lattice &rhs){ conformable(lhs,rhs); - uint32_t vec_len = lhs._grid->oSites(); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ + obj1 tmp; + mult(&tmp,&lhs._odata[ss],&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } template void mac(Lattice &ret,const Lattice &lhs,const Lattice &rhs){ conformable(lhs,rhs); - uint32_t vec_len = lhs._grid->oSites(); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ + obj1 tmp; + mac(&tmp,&lhs._odata[ss],&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } @@ -45,7 +47,9 @@ namespace Grid { conformable(lhs,rhs); #pragma omp parallel for for(int ss=0;ssoSites();ss++){ - sub(&ret._odata[ss],&lhs._odata[ss],&rhs._odata[ss]); + obj1 tmp; + sub(&tmp,&lhs._odata[ss],&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } template @@ -53,7 +57,9 @@ namespace Grid { conformable(lhs,rhs); #pragma omp parallel for for(int ss=0;ssoSites();ss++){ - add(&ret._odata[ss],&lhs._odata[ss],&rhs._odata[ss]); + obj1 tmp; + add(&tmp,&lhs._odata[ss],&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } @@ -62,88 +68,100 @@ namespace Grid { ////////////////////////////////////////////////////////////////////////////////////////////////////// template void mult(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ - conformable(lhs,rhs); - uint32_t vec_len = lhs._grid->oSites(); + conformable(lhs,ret); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ + obj1 tmp; + mult(&tmp,&lhs._odata[ss],&rhs); + vstream(ret._odata[ss],tmp); } } template void 
mac(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ - conformable(lhs,rhs); - uint32_t vec_len = lhs._grid->oSites(); + conformable(lhs,ret); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ + obj1 tmp; + mac(&tmp,&lhs._odata[ss],&rhs); + vstream(ret._odata[ss],tmp); } } template void sub(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ - conformable(lhs,rhs); + conformable(lhs,ret); #pragma omp parallel for for(int ss=0;ssoSites();ss++){ - sub(&ret._odata[ss],&lhs._odata[ss],&rhs); + obj1 tmp; + sub(&tmp,&lhs._odata[ss],&rhs); + vstream(ret._odata[ss],tmp); } } template void add(Lattice &ret,const Lattice &lhs,const obj3 &rhs){ - conformable(lhs,rhs); + conformable(lhs,ret); #pragma omp parallel for for(int ss=0;ssoSites();ss++){ - add(&ret._odata[ss],&lhs._odata[ss],&rhs); + obj1 tmp; + add(&tmp,&lhs._odata[ss],&rhs); + vstream(ret._odata[ss],tmp); } } ////////////////////////////////////////////////////////////////////////////////////////////////////// // avoid copy back routines for mult, mac, sub, add ////////////////////////////////////////////////////////////////////////////////////////////////////// - template + template void mult(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ - conformable(lhs,rhs); - uint32_t vec_len = lhs._grid->oSites(); + conformable(ret,rhs); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ + obj1 tmp; + mult(&tmp,&lhs,&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } template void mac(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ - conformable(lhs,rhs); - uint32_t vec_len = lhs._grid->oSites(); + conformable(ret,rhs); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ + obj1 tmp; + mac(&tmp,&lhs,&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } template void sub(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ - conformable(lhs,rhs); + conformable(ret,rhs); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ - sub(&ret._odata[ss],&lhs,&rhs._odata[ss]); + for(int ss=0;ssoSites();ss++){ + 
obj1 tmp; + sub(&tmp,&lhs,&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } template void add(Lattice &ret,const obj2 &lhs,const Lattice &rhs){ - conformable(lhs,rhs); + conformable(ret,rhs); #pragma omp parallel for - for(int ss=0;ssoSites();ss++){ - add(&ret._odata[ss],&lhs,&rhs._odata[ss]); + for(int ss=0;ssoSites();ss++){ + obj1 tmp; + add(&tmp,&lhs,&rhs._odata[ss]); + vstream(ret._odata[ss],tmp); } } ///////////////////////////////////////////////////////////////////////////////////// // Lattice BinOp Lattice, + //NB mult performs conformable check. Do not reapply here for performance. ///////////////////////////////////////////////////////////////////////////////////// template inline auto operator * (const Lattice &lhs,const Lattice &rhs)-> Lattice { - //NB mult performs conformable check. Do not reapply here for performance. Lattice ret(rhs._grid); mult(ret,lhs,rhs); return ret; @@ -151,7 +169,6 @@ namespace Grid { template inline auto operator + (const Lattice &lhs,const Lattice &rhs)-> Lattice { - //NB mult performs conformable check. Do not reapply here for performance. Lattice ret(rhs._grid); add(ret,lhs,rhs); return ret; @@ -159,7 +176,6 @@ namespace Grid { template inline auto operator - (const Lattice &lhs,const Lattice &rhs)-> Lattice { - //NB mult performs conformable check. Do not reapply here for performance. 
Lattice ret(rhs._grid); sub(ret,lhs,rhs); return ret; @@ -172,9 +188,11 @@ namespace Grid { Lattice ret(rhs._grid); #pragma omp parallel for for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=lhs*rhs._odata[ss]; + decltype(lhs*rhs._odata[0]) tmp=lhs*rhs._odata[ss]; + vstream(ret._odata[ss],tmp); + // ret._odata[ss]=lhs*rhs._odata[ss]; } - return ret; + return ret; } template inline auto operator + (const left &lhs,const Lattice &rhs) -> Lattice @@ -182,7 +200,9 @@ namespace Grid { Lattice ret(rhs._grid); #pragma omp parallel for for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=lhs+rhs._odata[ss]; + decltype(lhs+rhs._odata[0]) tmp =lhs+rhs._odata[ss]; + vstream(ret._odata[ss],tmp); + // ret._odata[ss]=lhs+rhs._odata[ss]; } return ret; } @@ -192,7 +212,9 @@ namespace Grid { Lattice ret(rhs._grid); #pragma omp parallel for for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=lhs-rhs._odata[ss]; + decltype(lhs-rhs._odata[0]) tmp=lhs-rhs._odata[ss]; + vstream(ret._odata[ss],tmp); + // ret._odata[ss]=lhs-rhs._odata[ss]; } return ret; } @@ -202,7 +224,9 @@ namespace Grid { Lattice ret(lhs._grid); #pragma omp parallel for for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=lhs._odata[ss]*rhs; + decltype(lhs._odata[0]*rhs) tmp =lhs._odata[ss]*rhs; + vstream(ret._odata[ss],tmp); + // ret._odata[ss]=lhs._odata[ss]*rhs; } return ret; } @@ -212,7 +236,9 @@ namespace Grid { Lattice ret(lhs._grid); #pragma omp parallel for for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=lhs._odata[ss]+rhs; + decltype(lhs._odata[0]+rhs) tmp=lhs._odata[ss]+rhs; + vstream(ret._odata[ss],tmp); + // ret._odata[ss]=lhs._odata[ss]+rhs; } return ret; } @@ -222,7 +248,9 @@ namespace Grid { Lattice ret(lhs._grid); #pragma omp parallel for for(int ss=0;ssoSites(); ss++){ - ret._odata[ss]=lhs._odata[ss]-rhs; + decltype(lhs._odata[0]-rhs) tmp=lhs._odata[ss]-rhs; + vstream(ret._odata[ss],tmp); + // ret._odata[ss]=lhs._odata[ss]-rhs; } return ret; } @@ -230,11 +258,10 @@ namespace Grid { template inline void 
axpy(Lattice &ret,sobj a,const Lattice &lhs,const Lattice &rhs){ conformable(lhs,rhs); - vobj tmp; #pragma omp parallel for for(int ss=0;ssoSites();ss++){ - tmp = a*lhs._odata[ss]; - ret._odata[ss]= tmp+rhs._odata[ss]; + vobj tmp = a*lhs._odata[ss]; + vstream(ret._odata[ss],tmp+rhs._odata[ss]); } }