From 2ef1fa66a8afe2066b8c1ef191a608bd64bdb3bd Mon Sep 17 00:00:00 2001 From: Christopher Kelly Date: Mon, 7 Dec 2020 11:53:35 -0500 Subject: [PATCH] Improved performance of G-parity kernel for GPUs by simplifying multLink implementation --- Grid/qcd/action/fermion/GparityWilsonImpl.h | 42 ++++++++------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/Grid/qcd/action/fermion/GparityWilsonImpl.h b/Grid/qcd/action/fermion/GparityWilsonImpl.h index 0b726db9..9dca403b 100644 --- a/Grid/qcd/action/fermion/GparityWilsonImpl.h +++ b/Grid/qcd/action/fermion/GparityWilsonImpl.h @@ -97,42 +97,30 @@ public: Coordinate icoor; #ifdef GRID_SIMT - _Spinor tmp; - const int Nsimd =SiteDoubledGaugeField::Nsimd(); int s = acceleratorSIMTlane(Nsimd); St.iCoorFromIindex(icoor,s); int mmu = mu % Nd; - if ( SE->_around_the_world && St.parameters.twists[mmu] ) { - - int permute_lane = (sl==1) - || ((distance== 1)&&(icoor[direction]==1)) - || ((distance==-1)&&(icoor[direction]==0)); - if ( permute_lane ) { - tmp(0) = chi(1); - tmp(1) = chi(0); - } else { - tmp(0) = chi(0); - tmp(1) = chi(1); - } + auto UU0=coalescedRead(U(0)(mu)); + auto UU1=coalescedRead(U(1)(mu)); + + //Decide whether we do a G-parity flavor twist + //Note: this assumes (but does not check) that sl==1 || sl==2 i.e. max 2 SIMD lanes in G-parity dir + //It also assumes (but does not check) that abs(distance) == 1 + int permute_lane = (sl==1) + || ((distance== 1)&&(icoor[direction]==1)) + || ((distance==-1)&&(icoor[direction]==0)); - auto UU0=coalescedRead(U(0)(mu)); - auto UU1=coalescedRead(U(1)(mu)); + permute_lane = permute_lane && SE->_around_the_world && St.parameters.twists[mmu]; //only if we are going around the world - mult(&phi(0),&UU0,&tmp(0)); - mult(&phi(1),&UU1,&tmp(1)); + //Apply the links + int f_upper = permute_lane ? 1 : 0; + int f_lower = !f_upper; - } else { - - auto UU0=coalescedRead(U(0)(mu)); - auto UU1=coalescedRead(U(1)(mu)); - - mult(&phi(0),&UU0,&chi(0)); - mult(&phi(1),&UU1,&chi(1)); - - } + mult(&phi(0),&UU0,&chi(f_upper)); + mult(&phi(1),&UU1,&chi(f_lower)); #else typedef _Spinor vobj;