diff --git a/Grid/qcd/hmc/integrators/Integrator.h b/Grid/qcd/hmc/integrators/Integrator.h index 5d875c6a..1763f504 100644 --- a/Grid/qcd/hmc/integrators/Integrator.h +++ b/Grid/qcd/hmc/integrators/Integrator.h @@ -66,6 +66,7 @@ public: template class Integrator { protected: + typedef typename FieldImplementation::Field MomentaField; //for readability typedef typename FieldImplementation::Field Field; @@ -118,6 +119,7 @@ protected: } } update_P_hireps{}; + void update_P(MomentaField& Mom, Field& U, int level, double ep) { // input U actually not used in the fundamental case // Fundamental updates, include smearing @@ -129,7 +131,9 @@ protected: Field& Us = Smearer.get_U(as[level].actions.at(a)->is_smeared); double start_force = usecond(); + as[level].actions.at(a)->deriv_timer_start(); as[level].actions.at(a)->deriv(Us, force); // deriv should NOT include Ta + as[level].actions.at(a)->deriv_timer_stop(); std::cout << GridLogIntegrator << "Smearing (on/off): " << as[level].actions.at(a)->is_smeared << std::endl; auto name = as[level].actions.at(a)->action_name(); @@ -145,6 +149,8 @@ protected: Real force_max = std::sqrt(maxLocalNorm2(force)); Real impulse_max = force_max * ep * HMC_MOMENTUM_DENOMINATOR; + as[level].actions.at(a)->deriv_log(force_abs,force_max); + std::cout << GridLogIntegrator<< "["<(force,Nd-1); DumpSliceNorm("force_t",pol); pol=Zero(); PokeIndex(force,pol,Nd-1); DumpSliceNorm("force_xyz",force); - + */ } // Force from the other representations @@ -226,6 +233,66 @@ public: const MomentaField & getMomentum() const{ return P; } + void reset_timer(void) + { + for (int level = 0; level < as.size(); ++level) { + for (int actionID = 0; actionID < as[level].actions.size(); ++actionID) { + as[level].actions.at(actionID)->reset_timer(); + } + } + } + void print_timer(void) + { + std::cout << GridLogMessage << ":::::::::::::::::::::::::::::::::::::::::" << std::endl; + std::cout << GridLogMessage << " Refresh cumulative timings "<action_name() + <<"["<refresh_us*1.0e-6<<" s"<< std::endl; + } + } + std::cout << GridLogMessage << "--------------------------- "<action_name() + <<"["<S_us*1.0e-6<<" s"<< std::endl; + } + } + std::cout << GridLogMessage << "--------------------------- "<action_name() + <<"["<deriv_us*1.0e-6<<" s"<< std::endl; + } + } + std::cout << GridLogMessage << "--------------------------- "<action_name() + <<"["<deriv_max_average() + <<" norm " << as[level].actions.at(actionID)->deriv_norm_average() + <<" calls " << as[level].actions.at(actionID)->deriv_num + << std::endl; + } + } + std::cout << GridLogMessage << ":::::::::::::::::::::::::::::::::::::::::"<< std::endl; + } + void print_parameters() { std::cout << GridLogMessage << "[Integrator] Name : "<< integrator_name() << std::endl; @@ -244,7 +311,6 @@ public: } } std::cout << GridLogMessage << ":::::::::::::::::::::::::::::::::::::::::"<< std::endl; - } void reverse_momenta() @@ -288,7 +354,9 @@ public: // get gauge field from the SmearingPolicy and // based on the boolean is_smeared in actionID Field& Us = Smearer.get_U(as[level].actions.at(actionID)->is_smeared); + as[level].actions.at(actionID)->refresh_timer_start(); as[level].actions.at(actionID)->refresh(Us, sRNG, pRNG); + as[level].actions.at(actionID)->refresh_timer_stop(); } // Refresh the higher representation actions @@ -330,7 +398,9 @@ public: // based on the boolean is_smeared in actionID Field& Us = Smearer.get_U(as[level].actions.at(actionID)->is_smeared); std::cout << GridLogMessage << "S [" << level << "][" << actionID << "] action eval " << std::endl; + as[level].actions.at(actionID)->S_timer_start(); Hterm = as[level].actions.at(actionID)->S(Us); + as[level].actions.at(actionID)->S_timer_stop(); std::cout << GridLogMessage << "S [" << level << "][" << actionID << "] H = " << Hterm << std::endl; H += Hterm; }