1
0
mirror of https://github.com/aportelli/LatAnalyze.git synced 2024-11-14 01:45:35 +00:00

code cleaning + derivative support for NLopt

This commit is contained in:
Antonin Portelli 2016-04-04 14:56:21 +01:00
parent 2de5a9440b
commit f82b20dc73
8 changed files with 58 additions and 40 deletions

View File

@ -31,11 +31,12 @@ using namespace Math;
// constructor ///////////////////////////////////////////////////////////////// // constructor /////////////////////////////////////////////////////////////////
Derivative::Derivative(const DoubleFunction &f, const Index dir, Derivative::Derivative(const DoubleFunction &f, const Index dir,
const double step) const double step)
: f_(f) : buffer_(new DVec(f.getNArg()))
, dir_(dir) {
, step_(step) setFunction(f);
, buffer_(new DVec(f.getNArg())) setDir(dir);
{} setStep(step);
}
Derivative::Derivative(const DoubleFunction &f, const Index dir, Derivative::Derivative(const DoubleFunction &f, const Index dir,
const Index order, const DVec &point, const double step) const Index order, const DVec &point, const double step)
@ -45,6 +46,11 @@ Derivative::Derivative(const DoubleFunction &f, const Index dir,
} }
// access ////////////////////////////////////////////////////////////////////// // access //////////////////////////////////////////////////////////////////////
Index Derivative::getDir(void) const
{
return dir_;
}
Index Derivative::getOrder(void) const Index Derivative::getOrder(void) const
{ {
return order_; return order_;
@ -60,6 +66,11 @@ double Derivative::getStep(void) const
return step_; return step_;
} }
void Derivative::setDir(const Index dir)
{
dir_ = dir;
}
void Derivative::setFunction(const DoubleFunction &f) void Derivative::setFunction(const DoubleFunction &f)
{ {
f_ = f; f_ = f;

View File

@ -39,9 +39,11 @@ public:
// destructor // destructor
virtual ~Derivative(void) = default; virtual ~Derivative(void) = default;
// access // access
Index getDir(void) const;
Index getNPoint(void) const; Index getNPoint(void) const;
Index getOrder(void) const; Index getOrder(void) const;
double getStep(void) const; double getStep(void) const;
void setDir(const Index dir);
void setFunction(const DoubleFunction &f); void setFunction(const DoubleFunction &f);
void setOrderAndPoint(const Index order, const DVec &point); void setOrderAndPoint(const Index order, const DVec &point);
void setStep(const double step); void setStep(const double step);
@ -73,16 +75,13 @@ public:
static const Index defaultPrecOrder = 2; static const Index defaultPrecOrder = 2;
public: public:
// constructor // constructor
CentralDerivative(const DoubleFunction &f, const Index dir = 0, CentralDerivative(const DoubleFunction &f = DoubleFunction(),
const Index dir = 0,
const Index order = 1, const Index order = 1,
const Index precOrder = defaultPrecOrder); const Index precOrder = defaultPrecOrder);
// destructor // destructor
virtual ~CentralDerivative(void) = default; virtual ~CentralDerivative(void) = default;
// access // access
using Derivative::getNPoint;
using Derivative::getStep;
using Derivative::getOrder;
using Derivative::setStep;
Index getPrecOrder(void) const; Index getPrecOrder(void) const;
void setOrder(const Index order, const Index precOrder = defaultPrecOrder); void setOrder(const Index order, const Index precOrder = defaultPrecOrder);
// function call // function call

View File

@ -23,12 +23,6 @@
using namespace std; using namespace std;
using namespace Latan; using namespace Latan;
// constructor /////////////////////////////////////////////////////////////////
Minimizer::Minimizer(const Index dim)
{
resize(dim);
}
// access ////////////////////////////////////////////////////////////////////// // access //////////////////////////////////////////////////////////////////////
void Minimizer::resize(const Index dim) void Minimizer::resize(const Index dim)
{ {

View File

@ -36,7 +36,6 @@ class Minimizer: public Solver
public: public:
// constructor // constructor
Minimizer(void) = default; Minimizer(void) = default;
explicit Minimizer(const Index dim);
// destructor // destructor
virtual ~Minimizer(void) = default; virtual ~Minimizer(void) = default;
// access // access

View File

@ -36,12 +36,6 @@ MinuitMinimizer::MinuitMinimizer(const Algorithm algorithm)
setAlgorithm(algorithm); setAlgorithm(algorithm);
} }
MinuitMinimizer::MinuitMinimizer(const Index dim, const Algorithm algorithm)
: Minimizer(dim)
{
setAlgorithm(algorithm);
}
// access ////////////////////////////////////////////////////////////////////// // access //////////////////////////////////////////////////////////////////////
MinuitMinimizer::Algorithm MinuitMinimizer::getAlgorithm(void) const MinuitMinimizer::Algorithm MinuitMinimizer::getAlgorithm(void) const
{ {

View File

@ -41,9 +41,7 @@ public:
}; };
public: public:
// constructor // constructor
MinuitMinimizer(const Algorithm algorithm = defaultAlg_); explicit MinuitMinimizer(const Algorithm algorithm = defaultAlg_);
explicit MinuitMinimizer(const Index dim,
const Algorithm algorithm = defaultAlg_);
// destructor // destructor
virtual ~MinuitMinimizer(void) = default; virtual ~MinuitMinimizer(void) = default;
// access // access

View File

@ -30,12 +30,7 @@ using namespace Latan;
NloptMinimizer::NloptMinimizer(const Algorithm algorithm) NloptMinimizer::NloptMinimizer(const Algorithm algorithm)
{ {
setAlgorithm(algorithm); setAlgorithm(algorithm);
} der_.setOrder(1, 1);
NloptMinimizer::NloptMinimizer(const Index dim, const Algorithm algorithm)
: Minimizer(dim)
{
setAlgorithm(algorithm);
} }
// access ////////////////////////////////////////////////////////////////////// // access //////////////////////////////////////////////////////////////////////
@ -67,7 +62,10 @@ const DVec & NloptMinimizer::operator()(const DoubleFunction &f)
min.set_maxeval(getMaxIteration()); min.set_maxeval(getMaxIteration());
min.set_xtol_rel(getPrecision()); min.set_xtol_rel(getPrecision());
min.set_ftol_rel(-1.);
der_.setFunction(f);
data.f = &f; data.f = &f;
data.d = &der_;
min.set_min_objective(&funcWrapper, &data); min.set_min_objective(&funcWrapper, &data);
for (Index i = 0; i < x.size(); ++i) for (Index i = 0; i < x.size(); ++i)
{ {
@ -126,13 +124,12 @@ const DVec & NloptMinimizer::operator()(const DoubleFunction &f)
x(i) = vx[i]; x(i) = vx[i];
} }
n++; n++;
} while ((status != nlopt::XTOL_REACHED) and (status != nlopt::SUCCESS) } while (!minSuccess(status) and (n < getMaxPass()));
and (n < getMaxPass()));
if (getVerbosity() >= Verbosity::Normal) if (getVerbosity() >= Verbosity::Normal)
{ {
cout << "=================================================" << endl; cout << "=================================================" << endl;
} }
if ((status != nlopt::XTOL_REACHED) and (status != nlopt::SUCCESS)) if (!minSuccess(status))
{ {
LATAN_WARNING("invalid minimum: " + returnMessage(status)); LATAN_WARNING("invalid minimum: " + returnMessage(status));
} }
@ -163,13 +160,37 @@ string NloptMinimizer::returnMessage(const nlopt::result status)
} }
// NLopt function wrapper ////////////////////////////////////////////////////// // NLopt function wrapper //////////////////////////////////////////////////////
double NloptMinimizer::funcWrapper(unsigned int n __dumb, const double *arg, double NloptMinimizer::funcWrapper(unsigned int n, const double *arg,
double *grad , void *vdata) double *grad , void *vdata)
{ {
NloptFuncData &data = *static_cast<NloptFuncData *>(vdata); NloptFuncData &data = *static_cast<NloptFuncData *>(vdata);
assert(grad == nullptr); if (grad)
{
for (unsigned int i = 0; i < n; ++i)
{
data.d->setDir(i);
grad[i] = (*(data.d))(arg);
}
data.evalCount += data.d->getNPoint()*n;
}
data.evalCount++; data.evalCount++;
return (*data.f)(arg); return (*data.f)(arg);
} }
// NLopt return status parser //////////////////////////////////////////////////
bool NloptMinimizer::minSuccess(const nlopt::result status)
{
switch (status)
{
case nlopt::SUCCESS:
case nlopt::FTOL_REACHED:
case nlopt::XTOL_REACHED:
return true;
break;
default:
return false;
break;
}
}

