From bf4369f72de9bd241a0bde9e5382c3dc57e38f22 Mon Sep 17 00:00:00 2001 From: david clarke Date: Thu, 12 Oct 2023 12:41:06 -0600 Subject: [PATCH] clean up HISQSmear with decltypes --- Grid/qcd/smearing/HISQSmearing.h | 101 ++++++++++++++----------------- 1 file changed, 47 insertions(+), 54 deletions(-) diff --git a/Grid/qcd/smearing/HISQSmearing.h b/Grid/qcd/smearing/HISQSmearing.h index 1ea1b7b9..432184e0 100644 --- a/Grid/qcd/smearing/HISQSmearing.h +++ b/Grid/qcd/smearing/HISQSmearing.h @@ -40,6 +40,8 @@ directory NAMESPACE_BEGIN(Grid); +// TODO: find a way to fold this into the stencil header. need to access grid to get +// Nd, since you don't want to inherit from QCD.h /*! @brief append arbitrary shift path to shifts */ template void appendShift(std::vector& shifts, int dir, Args... args) { @@ -51,16 +53,6 @@ void appendShift(std::vector& shifts, int dir, Args... args) { } -// This is to optimize the SIMD (will also need to be in the class, at least for now) -template void gpermute(vobj & inout,int perm) { - vobj tmp=inout; - if (perm & 0x1) {permute(inout,tmp,0); tmp=inout;} - if (perm & 0x2) {permute(inout,tmp,1); tmp=inout;} - if (perm & 0x4) {permute(inout,tmp,2); tmp=inout;} - if (perm & 0x8) {permute(inout,tmp,3); tmp=inout;} -} - - /*! @brief figure out the stencil index from mu and nu */ inline int stencilIndex(int mu, int nu) { // Nshifts depends on how you built the stencil @@ -163,6 +155,12 @@ public: Ughost_5linkA=Zero(); Ughost_5linkB=Zero(); + // We infer some types that will be needed in the calculation. + typedef decltype(gStencil.GetEntry(0,0)) stencilElement; + typedef decltype(coalescedReadGeneralPermute(U_v[0](0),gStencil.GetEntry(0,0)->_permute,Nd)) U3matrix; + stencilElement SE0, SE1, SE2, SE3, SE4; + U3matrix U0, U1, U2, U3, U4, U5, W; + // 3-link for(int site=0;site_offset; - auto SE1 = gStencil.GetEntry(s+1,site); int x_p_nu = SE1->_offset; - auto SE2 = gStencil.GetEntry(s+2,site); int x = SE2->_offset; - auto SE3 = gStencil.GetEntry(s+3,site); int x_p_mu_m_nu = SE3->_offset; - auto SE4 = gStencil.GetEntry(s+4,site); int x_m_nu = SE4->_offset; + SE0 = gStencil.GetEntry(s+0,site); int x_p_mu = SE0->_offset; + SE1 = gStencil.GetEntry(s+1,site); int x_p_nu = SE1->_offset; + SE2 = gStencil.GetEntry(s+2,site); int x = SE2->_offset; + SE3 = gStencil.GetEntry(s+3,site); int x_p_mu_m_nu = SE3->_offset; + SE4 = gStencil.GetEntry(s+4,site); int x_m_nu = SE4->_offset; // When you're deciding whether to take an adjoint, the question is: how is the // stored link oriented compared to the one you want? If I imagine myself travelling // with the to-be-updated link, I have two possible, alternative 3-link paths I can // take, one starting by going to the left, the other starting by going to the right. - auto U0 = U_v[x_p_mu ](nu); gpermute(U0,SE0->_permute); - auto U1 = U_v[x_p_nu ](mu); gpermute(U1,SE1->_permute); - auto U2 = U_v[x ](nu); gpermute(U2,SE2->_permute); - auto U3 = U_v[x_p_mu_m_nu](nu); gpermute(U3,SE3->_permute); - auto U4 = U_v[x_m_nu ](mu); gpermute(U4,SE4->_permute); - auto U5 = U_v[x_m_nu ](nu); gpermute(U5,SE4->_permute); + U0 = coalescedReadGeneralPermute(U_v[x_p_mu ](nu),SE0->_permute,Nd); + U1 = coalescedReadGeneralPermute(U_v[x_p_nu ](mu),SE1->_permute,Nd); + U2 = coalescedReadGeneralPermute(U_v[x ](nu),SE2->_permute,Nd); + U3 = coalescedReadGeneralPermute(U_v[x_p_mu_m_nu](nu),SE3->_permute,Nd); + U4 = coalescedReadGeneralPermute(U_v[x_m_nu ](mu),SE4->_permute,Nd); + U5 = coalescedReadGeneralPermute(U_v[x_m_nu ](nu),SE4->_permute,Nd); - // "left" "right" - auto W = U2*U1*adj(U0) + adj(U5)*U4*U3; + // "left" "right" + W = U2*U1*adj(U0) + adj(U5)*U4*U3; U_3link_v[site](nu) = W; @@ -197,7 +195,6 @@ public: } } - // 5-link for(int site=0;site_offset; - auto SE1 = gStencil.GetEntry(s+1,site); int x_p_nu = SE1->_offset; - auto SE2 = gStencil.GetEntry(s+2,site); int x = SE2->_offset; - auto SE3 = gStencil.GetEntry(s+3,site); int x_p_mu_m_nu = SE3->_offset; - auto SE4 = gStencil.GetEntry(s+4,site); int x_m_nu = SE4->_offset; + SE0 = gStencil.GetEntry(s+0,site); int x_p_mu = SE0->_offset; + SE1 = gStencil.GetEntry(s+1,site); int x_p_nu = SE1->_offset; + SE2 = gStencil.GetEntry(s+2,site); int x = SE2->_offset; + SE3 = gStencil.GetEntry(s+3,site); int x_p_mu_m_nu = SE3->_offset; + SE4 = gStencil.GetEntry(s+4,site); int x_m_nu = SE4->_offset; // gpermutes will be replaced with single line of code, combines load and permute // into one step. still in pull request stage - auto U0 = U_v[x_p_mu ](nu) ; gpermute(U0,SE0->_permute); - auto U1 = U_3link_v[x_p_nu ](rho); gpermute(U1,SE1->_permute); - auto U2 = U_v[x ](nu) ; gpermute(U2,SE2->_permute); - auto U3 = U_v[x_p_mu_m_nu](nu) ; gpermute(U3,SE3->_permute); - auto U4 = U_3link_v[x_m_nu ](rho); gpermute(U4,SE4->_permute); - auto U5 = U_v[x_m_nu ](nu) ; gpermute(U5,SE4->_permute); + U0 = coalescedReadGeneralPermute( U_v[x_p_mu ](nu ),SE0->_permute,Nd); + U1 = coalescedReadGeneralPermute(U_3link_v[x_p_nu ](rho),SE1->_permute,Nd); + U2 = coalescedReadGeneralPermute( U_v[x ](nu ),SE2->_permute,Nd); + U3 = coalescedReadGeneralPermute( U_v[x_p_mu_m_nu](nu ),SE3->_permute,Nd); + U4 = coalescedReadGeneralPermute(U_3link_v[x_m_nu ](rho),SE4->_permute,Nd); + U5 = coalescedReadGeneralPermute( U_v[x_m_nu ](nu ),SE4->_permute,Nd); - auto W = U2*U1*adj(U0) + adj(U5)*U4*U3; + W = U2*U1*adj(U0) + adj(U5)*U4*U3; if(sigmaIndex<3) { U_5linkA_v[site](rho) = W; @@ -246,33 +243,29 @@ public: for(int rho=0;rho_offset; - auto SE1 = gStencil.GetEntry(s+1,site); int x_p_nu = SE1->_offset; - auto SE2 = gStencil.GetEntry(s+2,site); int x = SE2->_offset; - auto SE3 = gStencil.GetEntry(s+3,site); int x_p_mu_m_nu = SE3->_offset; - auto SE4 = gStencil.GetEntry(s+4,site); int x_m_nu = SE4->_offset; + SE0 = gStencil.GetEntry(s+0,site); int x_p_mu = SE0->_offset; + SE1 = gStencil.GetEntry(s+1,site); int x_p_nu = SE1->_offset; + SE2 = gStencil.GetEntry(s+2,site); int x = SE2->_offset; + SE3 = gStencil.GetEntry(s+3,site); int x_p_mu_m_nu = SE3->_offset; + SE4 = gStencil.GetEntry(s+4,site); int x_m_nu = SE4->_offset; - auto U0 = U_v[x_p_mu ](nu) ; gpermute(U0,SE0->_permute); - // decltype, or auto U1 = { ? ... } - auto U1 = U0; + U0 = coalescedReadGeneralPermute(U_v[x_p_mu](nu),SE0->_permute,Nd); if(sigmaIndex<3) { - U1 = U_5linkB_v[x_p_nu](rho); gpermute(U1,SE1->_permute); + U1 = coalescedReadGeneralPermute(U_5linkB_v[x_p_nu](rho),SE1->_permute,Nd); } else { - U1 = U_5linkA_v[x_p_nu](rho); gpermute(U1,SE1->_permute); + U1 = coalescedReadGeneralPermute(U_5linkA_v[x_p_nu](rho),SE1->_permute,Nd); } - auto U2 = U_v[x ](nu) ; gpermute(U2,SE2->_permute); - auto U3 = U_v[x_p_mu_m_nu](nu) ; gpermute(U3,SE3->_permute); - auto U4 = U0; + U2 = coalescedReadGeneralPermute(U_v[x](nu),SE2->_permute,Nd); + U3 = coalescedReadGeneralPermute(U_v[x_p_mu_m_nu](nu),SE3->_permute,Nd); if(sigmaIndex<3) { - U4 = U_5linkB_v[x_m_nu](rho); gpermute(U4,SE4->_permute); + U4 = coalescedReadGeneralPermute(U_5linkB_v[x_m_nu](rho),SE4->_permute,Nd); } else { - U4 = U_5linkA_v[x_m_nu](rho); gpermute(U4,SE4->_permute); + U4 = coalescedReadGeneralPermute(U_5linkA_v[x_m_nu](rho),SE4->_permute,Nd); } - auto U5 = U_v[x_m_nu ](nu) ; gpermute(U5,SE4->_permute); + U5 = coalescedReadGeneralPermute(U_v[x_m_nu](nu),SE4->_permute,Nd); - auto W = U2*U1*adj(U0) + adj(U5)*U4*U3; + W = U2*U1*adj(U0) + adj(U5)*U4*U3; - // std::vector(3) ? U_fat_v[site](mu) = U_fat_v[site](mu) + lt.c_7*W; sigmaIndex++;