diff --git a/Grid/qcd/utils/BaryonUtils.h b/Grid/qcd/utils/BaryonUtils.h index 2a7fc9d0..5b7404c4 100644 --- a/Grid/qcd/utils/BaryonUtils.h +++ b/Grid/qcd/utils/BaryonUtils.h @@ -43,74 +43,67 @@ public: typedef typename FImpl::PropagatorField PropagatorField; typedef typename FImpl::SitePropagator pobj; - typedef typename FImpl::SiteSpinor vobj; + typedef typename ComplexField::vector_object vobj; + /* typedef typename FImpl::SiteSpinor vobj; typedef typename vobj::scalar_object sobj; typedef typename vobj::scalar_type scalar_type; typedef typename vobj::vector_type vector_type; - +*/ + private: + static void baryon_site(const pobj &D1, + const pobj &D2, + const pobj &D3, + const Gamma GammaA, + const Gamma GammaB, + const int parity, + const std::vector wick_contractions, + vobj &result); + public: static void ContractBaryons(const PropagatorField &q1_src, - const PropagatorField &q2_src, - const PropagatorField &q3_src, - const Gamma GammaA, - const Gamma GammaB, - const char * quarks_snk, - const char * quarks_src, - const int parity, - ComplexField &baryon_corr); - + const PropagatorField &q2_src, + const PropagatorField &q3_src, + const Gamma GammaA, + const Gamma GammaB, + const char * quarks_snk, + const char * quarks_src, + const int parity, + ComplexField &baryon_corr); + /* template + static void ContractBaryons_Sliced(const T1 &D1, + const T2 &D2, + const T3 &D3,*/ + static void ContractBaryons_Sliced(const pobj &D1, + const pobj &D2, + const pobj &D3, + const Gamma GammaA, + const Gamma GammaB, + const char * quarks_snk, + const char * quarks_src, + const int parity, + vobj &result); }; - template -void BaryonUtils::ContractBaryons(const PropagatorField &q1_src, - const PropagatorField &q2_src, - const PropagatorField &q3_src, +void BaryonUtils::baryon_site(const pobj &D1, + const pobj &D2, + const pobj &D3, const Gamma GammaA, const Gamma GammaB, - const char * quarks_snk, - const char * quarks_src, const int parity, - ComplexField &baryon_corr) + const std::vector wick_contraction, + vobj &result) { - std::cout << "quarks_snk " << quarks_snk[0] << quarks_snk[1] << quarks_snk[2] << std::endl; - std::cout << "GammaA " << (GammaA.g) << std::endl; - std::cout << "GammaB " << (GammaB.g) << std::endl; - - assert(parity==1 || parity == -1 && "Parity must be +1 or -1"); - - GridBase *grid = q1_src.Grid(); Gamma g4(Gamma::Algebra::GammaT); //needed for parity P_\pm = 0.5*(1 \pm \gamma_4) std::vector> epsilon = {{0,1,2},{1,2,0},{2,0,1},{0,2,1},{2,1,0},{1,0,2}}; std::vector epsilon_sgn = {1,1,1,-1,-1,-1}; - std::vector wick_contraction = {0,0,0,0,0,0}; - - for (int ie=0; ie < 6 ; ie++) - if (quarks_src[0] == quarks_snk[epsilon[ie][0]] && quarks_src[1] == quarks_snk[epsilon[ie][1]] && quarks_src[2] == quarks_snk[epsilon[ie][2]]) - wick_contraction[ie]=1; - - typedef typename ComplexField::vector_object vobj; - auto vbaryon_corr= baryon_corr.View(); - auto v1 = q1_src.View(); - auto v2 = q2_src.View(); - auto v3 = q3_src.View(); - - // accelerator_for(ss, grid->oSites(), grid->Nsimd(), { - //thread_for(ss,grid->oSites(),{ - for(int ss=0; ss < grid->oSites(); ss++){ - - auto D1 = v1[ss]; - auto D2 = v2[ss]; - auto D3 = v3[ss]; auto gD1a = GammaA * GammaA * D1; auto gD1b = GammaA * g4 * GammaA * D1; auto pD1 = 0.5* (gD1a + (double)parity * gD1b); auto gD3 = GammaB * D3; - vobj result=Zero(); - for (int ie_src=0; ie_src < 6 ; ie_src++){ int a_src = epsilon[ie_src][0]; //a int b_src = epsilon[ie_src][1]; //b @@ -175,8 +168,85 @@ void BaryonUtils::ContractBaryons(const PropagatorField &q1_src, } } } - vbaryon_corr[ss] = result; +} +template +void BaryonUtils::ContractBaryons(const PropagatorField &q1_src, + const PropagatorField &q2_src, + const PropagatorField &q3_src, + const Gamma GammaA, + const Gamma GammaB, + const char * quarks_snk, + const char * quarks_src, + const int parity, + ComplexField &baryon_corr) +{ + std::cout << "quarks_snk " << quarks_snk[0] << quarks_snk[1] << quarks_snk[2] << std::endl; + std::cout << "GammaA " << (GammaA.g) << std::endl; + std::cout << "GammaB " << (GammaB.g) << std::endl; + + assert(parity==1 || parity == -1 && "Parity must be +1 or -1"); + + GridBase *grid = q1_src.Grid(); + + Gamma g4(Gamma::Algebra::GammaT); //needed for parity P_\pm = 0.5*(1 \pm \gamma_4) + + std::vector> epsilon = {{0,1,2},{1,2,0},{2,0,1},{0,2,1},{2,1,0},{1,0,2}}; + std::vector epsilon_sgn = {1,1,1,-1,-1,-1}; + std::vector wick_contraction = {0,0,0,0,0,0}; + + for (int ie=0; ie < 6 ; ie++) + if (quarks_src[0] == quarks_snk[epsilon[ie][0]] && quarks_src[1] == quarks_snk[epsilon[ie][1]] && quarks_src[2] == quarks_snk[epsilon[ie][2]]) + wick_contraction[ie]=1; + +// typedef typename ComplexField::vector_object vobj; + auto vbaryon_corr= baryon_corr.View(); + auto v1 = q1_src.View(); + auto v2 = q2_src.View(); + auto v3 = q3_src.View(); + + // accelerator_for(ss, grid->oSites(), grid->Nsimd(), { + //thread_for(ss,grid->oSites(),{ + for(int ss=0; ss < grid->oSites(); ss++){ + + auto D1 = v1[ss]; + auto D2 = v2[ss]; + auto D3 = v3[ss]; + + vobj result=Zero(); + baryon_site(D1,D2,D3,GammaA,GammaB,parity,wick_contraction,result); + vbaryon_corr[ss] = result; } // );//end loop over lattice sites } +/*template +void BaryonUtils::ContractBaryons_Sliced(const T1 &D1, + const T2 &D2, + const T3 &D3,*/ +template +void BaryonUtils::ContractBaryons_Sliced(const pobj &D1, + const pobj &D2, + const pobj &D3, + const Gamma GammaA, + const Gamma GammaB, + const char * quarks_snk, + const char * quarks_src, + const int parity, + vobj &result) +{ + + assert(parity==1 || parity == -1 && "Parity must be +1 or -1"); + + Gamma g4(Gamma::Algebra::GammaT); //needed for parity P_\pm = 0.5*(1 \pm \gamma_4) + + std::vector> epsilon = {{0,1,2},{1,2,0},{2,0,1},{0,2,1},{2,1,0},{1,0,2}}; + std::vector epsilon_sgn = {1,1,1,-1,-1,-1}; + std::vector wick_contraction = {0,0,0,0,0,0}; + + for (int ie=0; ie < 6 ; ie++) + if (quarks_src[0] == quarks_snk[epsilon[ie][0]] && quarks_src[1] == quarks_snk[epsilon[ie][1]] && quarks_src[2] == quarks_snk[epsilon[ie][2]]) + wick_contraction[ie]=1; + + result=Zero(); + baryon_site(D1,D2,D3,GammaA,GammaB,parity,wick_contraction,result); +} NAMESPACE_END(Grid); diff --git a/Hadrons/Modules/MContraction/Baryon.hpp b/Hadrons/Modules/MContraction/Baryon.hpp index 4a3a39e4..51751f0a 100644 --- a/Hadrons/Modules/MContraction/Baryon.hpp +++ b/Hadrons/Modules/MContraction/Baryon.hpp @@ -34,7 +34,6 @@ See the full license in the file "LICENSE" in the top level distribution directo #include #include #include -#include BEGIN_HADRONS_NAMESPACE @@ -147,19 +146,58 @@ void TBaryon::execute(void) std::vector ggB = strToVec(par().GammaB); Gamma GammaB(ggB[0]); std::vector buf; + vTComplex cs; const int parity {par().parity}; const char * quarks_snk{par().quarks_snk.c_str()}; const char * quarks_src{par().quarks_src.c_str()}; - BaryonUtils::ContractBaryons(q1_src,q2_src,q3_src,GammaA,GammaB,quarks_snk,quarks_src,parity,c); - - //sliceSum(c,buf,Tp); - SinkFnScalar &sink = envGet(SinkFnScalar, par().sink); - buf = sink(c); - - for (unsigned int t = 0; t < buf.size(); ++t) + if (envHasType(SlicedPropagator1, par().q1_src) and + envHasType(SlicedPropagator2, par().q2_src) and + envHasType(SlicedPropagator3, par().q3_src)) { - result.corr[t] = TensorRemove(buf[t]); + auto &q1_src = envGet(SlicedPropagator1, par().q1_src); + auto &q2_src = envGet(SlicedPropagator2, par().q2_src); + auto &q3_src = envGet(SlicedPropagator3, par().q3_src); + + LOG(Message) << "(propagator already sinked)" << std::endl; + for (unsigned int t = 0; t < buf.size(); ++t) + { + //TODO: Get this to compile without the casts. Templates? + //BaryonUtils::ContractBaryons_Sliced(*reinterpret_cast, 4>>*>(&q1_src[t]),*reinterpret_cast, 4>>*>(&q2_src[t]),*reinterpret_cast, 4>>*>(&q3_src[t]),GammaA,GammaB,quarks_snk,quarks_src,parity,cs); + //result.corr[t] = TensorRemove(*reinterpret_cast(&cs)); + // BaryonUtils::ContractBaryons_Sliced(q1_src[t],q2_src[t],q3_src[t],GammaA,GammaB,quarks_snk,quarks_src,parity,cs); + // result.corr[t] = TensorRemove(cs); + } + } + else + { + std::string ns; + + ns = vm().getModuleNamespace(env().getObjectModule(par().sink)); + if (ns == "MSource") + { + //TODO: Understand what this is and then get it to compile. Hopefully no new function needed. The following lines are from the Meson.hpp module. + /* PropagatorField1 &sink = envGet(PropagatorField1, par().sink); + + c = trace(mesonConnected(q1, q2, gSnk, gSrc)*sink); + sliceSum(c, buf, Tp); */ +// My attempt at some code, which doesn't work. I also don't know whether anything like this is what we want here. + /* BaryonUtils::ContractBaryons(q1_src,q2_src,q3_src,GammaA,GammaB,quarks_snk,quarks_src,parity,c); + PropagatorField1 &sink = envGet(PropagatorField1, par().sink); + auto test = trace(c*sink); + sliceSum(test, buf, Tp); */ + } + else if (ns == "MSink") + { + BaryonUtils::ContractBaryons(q1_src,q2_src,q3_src,GammaA,GammaB,quarks_snk,quarks_src,parity,c); + + SinkFnScalar &sink = envGet(SinkFnScalar, par().sink); + buf = sink(c); + } + for (unsigned int t = 0; t < buf.size(); ++t) + { + result.corr[t] = TensorRemove(buf[t]); + } } saveResult(par().output, "baryon", result);