View File

@ -43,13 +43,12 @@ private:
struct NloptFuncData struct NloptFuncData
{ {
const DoubleFunction *f{nullptr}; const DoubleFunction *f{nullptr};
Derivative *d{nullptr};
unsigned int evalCount{0}; unsigned int evalCount{0};
}; };
public: public:
// constructor // constructor
NloptMinimizer(const Algorithm algorithm = defaultAlg_); explicit NloptMinimizer(const Algorithm algorithm = defaultAlg_);
explicit NloptMinimizer(const Index dim,
const Algorithm algorithm = defaultAlg_);
// destructor // destructor
virtual ~NloptMinimizer(void) = default; virtual ~NloptMinimizer(void) = default;
// access // access
@ -63,9 +62,12 @@ private:
// NLopt function wrapper // NLopt function wrapper
static double funcWrapper(unsigned int n, const double *arg, static double funcWrapper(unsigned int n, const double *arg,
double *grad , void *vdata); double *grad , void *vdata);
// NLopt return status parser
static bool minSuccess(const nlopt::result status);
private: private:
Algorithm algorithm_; Algorithm algorithm_;
static constexpr Algorithm defaultAlg_ = Algorithm::LN_NELDERMEAD; static constexpr Algorithm defaultAlg_ = Algorithm::LN_NELDERMEAD;
CentralDerivative der_;
}; };
END_LATAN_NAMESPACE END_LATAN_NAMESPACE