From d55212c998bba55b78bb676eb8a33f86fe6825cc Mon Sep 17 00:00:00 2001 From: Vera Guelpers Date: Wed, 14 Feb 2018 10:45:18 +0000 Subject: [PATCH] restructure SeqConservedCurrent for DWF to need less memory --- lib/qcd/action/fermion/WilsonFermion5D.cc | 96 ++++++++++------------- 1 file changed, 43 insertions(+), 53 deletions(-) diff --git a/lib/qcd/action/fermion/WilsonFermion5D.cc b/lib/qcd/action/fermion/WilsonFermion5D.cc index 92280c13..6b830a03 100644 --- a/lib/qcd/action/fermion/WilsonFermion5D.cc +++ b/lib/qcd/action/fermion/WilsonFermion5D.cc @@ -883,8 +883,7 @@ void WilsonFermion5D::SeqConservedCurrent(PropagatorField &q_in, { conformable(q_in._grid, FermionGrid()); conformable(q_in._grid, q_out._grid); - PropagatorField tmpFwd(FermionGrid()), tmpBwd(FermionGrid()), - tmp(FermionGrid()); + PropagatorField tmp(GaugeGrid()),tmp2(GaugeGrid()); Complex i(0.0, 1.0); unsigned int tshift = (mu == Tp) ? 1 : 0; unsigned int LLs = q_in._grid->_rdimensions[0]; @@ -895,69 +894,60 @@ void WilsonFermion5D::SeqConservedCurrent(PropagatorField &q_in, LatticeCoordinate(coords, Tp); - //QED: photon field is 4dim, but need a 5dim object to multiply to - // DWF PropagatorField - Lattice> lattice_cmplx_5d(FermionGrid()); for (unsigned int s = 0; s < LLs; ++s) - { - InsertSlice(lattice_cmplx,lattice_cmplx_5d, s, 0); - } - - - - // Need q(x + mu, s) and q(x - mu, s). 5D lattice so shift 4D coordinate mu - // by one. - tmp = Cshift(q_in, mu + 1, 1); - tmpFwd = tmp*lattice_cmplx_5d; - tmp = lattice_cmplx_5d*q_in; - tmpBwd = Cshift(tmp, mu + 1, -1); - - parallel_for (unsigned int sU = 0; sU < Umu._grid->oSites(); ++sU) { - // Compute the sequential conserved current insertion only if our simd - // object contains a timeslice we need. - vInteger t_mask = ((coords._odata[sU] >= tmin) && - (coords._odata[sU] <= tmax)); - Integer timeSlices = Reduce(t_mask); + bool axial_sign = ((curr_type == Current::Axial) && (s < (LLs / 2))); + bool tadpole_sign = (curr_type == Current::Tadpole); + bool switch_sgn = tadpole_sign || axial_sign; - if (timeSlices > 0) - { - unsigned int sF = sU * LLs; - for (unsigned int s = 0; s < LLs; ++s) + + //forward direction: Need q(x + mu, s)*A(x) + ExtractSlice(tmp2, q_in, s, 0); //q(x,s) + tmp = Cshift(tmp2, mu, 1); //q(x+mu,s) + tmp2 = tmp*lattice_cmplx; //q(x+mu,s)*A(x) + + parallel_for (unsigned int sU = 0; sU < Umu._grid->oSites(); ++sU) + { + // Compute the sequential conserved current insertion only if our simd + // object contains a timeslice we need. + vInteger t_mask = ((coords._odata[sU] >= tmin) && + (coords._odata[sU] <= tmax)); + Integer timeSlices = Reduce(t_mask); + + if (timeSlices > 0) { - bool axial_sign = ((curr_type == Current::Axial) && (s < (LLs / 2))); - bool tadpole_sign = (curr_type == Current::Tadpole); - bool switch_sgn = tadpole_sign || axial_sign; - - Kernels::SeqConservedCurrentSiteFwd(tmpFwd._odata[sF], - q_out._odata[sF], Umu, sU, - mu, t_mask, switch_sgn); - ++sF; + unsigned int sF = sU * LLs + s; + Kernels::SeqConservedCurrentSiteFwd(tmp2._odata[sU], + q_out._odata[sF], Umu, sU, + mu, t_mask, switch_sgn); } + } - // Repeat for backward direction. - t_mask = ((coords._odata[sU] >= (tmin + tshift)) && - (coords._odata[sU] <= (tmax + tshift))); + //backward direction: Need q(x - mu, s)*A(x-mu) + ExtractSlice(tmp2, q_in, s, 0); //q(x,s) + tmp = lattice_cmplx*tmp2; //q(x,s)*A(x) + tmp2 = Cshift(tmp, mu, -1); //q(x-mu,s)*A(x-mu,s) - //if tmax = LLt-1 (last timeslice) include timeslice 0 if the time is shifted (mu=3) - unsigned int t0 = 0; - if((tmax==LLt-1) && (tshift==1)) t_mask = (t_mask || (coords._odata[sU] == t0 )); + parallel_for (unsigned int sU = 0; sU < Umu._grid->oSites(); ++sU) + { + vInteger t_mask = ((coords._odata[sU] >= (tmin + tshift)) && + (coords._odata[sU] <= (tmax + tshift))); - timeSlices = Reduce(t_mask); + //if tmax = LLt-1 (last timeslice) include timeslice 0 if the time is shifted (mu=3) + unsigned int t0 = 0; + if((tmax==LLt-1) && (tshift==1)) t_mask = (t_mask || (coords._odata[sU] == t0 )); - if (timeSlices > 0) - { - unsigned int sF = sU * LLs; - for (unsigned int s = 0; s < LLs; ++s) + Integer timeSlices = Reduce(t_mask); + + if (timeSlices > 0) { - bool axial_sign = ((curr_type == Current::Axial) && (s < (LLs / 2))); - Kernels::SeqConservedCurrentSiteBwd(tmpBwd._odata[sF], - q_out._odata[sF], Umu, sU, - mu, t_mask, axial_sign); - ++sF; + unsigned int sF = sU * LLs + s; + Kernels::SeqConservedCurrentSiteBwd(tmp2._odata[sU], + q_out._odata[sF], Umu, sU, + mu, t_mask, axial_sign); } - } + } } }