diff --git a/extras/Hadrons/AllToAllVectors.hpp b/extras/Hadrons/AllToAllVectors.hpp index de97051f..516a425f 100644 --- a/extras/Hadrons/AllToAllVectors.hpp +++ b/extras/Hadrons/AllToAllVectors.hpp @@ -13,9 +13,79 @@ BEGIN_HADRONS_NAMESPACE template class A2AModesSchurDiagTwo { + private: + const std::vector *evec; + const std::vector *eval; + Matrix &action; + std::function &Solver; + const int Nl, Nh; + const bool return_5d; + std::vector w_high_5d, v_high_5d, w_high_4d, v_high_4d; + public: - A2AModesSchurDiagTwo(void) = default; - virtual ~A2AModesSchurDiagTwo(void) = default; + A2AModesSchurDiagTwo(const std::vector *_evec, const std::vector *_eval, + Matrix &_action, + std::function &_Solver, + const int _Nl, const int _Nh, + const bool _return_5d) + : evec(_evec), eval(_eval), + action(_action), + Solver(_Solver), + Nl(_Nl), Nh(_Nh), + return_5d(_return_5d) + { + init_resize(1, Nh); + if (return_5d) init_resize(Nh, Nh); + }; + + void init_resize(const size_t size_5d, const size_t size_4d) + { + GridBase *grid_5d = action.Grid(); + GridBase *grid_4d = action.GaugeGrid(); + + w_high_5d.resize(size_5d, grid_5d); + v_high_5d.resize(size_5d, grid_5d); + + w_high_4d.resize(size_4d, grid_4d); + v_high_4d.resize(size_4d, grid_4d); + } + + void high_modes(Field &source_5d, Field &source_4d, int i) + { + int i5d; + LOG(Message) << "A2A high modes for i = " << i << std::endl; + i5d = 0; + if (return_5d) i5d = i; + this->high_mode_v(action, Solver, source_5d, v_high_5d[i5d], v_high_4d[i]); + this->high_mode_w(source_5d, source_4d, w_high_5d[i5d], w_high_4d[i]); + } + + void return_v(int i, Field &vout_5d, Field &vout_4d) + { + if (i < Nl) + { + this->low_mode_v(action, evec->at(i), eval->at(i), vout_5d, vout_4d); + } + else + { + vout_4d = v_high_4d[i - Nl]; + if (!(return_5d)) i = Nl; + vout_5d = v_high_5d[i - Nl]; + } + } + void return_w(int i, Field &wout_5d, Field &wout_4d) + { + if (i < Nl) + { + this->low_mode_w(action, evec->at(i), eval->at(i), wout_5d, wout_5d); + } + else + { + wout_4d = w_high_4d[i - Nl]; + if (!(return_5d)) i = Nl; + wout_5d = w_high_5d[i - Nl]; + } + } void Doo(Matrix &action, const Field &in, Field &out) { @@ -26,10 +96,10 @@ class A2AModesSchurDiagTwo action.MooeeInv(tmp, out); action.Meooe(out, tmp); - axpy(out, -1.0, tmp, in); + axpy(out,-1.0, tmp, in); } - void low_mode_v(Matrix &action, const Field &evec, const RealD &eval, Field &vout, bool return_5d = true) + void low_mode_v(Matrix &action, const Field &evec, const RealD &eval, Field &vout_5d, Field &vout_4d) { GridBase *grid = action.RedBlackGrid(); @@ -38,13 +108,10 @@ class A2AModesSchurDiagTwo Field sol_o(grid); Field tmp(grid); - GridBase *fgrid = action.Grid(); - Field tmp_out(fgrid); - src_o = evec; src_o.checkerboard = Odd; - pickCheckerboard(Even, sol_e, tmp_out); - pickCheckerboard(Odd, sol_o, tmp_out); + pickCheckerboard(Even, sol_e, vout_5d); + pickCheckerboard(Odd, sol_o, vout_5d); ///////////////////////////////////////////////////// // v_ie = -(1/eval_i) * MeeInv Meo MooInv evec_i @@ -66,15 +133,15 @@ class A2AModesSchurDiagTwo sol_o = (1.0 / eval) * tmp; assert(sol_o.checkerboard == Odd); - setCheckerboard(tmp_out, sol_e); + setCheckerboard(vout_5d, sol_e); assert(sol_e.checkerboard == Even); - setCheckerboard(tmp_out, sol_o); + setCheckerboard(vout_5d, sol_o); assert(sol_o.checkerboard == Odd); - this->return_dim(action, tmp_out, vout, return_5d); + action.ExportPhysicalFermionSolution(vout_5d, vout_4d); } - void low_mode_w(Matrix &action, const Field &evec, const RealD &eval, Field &wout, bool return_5d = true) + void low_mode_w(Matrix &action, const Field &evec, const RealD &eval, Field &wout_5d, Field &wout_4d) { GridBase *grid = action.RedBlackGrid(); SchurDiagTwoOperator _HermOpEO(action); @@ -85,7 +152,6 @@ class A2AModesSchurDiagTwo Field tmp(grid); GridBase *fgrid = action.Grid(); - Field tmp_out(fgrid); Field tmp_wout(fgrid); src_o = evec; @@ -115,41 +181,24 @@ class A2AModesSchurDiagTwo setCheckerboard(tmp_wout, sol_o); assert(sol_o.checkerboard == Odd); - action.DminusDag(tmp_wout, tmp_out); - this->return_dim(action, tmp_out, wout, return_5d); + action.DminusDag(tmp_wout, wout_5d); + action.ExportPhysicalFermionSolution(wout_5d, wout_4d); } - void high_mode_v(Matrix &action, std::function &Solver, const Field &source, Field &vout, bool return_5d = true) + void high_mode_v(Matrix &action, std::function &Solver, const Field &source, Field &vout_5d, Field &vout_4d) { GridBase *fgrid = action.Grid(); Field tmp(fgrid); - Field tmp_out(fgrid); action.Dminus(source, tmp); - Solver(tmp_out, source); // Note: Solver is Solver(out, in) - this->return_dim(action, tmp_out, vout, return_5d); + Solver(vout_5d, source); // Note: Solver is Solver(out, in) + action.ExportPhysicalFermionSolution(vout_5d, vout_4d); } - void high_mode_w(Matrix &action, const Field &source4d, Field &wout, bool return_5d = true) + void high_mode_w(const Field &source_5d, const Field &source_4d, Field &wout_5d, Field &wout_4d) { - // GridBase *fgrid = action.Grid(); - // Field tmp_out(fgrid); - - // tmp_out = source; - // this->return_dim(action, tmp_out, wout, return_5d); - wout = source4d; - } - - void return_dim(Matrix &action, const Field &in, Field &out, bool return_5d) - { - if (return_5d) - { - out = in; - } - else - { - action.ExportPhysicalFermionSolution(in, out); - } + wout_5d = source_5d; + wout_4d = source_4d; } }; @@ -211,75 +260,6 @@ class A2AHMSchurDiagTwo : virtual public A2AModesSchurDiagTwo } }; -//////////////////////////////// -// Both Modes -//////////////////////////////// - -template -class A2AVectorsReturn : public A2AModesSchurDiagTwo -{ - private: - const std::vector *evec; - const std::vector *eval; - Matrix &action; - std::function &Solver; - const int Nl, Nh; - const bool return_5d; - std::vector w_high, v_high; - - public: - A2AVectorsReturn(const std::vector *_evec, const std::vector *_eval, - Matrix &_action, - std::function &_Solver, - const int _Nl, const int _Nh, - const bool _return_5d) - : evec(_evec), eval(_eval), - action(_action), - Solver(_Solver), - Nl(_Nl), Nh(_Nh), - return_5d(_return_5d) - { - GridBase *grid; - if (return_5d) - { - grid = action.Grid(); - } - else - { - grid = action.GaugeGrid(); - } - resize(Nh, grid); - }; - - void resize(const size_t size, GridBase *grid) - { - w_high.resize(size, grid); - v_high.resize(size, grid); - } - - void high_modes(Field &source5d, Field &source4d, int i) - { - LOG(Message) << "A2A high modes for i = " << i << std::endl; - this->high_mode_v(action, Solver, source5d, v_high[i], return_5d); - this->high_mode_w(action, source4d, w_high[i], return_5d); - } - - void operator()(int i, Field &vout, Field &wout) - { - if (i < Nl) - { - LOG(Message) << "A2A low modes for i = " << i << std::endl; - this->low_mode_v(action, evec->at(i), eval->at(i), vout, return_5d); - this->low_mode_w(action, evec->at(i), eval->at(i), wout, return_5d); - } - else - { - vout = v_high[i - Nl]; - wout = w_high[i - Nl]; - } - } -}; - END_HADRONS_NAMESPACE #endif // A2A_Vectors_hpp_ \ No newline at end of file