1
0
mirror of https://github.com/aportelli/LatAnalyze.git synced 2025-07-14 17:17:05 +01:00

first cmake draft, source relocation, not working

This commit is contained in:
2024-01-28 22:13:07 -03:00
parent 0b5da3866e
commit 0b5c6e851c
103 changed files with 177 additions and 8 deletions

View File

@ -0,0 +1,123 @@
/*
* Eigen.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
// Eigen inclusion
#define EIGEN_DONT_PARALLELIZE
#define EIGEN_MATRIXBASE_PLUGIN <LatAnalyze/Core/EigenPlugin.hpp>
#include <LatAnalyze/Eigen/Dense>
// copy/assignement from Eigen expression
#define EIGEN_EXPR_CTOR(ctorName, Class, Base, ExprType) \
template <typename Derived>\
ctorName(const ExprType<Derived> &m): Base(m) {}\
template<typename Derived>\
Class & operator=(const ExprType<Derived> &m)\
{\
this->Base::operator=(m);\
return *this;\
}
#define FOR_MAT(mat, i, j) \
for (Latan::Index j = 0; j < mat.cols(); ++j)\
for (Latan::Index i = 0; i < mat.rows(); ++i)
BEGIN_LATAN_NAMESPACE
const int dynamic = Eigen::Dynamic;
// array types
template <typename Derived>
using ArrayExpr = Eigen::ArrayBase<Derived>;
template <typename T, int nRow = dynamic, int nCol = dynamic>
using Array = Eigen::Array<T, nRow, nCol>;
// matrix types
template <typename Derived>
using MatExpr = Eigen::MatrixBase<Derived>;
template <typename T, int nRow = dynamic, int nCol = dynamic>
using MatBase = Eigen::Matrix<T, nRow, nCol>;
template <int nRow, int nCol>
using SFMat = Eigen::Matrix<float, nRow, nCol>;
template <int nRow, int nCol>
using SDMat = Eigen::Matrix<double, nRow, nCol>;
template <int nRow, int nCol>
using SCMat = Eigen::Matrix<std::complex<double>, nRow, nCol>;
// vector types
template <typename T, int size = dynamic>
using Vec = MatBase<T, size, 1>;
template <int size>
using SIVec = Vec<int, size>;
template <int size>
using SUVec = Vec<unsigned int, size>;
template <int size>
using SFVec = Vec<float, size>;
template <int size>
using SDVec = Vec<double, size>;
template <int size>
using SCVec = Vec<std::complex<double>, size>;
typedef SIVec<dynamic> IVec;
typedef SUVec<dynamic> UVec;
typedef SDVec<dynamic> DVec;
typedef SCVec<dynamic> CVec;
// block types
template <typename Derived>
using Block = Eigen::Block<Derived>;
template <typename Derived>
using ConstBlock = const Eigen::Block<const Derived>;
template <typename Derived>
using Row = typename Derived::RowXpr;
template <typename Derived>
using ConstRow = typename Derived::ConstRowXpr;
template <typename Derived>
using Col = typename Derived::ColXpr;
template <typename Derived>
using ConstCol = typename Derived::ConstColXpr;
// map type
template <int stride>
using InnerStride = Eigen::InnerStride<stride>;
template <int rowStride, int colStride>
using Stride = Eigen::Stride<rowStride, colStride>;
template <typename Derived, typename StrideType = Stride<0, 0>>
using Map = Eigen::Map<Derived, Eigen::Unaligned, StrideType>;
template <typename Derived, typename StrideType = Stride<0, 0>>
using ConstMap = Eigen::Map<const Derived, Eigen::Unaligned, StrideType>;
// Index type //////////////////////////////////////////////////////////////////
typedef MatBase<int>::Index Index;
#define FOR_VEC(vec, i) for (Latan::Index i = 0; i < (vec).size(); ++i)
#define FOR_ARRAY(ar, i) FOR_VEC(ar, i)
END_LATAN_NAMESPACE

View File

@ -0,0 +1,60 @@
/*
* EigenPlugin.hpp, part of LatAnalyze
*
* Copyright (C) 2015 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
Derived pInverse(const double tolerance = 1.0e-10) const
{
auto svd = jacobiSvd(Eigen::ComputeThinU|Eigen::ComputeThinV);
const auto u = svd.matrixU();
const auto v = svd.matrixV();
auto s = svd.singularValues();
double maxsv = 0.;
unsigned int elim = 0;
for (Index i = 0; i < s.rows(); ++i)
{
if (fabs(s(i)) > maxsv) maxsv = fabs(s(i));
}
for (Index i = 0; i < s.rows(); ++i)
{
if (fabs(s(i)) > maxsv*tolerance)
{
s(i) = 1./s(i);
}
else
{
elim++;
s(i) = 0.;
}
}
if (elim)
{
std::cerr << "warning: pseudoinverse: " << elim << "/";
std::cerr << s.rows() << " singular value(s) eliminated (tolerance= ";
std::cerr << tolerance << ")" << std::endl;
}
return v*s.asDiagonal()*u.transpose();
}
Derived singularValues(void) const
{
auto svd = jacobiSvd();
return svd.singularValues();
}

View File

@ -0,0 +1,51 @@
/*
* Exceptions.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Exceptions.hpp>
#include <LatAnalyze/includes.hpp>
#ifndef ERR_SUFF
#define ERR_SUFF " (" + loc + ")"
#endif
#define CONST_EXC(name, init) \
name::name(string msg, string loc)\
:init\
{}
using namespace std;
using namespace Latan;
using namespace Exceptions;
// logic errors
CONST_EXC(Logic, logic_error(Env::msgPrefix + msg + ERR_SUFF))
CONST_EXC(Definition, Logic("definition error: " + msg, loc))
CONST_EXC(Implementation, Logic("implementation error: " + msg, loc))
CONST_EXC(Range, Logic("range error: " + msg, loc))
CONST_EXC(Size, Logic("size error: " + msg, loc))
// runtime errors
CONST_EXC(Runtime, runtime_error(Env::msgPrefix + msg + ERR_SUFF))
CONST_EXC(Argument, Runtime("argument error: " + msg, loc))
CONST_EXC(Compilation, Runtime("compilation error: " + msg, loc))
CONST_EXC(Io, Runtime("IO error: " + msg, loc))
CONST_EXC(Memory, Runtime("memory error: " + msg, loc))
CONST_EXC(Parsing, Runtime(msg, loc))
CONST_EXC(Program, Runtime(msg, loc))
CONST_EXC(Syntax, Runtime("syntax error: " + msg, loc))
CONST_EXC(System, Runtime("system error: " + msg, loc))

View File

@ -0,0 +1,66 @@
/*
* Exceptions.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Exceptions_hpp_
#define Latan_Exceptions_hpp_
#include <stdexcept>
#ifndef LATAN_GLOBAL_HPP_
#include <LatAnalyze/Global.hpp>
#endif
#define SRC_LOC strFrom(__FUNCTION__) + " at " + strFrom(__FILE__) + ":"\
+ strFrom(__LINE__)
#define LATAN_ERROR(exc,msg) throw(Exceptions::exc(msg, SRC_LOC))
#define LATAN_WARNING(msg) \
std::cerr << Env::msgPrefix << "warning: " << msg\
<< " (" << SRC_LOC << ")" << std::endl
#define DECL_EXC(name, base) \
class name: public base\
{\
public:\
name(std::string msg, std::string loc);\
}
BEGIN_LATAN_NAMESPACE
namespace Exceptions
{
// logic errors
DECL_EXC(Logic, std::logic_error);
DECL_EXC(Definition, Logic);
DECL_EXC(Implementation, Logic);
DECL_EXC(Range, Logic);
DECL_EXC(Size, Logic);
// runtime errors
DECL_EXC(Runtime, std::runtime_error);
DECL_EXC(Argument, Runtime);
DECL_EXC(Compilation, Runtime);
DECL_EXC(Io, Runtime);
DECL_EXC(Memory, Runtime);
DECL_EXC(Parsing, Runtime);
DECL_EXC(Program, Runtime);
DECL_EXC(Syntax, Runtime);
DECL_EXC(System, Runtime);
}
END_LATAN_NAMESPACE
#endif // Latan_Exceptions_hpp_

View File

@ -0,0 +1,36 @@
/*
* Mat.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
/******************************************************************************
* DMat implementation *
******************************************************************************/
// IO //////////////////////////////////////////////////////////////////////////
namespace Latan
{
template <>
IoObject::IoType Mat<double>::getType(void) const
{
return IoType::dMat;
}
}

View File

@ -0,0 +1,69 @@
/*
* Mat.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Mat_hpp_
#define Latan_Mat_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/IoObject.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* matrix type *
******************************************************************************/
template <typename T>
class Mat: public MatBase<T>, public IoObject
{
public:
// constructors
Mat(void) = default;
Mat(const Index nRow, const Index nCol);
EIGEN_EXPR_CTOR(Mat, Mat<T>, MatBase<T>, MatExpr)
// destructor
virtual ~Mat(void) = default;
// IO
virtual IoType getType(void) const;
};
// type aliases
typedef Mat<int> IMat;
typedef Mat<long int> LMat;
typedef Mat<double> DMat;
typedef Mat<std::complex<double>> CMat;
/******************************************************************************
* Mat template implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
template <typename T>
Mat<T>::Mat(const Index nRow, const Index nCol)
: MatBase<T>(nRow, nCol)
{}
// IO //////////////////////////////////////////////////////////////////////////
template <typename T>
IoObject::IoType Mat<T>::getType(void) const
{
return IoType::noType;
}
END_LATAN_NAMESPACE
#endif // Latan_Mat_hpp_

View File

@ -0,0 +1,170 @@
/*
* Math.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Math.hpp>
#include <LatAnalyze/Numerical/GslFFT.hpp>
#include <LatAnalyze/includes.hpp>
#include <gsl/gsl_cdf.h>
using namespace std;
using namespace Latan;
/******************************************************************************
* Custom math functions *
******************************************************************************/
DMat MATH_NAMESPACE::varToCorr(const DMat &var)
{
DMat res = var;
DVec invDiag = res.diagonal();
invDiag = invDiag.cwiseInverse().cwiseSqrt();
res = (invDiag*invDiag.transpose()).cwiseProduct(res);
return res;
}
DMat MATH_NAMESPACE::corrToVar(const DMat &corr, const DVec &varDiag)
{
DMat res = corr;
DVec varSqrtDiag = varDiag.cwiseSqrt();
res = (varSqrtDiag*varSqrtDiag.transpose()).cwiseProduct(res);
return res;
}
double MATH_NAMESPACE::conditionNumber(const DMat &mat)
{
DVec s = mat.singularValues();
return s.maxCoeff()/s.minCoeff();
}
double MATH_NAMESPACE::cdr(const DMat &mat)
{
return 10.*log10(conditionNumber(mat));
}
template <typename FFT>
double nsdr(const DMat &m)
{
Index n = m.rows();
FFT fft(n);
CMat buf(n, 1);
FOR_VEC(buf, i)
{
buf(i) = 0.;
for (Index j = 0; j < n; ++j)
{
buf(i) += m(j, (i+j) % n);
}
buf(i) /= n;
}
fft(buf, FFT::Forward);
return 10.*log10(buf.real().maxCoeff()/buf.real().minCoeff());
}
double MATH_NAMESPACE::nsdr(const DMat &mat)
{
return ::nsdr<GslFFT>(mat);
}
/******************************************************************************
* Standard C functions *
******************************************************************************/
#define DEF_STD_FUNC_1ARG(name) \
auto name##VecFunc = [](const double arg[1]){return (name)(arg[0]);};\
DoubleFunction STDMATH_NAMESPACE::name(name##VecFunc, 1);
#define DEF_STD_FUNC_2ARG(name) \
auto name##VecFunc = [](const double arg[2]){return (name)(arg[0], arg[1]);};\
DoubleFunction STDMATH_NAMESPACE::name(name##VecFunc, 2);
// Trigonometric functions
DEF_STD_FUNC_1ARG(cos)
DEF_STD_FUNC_1ARG(sin)
DEF_STD_FUNC_1ARG(tan)
DEF_STD_FUNC_1ARG(acos)
DEF_STD_FUNC_1ARG(asin)
DEF_STD_FUNC_1ARG(atan)
DEF_STD_FUNC_2ARG(atan2)
// Hyperbolic functions
DEF_STD_FUNC_1ARG(cosh)
DEF_STD_FUNC_1ARG(sinh)
DEF_STD_FUNC_1ARG(tanh)
DEF_STD_FUNC_1ARG(acosh)
DEF_STD_FUNC_1ARG(asinh)
DEF_STD_FUNC_1ARG(atanh)
// Exponential and logarithmic functions
DEF_STD_FUNC_1ARG(exp)
DEF_STD_FUNC_1ARG(log)
DEF_STD_FUNC_1ARG(log10)
DEF_STD_FUNC_1ARG(exp2)
DEF_STD_FUNC_1ARG(expm1)
DEF_STD_FUNC_1ARG(log1p)
DEF_STD_FUNC_1ARG(log2)
// Power functions
DEF_STD_FUNC_2ARG(pow)
DEF_STD_FUNC_1ARG(sqrt)
DEF_STD_FUNC_1ARG(cbrt)
DEF_STD_FUNC_2ARG(hypot)
// Error and gamma functions
DEF_STD_FUNC_1ARG(erf)
DEF_STD_FUNC_1ARG(erfc)
DEF_STD_FUNC_1ARG(tgamma)
DEF_STD_FUNC_1ARG(lgamma)
// Rounding and remainder functions
DEF_STD_FUNC_1ARG(ceil)
DEF_STD_FUNC_1ARG(floor)
DEF_STD_FUNC_2ARG(fmod)
DEF_STD_FUNC_1ARG(trunc)
DEF_STD_FUNC_1ARG(round)
DEF_STD_FUNC_1ARG(rint)
DEF_STD_FUNC_1ARG(nearbyint)
DEF_STD_FUNC_2ARG(remainder)
// Minimum, maximum, difference functions
DEF_STD_FUNC_2ARG(fdim)
DEF_STD_FUNC_2ARG(fmax)
DEF_STD_FUNC_2ARG(fmin)
// Absolute value
DEF_STD_FUNC_1ARG(fabs)
// p-value
auto chi2PValueVecFunc = [](const double arg[2])
{
return 2.*min(gsl_cdf_chisq_P(arg[0], arg[1]), gsl_cdf_chisq_Q(arg[0], arg[1]));
};
auto chi2CcdfVecFunc = [](const double arg[2])
{
return gsl_cdf_chisq_Q(arg[0], arg[1]);
};
DoubleFunction MATH_NAMESPACE::chi2PValue(chi2PValueVecFunc, 2);
DoubleFunction MATH_NAMESPACE::chi2Ccdf(chi2CcdfVecFunc, 2);

View File

@ -0,0 +1,167 @@
/*
* Math.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Math_hpp_
#define Latan_Math_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Core/MathInterpreter.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Custom math functions *
******************************************************************************/
#define MATH_NAMESPACE Math
namespace MATH_NAMESPACE
{
// integer power function
template <unsigned int n, typename T>
typename std::enable_if<(n == 0), T>::type pow(const T x __dumb)
{
return 1;
}
template <unsigned int n, typename T>
typename std::enable_if<(n == 1), T>::type pow(const T x)
{
return x;
}
template <unsigned int n, typename T>
typename std::enable_if<(n > 1), T>::type pow(const T x)
{
return x*pow<n-1>(x);
}
// integral factorial function
template <typename T>
T factorial(const T n)
{
static_assert(std::is_integral<T>::value,
"factorial must me used with an integral argument");
T res = n;
for (T i = n - 1; i != 0; --i)
{
res *= i;
}
return res;
}
// convert variance matrix to correlation matrix
DMat varToCorr(const DMat &var);
DMat corrToVar(const DMat &corr, const DVec &varDiag);
// matrix SVD dynamic range
double conditionNumber(const DMat &mat);
double cdr(const DMat &mat);
double nsdr(const DMat &mat);
// Constants
constexpr double pi = 3.1415926535897932384626433832795028841970;
constexpr double e = 2.7182818284590452353602874713526624977572;
constexpr double inf = std::numeric_limits<double>::infinity();
constexpr double nan = std::numeric_limits<double>::quiet_NaN();
}
/******************************************************************************
* Standard C functions *
******************************************************************************/
#define STDMATH_NAMESPACE StdMath
#define DECL_STD_FUNC(name) \
namespace STDMATH_NAMESPACE\
{\
extern DoubleFunction name;\
}
// Trigonometric functions
DECL_STD_FUNC(cos)
DECL_STD_FUNC(sin)
DECL_STD_FUNC(tan)
DECL_STD_FUNC(acos)
DECL_STD_FUNC(asin)
DECL_STD_FUNC(atan)
DECL_STD_FUNC(atan2)
// Hyperbolic functions
DECL_STD_FUNC(cosh)
DECL_STD_FUNC(sinh)
DECL_STD_FUNC(tanh)
DECL_STD_FUNC(acosh)
DECL_STD_FUNC(asinh)
DECL_STD_FUNC(atanh)
// Exponential and logarithmic functions
DECL_STD_FUNC(exp)
DECL_STD_FUNC(log)
DECL_STD_FUNC(log10)
DECL_STD_FUNC(exp2)
DECL_STD_FUNC(expm1)
DECL_STD_FUNC(log1p)
DECL_STD_FUNC(log2)
// Power functions
DECL_STD_FUNC(pow)
DECL_STD_FUNC(sqrt)
DECL_STD_FUNC(cbrt)
DECL_STD_FUNC(hypot)
// Error and gamma functions
DECL_STD_FUNC(erf)
DECL_STD_FUNC(erfc)
DECL_STD_FUNC(tgamma)
DECL_STD_FUNC(lgamma)
// Rounding and remainder functions
DECL_STD_FUNC(ceil)
DECL_STD_FUNC(floor)
DECL_STD_FUNC(fmod)
DECL_STD_FUNC(trunc)
DECL_STD_FUNC(round)
DECL_STD_FUNC(rint)
DECL_STD_FUNC(nearbyint)
DECL_STD_FUNC(remainder)
// Minimum, maximum, difference functions
DECL_STD_FUNC(fdim)
DECL_STD_FUNC(fmax)
DECL_STD_FUNC(fmin)
// Absolute value
DECL_STD_FUNC(fabs)
/******************************************************************************
* Other functions *
******************************************************************************/
// p-value
namespace MATH_NAMESPACE
{
extern DoubleFunction chi2PValue;
extern DoubleFunction chi2Ccdf;
}
END_LATAN_NAMESPACE
#endif // Latan_Math_hpp_

View File

@ -0,0 +1,745 @@
/*
* MathInterpreter.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/MathInterpreter.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Math.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* RunContext implementation *
******************************************************************************/
// access //////////////////////////////////////////////////////////////////////
unsigned int RunContext::addFunction(const string &name, DoubleFunction *init)
{
try
{
setFunction(name, init);
return getFunctionAddress(name);
}
catch (Exceptions::Definition)
{
unsigned int address = fTable_.size();
fMem_.push_back(init);
fTable_[name] = address;
return address;
}
}
unsigned int RunContext::addVariable(const string &name, double init)
{
try
{
setVariable(name, init);
return getVariableAddress(name);
}
catch (Exceptions::Definition)
{
unsigned int address = vTable_.size();
vMem_.push_back(init);
vTable_[name] = address;
return address;
}
}
DoubleFunction * RunContext::getFunction(const string &name) const
{
return getFunction(getFunctionAddress(name));
}
DoubleFunction * RunContext::getFunction(const unsigned int address) const
{
if (address >= fTable_.size())
{
LATAN_ERROR(Range, "function address " + strFrom(address)
+ " out of range");
return nullptr;
}
else
{
return fMem_[address];
}
}
unsigned int RunContext::getFunctionAddress(const string &name) const
{
try
{
return fTable_.at(name);
}
catch (out_of_range)
{
LATAN_ERROR(Definition, "undefined function '" + name + "'");
}
}
const RunContext::AddressTable & RunContext::getFunctionTable(void) const
{
return fTable_;
}
unsigned int RunContext::getInsIndex(void) const
{
return insIndex_;
}
double RunContext::getVariable(const string &name) const
{
return getVariable(getVariableAddress(name));
}
double RunContext::getVariable(const unsigned int address) const
{
if (address >= vTable_.size())
{
LATAN_ERROR(Range, "variable address " + strFrom(address)
+ " out of range");
return 0.;
}
else
{
return vMem_[address];
}
}
const RunContext::AddressTable & RunContext::getVariableTable(void) const
{
return vTable_;
}
unsigned int RunContext::getVariableAddress(const string &name) const
{
try
{
return vTable_.at(name);
}
catch (out_of_range)
{
LATAN_ERROR(Definition, "undefined variable '" + name + "'");
}
}
void RunContext::incrementInsIndex(const unsigned int inc)
{
setInsIndex(getInsIndex() + inc);
}
void RunContext::setFunction(const string &name, DoubleFunction *f)
{
setFunction(getFunctionAddress(name), f);
}
void RunContext::setFunction(const unsigned int address, DoubleFunction *f)
{
if (address >= fTable_.size())
{
LATAN_ERROR(Range, "function address " + strFrom(address)
+ " out of range");
}
else
{
fMem_[address] = f;
}
}
void RunContext::setInsIndex(const unsigned index)
{
insIndex_ = index;
}
void RunContext::setVariable(const string &name, const double value)
{
setVariable(getVariableAddress(name), value);
}
void RunContext::setVariable(const unsigned int address, const double value)
{
if (address >= vTable_.size())
{
LATAN_ERROR(Range, "variable address " + strFrom(address)
+ " out of range");
}
else
{
vMem_[address] = value;
}
}
stack<double> & RunContext::stack(void)
{
return dStack_;
}
// reset ///////////////////////////////////////////////////////////////////////
void RunContext::reset(void)
{
insIndex_ = 0;
while (!dStack_.empty())
{
dStack_.pop();
}
vMem_.clear();
fMem_.clear();
vTable_.clear();
fTable_.clear();
}
/******************************************************************************
* Instruction set *
******************************************************************************/
#define CODE_WIDTH 6
#define CODE_MOD setw(CODE_WIDTH) << left
// Instruction operator ////////////////////////////////////////////////////////
ostream &Latan::operator<<(ostream& out, const Instruction& ins)
{
ins.print(out);
return out;
}
// Push constructors ///////////////////////////////////////////////////////////
Push::Push(const double val)
: type_(ArgType::Constant)
, val_(val)
, address_(0)
, name_("")
{}
Push::Push(const unsigned int address, const string &name)
: type_(ArgType::Variable)
, val_(0.0)
, address_(address)
, name_(name)
{}
// Push execution //////////////////////////////////////////////////////////////
void Push::operator()(RunContext &context) const
{
if (type_ == ArgType::Constant)
{
context.stack().push(val_);
}
else
{
context.stack().push(context.getVariable(address_));
}
context.incrementInsIndex();
}
// Push print //////////////////////////////////////////////////////////////////
void Push::print(ostream &out) const
{
out << CODE_MOD << "push";
if (type_ == ArgType::Constant)
{
out << CODE_MOD << val_;
}
else
{
out << CODE_MOD << name_ << " @v" << address_;
}
}
// Pop constructor /////////////////////////////////////////////////////////////
Pop::Pop(const unsigned int address, const string &name)
: address_(address)
, name_(name)
{}
// Pop execution ///////////////////////////////////////////////////////////////
void Pop::operator()(RunContext &context) const
{
if (!name_.empty())
{
context.setVariable(address_, context.stack().top());
}
context.stack().pop();
context.incrementInsIndex();
}
// Pop print ///////////////////////////////////////////////////////////////////
void Pop::print(ostream &out) const
{
out << CODE_MOD << "pop" << CODE_MOD << name_ << " @v" << address_;
}
// Store constructor ///////////////////////////////////////////////////////////
Store::Store(const unsigned int address, const string &name)
: address_(address)
, name_(name)
{}
// Store execution /////////////////////////////////////////////////////////////
void Store::operator()(RunContext &context) const
{
if (!name_.empty())
{
context.setVariable(address_, context.stack().top());
}
context.incrementInsIndex();
}
// Store print /////////////////////////////////////////////////////////////////
void Store::print(ostream &out) const
{
out << CODE_MOD << "store" << CODE_MOD << name_ << " @v" << address_;
}
// Call constructor ////////////////////////////////////////////////////////////
Call::Call(const unsigned int address, const string &name)
: address_(address)
, name_(name)
{}
// Call execution //////////////////////////////////////////////////////////////
void Call::operator()(RunContext &context) const
{
context.stack().push((*context.getFunction(address_))(context.stack()));
context.incrementInsIndex();
}
// Call print //////////////////////////////////////////////////////////////////
void Call::print(ostream &out) const
{
out << CODE_MOD << "call" << CODE_MOD << name_ << " @f" << address_;
}
// Math operations /////////////////////////////////////////////////////////////
#define DEF_OP(name, nArg, exp, insName)\
void name::operator()(RunContext &context) const\
{\
double x[nArg];\
for (int i = 0; i < nArg; ++i)\
{\
x[nArg-1-i] = context.stack().top();\
context.stack().pop();\
}\
context.stack().push(exp);\
context.incrementInsIndex();\
}\
void name::print(ostream &out) const\
{\
out << CODE_MOD << insName;\
}
DEF_OP(Neg, 1, -x[0], "neg")
DEF_OP(Add, 2, x[0] + x[1], "add")
DEF_OP(Sub, 2, x[0] - x[1], "sub")
DEF_OP(Mul, 2, x[0]*x[1], "mul")
DEF_OP(Div, 2, x[0]/x[1], "div")
DEF_OP(Pow, 2, pow(x[0],x[1]), "pow")
/******************************************************************************
* ExprNode implementation *
******************************************************************************/
// ExprNode constructors ///////////////////////////////////////////////////////
ExprNode::ExprNode(const string &name)
: name_(name)
, parent_(nullptr)
{}
// ExprNode access /////////////////////////////////////////////////////////////
const string &ExprNode::getName(void) const
{
return name_;
}
Index ExprNode::getNArg(void) const
{
return static_cast<Index>(arg_.size());
}
const ExprNode * ExprNode::getParent(void) const
{
return parent_;
}
Index ExprNode::getLevel(void) const
{
if (getParent())
{
return getParent()->getLevel() + 1;
}
else
{
return 0;
}
}
void ExprNode::setName(const std::string &name)
{
name_ = name;
}
void ExprNode::pushArg(ExprNode *node)
{
if (node)
{
node->parent_ = this;
arg_.push_back(unique_ptr<ExprNode>(node));
}
}
// ExprNode operators //////////////////////////////////////////////////////////
const ExprNode &ExprNode::operator[](const Index i) const
{
return *arg_[i];
}
ostream &Latan::operator<<(ostream &out, const ExprNode &n)
{
Index level = n.getLevel();
for (Index i = 0; i <= level; ++i)
{
if (i == level)
{
out << "_";
}
else if (i == level - 1)
{
out << "|";
}
else
{
out << " ";
}
}
out << " " << n.getName() << endl;
for (Index i = 0; i < n.getNArg(); ++i)
{
out << n[i];
}
return out;
}
#define PUSH_INS(program, type, ...)\
program.push_back(unique_ptr<type>(new type(__VA_ARGS__)))
#define GET_ADDRESS(address, table, name)\
try\
{\
address = (table).at(name);\
}\
catch (out_of_range)\
{\
address = (table).size();\
(table)[(name)] = address;\
}\
// VarNode compile /////////////////////////////////////////////////////////////
void VarNode::compile(Program &program, RunContext &context) const
{
PUSH_INS(program, Push, context.getVariableAddress(getName()), getName());
}
// CstNode compile /////////////////////////////////////////////////////////////
void CstNode::compile(Program &program, RunContext &context __dumb) const
{
PUSH_INS(program, Push, strTo<double>(getName()));
}
// SemicolonNode compile ///////////////////////////////////////////////////////
void SemicolonNode::compile(Program &program, RunContext &context) const
{
auto &n = *this;
for (Index i = 0; i < getNArg(); ++i)
{
bool isAssign = isDerivedFrom<AssignNode>(&n[i]);
bool isSemiColumn = isDerivedFrom<SemicolonNode>(&n[i]);
bool isKeyword = isDerivedFrom<KeywordNode>(&n[i]);
if (isAssign or isSemiColumn or isKeyword)
{
n[i].compile(program, context);
}
}
}
// AssignNode compile //////////////////////////////////////////////////////////
void AssignNode::compile(Program &program, RunContext &context) const
{
auto &n = *this;
if (isDerivedFrom<VarNode>(&n[0]))
{
bool hasSemicolonParent = isDerivedFrom<SemicolonNode>(getParent());
unsigned int address;
n[1].compile(program, context);
address = context.addVariable(n[0].getName());
if (hasSemicolonParent)
{
PUSH_INS(program, Pop, address, n[0].getName());
}
else
{
PUSH_INS(program, Store, address, n[0].getName());
}
}
else
{
LATAN_ERROR(Compilation, "invalid LHS for '='");
}
}
// MathOpNode compile //////////////////////////////////////////////////////////
#define IFNODE(name, nArg) if ((n.getName() == (name)) and (n.getNArg() == nArg))
#define ELIFNODE(name, nArg) else IFNODE(name, nArg)
#define ELSE else
void MathOpNode::compile(Program &program, RunContext &context) const
{
auto &n = *this;
for (Index i = 0; i < n.getNArg(); ++i)
{
n[i].compile(program, context);
}
IFNODE("-", 1) PUSH_INS(program, Neg,);
ELIFNODE("+", 2) PUSH_INS(program, Add,);
ELIFNODE("-", 2) PUSH_INS(program, Sub,);
ELIFNODE("*", 2) PUSH_INS(program, Mul,);
ELIFNODE("/", 2) PUSH_INS(program, Div,);
ELIFNODE("^", 2) PUSH_INS(program, Pow,);
ELSE LATAN_ERROR(Compilation, "unknown operator '" + getName() + "'");
}
// FuncNode compile ////////////////////////////////////////////////////////////
void FuncNode::compile(Program &program, RunContext &context) const
{
auto &n = *this;
for (Index i = 0; i < n.getNArg(); ++i)
{
n[i].compile(program, context);
}
PUSH_INS(program, Call, context.getFunctionAddress(getName()), getName());
}
// ReturnNode compile ////////////////////////////////////////////////////////////
void ReturnNode::compile(Program &program, RunContext &context) const
{
auto &n = *this;
n[0].compile(program, context);
program.push_back(nullptr);
}
/******************************************************************************
* MathInterpreter implementation *
******************************************************************************/
// MathParserState constructor /////////////////////////////////////////////////
MathInterpreter::MathParserState::MathParserState
(istream *stream, string *name, std::unique_ptr<ExprNode> *data)
: ParserState<std::unique_ptr<ExprNode>>(stream, name, data)
{
initScanner();
}
// MathParserState destructor //////////////////////////////////////////////////
MathInterpreter::MathParserState::~MathParserState(void)
{
destroyScanner();
}
// constructors ////////////////////////////////////////////////////////////////
MathInterpreter::MathInterpreter(const std::string &code)
: codeName_("<string>")
{
setCode(code);
}
// access //////////////////////////////////////////////////////////////////////
const Instruction * MathInterpreter::operator[](const Index i) const
{
return program_[i].get();
}
const ExprNode * MathInterpreter::getAST(void) const
{
return root_.get();
}
void MathInterpreter::push(const Instruction *i)
{
program_.push_back(unique_ptr<const Instruction>(i));
}
// initialization //////////////////////////////////////////////////////////////
void MathInterpreter::setCode(const std::string &code)
{
if (status_)
{
reset();
}
code_.reset(new stringstream(code));
codeName_ = "<string>";
state_.reset(new MathParserState(code_.get(), &codeName_, &root_));
program_.clear();
status_ = Status::initialised;
}
void MathInterpreter::reset(void)
{
code_.reset();
codeName_ = "<no_code>";
state_.reset();
root_.reset();
program_.clear();
status_ = 0;
}
// parser //////////////////////////////////////////////////////////////////////
// Bison/Flex parser declaration
int _math_parse(MathInterpreter::MathParserState *state);
void MathInterpreter::parse(void)
{
_math_parse(state_.get());
}
// interpreter /////////////////////////////////////////////////////////////////
#define ADD_FUNC(context, func)\
(context).addFunction(#func, &STDMATH_NAMESPACE::func);\
#define ADD_STDMATH_FUNCS(context)\
ADD_FUNC(context, cos);\
ADD_FUNC(context, sin);\
ADD_FUNC(context, tan);\
ADD_FUNC(context, acos);\
ADD_FUNC(context, asin);\
ADD_FUNC(context, atan);\
ADD_FUNC(context, atan2);\
ADD_FUNC(context, cosh);\
ADD_FUNC(context, sinh);\
ADD_FUNC(context, tanh);\
ADD_FUNC(context, acosh);\
ADD_FUNC(context, asinh);\
ADD_FUNC(context, atanh);\
ADD_FUNC(context, exp);\
ADD_FUNC(context, log);\
ADD_FUNC(context, log10);\
ADD_FUNC(context, exp2);\
ADD_FUNC(context, expm1);\
ADD_FUNC(context, log1p);\
ADD_FUNC(context, log2);\
ADD_FUNC(context, pow);\
ADD_FUNC(context, sqrt);\
ADD_FUNC(context, cbrt);\
ADD_FUNC(context, hypot);\
ADD_FUNC(context, erf);\
ADD_FUNC(context, erfc);\
ADD_FUNC(context, tgamma);\
ADD_FUNC(context, lgamma);\
ADD_FUNC(context, ceil);\
ADD_FUNC(context, floor);\
ADD_FUNC(context, fmod);\
ADD_FUNC(context, trunc);\
ADD_FUNC(context, round);\
ADD_FUNC(context, rint);\
ADD_FUNC(context, nearbyint);\
ADD_FUNC(context, remainder);\
ADD_FUNC(context, fdim);\
ADD_FUNC(context, fmax);\
ADD_FUNC(context, fmin);\
ADD_FUNC(context, fabs);
void MathInterpreter::compile(RunContext &context)
{
bool gotReturn = false;
if (!(status_ & Status::parsed))
{
parse();
status_ |= Status::parsed;
status_ -= status_ & Status::compiled;
}
if (!(status_ & Status::compiled))
{
if (root_)
{
context.addVariable("pi", Math::pi);
context.addVariable("inf", Math::inf);
ADD_STDMATH_FUNCS(context);
root_->compile(program_, context);
for (unsigned int i = 0; i < program_.size(); ++i)
{
if (!program_[i])
{
gotReturn = true;
program_.resize(i);
program_.shrink_to_fit();
break;
}
}
}
if (!root_ or !gotReturn)
{
LATAN_ERROR(Syntax, "expected 'return' in program '" + codeName_
+ "'");
}
status_ |= Status::compiled;
}
}
// execution ///////////////////////////////////////////////////////////////////
void MathInterpreter::operator()(RunContext &context)
{
if (!(status_ & Status::compiled))
{
compile(context);
}
execute(context);
}
void MathInterpreter::execute(RunContext &context) const
{
context.setInsIndex(0);
while (context.getInsIndex() != program_.size())
{
(*(program_[context.getInsIndex()]))(context);
}
}
// IO //////////////////////////////////////////////////////////////////////////
ostream &Latan::operator<<(ostream &out, const MathInterpreter &program)
{
for (unsigned int i = 0; i < program.program_.size(); ++i)
{
out << *(program.program_[i]) << endl;
}
return out;
}

View File

@ -0,0 +1,313 @@
/*
* MathInterpreter.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_MathInterpreter_hpp_
#define Latan_MathInterpreter_hpp_
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/ParserState.hpp>
#define MAXIDLENGTH 256
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Class for runtime context *
******************************************************************************/
class RunContext
{
public:
typedef std::map<std::string, unsigned int> AddressTable;
public:
// constructor
RunContext(void) = default;
// destructor
~RunContext(void) = default;
// access
unsigned int addFunction(const std::string &name,
DoubleFunction *init = nullptr);
unsigned int addVariable(const std::string &name,
const double init = 0.);
DoubleFunction * getFunction(const std::string &name) const;
DoubleFunction * getFunction(const unsigned int address) const;
unsigned int getFunctionAddress(const std::string &name) const;
const AddressTable & getFunctionTable(void) const;
unsigned int getInsIndex(void) const;
double getVariable(const std::string &name) const;
double getVariable(const unsigned int address) const;
unsigned int getVariableAddress(const std::string &name) const;
const AddressTable & getVariableTable(void) const;
void incrementInsIndex(const unsigned int inc = 1);
void setFunction(const std::string &name,
DoubleFunction *f);
void setFunction(const unsigned int address,
DoubleFunction *f);
void setInsIndex(const unsigned index);
void setVariable(const std::string &name,
const double value);
void setVariable(const unsigned int address,
const double value);
std::stack<double> & stack(void);
// reset
void reset(void);
private:
unsigned int insIndex_;
std::stack<double> dStack_;
std::vector<double> vMem_;
std::vector<DoubleFunction *> fMem_;
AddressTable vTable_, fTable_;
};
/******************************************************************************
* Instruction classes *
******************************************************************************/
// Abstract base
class Instruction
{
public:
// destructor
virtual ~Instruction(void) = default;
// instruction execution
virtual void operator()(RunContext &context) const = 0;
friend std::ostream & operator<<(std::ostream &out, const Instruction &ins);
private:
virtual void print(std::ostream &out) const = 0;
};
std::ostream & operator<<(std::ostream &out, const Instruction &ins);
// Instruction container
typedef std::vector<std::unique_ptr<const Instruction>> Program;
// Push
class Push: public Instruction
{
private:
enum class ArgType
{
Constant = 0,
Variable = 1
};
public:
//constructors
explicit Push(const double val);
explicit Push(const unsigned int address, const std::string &name);
// instruction execution
virtual void operator()(RunContext &context) const;
private:
virtual void print(std::ostream& out) const;
private:
ArgType type_;
double val_;
unsigned int address_;
std::string name_;
};
// Pop
class Pop: public Instruction
{
public:
//constructor
explicit Pop(const unsigned int address, const std::string &name);
// instruction execution
virtual void operator()(RunContext &context) const;
private:
virtual void print(std::ostream& out) const;
private:
unsigned int address_;
std::string name_;
};
// Store
class Store: public Instruction
{
public:
//constructor
explicit Store(const unsigned int address, const std::string &name);
// instruction execution
virtual void operator()(RunContext &context) const;
private:
virtual void print(std::ostream& out) const;
private:
unsigned int address_;
std::string name_;
};
// Call function
class Call: public Instruction
{
public:
//constructor
explicit Call(const unsigned int address, const std::string &name);
// instruction execution
virtual void operator()(RunContext &context) const;
private:
virtual void print(std::ostream& out) const;
private:
unsigned int address_;
std::string name_;
};
// Floating point operations
#define DECL_OP(name)\
class name: public Instruction\
{\
public:\
virtual void operator()(RunContext &context) const;\
private:\
virtual void print(std::ostream &out) const;\
}
DECL_OP(Neg);
DECL_OP(Add);
DECL_OP(Sub);
DECL_OP(Mul);
DECL_OP(Div);
DECL_OP(Pow);
/******************************************************************************
* Expression node classes *
******************************************************************************/
class ExprNode
{
public:
// constructors
explicit ExprNode(const std::string &name);
// destructor
virtual ~ExprNode() = default;
// access
const std::string& getName(void) const;
Index getNArg(void) const;
const ExprNode * getParent(void) const;
Index getLevel(void) const;
void setName(const std::string &name);
void pushArg(ExprNode *node);
// operator
const ExprNode &operator[](const Index i) const;
// compile
virtual void compile(Program &program, RunContext &context) const = 0;
private:
std::string name_;
std::vector<std::unique_ptr<ExprNode>> arg_;
const ExprNode * parent_;
};
std::ostream &operator<<(std::ostream &out, const ExprNode &n);
#define DECL_NODE(base, name) \
class name: public base\
{\
public:\
using base::base;\
virtual void compile(Program &program, RunContext &context) const;\
}
DECL_NODE(ExprNode, VarNode);
DECL_NODE(ExprNode, CstNode);
DECL_NODE(ExprNode, SemicolonNode);
DECL_NODE(ExprNode, AssignNode);
DECL_NODE(ExprNode, MathOpNode);
DECL_NODE(ExprNode, FuncNode);
class KeywordNode: public ExprNode
{
public:
using ExprNode::ExprNode;
virtual void compile(Program &program, RunContext &context) const = 0;
};
DECL_NODE(KeywordNode, ReturnNode);
/******************************************************************************
* Interpreter class *
******************************************************************************/
class MathInterpreter
{
public:
// parser state
class MathParserState: public ParserState<std::unique_ptr<ExprNode>>
{
public:
// constructor
explicit MathParserState(std::istream *stream, std::string *name,
std::unique_ptr<ExprNode> *data);
// destructor
virtual ~MathParserState(void);
private:
// allocation/deallocation functions defined in MathLexer.lpp
virtual void initScanner(void);
virtual void destroyScanner(void);
};
private:
// status flags
class Status
{
public:
enum
{
none = 0,
initialised = 1 << 0,
parsed = 1 << 1,
compiled = 1 << 2
};
};
public:
// constructors
MathInterpreter(void) = default;
MathInterpreter(const std::string &code);
// destructor
~MathInterpreter(void) = default;
// access
const Instruction * operator[](const Index i) const;
const ExprNode * getAST(void) const;
// initialization
void setCode(const std::string &code);
// interpreter
void compile(RunContext &context);
// execution
void operator()(RunContext &context);
// IO
friend std::ostream & operator<<(std::ostream &out,
const MathInterpreter &program);
private:
// initialization
void reset(void);
// access
void push(const Instruction *i);
// parser
void parse(void);
// interpreter
void compileNode(const ExprNode &node);
// execution
void execute(RunContext &context) const;
private:
std::unique_ptr<std::istream> code_{nullptr};
std::string codeName_{"<no_code>"};
std::unique_ptr<MathParserState> state_{nullptr};
std::unique_ptr<ExprNode> root_{nullptr};
Program program_;
unsigned int status_{Status::none};
};
std::ostream & operator<<(std::ostream &out, const MathInterpreter &program);
END_LATAN_NAMESPACE
#endif // Latan_MathInterpreter_hpp_

View File

@ -0,0 +1,89 @@
/*
* MathLexer.lpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
%option reentrant
%option prefix="_math_"
%option bison-bridge
%option bison-locations
%option noyywrap
%option yylineno
%{
#include <LatAnalyze/Core/MathInterpreter.hpp>
#include "MathParser.hpp"
using namespace std;
using namespace Latan;
#define YY_EXTRA_TYPE MathInterpreter::MathParserState *
#define YY_USER_ACTION \
yylloc->first_line = yylloc->last_line = yylineno;\
yylloc->first_column = yylloc->last_column + 1;\
yylloc->last_column = yylloc->first_column + yyleng - 1;
#define YY_INPUT(buf, result, max_size) \
{ \
(*yyextra->stream).read(buf, max_size);\
result = (*yyextra->stream).gcount();\
}
#define YY_DEBUG 0
#if (YY_DEBUG == 1)
#define RET(var) cout << (var) << "(" << yytext << ")" << endl; return (var)
#define RETTOK(tok) cout << #tok << "(" << yytext << ")" << endl; return tok
#else
#define RET(var) return (var)
#define RETTOK(tok) return tok
#endif
%}
DIGIT [0-9]
ALPHA [a-zA-Z_]
FLOAT (({DIGIT}+(\.{DIGIT}*)?)|({DIGIT}*\.{DIGIT}+))([eE][+-]?{DIGIT}+)?
SPECIAL [;,()+\-*/^={}]
BLANK [ \t]
%%
{FLOAT} {
strncpy(yylval->val_str,yytext,MAXIDLENGTH);
RETTOK(FLOAT);
}
{SPECIAL} {RET(*yytext);}
return {RETTOK(RETURN);}
{ALPHA}({ALPHA}|{DIGIT})* {
strncpy(yylval->val_str,yytext,MAXIDLENGTH);
RETTOK(ID);
}
<*>\n {yylloc->last_column = 0;}
<*>{BLANK}
<*>. {yylval->val_char = yytext[0]; RETTOK(ERR);}
%%
void MathInterpreter::MathParserState::initScanner()
{
yylex_init(&scanner);
yyset_extra(this, scanner);
}
void MathInterpreter::MathParserState::destroyScanner()
{
yylex_destroy(scanner);
}

View File

@ -0,0 +1,142 @@
/*
* MathParser.ypp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
%{
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/MathInterpreter.hpp>
using namespace std;
using namespace Latan;
%}
%pure-parser
%name-prefix "_math_"
%locations
%defines
%error-verbose
%parse-param { Latan::MathInterpreter::MathParserState *state }
%initial-action {yylloc.last_column = 0;}
%lex-param { void* scanner }
%union
{
double val_double;
char val_char;
char val_str[MAXIDLENGTH];
Latan::ExprNode *val_node;
}
%token <val_char> ERR
%token <val_str> FLOAT
%token <val_str> ID
%token RETURN
%left '='
%left '+' '-'
%left '*' '/'
%nonassoc UMINUS
%left '^'
%type <val_node> stmt stmt_list expr func_args
%{
int _math_lex(YYSTYPE *lvalp, YYLTYPE *llocp, void *scanner);
void _math_error(YYLTYPE *locp, MathInterpreter::MathParserState *state,
const char *err)
{
stringstream buf;
buf << *(state->streamName) << ":" << locp->first_line << ":"\
<< locp->first_column << ": " << err;
LATAN_ERROR(Parsing, buf.str());
}
void _math_warning(YYLTYPE *locp, MathInterpreter::MathParserState *state,
const char *err)
{
stringstream buf;
buf << *(state->streamName) << ":" << locp->first_line << ":"\
<< locp->first_column << ": " << err;
LATAN_WARNING(buf.str());
}
#define scanner state->scanner
%}
%%
program:
/* empty string */
| stmt_list {(*(state->data)).reset($1);}
;
stmt:
';'
{$$ = new SemicolonNode(";");}
| expr ';'
{$$ = nullptr; _math_warning(&yylloc, state, "useless statement removed");}
| ID '=' expr ';'
{$$ = new AssignNode("="); $$->pushArg(new VarNode($1)); $$->pushArg($3);}
| RETURN expr ';'
{$$ = new ReturnNode("return"); $$->pushArg($2);}
| '{' stmt_list '}'
{$$ = $2;}
;
stmt_list:
stmt
{$$ = $1;}
| stmt_list stmt
{$$ = new SemicolonNode(";"); $$->pushArg($1); $$->pushArg($2);}
;
expr:
FLOAT
{$$ = new CstNode($1);}
| ID
{$$ = new VarNode($1);}
| '-' expr %prec UMINUS
{$$ = new MathOpNode("-"); $$->pushArg($2);}
| expr '+' expr
{$$ = new MathOpNode("+"); $$->pushArg($1); $$->pushArg($3);}
| expr '-' expr
{$$ = new MathOpNode("-"); $$->pushArg($1); $$->pushArg($3);}
| expr '*' expr
{$$ = new MathOpNode("*"); $$->pushArg($1); $$->pushArg($3);}
| expr '/' expr
{$$ = new MathOpNode("/"); $$->pushArg($1); $$->pushArg($3);}
| expr '^' expr
{$$ = new MathOpNode("^"); $$->pushArg($1); $$->pushArg($3);}
| '(' expr ')'
{$$ = $2;}
| ID '(' func_args ')'
{$$ = $3; $$->setName($1);}
;
func_args:
/* empty string */
{$$ = new FuncNode("");}
| expr
{$$ = new FuncNode(""); $$->pushArg($1);}
| func_args ',' expr
{$$ = $1; $$->pushArg($3);}
;

View File

@ -0,0 +1,299 @@
/*
* OptParser.cpp, part of LatAnalyze
*
* Copyright (C) 2016 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/OptParser.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
static char optRegex[] = "(-([a-zA-Z])(.+)?)|(--([a-zA-Z_-]+)=?(.+)?)";
/******************************************************************************
* OptParser implementation *
******************************************************************************/
// regular expressions /////////////////////////////////////////////////////////
const regex OptParser::optRegex_(optRegex);
// access //////////////////////////////////////////////////////////////////////
void OptParser::addOption(const std::string shortName,
const std::string longName,
const OptType type, const bool optional,
const std::string helpMessage,
const std::string defaultVal)
{
OptPar par;
par.shortName = shortName;
par.longName = longName;
par.defaultVal = defaultVal;
par.helpMessage = helpMessage;
par.type = type;
par.optional = optional;
auto it = std::find_if(opt_.begin(), opt_.end(), [&par](const OptPar & p)
{
bool match = false;
match |= (par.shortName == p.shortName) and !par.shortName.empty();
match |= (par.longName == p.longName) and !par.longName.empty();
return match;
});
if (it != opt_.end())
{
string opt;
if (!it->shortName.empty())
{
opt += "-" + it->shortName;
}
if (!opt.empty())
{
opt += "/";
}
if (!it->longName.empty())
{
opt += "--" + it->longName;
}
throw(logic_error("duplicate option " + opt + " (in the code, not in the command line)"));
}
opt_.push_back(par);
}
bool OptParser::gotOption(const std::string name) const
{
int i = optIndex(name);
if (result_.size() != opt_.size())
{
throw(runtime_error("options not parsed"));
}
if (i >= 0)
{
return result_[i].present;
}
else
{
throw(out_of_range("no option with name '" + name + "'"));
}
}
const vector<string> & OptParser::getArgs(void) const
{
return arg_;
}
// parse ///////////////////////////////////////////////////////////////////////
bool OptParser::parse(int argc, char *argv[])
{
smatch sm;
queue<string> arg;
int expectVal = -1;
bool isCorrect = true;
for (int i = 1; i < argc; ++i)
{
arg.push(argv[i]);
}
result_.clear();
result_.resize(opt_.size());
arg_.clear();
for (unsigned int i = 0; i < opt_.size(); ++i)
{
result_[i].value = opt_[i].defaultVal;
}
while (!arg.empty())
{
// option
if (regex_match(arg.front(), sm, optRegex_))
{
// should it be a value?
if (expectVal >= 0)
{
cerr << "warning: expected value for option ";
cerr << optName(opt_[expectVal]);
cerr << ", got option '" << arg.front() << "' instead" << endl;
expectVal = -1;
isCorrect = false;
}
// short option
if (sm[1].matched)
{
string optName = sm[2].str();
// find option
auto it = find_if(opt_.begin(), opt_.end(),
[&optName](const OptPar &p)
{
return (p.shortName == optName);
});
// parse if found
if (it != opt_.end())
{
unsigned int i = it - opt_.begin();
result_[i].present = true;
if (opt_[i].type == OptType::value)
{
if (sm[3].matched)
{
result_[i].value = sm[3].str();
}
else
{
expectVal = i;
}
}
}
// warning if not found
else
{
cerr << "warning: unknown option '" << arg.front() << "'";
cerr << endl;
}
}
// long option
else if (sm[4].matched)
{
string optName = sm[5].str();
// find option
auto it = find_if(opt_.begin(), opt_.end(),
[&optName](const OptPar &p)
{
return (p.longName == optName);
});
// parse if found
if (it != opt_.end())
{
unsigned int i = it - opt_.begin();
result_[i].present = true;
if (opt_[i].type == OptType::value)
{
if (sm[6].matched)
{
result_[i].value = sm[6].str();
}
else
{
expectVal = i;
}
}
}
// warning if not found
else
{
cerr << "warning: unknown option '" << arg.front() << "'";
cerr << endl;
}
}
}
else if (expectVal >= 0)
{
result_[expectVal].value = arg.front();
expectVal = -1;
}
else
{
arg_.push_back(arg.front());
}
arg.pop();
}
if (expectVal >= 0)
{
cerr << "warning: expected value for option ";
cerr << optName(opt_[expectVal]) << endl;
expectVal = -1;
isCorrect = false;
}
for (unsigned int i = 0; i < opt_.size(); ++i)
{
if (!opt_[i].optional and !result_[i].present)
{
cerr << "warning: mandatory option " << optName(opt_[i]);
cerr << " is missing" << endl;
isCorrect = false;
}
}
return isCorrect;
}
// find option index ///////////////////////////////////////////////////////////
int OptParser::optIndex(const string name) const
{
auto it = find_if(opt_.begin(), opt_.end(), [&name](const OptPar &p)
{
return (p.shortName == name) or (p.longName == name);
});
if (it != opt_.end())
{
return static_cast<int>(it - opt_.begin());
}
else
{
return -1;
}
}
// option name for messages ////////////////////////////////////////////////////
std::string OptParser::optName(const OptPar &opt)
{
std::string res = "";
if (!opt.shortName.empty())
{
res += "-" + opt.shortName;
if (!opt.longName.empty())
{
res += "/";
}
}
if (!opt.longName.empty())
{
res += "--" + opt.longName;
if (opt.type == OptParser::OptType::value)
{
res += "=";
}
}
return res;
}
// print option list ///////////////////////////////////////////////////////////
std::ostream & Latan::operator<<(std::ostream &out, const OptParser &parser)
{
for (auto &o: parser.opt_)
{
out << setw(20) << OptParser::optName(o);
out << ": " << o.helpMessage;
if (!o.defaultVal.empty())
{
out << " (default: " << o.defaultVal << ")";
}
out << endl;
}
return out;
}

View File

@ -0,0 +1,103 @@
/*
* OptParser.hpp, part of LatAnalyze
*
* Copyright (C) 2016 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef LatAnalyze_OptParser_hpp_
#define LatAnalyze_OptParser_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* command-line option parser *
******************************************************************************/
class OptParser
{
public:
enum class OptType {value, trigger};
private:
struct OptPar
{
std::string shortName, longName, defaultVal, helpMessage;
OptType type;
bool optional;
};
struct OptRes
{
std::string value;
bool present;
};
public:
// constructor
OptParser(void) = default;
// destructor
virtual ~OptParser(void) = default;
// access
void addOption(const std::string shortName, const std::string longName,
const OptType type, const bool optional = false,
const std::string helpMessage = "",
const std::string defaultVal = "");
bool gotOption(const std::string name) const;
template <typename T = std::string>
T optionValue(const std::string name) const;
const std::vector<std::string> & getArgs(void) const;
// parse
bool parse(int argc, char *argv[]);
// print option list
friend std::ostream & operator<<(std::ostream &out,
const OptParser &parser);
private:
// find option index
int optIndex(const std::string name) const;
// option name for messages
static std::string optName(const OptPar &opt);
private:
std::vector<OptPar> opt_;
std::vector<OptRes> result_;
std::vector<std::string> arg_;
static const std::regex optRegex_;
};
std::ostream & operator<<(std::ostream &out, const OptParser &parser);
/******************************************************************************
* OptParser template implementation *
******************************************************************************/
template <typename T>
T OptParser::optionValue(const std::string name) const
{
int i = optIndex(name);
if (result_.size() != opt_.size())
{
throw(std::runtime_error("options not parsed"));
}
if (i >= 0)
{
return strTo<T>(result_[i].value);
}
else
{
throw(std::out_of_range("no option with name '" + name + "'"));
}
}
END_LATAN_NAMESPACE
#endif // LatAnalyze_OptParser_hpp_

View File

@ -0,0 +1,58 @@
/*
* ParserState.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_ParserState_hpp_
#define Latan_ParserState_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
template <typename DataObj>
class ParserState
{
public:
// constructor
ParserState(std::istream *streamPt, std::string *namePt, DataObj *dataPt);
// destructor
virtual ~ParserState(void) = default;
private:
// scanner allocation/deallocation
virtual void initScanner(void) = 0;
virtual void destroyScanner(void) = 0;
public:
DataObj *data;
void *scanner;
std::istream *stream;
std::string *streamName;
};
template <typename DataObj>
ParserState<DataObj>::ParserState(std::istream *streamPt, std::string *namePt,
DataObj *dataPt)
: data(dataPt)
, scanner(nullptr)
, stream(streamPt)
, streamName(namePt)
{}
END_LATAN_NAMESPACE
#endif // Latan_ParserState_hpp_

View File

@ -0,0 +1,951 @@
/*
* Plot.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Plot.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Mat.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* Plot objects *
******************************************************************************/
// PlotObject access ///////////////////////////////////////////////////////////
const string & PlotObject::getCommand(void) const
{
return command_;
}
const string & PlotObject::getHeadCommand(void) const
{
return headCommand_;
}
void PlotObject::setCommand(const string &command)
{
command_ = command;
}
void PlotObject::setHeadCommand(const string &command)
{
headCommand_ = command;
}
string PlotObject::popTmpFile(void)
{
string res = tmpFileName_.top();
tmpFileName_.pop();
return res;
}
void PlotObject::pushTmpFile(const std::string &fileName)
{
tmpFileName_.push(fileName);
}
// PlotObject dump a matrix to a temporary file ////////////////////////////////
string PlotObject::dumpToTmpFile(const DMat &m)
{
char tmpFileName[MAX_PATH_LENGTH];
int fd;
FILE *tmpFile;
for (Index j = 0; j < m.cols(); ++j)
{
}
sprintf(tmpFileName, "%s/latan_plot_tmp.XXXXXX.dat", P_tmpdir);
fd = mkstemps(tmpFileName, 4);
if (fd == -1)
{
LATAN_ERROR(System, "impossible to create a temporary file from template "
+ strFrom(tmpFileName));
}
tmpFile = fdopen(fd, "w");
for (Index i = 0; i < m.rows(); ++i)
{
for (Index j = 0; j < m.cols(); ++j)
{
fprintf(tmpFile, "%e ", m(i, j));
}
fprintf(tmpFile, "\n");
}
fclose(tmpFile);
return string(tmpFileName);
}
// PlotObject test /////////////////////////////////////////////////////////////
bool PlotObject::gotTmpFile(void) const
{
return !tmpFileName_.empty();
}
// PlotCommand constructor /////////////////////////////////////////////////////
PlotCommand::PlotCommand(const string &command)
{
setCommand(command);
}
// PlotHeadCommand constructor /////////////////////////////////////////////////
PlotHeadCommand::PlotHeadCommand(const string &command)
{
setHeadCommand(command);
}
// PlotData constructor ////////////////////////////////////////////////////////
PlotData::PlotData(const DMatSample &x, const DMatSample &y, const bool abs)
{
if (x[central].rows() != y[central].rows())
{
LATAN_ERROR(Size, "x and y vectors do not have the same size");
}
DMat d(x[central].rows(), 4);
string usingCmd, tmpFileName;
d.col(0) = x[central].col(0);
d.col(2) = y[central].col(0);
d.col(1) = x.variance().cwiseSqrt().col(0);
d.col(3) = y.variance().cwiseSqrt().col(0);
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
if (!abs)
{
setCommand("'" + tmpFileName + "' u 1:3:2:4 w xyerr");
}
else
{
setCommand("'" + tmpFileName + "' u 1:(abs($3)):2:4 w xyerr");
}
}
PlotData::PlotData(const DVec &x, const DMatSample &y, const bool abs)
{
if (x.rows() != y[central].rows())
{
LATAN_ERROR(Size, "x and y vector does not have the same size");
}
DMat d(x.rows(), 3);
string usingCmd, tmpFileName;
d.col(0) = x;
d.col(1) = y[central].col(0);
d.col(2) = y.variance().cwiseSqrt().col(0);
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
if (!abs)
{
setCommand("'" + tmpFileName + "' u 1:2:3 w yerr");
}
else
{
setCommand("'" + tmpFileName + "' u 1:(abs($2)):3 w yerr");
}
}
PlotData::PlotData(const DMatSample &x, const DVec &y, const bool abs)
{
if (x[central].rows() != y.rows())
{
LATAN_ERROR(Size, "x and y vectors do not have the same size");
}
DMat d(x[central].rows(), 3), xerr, yerr;
string usingCmd, tmpFileName;
d.col(0) = x[central].col(0);
d.col(2) = y;
d.col(1) = x.variance().cwiseSqrt().col(0);
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
if (!abs)
{
setCommand("'" + tmpFileName + "' u 1:3:2 w xerr");
}
else
{
setCommand("'" + tmpFileName + "' u 1:(abs($3)):2 w xerr");
}
}
PlotData::PlotData(const XYStatData &data, const Index i, const Index j, const bool abs)
{
string usingCmd, tmpFileName;
if (!abs)
{
usingCmd = (data.isXExact(i)) ? "u 1:3:4 w yerr" : "u 1:3:2:4 w xyerr";
}
else
{
usingCmd = (data.isXExact(i)) ? "u 1:(abs($3)):4 w yerr" : "u 1:(abs($3)):2:4 w xyerr";
}
tmpFileName = dumpToTmpFile(data.getTable(i, j));
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' " + usingCmd);
}
// PlotPoint constructor ///////////////////////////////////////////////////////
PlotPoint::PlotPoint(const double x, const double y)
{
DMat d(1, 2);
string usingCmd, tmpFileName;
d(0, 0) = x;
d(0, 1) = y;
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:2");
}
PlotPoint::PlotPoint(const DSample &x, const double y)
{
DMat d(1, 3);
string usingCmd, tmpFileName;
d(0, 0) = x[central];
d(0, 2) = y;
d(0, 1) = sqrt(x.variance());
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:3:2 w xerr");
}
PlotPoint::PlotPoint(const double x, const DSample &y)
{
DMat d(1, 3);
string usingCmd, tmpFileName;
d(0, 0) = x;
d(0, 1) = y[central];
d(0, 2) = sqrt(y.variance());
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:2:3 w yerr");
}
PlotPoint::PlotPoint(const DSample &x, const DSample &y)
{
DMat d(1, 4);
string usingCmd, tmpFileName;
d(0, 0) = x[central];
d(0, 2) = y[central];
d(0, 1) = sqrt(x.variance());
d(0, 3) = sqrt(y.variance());
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:3:2:4 w xyerr");
}
// PlotLine constructor ////////////////////////////////////////////////////////
PlotLine::PlotLine(const DVec &x, const DVec &y)
{
if (x.size() != y.size())
{
LATAN_ERROR(Size, "x and y vectors do not have the same size");
}
DMat d(x.size(), 2);
string usingCmd, tmpFileName;
d.col(0) = x;
d.col(1) = y;
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:2 w lines");
}
// PlotPoints constructor ////////////////////////////////////////////////////////
PlotPoints::PlotPoints(const DVec &x, const DVec &y)
{
if (x.size() != y.size())
{
LATAN_ERROR(Size, "x and y vectors do not have the same size");
}
DMat d(x.size(), 2);
string usingCmd, tmpFileName;
d.col(0) = x;
d.col(1) = y;
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:2");
}
// PlotHLine constructor ///////////////////////////////////////////////////////
PlotHLine::PlotHLine(const double y)
{
setCommand(strFrom(y));
}
// PlotHBand constructor ///////////////////////////////////////////////////////
PlotBand::PlotBand(const double xMin, const double xMax, const double yMin,
const double yMax, const double opacity)
{
setCommand("'< printf \"%e %e\\n%e %e\\n%e %e\\n%e %e\\n%e %e\\n\" "
+ strFrom(xMin) + " " + strFrom(yMin) + " "
+ strFrom(xMax) + " " + strFrom(yMin) + " "
+ strFrom(xMax) + " " + strFrom(yMax) + " "
+ strFrom(xMin) + " " + strFrom(yMax) + " "
+ strFrom(xMin) + " " + strFrom(yMin)
+ "' u 1:2 w filledcurves closed fs solid " + strFrom(opacity)
+ " noborder");
}
// PlotFunction constructor ////////////////////////////////////////////////////
PlotFunction::PlotFunction(const DoubleFunction &function, const double xMin,
const double xMax, const unsigned int nPoint,
const bool abs)
{
DMat d(nPoint, 2);
string tmpFileName;
double dx = (xMax - xMin)/static_cast<double>(nPoint - 1);
for (Index i = 0; i < nPoint; ++i)
{
d(i, 0) = xMin + i*dx;
d(i, 1) = function(d(i, 0));
}
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
if (!abs)
{
setCommand("'" + tmpFileName + "' u 1:2 w lines");
}
else
{
setCommand("'" + tmpFileName + "' u 1:(abs($2)) w lines");
}
}
// PlotPredBand constructor ////////////////////////////////////////////////////
void PlotPredBand::makePredBand(const DMat &low, const DMat &high, const double opacity)
{
string lowFileName, highFileName, contFileName;
DMat contour(low.rows() + high.rows() + 1, 2);
FOR_MAT(low, i, j)
{
contour(i, j) = low(i, j);
}
FOR_MAT(high, i, j)
{
contour(low.rows() + i, j) = high(high.rows() - i - 1, j);
}
contour.row(low.rows() + high.rows()) = low.row(0);
contFileName = dumpToTmpFile(contour);
pushTmpFile(contFileName);
setCommand("'" + contFileName + "' u 1:2 w filledcurves closed" +
" fs solid " + strFrom(opacity) + " noborder");
}
PlotPredBand::PlotPredBand(const DVec &x, const DVec &y, const DVec &yerr,
const double opacity)
{
if (x.size() != y.size())
{
LATAN_ERROR(Size, "x and y vectors do not have the same size");
}
if (y.size() != yerr.size())
{
LATAN_ERROR(Size, "y and y error vectors do not have the same size");
}
Index nPoint = x.size();
DMat dLow(nPoint, 2), dHigh(nPoint, 2);
dLow.col(0) = x;
dLow.col(1) = y - yerr;
dHigh.col(0) = x;
dHigh.col(1) = y + yerr;
makePredBand(dLow, dHigh, opacity);
}
PlotPredBand::PlotPredBand(const DoubleFunctionSample &function,
const double xMin, const double xMax,
const unsigned int nPoint, const double opacity)
{
DMat dLow(nPoint, 2), dHigh(nPoint, 2);
DSample pred(function.size());
double dx = (xMax - xMin)/static_cast<double>(nPoint - 1);
string lowFileName, highFileName;
for (Index i = 0; i < nPoint; ++i)
{
double x = xMin + i*dx, err;
pred = function(x);
err = sqrt(pred.variance());
dLow(i, 0) = x;
dLow(i, 1) = pred[central] - err;
dHigh(i, 0) = x;
dHigh(i, 1) = pred[central] + err;
}
makePredBand(dLow, dHigh, opacity);
}
// PlotHistogram constructor ///////////////////////////////////////////////////
PlotHistogram::PlotHistogram(const Histogram &h)
{
DMat d(h.size(), 2);
string tmpFileName;
for (Index i = 0; i < h.size(); ++i)
{
d(i, 0) = h.getX(i);
d(i, 1) = h[i];
}
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:2 w steps");
}
// PlotImpulses constructor ////////////////////////////////////////////////////
PlotImpulses::PlotImpulses(const DVec &x, const DVec &y)
{
if (x.rows() != y.rows())
{
LATAN_ERROR(Size, "x and y vector does not have the same size");
}
DMat d(x.rows(), 2);
string tmpFileName;
for (Index i = 0; i < x.rows(); ++i)
{
d(i, 0) = x(i);
d(i, 1) = y(i);
}
tmpFileName = dumpToTmpFile(d);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' u 1:2 w impulses");
}
// PlotMatrixNoRange constructor ///////////////////////////////////////////////
PlotMatrixNoRange::PlotMatrixNoRange(const DMat &m)
{
string tmpFileName = dumpToTmpFile(m);
pushTmpFile(tmpFileName);
setCommand("'" + tmpFileName + "' matrix w image");
}
/******************************************************************************
* Plot modifiers *
******************************************************************************/
// Caption constructor /////////////////////////////////////////////////////////
Caption::Caption(const string &caption)
: caption_(caption)
{}
// Caption modifier ////////////////////////////////////////////////////////////
void Caption::operator()(PlotOptions &option) const
{
option.caption = caption_;
}
// Label constructor ///////////////////////////////////////////////////////////
Label::Label(const string &label, const Axis axis)
: label_(label)
, axis_(axis)
{}
// Label modifier //////////////////////////////////////////////////////////////
void Label::operator()(PlotOptions &option) const
{
option.label[static_cast<int>(axis_)] = label_;
}
// Color constructor ///////////////////////////////////////////////////////////
Color::Color(const string &color)
: color_(color)
{}
// Color modifier //////////////////////////////////////////////////////////////
void Color::operator()(PlotOptions &option) const
{
option.lineColor = color_;
}
// LineWidth constructor ///////////////////////////////////////////////////////
LineWidth::LineWidth(const unsigned int width)
: width_(width)
{}
// LineWidth modifier //////////////////////////////////////////////////////////
void LineWidth::operator()(PlotOptions &option) const
{
option.lineWidth = static_cast<int>(width_);
}
// Dash constructor ///////////////////////////////////////////////////////////
Dash::Dash(const string &dash)
: dash_(dash)
{}
// Dash modifier //////////////////////////////////////////////////////////////
void Dash::operator()(PlotOptions &option) const
{
option.dashType = dash_;
}
// LogScale constructor ////////////////////////////////////////////////////////
LogScale::LogScale(const Axis axis, const double basis)
: axis_(axis)
, basis_(basis)
{}
// Logscale modifier ///////////////////////////////////////////////////////////
void LogScale::operator()(PlotOptions &option) const
{
option.scaleMode[static_cast<int>(axis_)] |= Plot::Scale::log;
option.logScaleBasis[static_cast<int>(axis_)] = basis_;
}
// PlotRange constructors //////////////////////////////////////////////////////
PlotRange::PlotRange(const Axis axis)
: axis_(axis)
, reset_(true)
, min_(0.)
, max_(0.)
{}
PlotRange::PlotRange(const Axis axis, const double min, const double max)
: axis_(axis)
, reset_(false)
, min_(min)
, max_(max)
{}
// PlotRange modifier ///////////////////////////////////////////////////////////
void PlotRange::operator()(PlotOptions &option) const
{
int a = static_cast<int>(axis_);
if (!reset_)
{
option.scaleMode[a] |= Plot::Scale::manual;
option.scale[a].min = min_;
option.scale[a].max = max_;
}
else
{
option.scaleMode[a] = Plot::Scale::reset;
}
}
// Terminal constructor ////////////////////////////////////////////////////////
Terminal::Terminal(const string &terminal, const std::string &options)
: terminalCmd_(terminal + " " + options)
{}
// Terminal modifier ///////////////////////////////////////////////////////////
void Terminal::operator()(PlotOptions &option) const
{
option.terminal = terminalCmd_;
}
// Title constructor ///////////////////////////////////////////////////////////
Title::Title(const string &title)
: title_(title)
{}
// Title modifier //////////////////////////////////////////////////////////////
void Title::operator()(PlotOptions &option) const
{
option.title = title_;
}
// Palette constructor /////////////////////////////////////////////////////////
Palette::Palette(const std::vector<std::string> &palette)
: palette_(palette)
{}
// Palette modifier ////////////////////////////////////////////////////////////
void Palette::operator()(PlotOptions &option) const
{
option.palette = palette_;
}
// category10 palette //////////////////////////////////////////////////////////
const std::vector<std::string> Palette::category10 =
{
"rgb '#1f77b4'",
"rgb '#ff7f0e'",
"rgb '#2ca02c'",
"rgb '#d62728'",
"rgb '#9467bd'",
"rgb '#8c564b'",
"rgb '#e377c2'",
"rgb '#7f7f7f'",
"rgb '#bcbd22'"
};
/******************************************************************************
* Plot implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
Plot::Plot(void)
{
initOptions();
}
// default options /////////////////////////////////////////////////////////////
void Plot::initOptions(void)
{
options_.terminal = "qt noenhanced font 'Arial,12'";
options_.output = "";
options_.caption = "";
options_.title = "";
options_.scaleMode[0] = Plot::Scale::reset;
options_.scaleMode[1] = Plot::Scale::reset;
options_.scale[0] = {0.0, 0.0};
options_.scale[1] = {0.0, 0.0};
options_.label[0] = "";
options_.label[1] = "";
options_.lineColor = "";
options_.lineWidth = -1;
options_.dashType = "";
options_.palette = Palette::category10;
}
// plot reset //////////////////////////////////////////////////////////////////
void Plot::reset(void)
{
headCommand_.clear();
plotCommand_.clear();
tmpFileName_.clear();
initOptions();
}
// plot objects ////////////////////////////////////////////////////////////////
Plot & Plot::operator<<(PlotObject &&command)
{
string commandStr;
while (command.gotTmpFile())
{
tmpFileName_.push_back(command.popTmpFile());
commandStr += "'" + tmpFileName_.back() + "' ";
}
commandStr = command.getCommand();
if (!commandStr.empty())
{
if (!options_.lineColor.empty())
{
commandStr += " lc " + options_.lineColor;
options_.lineColor = "";
}
if (options_.lineWidth > 0)
{
commandStr += " lw " + strFrom(options_.lineWidth);
options_.lineWidth = -1;
}
if (!options_.dashType.empty())
{
commandStr += " dt " + options_.dashType;
options_.dashType = "";
}
if (options_.title.empty())
{
commandStr += " notitle";
}
else
{
commandStr += " t '" + options_.title + "'";
options_.title = "";
}
plotCommand_.push_back(commandStr);
}
if (!command.getHeadCommand().empty())
{
headCommand_.push_back(command.getHeadCommand());
}
return *this;
}
Plot & Plot::operator<<(PlotModifier &&modifier)
{
modifier(options_);
return *this;
}
// find gnuplot ////////////////////////////////////////////////////////////////
#define SEARCH_DIR(dir) \
sprintf(buf, "%s/%s", dir, gnuplotBin_.c_str());\
if (access(buf, X_OK) == 0)\
{\
return dir;\
}
std::string Plot::getProgramPath(void)
{
int i, j, lg;
char *path;
static char buf[MAX_PATH_LENGTH];
// try out in all paths given in the PATH variable
buf[0] = 0;
path = getenv("PATH") ;
if (path)
{
for (i=0;path[i];)
{
for (j=i;(path[j]) and (path[j]!=':');j++);
lg = j - i;
strncpy(buf,path + i,(size_t)(lg));
if (lg == 0)
{
buf[lg++] = '.';
}
buf[lg++] = '/';
strcpy(buf + lg, gnuplotBin_.c_str());
if (access(buf, X_OK) == 0)
{
// found it!
break ;
}
buf[0] = 0;
i = j;
if (path[i] == ':') i++ ;
}
}
// if the buffer is still empty, the command was not found
if (buf[0] != 0)
{
lg = (int)(strlen(buf) - 1);
while (buf[lg]!='/')
{
buf[lg] = 0;
lg--;
}
buf[lg] = 0;
gnuplotPath_ = buf;
return gnuplotPath_;
}
// try in CWD, /usr/bin & /usr/local/bin
SEARCH_DIR(".");
SEARCH_DIR("/usr/bin");
SEARCH_DIR("/usr/local/bin");
// if this code is reached nothing was found
LATAN_ERROR(System, "cannot find gnuplot");
return "";
}
// plot parsing and output /////////////////////////////////////////////////////
void Plot::display(void)
{
int pid = fork();
if (pid == 0)
{
FILE *gnuplotPipe;
string command;
ostringstream scriptBuf;
getProgramPath();
command = gnuplotPath_ + "/" + gnuplotBin_ + " 2>/dev/null";
gnuplotPipe = popen(command.c_str(), "w");
if (!gnuplotPipe)
{
LATAN_ERROR(System, "error starting gnuplot (command was '"
+ command + "')");
}
commandBuffer_.str("");
commandBuffer_ << *this << endl;
commandBuffer_ << "pause mouse close" << endl;
fprintf(gnuplotPipe, "%s", commandBuffer_.str().c_str());
if (pclose(gnuplotPipe) == -1)
{
LATAN_ERROR(System, "problem closing communication to gnuplot");
}
exit(EXIT_SUCCESS);
}
else if (pid == -1)
{
perror("fork error");
LATAN_ERROR(System, "problem forking to the process handling gnuplot");
}
}
void Plot::save(string dirName, bool savePdf)
{
vector<string> commandBack;
string path, terminalBack, outputBack, gpCommand, scriptName;
mode_t mode755;
ofstream script;
mode755 = S_IRWXU|S_IRGRP|S_IXGRP|S_IROTH|S_IXOTH;
// generate directory
if (mkdir(dirName))
{
LATAN_ERROR(Io, "impossible to create directory '" + dirName + "'");
}
// backup I/O parameters
terminalBack = options_.terminal;
outputBack = options_.output;
commandBack = plotCommand_;
// save PDF
if (savePdf)
{
options_.terminal = "pdf";
options_.output = dirName + "/plot.pdf";
display();
options_.terminal = terminalBack;
options_.output = outputBack;
}
// save script and datafiles
for (unsigned int i = 0; i < tmpFileName_.size(); ++i)
{
ofstream dataFile;
ifstream tmpFile;
string dataFileName = "points_" + strFrom(i) + ".dat";
dataFile.open(dirName + "/" + dataFileName);
tmpFile.open(tmpFileName_[i]);
dataFile << tmpFile.rdbuf();
dataFile.close();
tmpFile.close();
for (string &command: plotCommand_)
{
auto pos = command.find(tmpFileName_[i]);
while (pos != string::npos)
{
command.replace(pos, tmpFileName_[i].size(), dataFileName);
pos = command.find(tmpFileName_[i], pos + 1);
}
}
}
scriptName = dirName + "/source.plt";
script.open(scriptName);
getProgramPath();
gpCommand = gnuplotPath_ + "/" + gnuplotBin_ + " " + gnuplotArgs_;
script << "#!/usr/bin/env " << gpCommand << "\n" << endl;
script << "# script generated by " << Env::fullName << "\n" << endl;
script << *this;
script.close();
if (chmod(scriptName.c_str(), mode755))
{
LATAN_ERROR(Io, "impossible to set file '" + scriptName +
"' in mode 755");
}
plotCommand_ = commandBack;
}
ostream & Latan::operator<<(ostream &out, const Plot &plot)
{
std::string begin, end;
int x = static_cast<int>(Axis::x), y = static_cast<int>(Axis::y);
if (!plot.options_.terminal.empty())
{
out << "set term " << plot.options_.terminal << endl;
}
if (!plot.options_.output.empty())
{
out << "set output '" << plot.options_.output << "'" << endl;
}
if (!plot.options_.caption.empty())
{
out << "set title '" << plot.options_.caption << "'" << endl;
}
if (plot.options_.scaleMode[x] & Plot::Scale::manual)
{
out << "xMin = " << plot.options_.scale[x].min << endl;
out << "xMax = " << plot.options_.scale[x].max << endl;
}
if (plot.options_.scaleMode[y] & Plot::Scale::manual)
{
out << "yMin = " << plot.options_.scale[y].min << endl;
out << "yMax = " << plot.options_.scale[y].max << endl;
}
out << "unset xrange" << endl;
if (plot.options_.scaleMode[x] & Plot::Scale::manual)
{
out << "set xrange [xMin:xMax]" << endl;
}
else
{
out << "set xrange [:]" << endl;
}
out << "unset yrange" << endl;
if (plot.options_.scaleMode[y] & Plot::Scale::manual)
{
out << "set yrange [yMin:yMax]" << endl;
}
else
{
out << "set yrange [:]" << endl;
}
out << "unset log" << endl;
if (plot.options_.scaleMode[x] & Plot::Scale::log)
{
out << "set log x " << plot.options_.logScaleBasis[x] << endl;;
}
if (plot.options_.scaleMode[y] & Plot::Scale::log)
{
out << "set log y " << plot.options_.logScaleBasis[y] << endl;
}
if (!plot.options_.label[x].empty())
{
out << "set xlabel '" << plot.options_.label[x] << "'" << endl;
}
if (!plot.options_.label[y].empty())
{
out << "set ylabel '" << plot.options_.label[y] << "'" << endl;
}
for (unsigned int i = 0; i < plot.options_.palette.size(); ++i)
{
out << "set linetype " << i + 1 << " lc "
<< plot.options_.palette[i] << endl;
}
for (unsigned int i = 0; i < plot.headCommand_.size(); ++i)
{
out << plot.headCommand_[i] << endl;
}
for (unsigned int i = 0; i < plot.plotCommand_.size(); ++i)
{
begin = (i == 0) ? "plot " : " ";
end = (i == plot.plotCommand_.size() - 1) ? "" : ",\\";
out << begin << plot.plotCommand_[i] << end << endl;
}
return out;
}

View File

@ -0,0 +1,438 @@
/*
* Plot.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Plot_hpp_
#define Latan_Plot_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
#include <LatAnalyze/Statistics/Histogram.hpp>
#include <LatAnalyze/Statistics/XYStatData.hpp>
// gnuplot default parameters
#ifndef GNUPLOT_BIN
#define GNUPLOT_BIN "gnuplot"
#endif
#ifndef GNUPLOT_ARGS
#define GNUPLOT_ARGS "-p"
#endif
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Plot objects *
******************************************************************************/
class PlotObject
{
public:
// destructor
virtual ~PlotObject(void) = default;
// access
std::string popTmpFile(void);
const std::string & getCommand(void) const;
const std::string & getHeadCommand(void) const;
// test
bool gotTmpFile(void) const;
protected:
// access
void pushTmpFile(const std::string &fileName);
void setCommand(const std::string &command);
void setHeadCommand(const std::string &command);
// dump a matrix to a temporary file
std::string dumpToTmpFile(const DMat &m);
private:
// plot command
std::string command_;
std::string headCommand_;
// stack of created temporary files
std::stack<std::string> tmpFileName_;
};
class PlotCommand: public PlotObject
{
public:
// constructor
explicit PlotCommand(const std::string &command);
// destructor
virtual ~PlotCommand(void) = default;
};
class PlotHeadCommand: public PlotObject
{
public:
// constructor
explicit PlotHeadCommand(const std::string &command);
// destructor
virtual ~PlotHeadCommand(void) = default;
};
class PlotData: public PlotObject
{
public:
// constructor
PlotData(const DMatSample &x, const DMatSample &y, const bool abs = false);
PlotData(const DVec &x, const DMatSample &y, const bool abs = false);
PlotData(const DMatSample &x, const DVec &y, const bool abs = false);
PlotData(const XYStatData &data, const Index i = 0, const Index j = 0,
const bool abs = false);
// destructor
virtual ~PlotData(void) = default;
};
class PlotPoint: public PlotObject
{
public:
// constructor
PlotPoint(const double x, const double y);
PlotPoint(const DSample &x, const double y);
PlotPoint(const double x, const DSample &y);
PlotPoint(const DSample &x, const DSample &y);
// destructor
virtual ~PlotPoint(void) = default;
};
class PlotHLine: public PlotObject
{
public:
// constructor
PlotHLine(const double y);
// destructor
virtual ~PlotHLine(void) = default;
};
class PlotLine: public PlotObject
{
public:
// constructor
PlotLine(const DVec &x, const DVec &y);
// destructor
virtual ~PlotLine(void) = default;
};
class PlotPoints: public PlotObject
{
public:
// constructor
PlotPoints(const DVec &x, const DVec &y);
// destructor
virtual ~PlotPoints(void) = default;
};
class PlotBand: public PlotObject
{
public:
// constructor
PlotBand(const double xMin, const double xMax, const double yMin,
const double yMax, const double opacity = 0.15);
// destructor
virtual ~PlotBand(void) = default;
};
class PlotFunction: public PlotObject
{
public:
// constructor
PlotFunction(const DoubleFunction &function, const double xMin,
const double xMax, const unsigned int nPoint = 1000,
const bool abs = false);
// destructor
virtual ~PlotFunction(void) = default;
};
class PlotPredBand: public PlotObject
{
public:
// constructor
PlotPredBand(const DVec &x, const DVec &y, const DVec &yerr,
const double opacity = 0.15);
PlotPredBand(const DoubleFunctionSample &function, const double xMin,
const double xMax, const unsigned int nPoint = 1000,
const double opacity = 0.15);
// destructor
virtual ~PlotPredBand(void) = default;
private:
void makePredBand(const DMat &low, const DMat &high, const double opacity);
};
class PlotHistogram: public PlotObject
{
public:
// constructor
PlotHistogram(const Histogram &h);
// destructor
virtual ~PlotHistogram(void) = default;
};
class PlotImpulses: public PlotObject
{
public:
// constructor
PlotImpulses(const DVec &x, const DVec &y);
// destructor
virtual ~PlotImpulses(void) = default;
};
class PlotMatrixNoRange: public PlotObject
{
public:
// constructor
PlotMatrixNoRange(const DMat &m);
// destructor
virtual ~PlotMatrixNoRange(void) = default;
};
#define PlotMatrix(m)\
PlotRange(Axis::x, -.5, (m).cols() - .5) <<\
PlotRange(Axis::y, (m).rows() - .5, -.5) <<\
PlotMatrixNoRange(m)
#define PlotCorrMatrix(m)\
PlotHeadCommand("set cbrange [-1:1]") <<\
PlotHeadCommand("set palette defined (0 'blue', 1 'white', 2 'red')") <<\
PlotMatrix(m)
/******************************************************************************
* Plot modifiers *
******************************************************************************/
enum class Axis {x = 0, y = 1};
struct Range
{
double min, max;
};
struct PlotOptions
{
std::string terminal;
std::string output;
std::string caption;
std::string title;
unsigned int scaleMode[2];
double logScaleBasis[2];
Range scale[2];
std::string label[2];
std::string lineColor;
int lineWidth;
std::string dashType;
std::vector<std::string> palette;
};
class PlotModifier
{
public:
// destructor
virtual ~PlotModifier(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const = 0;
};
class Caption: public PlotModifier
{
public:
// constructor
explicit Caption(const std::string &title);
// destructor
virtual ~Caption(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::string caption_;
};
class Label: public PlotModifier
{
public:
// constructor
explicit Label(const std::string &label, const Axis axis);
// destructor
virtual ~Label(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::string label_;
const Axis axis_;
};
class Color: public PlotModifier
{
public:
// constructor
explicit Color(const std::string &color);
// destructor
virtual ~Color(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::string color_;
};
class LineWidth: public PlotModifier
{
public:
// constructor
explicit LineWidth(const unsigned int width);
// destructor
virtual ~LineWidth(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const unsigned width_;
};
class Dash: public PlotModifier
{
public:
// constructor
explicit Dash(const std::string &dash);
// destructor
virtual ~Dash(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::string dash_;
};
class LogScale: public PlotModifier
{
public:
// constructor
explicit LogScale(const Axis axis, const double basis = 10);
// destructor
virtual ~LogScale(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const Axis axis_;
const double basis_;
};
class PlotRange: public PlotModifier
{
public:
// constructors
PlotRange(const Axis axis);
PlotRange(const Axis axis, const double min, const double max);
// destructor
virtual ~PlotRange(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const Axis axis_;
const bool reset_;
const double min_, max_;
};
class Terminal: public PlotModifier
{
public:
// constructor
Terminal(const std::string &terminal, const std::string &options = "");
// destructor
virtual ~Terminal(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::string terminalCmd_;
};
class Title: public PlotModifier
{
public:
// constructor
explicit Title(const std::string &title);
// destructor
virtual ~Title(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::string title_;
};
class Palette: public PlotModifier
{
public:
static const std::vector<std::string> category10;
public:
// constructor
explicit Palette(const std::vector<std::string> &palette);
// destructor
virtual ~Palette(void) = default;
// modifier
virtual void operator()(PlotOptions &option) const;
private:
const std::vector<std::string> palette_;
};
/******************************************************************************
* Plot class *
******************************************************************************/
class Plot
{
public:
class Scale
{
public:
enum
{
reset = 0,
manual = 1 << 0,
log = 1 << 1
};
};
public:
// constructor/destructor
Plot(void);
virtual ~Plot(void) = default;
// plot operations
Plot & operator<<(PlotObject &&command);
Plot & operator<<(PlotModifier &&modifier);
// plot parsing and output
void display(void);
void save(std::string dirName, bool savePdf = true);
friend std::ostream & operator<<(std::ostream &out, const Plot &plot);
// plot reset
void reset(void);
// find gnuplot
std::string getProgramPath(void);
private:
// default options
void initOptions(void);
private:
// gnuplot execution parameters
std::string gnuplotBin_ {GNUPLOT_BIN};
std::string gnuplotArgs_ {GNUPLOT_ARGS};
std::string gnuplotPath_ {""};
// string buffer for commands
std::ostringstream commandBuffer_;
// stack of created temporary files
std::vector<std::string> tmpFileName_;
// plot content
PlotOptions options_;
std::vector<std::string> headCommand_;
std::vector<std::string> plotCommand_;
};
std::ostream & operator<<(std::ostream &out, const Plot &plot);
END_LATAN_NAMESPACE
#endif // Latan_Plot_hpp_

View File

@ -0,0 +1,117 @@
/*
* ThreadPool.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2021 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/ThreadPool.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* ThreadPool implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
ThreadPool::ThreadPool(void)
: ThreadPool(std::thread::hardware_concurrency())
{}
ThreadPool::ThreadPool(const unsigned int nThreads)
: nThreads_(nThreads)
{
for (unsigned int t = 0; t < nThreads_; ++t)
{
threads_.push_back(thread(&ThreadPool::workerLoop, this));
}
}
// destructor //////////////////////////////////////////////////////////////////
ThreadPool::~ThreadPool(void)
{
terminate();
}
// get the number of threads ///////////////////////////////////////////////////
unsigned int ThreadPool::getThreadNum(void) const
{
return nThreads_;
}
// get the pool mutex for synchronisation //////////////////////////////////////
std::mutex & ThreadPool::getMutex(void)
{
return mutex_;
}
// worker loop /////////////////////////////////////////////////////////////////
void ThreadPool::workerLoop(void)
{
while (true)
{
Job job;
{
unique_lock<mutex> lock(mutex_);
condition_.wait(lock, [this](){
return !queue_.empty() || terminatePool_;
});
if (terminatePool_ and queue_.empty())
{
return;
}
job = queue_.front();
queue_.pop();
}
job();
}
}
// add jobs ////////////////////////////////////////////////////////////////////
void ThreadPool::addJob(Job newJob)
{
{
unique_lock<mutex> lock(mutex_);
queue_.push(newJob);
}
condition_.notify_one();
}
// critical section ////////////////////////////////////////////////////////////
void ThreadPool::critical(Job fn)
{
unique_lock<mutex> lock(mutex_);
fn();
}
// wait for completion /////////////////////////////////////////////////////////
void ThreadPool::terminate(void)
{
{
unique_lock<mutex> lock(mutex_);
terminatePool_ = true;
}
condition_.notify_all();
for (auto &thread: threads_)
{
thread.join();
}
threads_.clear();
}

View File

@ -0,0 +1,56 @@
/*
* ThreadPool.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2021 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_ThreadPool_hpp_
#define Latan_ThreadPool_hpp_
#include <LatAnalyze/Global.hpp>
class ThreadPool
{
public:
typedef std::function<void(void)> Job;
public:
// constructors/destructor
ThreadPool(void);
ThreadPool(const unsigned int nThreads);
virtual ~ThreadPool(void);
// get the number of threads
unsigned int getThreadNum(void) const;
// get the pool mutex for synchronisation
std::mutex & getMutex(void);
// add jobs
void addJob(Job newJob);
// critical section
void critical(Job fn);
// wait for completion and terminate
void terminate(void);
private:
// worker loop
void workerLoop(void);
private:
unsigned int nThreads_;
std::condition_variable condition_;
std::vector<std::thread> threads_;
bool terminatePool_{false};
std::queue<Job> queue_;
std::mutex mutex_;
};
#endif

View File

@ -0,0 +1,137 @@
/*
* Utilities.cpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Utilities.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
void Latan::testFunction(void)
{}
ostream & Latan::operator<<(ostream &out, const ProgressBar &&bar)
{
const Index nTick = bar.nCol_*bar.current_/bar.total_;
out << "[";
for (Index i = 0; i < nTick; ++i)
{
out << "=";
}
for (Index i = nTick; i < bar.nCol_; ++i)
{
out << " ";
}
out << "] " << bar.current_ << "/" << bar.total_;
out.flush();
return out;
}
int Latan::mkdir(const std::string dirName)
{
if (access(dirName.c_str(), R_OK|W_OK|X_OK))
{
mode_t mode755;
char tmp[MAX_PATH_LENGTH];
char *p = NULL;
size_t len;
mode755 = S_IRWXU|S_IRGRP|S_IXGRP|S_IROTH|S_IXOTH;
snprintf(tmp, sizeof(tmp), "%s", dirName.c_str());
len = strlen(tmp);
if(tmp[len - 1] == '/')
{
tmp[len - 1] = 0;
}
for(p = tmp + 1; *p; p++)
{
if(*p == '/')
{
*p = 0;
::mkdir(tmp, mode755);
*p = '/';
}
}
return ::mkdir(tmp, mode755);
}
else
{
return 0;
}
}
string Latan::basename(const string &s)
{
constexpr char sep = '/';
size_t i = s.rfind(sep, s.length());
if (i != string::npos)
{
return s.substr(i+1, s.length() - i);
}
else
{
return s;
}
}
std::string Latan::dirname(const std::string &s)
{
constexpr char sep = '/';
size_t i = s.rfind(sep, s.length());
if (i != std::string::npos)
{
return s.substr(0, i);
}
else
{
return "";
}
}
VarName::VarName(const string defName)
: defName_(defName)
{}
string VarName::getName(const Index i) const
{
if (hasName(i))
{
return name_.at(i);
}
else
{
return defName_ + "_" + strFrom(i);
}
}
void VarName::setName(const Index i, const string name)
{
name_[i] = name;
}
bool VarName::hasName(const Index i) const
{
return (name_.find(i) != name_.end());
}

View File

@ -0,0 +1,250 @@
/*
* Utilities.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef LatAnalyze_Utilities_hpp_
#define LatAnalyze_Utilities_hpp_
#ifndef LATAN_GLOBAL_HPP_
#include <LatAnalyze/Global.hpp>
#endif
BEGIN_LATAN_NAMESPACE
// Random seed type ////////////////////////////////////////////////////////////
typedef std::random_device::result_type SeedType;
// Type utilities //////////////////////////////////////////////////////////////
// pointer type test
template <typename Derived, typename Base>
inline bool isDerivedFrom(const Base *pt)
{
return (dynamic_cast<const Derived *>(pt) != nullptr);
}
// static logical or
template <bool... b>
struct static_or;
template <bool... tail>
struct static_or<true, tail...> : static_or<tail...> {};
template <bool... tail>
struct static_or<false, tail...> : std::false_type {};
template <>
struct static_or<> : std::true_type {};
// Environment /////////////////////////////////////////////////////////////////
void testFunction(void);
// String utilities ////////////////////////////////////////////////////////////
inline std::string extension(const std::string fileName)
{
return fileName.substr(fileName.find_last_of(".") + 1);
}
template <typename T>
inline T strTo(const std::string &str)
{
T buf;
std::istringstream stream(str);
stream >> buf;
return buf;
}
// optimized specializations
template <>
inline float strTo<float>(const std::string &str)
{
return strtof(str.c_str(), (char **)NULL);
}
template <>
inline double strTo<double>(const std::string &str)
{
return strtod(str.c_str(), (char **)NULL);
}
template <>
inline int strTo<int>(const std::string &str)
{
return (int)(strtol(str.c_str(), (char **)NULL, 10));
}
template <>
inline long strTo<long>(const std::string &str)
{
return strtol(str.c_str(), (char **)NULL, 10);
}
template <>
inline std::string strTo<std::string>(const std::string &str)
{
return str;
}
template <typename T>
inline std::string strFrom(const T x)
{
std::ostringstream stream;
stream << x;
return stream.str();
}
// specialization for vectors
template<>
inline DVec strTo<DVec>(const std::string &str)
{
DVec res;
std::vector<double> vbuf;
double buf;
std::istringstream stream(str);
while (!stream.eof())
{
stream >> buf;
vbuf.push_back(buf);
}
res = Map<DVec>(vbuf.data(), static_cast<Index>(vbuf.size()));
return res;
}
template<>
inline IVec strTo<IVec>(const std::string &str)
{
IVec res;
std::vector<int> vbuf;
int buf;
std::istringstream stream(str);
while (!stream.eof())
{
stream >> buf;
vbuf.push_back(buf);
}
res = Map<IVec>(vbuf.data(), static_cast<Index>(vbuf.size()));
return res;
}
template<>
inline UVec strTo<UVec>(const std::string &str)
{
UVec res;
std::vector<unsigned int> vbuf;
unsigned int buf;
std::istringstream stream(str);
while (!stream.eof())
{
stream >> buf;
vbuf.push_back(buf);
}
res = Map<UVec>(vbuf.data(), static_cast<Index>(vbuf.size()));
return res;
}
template <typename T>
void tokenReplace(std::string &str, const std::string token,
const T &x, const std::string mark = "@")
{
std::string fullToken = mark + token + mark;
auto pos = str.find(fullToken);
if (pos != std::string::npos)
{
str.replace(pos, fullToken.size(), strFrom(x));
}
}
// Manifest file reader ////////////////////////////////////////////////////////
inline std::vector<std::string> readManifest(const std::string manFileName)
{
std::vector<std::string> list;
std::ifstream manFile;
char buf[MAX_PATH_LENGTH];
manFile.open(manFileName);
while (!manFile.eof())
{
manFile.getline(buf, MAX_PATH_LENGTH);
if (!std::string(buf).empty())
{
list.push_back(buf);
}
}
manFile.close();
return list;
}
// Recursive directory creation ////////////////////////////////////////////////
int mkdir(const std::string dirName);
// C++ version of basename/dirname /////////////////////////////////////////////
std::string basename(const std::string& s);
std::string dirname(const std::string& s);
// Progress bar class //////////////////////////////////////////////////////////
class ProgressBar
{
public:
// constructor
template <typename A, typename B>
ProgressBar(const A current, const B total, const Index nCol = 60);
// IO
friend std::ostream & operator<<(std::ostream &out,
const ProgressBar &&bar);
private:
Index current_, total_, nCol_;
};
template <typename A, typename B>
ProgressBar::ProgressBar(const A current, const B total, const Index nCol)
: current_(static_cast<Index>(current))
, total_(static_cast<Index>(total))
, nCol_(nCol)
{}
std::ostream & operator<<(std::ostream &out, const ProgressBar &&bar);
// named variable interface ////////////////////////////////////////////////////
// FIXME: check redundant names and variable number limit
class VarName
{
public:
// constructor
VarName(const std::string defName);
// destructor
virtual ~VarName(void) = default;
// access
std::string getName(const Index i) const;
void setName(const Index i, const std::string name);
// test
bool hasName(const Index i) const;
private:
std::string defName_;
std::unordered_map<Index, std::string> name_;
};
END_LATAN_NAMESPACE
#endif // LatAnalyze_Utilities_hpp_

View File

@ -0,0 +1,58 @@
/*
* stdincludes.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_stdincludes_hpp_
#define Latan_stdincludes_hpp_
#include <algorithm>
#include <array>
#include <chrono>
#include <complex>
#include <condition_variable>
#include <fstream>
#include <functional>
#include <iostream>
#include <iomanip>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <queue>
#include <random>
#include <regex>
#include <set>
#include <stack>
#include <string>
#include <sstream>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include <cfloat>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sys/stat.h>
#include <unistd.h>
#endif // Latan_stdincludes_hpp_

View File

@ -0,0 +1,131 @@
/*
* CompiledFunction.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Functional/CompiledFunction.hpp>
#include <LatAnalyze/Core/Math.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* Compiled double function implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
CompiledDoubleFunction::CompiledDoubleFunction(const Index nArg)
: nArg_(nArg)
{}
CompiledDoubleFunction::CompiledDoubleFunction(const string &code,
const Index nArg)
: CompiledDoubleFunction(nArg)
{
setCode(code);
}
// access //////////////////////////////////////////////////////////////////////
string CompiledDoubleFunction::getCode(void)
{
return code_;
}
void CompiledDoubleFunction::setCode(const string &code)
{
code_ = code;
interpreter_.reset(new MathInterpreter(code));
context_.reset(new RunContext);
varAddress_.reset(new std::vector<unsigned int>);
isCompiled_.reset(new bool(false));
}
// compile /////////////////////////////////////////////////////////////////////
void CompiledDoubleFunction::compile(void) const
{
if (!*isCompiled_)
{
varAddress_->clear();
for (Index i = 0; i < nArg_; ++i)
{
varAddress_->push_back(context_->addVariable("x_" + strFrom(i)));
}
interpreter_->compile(*(context_));
*isCompiled_ = true;
}
}
// function call ///////////////////////////////////////////////////////////////
double CompiledDoubleFunction::operator()(const double *arg) const
{
double result;
compile();
for (unsigned int i = 0; i < nArg_; ++i)
{
context_->setVariable((*varAddress_)[i], arg[i]);
}
(*interpreter_)(*context_);
if (!context_->stack().empty())
{
result = context_->stack().top();
context_->stack().pop();
}
else
{
result = 0.0;
LATAN_ERROR(Program, "program execution resulted in an empty stack");
}
return result;
}
// IO //////////////////////////////////////////////////////////////////////////
ostream & Latan::operator<<(ostream &out, CompiledDoubleFunction &f)
{
f.compile();
out << *(f.interpreter_);
return out;
}
// DoubleFunction factory //////////////////////////////////////////////////////
DoubleFunction CompiledDoubleFunction::makeFunction(const bool makeHardCopy)
const
{
DoubleFunction res;
if (makeHardCopy)
{
CompiledDoubleFunction copy(*this);
res.setFunction([copy](const double *p){return copy(p);}, nArg_);
}
else
{
res.setFunction([this](const double *p){return (*this)(p);}, nArg_);
}
return res;
}
DoubleFunction Latan::compile(const string code, const Index nArg)
{
CompiledDoubleFunction compiledFunc(code, nArg);
return compiledFunc.makeFunction();
}

View File

@ -0,0 +1,70 @@
/*
* CompiledFunction.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_CompiledFunction_hpp_
#define Latan_CompiledFunction_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Core/MathInterpreter.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* compiled double function class *
******************************************************************************/
class CompiledDoubleFunction: public DoubleFunctionFactory
{
public:
// constructors
explicit CompiledDoubleFunction(const Index nArg);
CompiledDoubleFunction(const std::string &code, const Index nArg);
// destructor
virtual ~CompiledDoubleFunction(void) = default;
// access
std::string getCode(void);
void setCode(const std::string &code);
// function call
double operator()(const double *arg) const;
// IO
friend std::ostream & operator<<(std::ostream &out,
CompiledDoubleFunction &f);
// factory
virtual DoubleFunction makeFunction(const bool makeHardCopy = true) const;
private:
// compile
void compile(void) const;
private:
Index nArg_;
std::string code_;
std::shared_ptr<MathInterpreter> interpreter_;
std::shared_ptr<RunContext> context_;
std::shared_ptr<std::vector<unsigned int>> varAddress_;
std::shared_ptr<bool> isCompiled_;
};
std::ostream & operator<<(std::ostream &out, CompiledDoubleFunction &f);
// DoubleFunction factory
DoubleFunction compile(const std::string code, const Index nArg);
END_LATAN_NAMESPACE
#endif // Latan_CompiledFunction_hpp_

View File

@ -0,0 +1,145 @@
/*
* CompiledModel.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Functional/CompiledModel.hpp>
#include <LatAnalyze/Core/Math.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* CompiledDoubleModel implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
CompiledDoubleModel::CompiledDoubleModel(const Index nArg, const Index nPar)
: nArg_(nArg)
, nPar_(nPar)
{}
CompiledDoubleModel::CompiledDoubleModel(const string &code, const Index nArg,
const Index nPar)
: CompiledDoubleModel(nArg, nPar)
{
setCode(code);
}
// access //////////////////////////////////////////////////////////////////////
string CompiledDoubleModel::getCode(void)
{
return code_;
}
void CompiledDoubleModel::setCode(const std::string &code)
{
code_ = code;
interpreter_.reset(new MathInterpreter(code_));
context_.reset(new RunContext);
varAddress_.reset(new std::vector<unsigned int>);
parAddress_.reset(new std::vector<unsigned int>);
isCompiled_.reset(new bool(false));
}
// compile /////////////////////////////////////////////////////////////////////
void CompiledDoubleModel::compile(void) const
{
if (!*isCompiled_)
{
varAddress_->clear();
parAddress_->clear();
for (Index i = 0; i < nArg_; ++i)
{
varAddress_->push_back(context_->addVariable("x_" + strFrom(i)));
}
for (Index j = 0; j < nPar_; ++j)
{
parAddress_->push_back(context_->addVariable("p_" + strFrom(j)));
}
interpreter_->compile(*(context_));
*isCompiled_ = true;
}
}
// function call ///////////////////////////////////////////////////////////////
double CompiledDoubleModel::operator()(const double *arg,
const double *par) const
{
double result;
compile();
for (unsigned int i = 0; i < nArg_; ++i)
{
context_->setVariable((*varAddress_)[i], arg[i]);
}
for (unsigned int j = 0; j < nPar_; ++j)
{
context_->setVariable((*parAddress_)[j], par[j]);
}
(*interpreter_)(*context_);
if (!context_->stack().empty())
{
result = context_->stack().top();
context_->stack().pop();
}
else
{
result = 0.0;
LATAN_ERROR(Program, "program execution resulted in an empty stack");
}
return result;
}
// IO //////////////////////////////////////////////////////////////////////////
ostream & Latan::operator<<(std::ostream &out, CompiledDoubleModel &m)
{
m.compile();
out << *(m.interpreter_);
return out;
}
// DoubleModel factory /////////////////////////////////////////////////////////
DoubleModel CompiledDoubleModel::makeModel(const bool makeHardCopy) const
{
DoubleModel res;
if (makeHardCopy)
{
CompiledDoubleModel copy(*this);
res.setFunction([copy](const double *x, const double *p)
{return copy(x, p);}, nArg_, nPar_);
}
else
{
res.setFunction([this](const double *x, const double *p)
{return (*this)(x, p);}, nArg_, nPar_);
}
return res;
}
DoubleModel Latan::compile(const std::string &code, const Index nArg,
const Index nPar)
{
CompiledDoubleModel compiledModel(code, nArg, nPar);
return compiledModel.makeModel();
}

View File

@ -0,0 +1,71 @@
/*
* CompiledModel.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_CompiledModel_hpp_
#define Latan_CompiledModel_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Model.hpp>
#include <LatAnalyze/Core/MathInterpreter.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* compiled double model class *
******************************************************************************/
class CompiledDoubleModel: public DoubleModelFactory
{
public:
// constructor
CompiledDoubleModel(const Index nArg, const Index nPar);
CompiledDoubleModel(const std::string &code, const Index nArg,
const Index nPar);
// destructor
virtual ~CompiledDoubleModel(void) = default;
// access
std::string getCode(void);
void setCode(const std::string &code);
// function call
double operator()(const double *arg, const double *par) const;
// IO
friend std::ostream & operator<<(std::ostream &out,
CompiledDoubleModel &f);
// factory
DoubleModel makeModel(const bool makeHardCopy = true) const;
private:
// compile
void compile(void) const;
private:
Index nArg_, nPar_;
std::string code_;
std::shared_ptr<MathInterpreter> interpreter_;
std::shared_ptr<RunContext> context_;
std::shared_ptr<std::vector<unsigned int>> varAddress_, parAddress_;
std::shared_ptr<bool> isCompiled_;
};
std::ostream & operator<<(std::ostream &out, CompiledDoubleModel &f);
// DoubleModel factory
DoubleModel compile(const std::string &code, const Index nArg,
const Index nPar);
END_LATAN_NAMESPACE
#endif // Latan_CompiledModel_hpp_

View File

@ -0,0 +1,316 @@
/*
* Function.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* DoubleFunction implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
DoubleFunction::DoubleFunction(const vecFunc &f, const Index nArg)
: buffer_(new DVec)
, varName_("x")
{
setFunction(f, nArg);
}
// access //////////////////////////////////////////////////////////////////////
Index DoubleFunction::getNArg(void) const
{
return buffer_->size();
}
void DoubleFunction::setFunction(const vecFunc &f, const Index nArg)
{
buffer_->resize(nArg);
f_ = f;
}
VarName & DoubleFunction::varName(void)
{
return varName_;
}
const VarName & DoubleFunction::varName(void) const
{
return varName_;
}
// error checking //////////////////////////////////////////////////////////////
void DoubleFunction::checkSize(const Index nPar) const
{
if (nPar != getNArg())
{
LATAN_ERROR(Size, "function argument vector has a wrong size (expected "
+ strFrom(getNArg()) + ", got " + strFrom(nPar)
+ ")");
}
}
// function call ///////////////////////////////////////////////////////////////
double DoubleFunction::operator()(const double *arg) const
{
return f_(arg);
}
double DoubleFunction::operator()(const DVec &arg) const
{
checkSize(arg.size());
return (*this)(arg.data());
}
double DoubleFunction::operator()(const std::vector<double> &arg) const
{
checkSize(static_cast<Index>(arg.size()));
return (*this)(arg.data());
}
double DoubleFunction::operator()(std::stack<double> &arg) const
{
for (Index i = 0; i < getNArg(); ++i)
{
if (arg.empty())
{
LATAN_ERROR(Size, "function argument stack is empty (expected "
+ strFrom(getNArg()) + "arguments, got " + strFrom(i)
+ ")");
}
(*buffer_)(getNArg() - i - 1) = arg.top();
arg.pop();
}
return (*this)(*buffer_);
}
double DoubleFunction::operator()(void) const
{
checkSize(0);
return (*this)(nullptr);
}
std::map<double, double> DoubleFunction::operator()(const std::map<double, double> &m) const
{
checkSize(1);
std::map<double, double> res;
for (auto &val: m)
{
res[val.first] = (*this)(val.second);
}
return res;
}
// bind ////////////////////////////////////////////////////////////////////////
DoubleFunction DoubleFunction::bind(const Index argIndex,
const double val) const
{
Index nArg = getNArg();
shared_ptr<DVec> buf(new DVec(nArg));
DoubleFunction copy(*this), bindFunc;
auto func = [copy, buf, argIndex, val](const double *arg)
{
FOR_VEC(*buf, i)
{
if (i < argIndex)
{
(*buf)(i) = arg[i];
}
else if (i == argIndex)
{
(*buf)(i) = val;
}
else
{
(*buf)(i) = arg[i - 1];
}
}
return copy(*buf);
};
bindFunc.setFunction(func, nArg - 1);
return bindFunc;
}
DoubleFunction DoubleFunction::bind(const Index argIndex,
const DVec &x) const
{
Index nArg = getNArg();
shared_ptr<DVec> buf(new DVec(nArg));
DoubleFunction copy(*this), bindFunc;
auto func = [copy, buf, argIndex, x](const double *arg)
{
*buf = x;
(*buf)(argIndex) = arg[0];
return copy(*buf);
};
bindFunc.setFunction(func, 1);
return bindFunc;
}
// sample //////////////////////////////////////////////////////////////////////
DVec DoubleFunction::sample(const DMat &x) const
{
if (x.cols() != getNArg())
{
LATAN_ERROR(Size, "sampling point matrix and number of arguments "
"mismatch (matrix has " + strFrom(x.cols())
+ ", number of arguments is " + strFrom(getNArg()) + ")");
}
DVec res(x.rows());
for (Index i = 0; i < res.size(); ++i)
{
res(i) = (*this)(x.row(i));
}
return res;
}
// arithmetic operators ////////////////////////////////////////////////////////
DoubleFunction DoubleFunction::operator-(void) const
{
DoubleFunction copy(*this), resFunc;
return DoubleFunction([copy](const double *arg){return -copy(arg);},
getNArg());
}
#define MAKE_SELF_FUNC_OP(op)\
DoubleFunction & DoubleFunction::operator op##=(const DoubleFunction &f)\
{\
DoubleFunction copy(*this);\
checkSize(f.getNArg());\
auto res = [f, copy](const double *arg){return copy(arg) op f(arg);};\
setFunction(res, getNArg());\
return *this;\
}\
DoubleFunction & DoubleFunction::operator op##=(const DoubleFunction and f)\
{\
*this op##= f;\
return *this;\
}
#define MAKE_SELF_SCALAR_OP(op)\
DoubleFunction & DoubleFunction::operator op##=(const double x)\
{\
DoubleFunction copy(*this);\
auto res = [x, copy](const double *arg){return copy(arg) op x;};\
setFunction(res, getNArg());\
return *this;\
}\
MAKE_SELF_FUNC_OP(+)
MAKE_SELF_FUNC_OP(-)
MAKE_SELF_FUNC_OP(*)
MAKE_SELF_FUNC_OP(/)
MAKE_SELF_SCALAR_OP(+)
MAKE_SELF_SCALAR_OP(-)
MAKE_SELF_SCALAR_OP(*)
MAKE_SELF_SCALAR_OP(/)
/******************************************************************************
* DoubleFunctionSample implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
DoubleFunctionSample::DoubleFunctionSample(void)
: Sample<DoubleFunction>()
{}
DoubleFunctionSample::DoubleFunctionSample(const Index nSample)
: Sample<DoubleFunction>(nSample)
{}
// function call ///////////////////////////////////////////////////////////////
DSample DoubleFunctionSample::operator()(const DMatSample &arg) const
{
DSample result(size());
FOR_STAT_ARRAY((*this), s)
{
result[s] = (*this)[s](arg[s]);
}
return result;
}
DSample DoubleFunctionSample::operator()(const double *arg) const
{
DSample result(size());
FOR_STAT_ARRAY((*this), s)
{
result[s] = (*this)[s](arg);
}
return result;
}
DSample DoubleFunctionSample::operator()(const DVec &arg) const
{
return (*this)(arg.data());
}
DSample DoubleFunctionSample::operator()(const vector<double> &arg) const
{
return (*this)(arg.data());
}
// bind ////////////////////////////////////////////////////////////////////////
DoubleFunctionSample DoubleFunctionSample::bind(const Index argIndex,
const double val) const
{
DoubleFunctionSample bindFunc(size());
FOR_STAT_ARRAY(bindFunc, s)
{
bindFunc[s] = (*this)[s].bind(argIndex, val);
}
return bindFunc;
}
DoubleFunctionSample DoubleFunctionSample::bind(const Index argIndex,
const DVec &x) const
{
DoubleFunctionSample bindFunc(size());
FOR_STAT_ARRAY(bindFunc, s)
{
bindFunc[s] = (*this)[s].bind(argIndex, x);
}
return bindFunc;
}

View File

@ -0,0 +1,207 @@
/*
* Function.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Function_hpp_
#define Latan_Function_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Double function class *
******************************************************************************/
class DoubleFunction
{
private:
// function type
typedef std::function<double(const double *)> vecFunc;
public:
// constructor
explicit DoubleFunction(const vecFunc &f = nullptr, const Index nArg = 0);
// destructor
virtual ~DoubleFunction(void) = default;
// access
virtual Index getNArg(void) const;
void setFunction(const vecFunc &f, const Index nArg);
VarName & varName(void);
const VarName & varName(void) const;
// function call
double operator()(const double *arg) const;
double operator()(const DVec &arg) const;
double operator()(const std::vector<double> &arg) const;
double operator()(std::stack<double> &arg) const;
double operator()(void) const;
template <typename... Ts>
double operator()(const double arg0, const Ts... args) const;
std::map<double, double> operator()(const std::map<double, double> &m) const;
// bind
DoubleFunction bind(const Index argIndex, const double val) const;
DoubleFunction bind(const Index argIndex, const DVec &x) const;
// sample
DVec sample(const DMat &x) const;
// arithmetic operators
DoubleFunction operator-(void) const;
DoubleFunction & operator+=(const DoubleFunction &f);
DoubleFunction & operator+=(const DoubleFunction &&f);
DoubleFunction & operator-=(const DoubleFunction &f);
DoubleFunction & operator-=(const DoubleFunction &&f);
DoubleFunction & operator*=(const DoubleFunction &f);
DoubleFunction & operator*=(const DoubleFunction &&f);
DoubleFunction & operator/=(const DoubleFunction &f);
DoubleFunction & operator/=(const DoubleFunction &&f);
DoubleFunction & operator+=(const double x);
DoubleFunction & operator-=(const double x);
DoubleFunction & operator*=(const double x);
DoubleFunction & operator/=(const double x);
private:
// error checking
void checkSize(const Index nPar) const;
private:
std::shared_ptr<DVec> buffer_{nullptr};
VarName varName_;
vecFunc f_;
};
/******************************************************************************
* DoubleFunction template implementation *
******************************************************************************/
template <typename... Ts>
double DoubleFunction::operator()(const double arg0, const Ts... args) const
{
static_assert(static_or<std::is_convertible<double, Ts>::value...>::value,
"DoubleFunction arguments are not compatible with double");
const double arg[] = {arg0, static_cast<double>(args)...};
checkSize(sizeof...(args) + 1);
return (*this)(arg);
}
/******************************************************************************
* DoubleFunction inline arithmetic operators *
******************************************************************************/
#define MAKE_INLINE_FUNC_OP(op)\
inline DoubleFunction operator op(DoubleFunction lhs,\
const DoubleFunction &rhs)\
{\
lhs op##= rhs;\
return lhs;\
}\
inline DoubleFunction operator op(DoubleFunction lhs,\
const DoubleFunction &&rhs)\
{\
return lhs op rhs;\
}
#define MAKE_INLINE_RSCALAR_OP(op)\
inline DoubleFunction operator op(DoubleFunction lhs, const double rhs)\
{\
lhs op##= rhs;\
return lhs;\
}\
#define MAKE_INLINE_LSCALAR_OP(op)\
inline DoubleFunction operator op(const double lhs, DoubleFunction rhs)\
{\
rhs op##= lhs;\
return rhs;\
}
MAKE_INLINE_FUNC_OP(+)
MAKE_INLINE_FUNC_OP(-)
MAKE_INLINE_FUNC_OP(*)
MAKE_INLINE_FUNC_OP(/)
MAKE_INLINE_RSCALAR_OP(+)
MAKE_INLINE_RSCALAR_OP(-)
MAKE_INLINE_RSCALAR_OP(*)
MAKE_INLINE_RSCALAR_OP(/)
MAKE_INLINE_LSCALAR_OP(+)
MAKE_INLINE_LSCALAR_OP(*)
// special case for scalar - function
inline DoubleFunction operator-(const double lhs, DoubleFunction rhs)
{
return (-rhs) + lhs;
}
// special case for scalar/function
inline DoubleFunction operator/(const double lhs, DoubleFunction rhs)
{
auto res = [lhs, rhs](const double *arg){return lhs/rhs(arg);};
rhs.setFunction(res, rhs.getNArg());
return rhs;
}
/******************************************************************************
* DoubleFunctionSample class *
******************************************************************************/
class DoubleFunctionSample: public Sample<DoubleFunction>
{
public:
// constructors
DoubleFunctionSample(void);
DoubleFunctionSample(const Index nSample);
EIGEN_EXPR_CTOR(DoubleFunctionSample, DoubleFunctionSample,
Sample<DoubleFunction>, ArrayExpr)
// destructor
virtual ~DoubleFunctionSample(void) = default;
// function call
DSample operator()(const DMatSample &arg) const;
DSample operator()(const double *arg) const;
DSample operator()(const DVec &arg) const;
DSample operator()(const std::vector<double> &arg) const;
template <typename... Ts>
DSample operator()(const double arg0, const Ts... args) const;
// bind
DoubleFunctionSample bind(const Index argIndex, const double val) const;
DoubleFunctionSample bind(const Index argIndex, const DVec &x) const ;
};
template <typename... Ts>
DSample DoubleFunctionSample::operator()(const double arg0,
const Ts... args) const
{
const double arg[] = {arg0, static_cast<double>(args)...};
return (*this)(arg);
}
/******************************************************************************
* DoubleFunctionFactory class *
******************************************************************************/
class DoubleFunctionFactory
{
public:
// constructor
DoubleFunctionFactory(void) = default;
// destructor
virtual ~DoubleFunctionFactory(void) = default;
// factory
virtual DoubleFunction makeFunction(const bool makeHardCopy) const = 0;
};
END_LATAN_NAMESPACE
#endif // Latan_Function_hpp_

View File

@ -0,0 +1,150 @@
/*
* Model.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Functional/Model.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace std::placeholders;
using namespace Latan;
/******************************************************************************
* Model implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
DoubleModel::DoubleModel(const vecFunc &f, const Index nArg, const Index nPar)
: size_(new ModelSize)
, varName_("x")
, parName_("p")
{
setFunction(f, nArg, nPar);
}
// access //////////////////////////////////////////////////////////////////////
Index DoubleModel::getNArg(void) const
{
return size_->nArg;
}
Index DoubleModel::getNPar(void) const
{
return size_->nPar;
}
void DoubleModel::setFunction(const vecFunc &f, const Index nArg,
const Index nPar)
{
size_->nArg = nArg;
size_->nPar = nPar;
f_ = f;
}
VarName & DoubleModel::varName(void)
{
return varName_;
}
const VarName & DoubleModel::varName(void) const
{
return varName_;
}
VarName & DoubleModel::parName(void)
{
return parName_;
}
const VarName & DoubleModel::parName(void) const
{
return parName_;
}
// error checking //////////////////////////////////////////////////////////////
void DoubleModel::checkSize(const Index nArg, const Index nPar) const
{
if (nArg != getNArg())
{
LATAN_ERROR(Size, "model argument vector has a wrong size (expected "
+ strFrom(getNArg()) + ", got " + strFrom(nArg)
+ ")");
}
if (nPar != getNPar())
{
LATAN_ERROR(Size, "model parameter vector has a wrong size (expected "
+ strFrom(getNPar()) + ", got " + strFrom(nPar)
+ ")");
}
}
// function call ///////////////////////////////////////////////////////////////
double DoubleModel::operator()(const DVec &arg, const DVec &par) const
{
checkSize(arg.size(), par.size());
return (*this)(arg.data(), par.data());
}
double DoubleModel::operator()(const vector<double> &arg,
const vector<double> &par) const
{
checkSize(static_cast<Index>(arg.size()), static_cast<Index>(par.size()));
return (*this)(arg.data(), par.data());
}
double DoubleModel::operator()(const double *data, const double *par) const
{
return f_(data, par);
}
// model bind //////////////////////////////////////////////////////////////////
DoubleFunction DoubleModel::fixArg(const DVec &arg) const
{
DoubleModel copy(*this);
auto modelWithVec = [copy](const DVec &x, const double *p)
{
return copy(x.data(), p);
};
auto modelBind = bind(modelWithVec, arg, _1);
return DoubleFunction(modelBind, getNPar());
}
DoubleFunction DoubleModel::fixPar(const DVec &par) const
{
DoubleModel copy(*this);
auto modelWithVec = [copy](const double *x, const DVec &p)
{
return copy(x, p.data());
};
auto modelBind = bind(modelWithVec, _1, par);
return DoubleFunction(modelBind, getNArg());
}
DoubleFunction DoubleModel::toFunction(void) const
{
DoubleModel copy(*this);
auto func = [copy](const double *x){return copy(x, x + copy.getNArg());};
return DoubleFunction(func, getNArg() + getNPar());
}

View File

@ -0,0 +1,87 @@
/*
* Model.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Model_hpp_
#define Latan_Model_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Core/Mat.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Double model class *
******************************************************************************/
class DoubleModel
{
public:
typedef std::function<double(const double *, const double *)> vecFunc;
private:
struct ModelSize{Index nArg, nPar;};
public:
// constructor
DoubleModel(const vecFunc &f = nullptr, const Index nArg = 0,
const Index nPar = 0);
// destructor
virtual ~DoubleModel(void) = default;
// access
virtual Index getNArg(void) const;
virtual Index getNPar(void) const;
void setFunction(const vecFunc &f, const Index nArg,
const Index nPar);
VarName & varName(void);
const VarName & varName(void) const;
VarName & parName(void);
const VarName & parName(void) const;
// function call
double operator()(const DVec &data, const DVec &par) const;
double operator()(const std::vector<double> &data,
const std::vector<double> &par) const;
double operator()(const double *data, const double *par) const;
// bind
DoubleFunction fixArg(const DVec &arg) const;
DoubleFunction fixPar(const DVec &par) const;
DoubleFunction toFunction(void) const;
private:
// error checking
void checkSize(const Index nArg, const Index nPar) const;
private:
std::shared_ptr<ModelSize> size_;
VarName varName_, parName_;
vecFunc f_;
};
/******************************************************************************
* base class for model factories *
******************************************************************************/
class DoubleModelFactory
{
public:
// constructor
DoubleModelFactory(void) = default;
// destructor
virtual ~DoubleModelFactory(void) = default;
// factory
virtual DoubleModel makeModel(const bool makeHardCopy) const = 0;
};
END_LATAN_NAMESPACE
#endif // Latan_Model_hpp_

View File

@ -0,0 +1,164 @@
/*
* TabFunction.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Functional/TabFunction.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* TabFunction implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
TabFunction::TabFunction(const DVec &x, const DVec &y,
const InterpType interpType)
{
setData(x, y);
setInterpolationType(interpType);
}
// access //////////////////////////////////////////////////////////////////////
void TabFunction::setData(const DVec &x, const DVec &y)
{
if (x.size() != y.size())
{
LATAN_ERROR(Size, "tabulated function x/y data size mismatch");
}
FOR_VEC(x, i)
{
value_[x(i)] = y(i);
}
}
void TabFunction::setInterpolationType(const InterpType interpType)
{
interpType_ = interpType;
}
// function call ///////////////////////////////////////////////////////////////
double TabFunction::operator()(const double *arg) const
{
double result = 0.0, x = arg[0];
if ((x < value_.begin()->first) or (x >= value_.rbegin()->first)) {
LATAN_ERROR(Range, "tabulated function variable out of range "
"(x= " + strFrom(x) + " not in ["
+ strFrom(value_.begin()->first) + ", "
+ strFrom(value_.rbegin()->first) + "])");
}
auto i = value_.equal_range(x);
auto low = (x == i.first->first) ? i.first : prev(i.first);
auto high = i.second;
switch (interpType_) {
case InterpType::LINEAR: {
double x_a, x_b, y_a, y_b;
x_a = low->first;
x_b = high->first;
y_a = low->second;
y_b = high->second;
result = y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a);
break;
}
case InterpType::NEAREST: {
result = nearest(x)->second;
break;
}
case InterpType::QUADRATIC: {
double xs[3], ys[3], ds[3], d01, d02, d12;
auto it = nearest(x);
if (it == value_.begin()) {
it = next(it);
}
else if (it == prev(value_.end())) {
it = prev(it);
}
xs[0] = prev(it)->first;
ys[0] = prev(it)->second;
xs[1] = it->first;
ys[1] = it->second;
xs[2] = next(it)->first;
ys[2] = next(it)->second;
ds[0] = x - xs[0];
ds[1] = x - xs[1];
ds[2] = x - xs[2];
d01 = xs[0] - xs[1];
d02 = xs[0] - xs[2];
d12 = xs[1] - xs[2];
// Lagrange polynomial coefficient computation
result = ds[1]/d01*ds[2]/d02*ys[0]
-ds[0]/d01*ds[2]/d12*ys[1]
+ds[0]/d02*ds[1]/d12*ys[2];
break;
}
default:
int intType = static_cast<int>(interpType_);
LATAN_ERROR(Implementation, "unsupported interpolation type in "
"tabulated function: "
+ strFrom(intType));
}
return result;
}
// DoubleFunction factory //////////////////////////////////////////////////////
DoubleFunction TabFunction::makeFunction(const bool makeHardCopy) const
{
DoubleFunction res;
if (makeHardCopy)
{
TabFunction copy(*this);
res.setFunction([copy](const double *x){return copy(x);}, 1);
}
else
{
res.setFunction([this](const double *x){return (*this)(x);}, 1);
}
return res;
}
DoubleFunction Latan::interpolate(const DVec &x, const DVec &y,
const InterpType interpType)
{
return TabFunction(x, y, interpType).makeFunction();
}
map<double, double>::const_iterator TabFunction::nearest(const double x) const
{
map<double, double>::const_iterator ret;
auto i = value_.equal_range(x);
auto low = (x == i.first->first) ? i.first : prev(i.first);
auto high = i.second;
if (fabs(high->first - x) < fabs(low->first - x)) {
ret = high;
}
else {
ret = low;
}
return ret;
}

View File

@ -0,0 +1,69 @@
/*
* TabFunction.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_TabFunction_hpp_
#define Latan_TabFunction_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Core/Math.hpp>
#include <LatAnalyze/Statistics/XYStatData.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* tabulated function: 1D only *
******************************************************************************/
enum class InterpType
{
NEAREST,
LINEAR,
QUADRATIC
};
class TabFunction: public DoubleFunctionFactory
{
public:
// constructors
TabFunction(void) = default;
TabFunction(const DVec &x, const DVec &y,
const InterpType interpType = InterpType::LINEAR);
// destructor
virtual ~TabFunction(void) = default;
// access
void setData(const DVec &x, const DVec &y);
void setInterpolationType(const InterpType interpType);
// function call
double operator()(const double *arg) const;
// factory
virtual DoubleFunction makeFunction(const bool makeHardCopy = true) const;
private:
std::map<double, double>::const_iterator nearest(const double x) const;
std::map<double, double> value_;
InterpType interpType_;
};
DoubleFunction interpolate(const DVec &x, const DVec &y,
const InterpType interpType = InterpType::LINEAR);
END_LATAN_NAMESPACE
#endif // Latan_TabFunction_hpp_

35
lib/LatAnalyze/Global.cpp Normal file
View File

@ -0,0 +1,35 @@
/*
* Global.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
PlaceHolder Latan::_;
const string Env::fullName = PACKAGE_STRING;
const string Env::name = PACKAGE_NAME;
const string Env::version = PACKAGE_VERSION;
const string Env::msgPrefix = "[" + strFrom(PACKAGE_NAME) + " v"
+ strFrom(PACKAGE_VERSION) + "] ";
void Env::function(void)
{}

68
lib/LatAnalyze/Global.hpp Normal file
View File

@ -0,0 +1,68 @@
/*
* Global.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Global_hpp_
#define Latan_Global_hpp_
#include <LatAnalyze/Core/stdincludes.hpp>
#define BEGIN_LATAN_NAMESPACE \
namespace Latan {
#define END_LATAN_NAMESPACE }
// macro utilities
#define unique_arg(...) __VA_ARGS__
#define DEBUG_VAR(x) std::cout << #x << "= " << x << std::endl
#define DEBUG_MAT(m) std::cout << #m << "=\n" << m << std::endl
// attribute to switch off unused warnings with gcc
#ifdef __GNUC__
#define __dumb __attribute__((unused))
#else
#define __dumb
#endif
// max length for paths
#define MAX_PATH_LENGTH 512u
BEGIN_LATAN_NAMESPACE
// Placeholder type ////////////////////////////////////////////////////////////
struct PlaceHolder {};
extern PlaceHolder _;
// Environment /////////////////////////////////////////////////////////////////
namespace Env
{
extern const std::string fullName;
extern const std::string name;
extern const std::string version;
extern const std::string msgPrefix;
// empty function for library test
void function(void);
}
END_LATAN_NAMESPACE
#include <LatAnalyze/Core/Eigen.hpp>
#include <LatAnalyze/Core/Exceptions.hpp>
#include <LatAnalyze/Core/Utilities.hpp>
#endif // Latan_Global_hpp_

View File

@ -0,0 +1,212 @@
/*
* AsciiFile.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Io/AsciiFile.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* AsciiFile implementation *
******************************************************************************/
// AsciiParserState constructor ////////////////////////////////////////////////
AsciiFile::AsciiParserState::AsciiParserState(istream *stream, string *name,
IoDataTable *data)
: ParserState<IoDataTable>(stream, name, data)
{
initScanner();
}
// AsciiParserState destructor /////////////////////////////////////////////////
AsciiFile::AsciiParserState::~AsciiParserState(void)
{
destroyScanner();
}
// constructor /////////////////////////////////////////////////////////////////
AsciiFile::AsciiFile(const string &name, const unsigned int mode)
{
open(name, mode);
}
// destructor //////////////////////////////////////////////////////////////////
AsciiFile::~AsciiFile(void)
{
close();
}
// access //////////////////////////////////////////////////////////////////////
void AsciiFile::save(const DMat &m, const std::string &name)
{
if (name.empty())
{
LATAN_ERROR(Io, "trying to save data with an empty name");
}
const auto defaultPrec = fileStream_.precision(defaultDoublePrec);
checkWritability();
isParsed_ = false;
fileStream_ << "#L latan_begin mat " << name << endl;
fileStream_ << m.cols() << endl;
fileStream_ << scientific << m << endl;
fileStream_ << "#L latan_end mat " << endl;
fileStream_.precision(defaultPrec);
}
void AsciiFile::save(const DSample &ds, const std::string &name)
{
if (name.empty())
{
LATAN_ERROR(Io, "trying to save data with an empty name");
}
checkWritability();
isParsed_ = false;
fileStream_ << "#L latan_begin rs_sample " << name << endl;
fileStream_ << ds.size() << endl;
save(ds.matrix(), name + "_data");
fileStream_ << "#L latan_end rs_sample " << endl;
}
void AsciiFile::save(const DMatSample &ms, const std::string &name)
{
if (name.empty())
{
LATAN_ERROR(Io, "trying to save data with an empty name");
}
checkWritability();
isParsed_ = false;
fileStream_ << "#L latan_begin rs_sample " << name << endl;
fileStream_ << ms.size() << endl;
save(ms[central], name + "_C");
for (Index i = 0; i < ms.size(); ++i)
{
save(ms[i], name + "_S_" + strFrom(i));
}
fileStream_ << "#L latan_end rs_sample " << endl;
}
// read first name ////////////////////////////////////////////////////////////
string AsciiFile::getFirstName(void)
{
return load();
}
// tests ///////////////////////////////////////////////////////////////////////
bool AsciiFile::isOpen() const
{
return fileStream_.is_open();
}
// IO //////////////////////////////////////////////////////////////////////////
void AsciiFile::close(void)
{
state_.reset();
if (isOpen())
{
fileStream_.close();
}
name_ = "";
mode_ = Mode::null;
isParsed_ = false;
deleteData();
}
void AsciiFile::open(const string &name, const unsigned int mode)
{
if (isOpen())
{
LATAN_ERROR(Io, "file already opened with name '" + name_ + "'");
}
else
{
ios_base::openmode stdMode = static_cast<ios_base::openmode>(0);
name_ = name;
mode_ = mode;
if (mode & Mode::write)
{
stdMode |= ios::out|ios::trunc;
}
if (mode & Mode::read)
{
stdMode |= ios::in;
}
if (mode & Mode::append)
{
stdMode |= ios::out|ios::app;
}
isParsed_ = false;
fileStream_.open(name_.c_str(), stdMode);
if (mode_ & Mode::read)
{
state_.reset(new AsciiParserState(&fileStream_, &name_, &data_));
}
else
{
state_.reset();
}
}
}
std::string AsciiFile::load(const string &name)
{
if ((mode_ & Mode::read) and (isOpen()))
{
if (!isParsed_)
{
state_->isFirst = true;
parse();
}
if (name.empty())
{
return state_->first;
}
else
{
return name;
}
}
else
{
if (isOpen())
{
LATAN_ERROR(Io, "file '" + name_ + "' is not opened in read mode");
}
else
{
LATAN_ERROR(Io, "file '" + name_ + "' is not opened");
}
}
}
// parser //////////////////////////////////////////////////////////////////////
//// Bison/Flex parser declaration
int _Ascii_parse(AsciiFile::AsciiParserState *state);
void AsciiFile::parse()
{
fileStream_.seekg(0);
_Ascii_parse(state_.get());
isParsed_ = true;
}

View File

@ -0,0 +1,92 @@
/*
* AsciiFile.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_AsciiFile_hpp_
#define Latan_AsciiFile_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/File.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
#include <LatAnalyze/Core/ParserState.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* ASCII datafile class *
******************************************************************************/
class AsciiFile: public File
{
public:
class AsciiParserState: public ParserState<IoDataTable>
{
public:
// constructor
AsciiParserState(std::istream *stream, std::string *name,
IoDataTable *data);
// destructor
virtual ~AsciiParserState(void);
// first element reference
bool isFirst;
std::string first;
// parsing buffers
int intBuf;
DSample dSampleBuf;
DMatSample dMatSampleBuf;
std::queue<DMat> dMatQueue;
std::queue<double> doubleQueue;
private:
// allocation/deallocation functions defined in IoAsciiLexer.lpp
virtual void initScanner(void);
virtual void destroyScanner(void);
};
public:
// constructors
AsciiFile(void) = default;
AsciiFile(const std::string &name, const unsigned int mode);
// destructor
virtual ~AsciiFile(void);
// access
virtual void save(const DMat &m, const std::string &name);
virtual void save(const DSample &ds, const std::string &name);
virtual void save(const DMatSample &ms, const std::string &name);
// read first name
virtual std::string getFirstName(void);
// tests
virtual bool isOpen(void) const;
// IO
virtual void close(void);
virtual void open(const std::string &name, const unsigned int mode);
public:
// default ASCII precision
static const unsigned int defaultDoublePrec = 15;
private:
// IO
virtual std::string load(const std::string &name = "");
// parser
void parse(void);
private:
std::fstream fileStream_;
bool isParsed_{false};
std::unique_ptr<AsciiParserState> state_{nullptr};
};
END_LATAN_NAMESPACE
#endif // Latan_AsciiFile_hpp_

View File

@ -0,0 +1,91 @@
/*
* AsciiLexer.lpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
%option reentrant
%option prefix="_Ascii_"
%option bison-bridge
%option bison-locations
%option noyywrap
%option yylineno
%{
#include <LatAnalyze/Io/AsciiFile.hpp>
#include "AsciiParser.hpp"
using namespace std;
using namespace Latan;
#define YY_EXTRA_TYPE AsciiFile::AsciiParserState*
#define YY_USER_ACTION \
yylloc->first_line = yylloc->last_line = yylineno;\
yylloc->first_column = yylloc->last_column + 1;\
yylloc->last_column = yylloc->first_column + yyleng - 1;
#define YY_INPUT(buf, result, max_size) \
{ \
(*yyextra->stream).read(buf,max_size);\
result = (*yyextra->stream).gcount();\
}
#define YY_DEBUG 0
#if (YY_DEBUG == 1)
#define RETTOK(tok) cout << #tok << "(" << yytext << ")" << endl; return tok
#else
#define RETTOK(tok) return tok
#endif
%}
DIGIT [0-9]
ALPHA [a-zA-Z_+./-]
SIGN \+|-
EXP e|E
INT {SIGN}?{DIGIT}+
FLOAT {SIGN}?(({DIGIT}+\.{DIGIT}*)|({DIGIT}*\.{DIGIT}+))({EXP}{SIGN}?{INT}+)?
LMARK #L
BLANK [ \t]
%x MARK TYPE
%%
{LMARK} {BEGIN(MARK);}
{INT} {yylval->val_int = strTo<long int>(yytext); RETTOK(INT);}
{FLOAT} {yylval->val_double = strTo<double>(yytext); RETTOK(FLOAT);}
({ALPHA}|{DIGIT})+ {strcpy(yylval->val_str,yytext); RETTOK(ID);}
<MARK>latan_begin {BEGIN(TYPE); RETTOK(OPEN);}
<MARK>latan_end {BEGIN(TYPE); RETTOK(CLOSE);}
<TYPE>mat {BEGIN(INITIAL); RETTOK(MAT);}
<TYPE>rs_sample {BEGIN(INITIAL); RETTOK(SAMPLE);}
<TYPE>rg_state {BEGIN(INITIAL); RETTOK(RG_STATE);}
<*>\r*\n {yylloc->last_column = 0;}
<*>[ \t]
<*>. {yylval->val_char = yytext[0]; RETTOK(ERR);}
%%
void AsciiFile::AsciiParserState::initScanner()
{
yylex_init(&scanner);
yyset_extra(this, scanner);
}
void AsciiFile::AsciiParserState::destroyScanner()
{
yylex_destroy(scanner);
}

View File

@ -0,0 +1,183 @@
/*
* AsciiParser.ypp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
%{
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/AsciiFile.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
using namespace std;
using namespace Latan;
#define TEST_FIRST(name) \
if (state->isFirst)\
{\
state->first = (name);\
state->isFirst = false;\
}
%}
%pure-parser
%name-prefix "_Ascii_"
%locations
%defines
%error-verbose
%parse-param { Latan::AsciiFile::AsciiParserState* state }
%initial-action {yylloc.last_column = 0;}
%lex-param { void* scanner }
%union
{
long int val_int;
double val_double;
char val_char;
char val_str[256];
}
%token <val_char> ERR
%token <val_double> FLOAT
%token <val_int> INT
%token <val_str> ID
%token OPEN CLOSE MAT SAMPLE RG_STATE
%type <val_str> mat matsample dsample
%{
int _Ascii_lex(YYSTYPE* lvalp, YYLTYPE* llocp, void* scanner);
void _Ascii_error(YYLTYPE* locp, AsciiFile::AsciiParserState* state,
const char* err)
{
stringstream buf;
buf << *state->streamName << ":" << locp->first_line << ":"\
<< locp->first_column << ": " << err;
LATAN_ERROR(Parsing, buf.str());
}
#define scanner state->scanner
%}
%%
datas:
/* empty string */
| datas data
;
data:
mat
{
TEST_FIRST($1);
(*state->data)[$1].reset(new DMat(state->dMatQueue.front()));
state->dMatQueue.pop();
}
| dsample
{
TEST_FIRST($1);
(*state->data)[$1].reset(new DSample(state->dSampleBuf));
}
| matsample
{
TEST_FIRST($1);
(*state->data)[$1].reset(new DMatSample(state->dMatSampleBuf));
}
;
mat:
OPEN MAT ID INT floats CLOSE MAT
{
const unsigned int nRow = state->doubleQueue.size()/$4, nCol = $4;
Index i, j, r = 0;
if (state->doubleQueue.size() != nRow*nCol)
{
LATAN_ERROR(Size, "matrix '" + *state->streamName + ":" + $3 +
"' has a wrong size");
}
state->dMatQueue.push(DMat(nRow, nCol));
while (!state->doubleQueue.empty())
{
j = r % nCol;
i = (r - j)/nCol;
state->dMatQueue.back()(i, j) = state->doubleQueue.front();
state->doubleQueue.pop();
r++;
}
strcpy($$, $3);
}
;
dsample:
OPEN SAMPLE ID INT mat CLOSE SAMPLE
{
const unsigned int nSample = $4, os = DMatSample::offset;
DMat &m = state->dMatQueue.front();
if (m.rows() != nSample + os)
{
LATAN_ERROR(Size, "double sample '" + *state->streamName + ":"
+ $3 + "' has a wrong size");
}
if (m.cols() != 1)
{
LATAN_ERROR(Size, "double sample '" + *state->streamName + ":"
+ $3 + "' is not stored as a column vector");
}
state->dSampleBuf = m.array();
state->dMatQueue.pop();
strcpy($$, $3);
}
;
matsample:
OPEN SAMPLE ID INT mat mats CLOSE SAMPLE
{
const unsigned int nSample = $4, os = DMatSample::offset;
if (state->dMatQueue.size() != nSample + os)
{
LATAN_ERROR(Size, "matrix sample '" + *state->streamName + ":"
+ $3 + "' has a wrong size");
}
state->dMatSampleBuf.resize(nSample);
state->dMatSampleBuf[central] = state->dMatQueue.front();
state->dMatQueue.pop();
for (unsigned int i = 0; i < nSample; ++i)
{
state->dMatSampleBuf[i] = state->dMatQueue.front();
state->dMatQueue.pop();
}
strcpy($$, $3);
}
;
mats:
mats mat
| mat
;
floats:
floats FLOAT {state->doubleQueue.push($2);}
| floats INT {state->doubleQueue.push(static_cast<double>($2));}
| FLOAT {state->doubleQueue.push($1);}
| INT {state->doubleQueue.push(static_cast<double>($1));}
;

View File

@ -0,0 +1,81 @@
/*
* BinReader.cpp, part of LatAnalyze
*
* Copyright (C) 2015 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Io/BinReader.hpp>
#include <LatAnalyze/includes.hpp>
#if (defined __GNUC__)||(defined __clang__)
#pragma GCC diagnostic ignored "-Wunreachable-code"
#endif
using namespace std;
using namespace Latan;
BinIO::BinIO(string msg, string loc)
: runtime_error("Binary reader error: " + msg + " (" + loc + ")")
{}
/******************************************************************************
* BinReader implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
BinReader::BinReader(const string fileName, const uint32_t endianness,
const bool isColMaj)
{
open(fileName, endianness, isColMaj);
}
// I/O /////////////////////////////////////////////////////////////////////////
void BinReader::open(const string fileName, const uint32_t endianness,
const bool isColMaj)
{
fileName_ = fileName;
endianness_ = endianness;
isColMaj_ = isColMaj;
file_.reset(new ifstream(fileName_, ios::in|ios::binary|ios::ate));
if (!file_->is_open())
{
LATAN_ERROR(Io, "impossible to open file '" + fileName_ + "'");
}
size_ = static_cast<size_t>(file_->tellg());
file_->seekg(0, ios::beg);
}
void BinReader::close(void)
{
file_.reset(nullptr);
}
template <>
std::string BinReader::read(void)
{
std::string s;
char c = 'a';
while (c != '\n')
{
c = read<char>();
if (c != '\n')
{
s.push_back(c);
}
}
return s;
}

View File

@ -0,0 +1,196 @@
/*
* BinReader.hpp, part of LatAnalyze
*
* Copyright (C) 2015 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef LatAnalyze_BinReader_hpp_
#define LatAnalyze_BinReader_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
// I/O exception
class BinIO: public std::runtime_error
{
public:
BinIO(std::string msg, std::string loc);
};
/******************************************************************************
* Byte manipulation utilities *
******************************************************************************/
class Endianness
{
public:
enum: uint32_t
{
little = 0x00000001,
big = 0x01000000,
unknown = 0xffffffff
};
};
class ByteManip
{
public:
static constexpr uint32_t getHostEndianness(void)
{
return ((0xffffffff & 1) == Endianness::little) ? Endianness::little
: (((0xffffffff & 1) == Endianness::big) ? Endianness::big
: Endianness::unknown);
}
template <typename T>
static T swapBytes(const T);
};
/******************************************************************************
* template implementation *
******************************************************************************/
template <typename T>
T ByteManip::swapBytes(const T u)
{
static_assert (CHAR_BIT == 8, "CHAR_BIT != 8");
union
{
T u;
unsigned char u8[sizeof(T)];
} source, dest;
source.u = u;
for (size_t k = 0; k < sizeof(T); ++k)
{
dest.u8[k] = source.u8[sizeof(T) - k - 1];
}
return dest.u;
}
/******************************************************************************
* Utility to read binary files *
******************************************************************************/
class BinReader
{
public:
// constructor
BinReader(void) = default;
BinReader(const std::string fileName,
const uint32_t endianness = ByteManip::getHostEndianness(),
const bool isColMaj = false);
// destructor
virtual ~BinReader(void) = default;
// I/O
void open(const std::string fileName,
const uint32_t endianness = ByteManip::getHostEndianness(),
const bool isColMaj = false);
void close(void);
template <typename T>
void read(T *pt, Index size);
template <typename T>
T read(void);
template <typename T>
MatBase<T> read(const Index nRow, const Index nCol);
private:
std::unique_ptr<std::ifstream> file_{nullptr};
std::string fileName_;
size_t size_;
uint32_t endianness_;
bool isColMaj_;
};
/******************************************************************************
* template implementation *
******************************************************************************/
template <typename T>
void BinReader::read(T *pt, Index n)
{
if (file_ != nullptr)
{
file_->read(reinterpret_cast<char *>(pt),
static_cast<long>(sizeof(T))*n);
if (endianness_ != ByteManip::getHostEndianness())
{
for (Index i = 0; i < n; ++i)
{
pt[i] = ByteManip::swapBytes(pt[i]);
}
}
}
else
{
LATAN_ERROR(Io, "file is not opened");
}
}
template <typename T>
T BinReader::read(void)
{
T x;
if (file_ != nullptr)
{
read(&x, 1);
}
else
{
LATAN_ERROR(Io, "file is not opened");
}
return x;
}
template <>
std::string BinReader::read(void);
template <typename T>
MatBase<T> BinReader::read(const Index nRow, const Index nCol)
{
MatBase<T> m;
// Eigen matrices use column-major ordering
if (isColMaj_)
{
m.resize(nRow, nCol);
}
else
{
m.resize(nCol, nRow);
}
if (file_ != nullptr)
{
read(m.data(), nRow*nCol);
}
else
{
LATAN_ERROR(Io, "file is not opened");
}
if (isColMaj_)
{
return m;
}
else
{
return m.transpose();
}
}
END_LATAN_NAMESPACE
#endif // LatAnalyze_BinReader_hpp_

View File

@ -0,0 +1,64 @@
/*
* File.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Io/File.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* File implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
File::File(const string &name, const unsigned int mode)
: name_(name)
, mode_(mode)
{}
// destructor //////////////////////////////////////////////////////////////////
File::~File(void)
{
deleteData();
}
// access //////////////////////////////////////////////////////////////////////
const string & File::getName(void) const
{
return name_;
}
unsigned int File::getMode(void) const
{
return mode_;
}
// internal functions //////////////////////////////////////////////////////////
void File::deleteData(void)
{
data_.clear();
}
void File::checkWritability(void)
{
if (!((mode_ & Mode::write) or (mode_ & Mode::append)) or !isOpen())
{
LATAN_ERROR(Io, "file '" + name_ + "' is not writable");
}
}

123
lib/LatAnalyze/Io/File.hpp Normal file
View File

@ -0,0 +1,123 @@
/*
* File.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_File_hpp_
#define Latan_File_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/IoObject.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Abstract datafile class *
******************************************************************************/
typedef std::unordered_map<std::string, std::unique_ptr<IoObject>> IoDataTable;
class File
{
public:
class Mode
{
public:
enum
{
null = 0,
write = 1 << 0,
read = 1 << 1,
append = 1 << 2
};
};
public:
// constructors
File(void) = default;
File(const std::string &name, const unsigned int mode);
// destructor
virtual ~File(void);
// access
const std::string & getName(void) const;
unsigned int getMode(void) const;
template <typename IoT>
const IoT & read(const std::string &name = "");
virtual void save(const DMat &m, const std::string &name) = 0;
virtual void save(const DSample &ds, const std::string &name) = 0;
virtual void save(const DMatSample &ms, const std::string &name) = 0;
// read first name
virtual std::string getFirstName(void) = 0;
// tests
virtual bool isOpen(void) const = 0;
// IO
virtual void close(void) = 0;
virtual void open(const std::string &name, const unsigned int mode) = 0;
protected:
// access
void setName(const std::string &name);
void setMode(const unsigned int mode);
// data access
void deleteData(void);
// error checking
void checkWritability(void);
private:
// data access
template <typename IoT>
const IoT& getData(const std::string &name = "") const;
// IO
virtual std::string load(const std::string &name = "") = 0;
protected:
std::string name_{""};
unsigned int mode_{Mode::null};
IoDataTable data_;
};
// Template implementations
template <typename IoT>
const IoT& File::read(const std::string &name)
{
std::string dataName;
dataName = load(name);
return getData<IoT>(dataName);
}
template <typename IoT>
const IoT& File::getData(const std::string &name) const
{
try
{
return dynamic_cast<const IoT &>(*(data_.at(name)));
}
catch(std::out_of_range)
{
LATAN_ERROR(Definition, "no data with name '" + name + "' in file "
+ name_);
}
catch(std::bad_cast)
{
LATAN_ERROR(Definition, "data with name '" + name + "' in file "
+ name_ + " does not have type '" + typeid(IoT).name()
+ "'");
}
}
END_LATAN_NAMESPACE
#endif

View File

@ -0,0 +1,355 @@
/*
* Hdf5File.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli, Matt Spraggs
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Io/Hdf5File.hpp>
#include <LatAnalyze/Io/IoObject.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
#ifndef H5_NO_NAMESPACE
using namespace H5NS;
#endif
constexpr unsigned int maxGroupNameSize = 1024u;
const short dMatType = static_cast<short>(IoObject::IoType::dMat);
const short dSampleType = static_cast<short>(IoObject::IoType::dSample);
const short dMatSampleType = static_cast<short>(IoObject::IoType::dMatSample);
/******************************************************************************
* Hdf5File implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
Hdf5File::Hdf5File(void)
{}
Hdf5File::Hdf5File(const std::string &name, const unsigned int mode)
{
open(name, mode);
}
// destructor //////////////////////////////////////////////////////////////////
Hdf5File::~Hdf5File(void)
{
close();
}
// access //////////////////////////////////////////////////////////////////////
void Hdf5File::save(const DMat &m, const string &name)
{
if (name.empty())
{
LATAN_ERROR(Io, "trying to save data with an empty name");
}
Group group;
Attribute attr;
DataSet dataset;
hsize_t dim[2] = {static_cast<hsize_t>(m.rows()),
static_cast<hsize_t>(m.cols())};
hsize_t attrDim = 1;
DataSpace dataSpace(2, dim), attrSpace(1, &attrDim);
group = h5File_->createGroup(name.c_str() + nameOffset(name));
attr = group.createAttribute("type", PredType::NATIVE_SHORT, attrSpace);
attr.write(PredType::NATIVE_SHORT, &dMatType);
dataset = group.createDataSet("data", PredType::NATIVE_DOUBLE, dataSpace);
dataset.write(m.data(), PredType::NATIVE_DOUBLE);
}
void Hdf5File::save(const DSample &ds, const string &name)
{
if (name.empty())
{
LATAN_ERROR(Io, "trying to save data with an empty name");
}
Group group;
Attribute attr;
DataSet dataset;
hsize_t dim = static_cast<hsize_t>(ds.size() + 1);
hsize_t attrDim = 1;
DataSpace dataSpace(1, &dim), attrSpace(1, &attrDim);
const long int nSample = ds.size();
group = h5File_->createGroup(name.c_str() + nameOffset(name));
attr = group.createAttribute("type", PredType::NATIVE_SHORT, attrSpace);
attr.write(PredType::NATIVE_SHORT, &dSampleType);
attr = group.createAttribute("nSample", PredType::NATIVE_LONG, attrSpace);
attr.write(PredType::NATIVE_LONG, &nSample);
dataset = group.createDataSet("data", PredType::NATIVE_DOUBLE, dataSpace);
dataset.write(ds.data(), PredType::NATIVE_DOUBLE);
}
void Hdf5File::save(const DMatSample &ms, const string &name)
{
if (name.empty())
{
LATAN_ERROR(Io, "trying to save data with an empty name");
}
Group group;
Attribute attr;
DataSet dataset;
hsize_t dim[2] = {static_cast<hsize_t>(ms[central].rows()),
static_cast<hsize_t>(ms[central].cols())};
hsize_t attrDim = 1;
DataSpace dataSpace(2, dim), attrSpace(1, &attrDim);
const long int nSample = ms.size();
string datasetName;
group = h5File_->createGroup(name.c_str() + nameOffset(name));
attr = group.createAttribute("type", PredType::NATIVE_SHORT, attrSpace);
attr.write(PredType::NATIVE_SHORT, &dMatSampleType);
attr = group.createAttribute("nSample", PredType::NATIVE_LONG, attrSpace);
attr.write(PredType::NATIVE_LONG, &nSample);
FOR_STAT_ARRAY(ms, s)
{
datasetName = (s == central) ? "data_C" : ("data_S_" + strFrom(s));
dataset = group.createDataSet(datasetName.c_str(),
PredType::NATIVE_DOUBLE,
dataSpace);
dataset.write(ms[s].data(), PredType::NATIVE_DOUBLE);
}
}
// read first name ////////////////////////////////////////////////////////////
string Hdf5File::getFirstName(void)
{
return getFirstGroupName();
}
// tests ///////////////////////////////////////////////////////////////////////
bool Hdf5File::isOpen(void) const
{
return (h5File_ != nullptr);
}
// check names for forbidden characters ////////////////////////////////////////
size_t Hdf5File::nameOffset(const string &name)
{
size_t ret = 0;
string badChars = "/";
for (auto c : badChars)
{
size_t pos = name.rfind(c);
if (pos != string::npos and pos > ret)
{
ret = pos;
}
}
return ret;
}
// IO //////////////////////////////////////////////////////////////////////////
void Hdf5File::close(void)
{
if (isOpen())
{
h5File_->close();
}
h5File_.reset(nullptr);
name_ = "";
mode_ = Mode::null;
deleteData();
}
void Hdf5File::open(const string &name, const unsigned int mode)
{
if (isOpen())
{
LATAN_ERROR(Io, "file already opened with name '" + name_ + "'");
}
else
{
unsigned int h5Mode = 0;
name_ = name;
mode_ = mode;
if (mode & Mode::write)
{
h5Mode |= H5F_ACC_TRUNC;
}
if (mode & Mode::read)
{
h5Mode |= H5F_ACC_RDONLY;
}
if (mode & Mode::append)
{
h5Mode |= H5F_ACC_RDWR|H5F_ACC_CREAT;
}
h5File_.reset(new H5File(name_.c_str(), h5Mode));
}
}
string Hdf5File::getFirstGroupName(void)
{
string res;
if ((mode_ & Mode::read) and (isOpen()))
{
auto firstGroupName = [](hid_t loc_id, const char *name, void *fname)
{
H5G_stat_t statbuf;
H5Gget_objinfo(loc_id, name, 0, &statbuf);
if ((statbuf.type == H5G_GROUP) and (strlen((char *)fname) == 0))
{
strncpy((char *)fname, name, maxGroupNameSize);
}
return 0;
};
char groupName[maxGroupNameSize] = "";
h5File_->iterateElems("/", nullptr, firstGroupName, groupName);
res = groupName;
}
else
{
if (isOpen())
{
LATAN_ERROR(Io, "file '" + name_ + "' is not opened in read mode");
}
else
{
LATAN_ERROR(Io, "file '" + name_ + "' is not opened");
}
return "";
}
return res;
}
void Hdf5File::load(DMat &m, const DataSet &d)
{
DataSpace dataspace;
hsize_t dim[2];
dataspace = d.getSpace();
dataspace.getSimpleExtentDims(dim);
m.resize(dim[0], dim[1]);
d.read(m.data(), PredType::NATIVE_DOUBLE);
}
void Hdf5File::load(DSample &ds, const DataSet &d)
{
DataSpace dataspace;
hsize_t dim[1];
dataspace = d.getSpace();
dataspace.getSimpleExtentDims(dim);
ds.resize(dim[0] - 1);
d.read(ds.data(), PredType::NATIVE_DOUBLE);
}
string Hdf5File::load(const string &name)
{
if ((mode_ & Mode::read) and (isOpen()))
{
string groupName;
Group group;
Attribute attribute;
DataSet dataset;
IoObject::IoType type;
groupName = (name.empty()) ? getFirstGroupName() : name;
if (groupName.empty())
{
LATAN_ERROR(Io, "file '" + name_ + "' is empty");
}
group = h5File_->openGroup(groupName.c_str());
attribute = group.openAttribute("type");
attribute.read(PredType::NATIVE_SHORT, &type);
switch (type)
{
case IoObject::IoType::dMat:
{
DMat *pt = new DMat;
data_[groupName].reset(pt);
dataset = group.openDataSet("data");
load(*pt, dataset);
break;
}
case IoObject::IoType::dSample:
{
DSample *pt = new DSample;
data_[groupName].reset(pt);
dataset = group.openDataSet("data");
load(*pt, dataset);
break;
}
case IoObject::IoType::dMatSample:
{
DMatSample *pt = new DMatSample;
long int nSample;
data_[groupName].reset(pt);
attribute = group.openAttribute("nSample");
attribute.read(PredType::NATIVE_LONG, &nSample);
pt->resize(nSample);
FOR_STAT_ARRAY(*pt, s)
{
if (s == central)
{
dataset = group.openDataSet("data_C");
}
else
{
dataset =
group.openDataSet(("data_S_" + strFrom(s)).c_str());
}
load((*pt)[s], dataset);
}
break;
}
default:
{
LATAN_ERROR(Io, "unknown data type ("
+ strFrom(static_cast<int>(type)) + ") "
" (" + name_ + ":" + groupName + ")");
break;
}
}
return groupName;
}
else
{
if (isOpen())
{
LATAN_ERROR(Io, "file '" + name_ + "' is not opened in read mode");
}
else
{
LATAN_ERROR(Io, "file '" + name_ + "' is not opened");
}
return "";
}
}

View File

@ -0,0 +1,73 @@
/*
* Hdf5File.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli, Matt Spraggs
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Hdf5File_hpp_
#define Latan_Hdf5File_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/File.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
#include <H5Cpp.h>
BEGIN_LATAN_NAMESPACE
#ifndef H5_NO_NAMESPACE
#define H5NS H5
#endif
/******************************************************************************
* HDF5 datafile class *
******************************************************************************/
class Hdf5File: public File
{
public:
// constructors
Hdf5File(void);
Hdf5File(const std::string &name, const unsigned int mode);
// destructor
virtual ~Hdf5File(void);
// access
virtual void save(const DMat &m, const std::string &name);
virtual void save(const DSample &ds, const std::string &name);
virtual void save(const DMatSample &ms, const std::string &name);
// read first name
virtual std::string getFirstName(void);
// tests
virtual bool isOpen(void) const;
// IO
virtual void close(void);
virtual void open(const std::string &name, const unsigned int mode);
private:
// IO
std::string getFirstGroupName(void);
virtual std::string load(const std::string &name = "");
void load(DMat &m, const H5NS::DataSet &d);
void load(DSample &ds, const H5NS::DataSet &d);
void load(DMatSample &s, const H5NS::DataSet &d);
// check name for forbidden characters
static size_t nameOffset(const std::string &name);
private:
// file name
std::unique_ptr<H5NS::H5File> h5File_{nullptr};
};
END_LATAN_NAMESPACE
#endif // Latan_Hdf5File_hpp_

51
lib/LatAnalyze/Io/Io.cpp Normal file
View File

@ -0,0 +1,51 @@
/*
* Io.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Io/Io.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Io/AsciiFile.hpp>
#include <LatAnalyze/Io/Hdf5File.hpp>
using namespace std;
using namespace Latan;
string Io::getFirstName(const string &fileName)
{
std::unique_ptr<File> file = open(fileName);
return file->getFirstName();
}
unique_ptr<File> Io::open(const std::string &fileName, const unsigned int mode)
{
string ext = extension(fileName);
if ((ext == "dat") or (ext == "sample") or (ext == "seed"))
{
return unique_ptr<File>(new AsciiFile(fileName, mode));
}
else if (ext == "h5")
{
return unique_ptr<File>(new Hdf5File(fileName, mode));
}
else
{
LATAN_ERROR(Io, "unknown file extension '" + ext + "'");
}
}

100
lib/LatAnalyze/Io/Io.hpp Normal file
View File

@ -0,0 +1,100 @@
/*
* Io.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Io_hpp_
#define Latan_Io_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/File.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Static IO functions *
******************************************************************************/
class Io
{
public:
template <typename IoT, typename FileType>
static IoT load(const std::string &fileName, const std::string &name = "");
template <typename IoT>
static IoT load(const std::string &fileName, const std::string &name = "");
template <typename IoT, typename FileType>
static void save(const IoT &data, const std::string &fileName,
const unsigned int mode = File::Mode::write,
const std::string &name = "");
template <typename IoT>
static void save(const IoT &data, const std::string &fileName,
const unsigned int mode = File::Mode::write,
const std::string &name = "");
template <typename FileType>
static std::string getFirstName(const std::string &fileName);
static std::string getFirstName(const std::string &fileName);
static std::unique_ptr<File> open(const std::string &fileName,
const unsigned int mode = File::Mode::read);
};
// template implementation /////////////////////////////////////////////////////
template <typename IoT, typename FileType>
IoT Io::load(const std::string &fileName, const std::string &name)
{
FileType file(fileName, File::Mode::read);
return file.template read<IoT>(name);
}
template <typename IoT>
IoT Io::load(const std::string &fileName, const std::string &name)
{
std::unique_ptr<File> file = open(fileName);
return file->read<IoT>(name);
}
template <typename IoT, typename FileType>
void Io::save(const IoT &data, const std::string &fileName,
const unsigned int mode, const std::string &name)
{
FileType file(fileName, mode);
std::string realName = (name.empty()) ? fileName : name;
file.save(data, realName);
}
template <typename IoT>
void Io::save(const IoT &data, const std::string &fileName,
const unsigned int mode, const std::string &name)
{
std::unique_ptr<File> file = open(fileName, mode);
std::string realName = (name.empty()) ? fileName : name;
file->save(data, realName);
}
template <typename FileType>
std::string Io::getFirstName(const std::string &fileName)
{
FileType file(fileName, File::Mode::read);
return file.getFirstName();
}
END_LATAN_NAMESPACE
#endif // Latan_Io_hpp_

View File

@ -0,0 +1,50 @@
/*
* IoObject.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_IoObject_hpp_
#define Latan_IoObject_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Abstract class for IO objects *
******************************************************************************/
class IoObject
{
public:
// conserve order for datafile retro-compatibility!
enum class IoType: short int
{
noType = 0,
dMat = 1,
dMatSample = 2,
dSample = 3
};
public:
// destructor
virtual ~IoObject(void) = default;
// access
virtual IoType getType(void) const = 0;
};
END_LATAN_NAMESPACE
#endif // Latan_IoObject_hpp_

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,84 @@
/*
* XmlReader.cpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Io/XmlReader.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
XmlParsing::XmlParsing(string msg, string loc)
: runtime_error("XML reader error: " + msg + " (" + loc + ")")
{}
/******************************************************************************
* XmlReader implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
XmlReader::XmlReader(const string &fileName)
{
open(fileName);
}
// IO //////////////////////////////////////////////////////////////////////////
void XmlReader::open(const string &fileName)
{
name_ = fileName;
doc_.LoadFile(name_.c_str());
if (doc_.Error())
{
string errMsg;
if (doc_.ErrorStr())
{
errMsg = doc_.ErrorStr();
}
LATAN_ERROR(Io, "cannot open file " + fileName + " [tinyxml2 code "
+ strFrom(doc_.ErrorID()) + ": " + errMsg + "]");
}
root_ = doc_.RootElement();
}
// XML structure access ////////////////////////////////////////////////////////
const XmlNode * XmlReader::getNextNode(const XmlNode *node,
const string &nodeName)
{
const char *pt = (nodeName.empty()) ? nullptr : nodeName.c_str();
if (node)
{
return node->NextSiblingElement(pt);
}
else
{
return nullptr;
}
}
const XmlNode * XmlReader::getNextSameNode(const XmlNode *node)
{
if (node)
{
return getNextNode(node, node->Name());
}
else
{
return nullptr;
}
}

View File

@ -0,0 +1,235 @@
/*
* XmlReader.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef LatAnalyze_XmlReader_hpp_
#define LatAnalyze_XmlReader_hpp_
#include <LatAnalyze/Global.hpp>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-conversion"
#include <LatAnalyze/Io/Xml/tinyxml2.hpp>
#pragma GCC diagnostic pop
BEGIN_LATAN_NAMESPACE
// parsing exception
class XmlParsing: public std::runtime_error
{
public:
XmlParsing(std::string msg, std::string loc);
};
/******************************************************************************
* XML parameter file reader *
******************************************************************************/
typedef tinyxml2::XMLElement XmlNode;
class XmlReader
{
public:
// constructor
XmlReader(void) = default;
explicit XmlReader(const std::string &fileName);
// destructor
virtual ~XmlReader(void) = default;
// IO
void open(const std::string &fileName);
// XML structure access
template <typename... Strs>
static const XmlNode * getFirstNode(const XmlNode *startNode,
const std::string &nodeName,
Strs... nodeNames);
template <typename... Strs>
const XmlNode * getFirstNode(const std::string &nodeName,
Strs... nodeNames) const;
static const XmlNode * getNextNode(const XmlNode *node,
const std::string &nodeName = "");
static const XmlNode * getNextSameNode(const XmlNode *node);
template <typename T>
static T getValue(const XmlNode *node);
template <typename T, typename... Strs>
static T getFirstValue(const XmlNode *startNode,
const std::string &nodeName, Strs... nodeNames);
template <typename T, typename... Strs>
T getFirstValue(const std::string &nodeName, Strs... nodeNames) const;
template <typename T, typename... Strs>
static std::vector<T> getAllValues(const XmlNode *startNode,
const std::string &nodeName,
Strs... nodeNames);
template <typename T, typename... Strs>
std::vector<T> getAllValues(const std::string &nodeName,
Strs... nodeNames) const;
// XML structure test
template <typename... Strs>
static bool hasNode(const XmlNode *startNode, const std::string &nodeName,
Strs... nodeNames);
template <typename... Strs>
bool hasNode(const std::string &nodeName, Strs... nodeNames) const;
private:
std::string name_;
tinyxml2::XMLDocument doc_;
XmlNode *root_{nullptr};
};
/******************************************************************************
* XmlReader template implementation *
******************************************************************************/
// XML structure access ////////////////////////////////////////////////////////
template <typename... Strs>
const XmlNode * XmlReader::getFirstNode(const XmlNode *startNode,
const std::string &nodeName,
Strs... nodeNames)
{
static_assert(static_or<std::is_assignable<std::string, Strs>::value...>::value,
"getFirstNode arguments are not compatible with std::string");
const unsigned int nName = sizeof...(nodeNames) + 1;
const std::string name[] = {nodeName, nodeNames...};
const XmlNode *node = startNode;
if (!node)
{
LATAN_ERROR(Io, "root node is null, no XML file opened");
}
for (unsigned int i = 0; i < nName; ++i)
{
node = node->FirstChildElement(name[i].c_str());
if (!node)
{
LATAN_ERROR(Io, "XML node " + name[i] + " not found");
}
}
return node;
}
template <typename... Strs>
const XmlNode * XmlReader::getFirstNode(const std::string &nodeName,
Strs... nodeNames) const
{
if (!root_)
{
LATAN_ERROR(Io, "root node is null, no XML file opened");
}
return getFirstNode(root_, nodeName, nodeNames...);
}
template <typename T>
T XmlReader::getValue(const XmlNode *node)
{
if (node)
{
if (node->GetText())
{
return Latan::strTo<T>(node->GetText());
}
else
{
return T();
}
}
else
{
return T();
}
}
template <typename T, typename... Strs>
T XmlReader::getFirstValue(const XmlNode *startNode,
const std::string &nodeName, Strs... nodeNames)
{
const XmlNode *node = getFirstNode(startNode, nodeName, nodeNames...);
return getValue<T>(node);
}
template <typename T, typename... Strs>
T XmlReader::getFirstValue(const std::string &nodeName, Strs... nodeNames) const
{
return getFirstValue<T>(root_, nodeName, nodeNames...);
}
template <typename T, typename... Strs>
std::vector<T> XmlReader::getAllValues(const XmlNode *startNode,
const std::string &nodeName,
Strs... nodeNames)
{
const XmlNode *node = getFirstNode(startNode, nodeName, nodeNames...);
std::vector<T> value;
while (node)
{
value.push_back(getValue<T>(node));
node = getNextSameNode(node);
}
return value;
}
template <typename T, typename... Strs>
std::vector<T> XmlReader::getAllValues(const std::string &nodeName,
Strs... nodeNames) const
{
return getAllValues<T>(root_, nodeName, nodeNames...);
}
// XML structure test //////////////////////////////////////////////////////////
template <typename... Strs>
bool XmlReader::hasNode(const XmlNode *startNode, const std::string &nodeName,
Strs... nodeNames)
{
static_assert(static_or<std::is_assignable<std::string, Strs>::value...>::value,
"hasNode arguments are not compatible with std::string");
const unsigned int nName = sizeof...(nodeNames) + 1;
const std::string name[] = {nodeName, nodeNames...};
const XmlNode *node = startNode;
if (!node)
{
LATAN_ERROR(Io, "root node is null, no XML file opened");
}
for (unsigned int i = 0; i < nName; ++i)
{
node = node->FirstChildElement(name[i].c_str());
if (!node)
{
return false;
}
}
return true;
}
template <typename... Strs>
bool XmlReader::hasNode(const std::string &nodeName, Strs... nodeNames) const
{
if (!root_)
{
LATAN_ERROR(Io, "root node is null, no XML file opened");
}
return hasNode(root_, nodeName, nodeNames...);
}
END_LATAN_NAMESPACE
#endif // LatAnalyze_XmlReader_hpp_

151
lib/LatAnalyze/Makefile.am Normal file
View File

@ -0,0 +1,151 @@
COM_CXXFLAGS = -Wall
if CXX_GNU
COM_CXXFLAGS += -W -pedantic -Wno-deprecated-declarations
else
if CXX_INTEL
COM_CXXFLAGS += -wd1682
endif
endif
include eigen_files.mk
AM_LFLAGS = -olex.yy.c
AM_YFLAGS = -y -d -Wno-yacc -Wno-deprecated
lib_LTLIBRARIES = libLatAnalyze.la
noinst_LTLIBRARIES = libLexers.la
libLexers_la_SOURCES = Io/AsciiLexer.lpp Core/MathLexer.lpp
if CXX_GNU
libLexers_la_CXXFLAGS = $(COM_CXXFLAGS) -Wno-unused-parameter -Wno-unused-function -Wno-deprecated-register
else
libLexers_la_CXXFLAGS = $(COM_CXXFLAGS)
endif
libLatAnalyze_la_SOURCES = \
includes.hpp \
Global.cpp \
Core/Exceptions.cpp \
Core/Mat.cpp \
Core/Math.cpp \
Core/MathInterpreter.cpp \
Core/MathParser.ypp \
Core/OptParser.cpp \
Core/Plot.cpp \
Core/ThreadPool.cpp \
Core/Utilities.cpp \
Functional/CompiledFunction.cpp \
Functional/CompiledModel.cpp \
Functional/Function.cpp \
Functional/Model.cpp \
Functional/TabFunction.cpp \
Io/AsciiFile.cpp \
Io/AsciiParser.ypp \
Io/BinReader.cpp \
Io/File.cpp \
Io/Hdf5File.cpp \
Io/Io.cpp \
Io/XmlReader.cpp \
Io/Xml/tinyxml2.cpp \
Numerical/Derivative.cpp \
Numerical/DWT.cpp \
Numerical/DWTFilters.cpp \
Numerical/GslFFT.cpp \
Numerical/GslHybridRootFinder.cpp\
Numerical/GslMinimizer.cpp \
Numerical/GslQagsIntegrator.cpp \
Numerical/Minimizer.cpp \
Numerical/RootFinder.cpp \
Numerical/Solver.cpp \
Physics/CorrelatorFitter.cpp \
Physics/DataFilter.cpp \
Physics/EffectiveMass.cpp \
Statistics/FitInterface.cpp \
Statistics/Histogram.cpp \
Statistics/Random.cpp \
Statistics/StatArray.cpp \
Statistics/XYSampleData.cpp \
Statistics/XYStatData.cpp \
../config.h
libLatAnalyze_ladir = $(pkgincludedir)
HPPFILES = \
Global.hpp \
Core/Eigen.hpp \
Core/EigenPlugin.hpp \
Core/Exceptions.hpp \
Core/Mat.hpp \
Core/Math.hpp \
Core/MathInterpreter.hpp \
Core/OptParser.hpp \
Core/ParserState.hpp \
Core/Plot.hpp \
Core/ThreadPool.hpp \
Core/stdincludes.hpp \
Core/Utilities.hpp \
Functional/CompiledFunction.hpp \
Functional/CompiledModel.hpp \
Functional/Function.hpp \
Functional/Model.hpp \
Functional/TabFunction.hpp \
Io/AsciiFile.hpp \
Io/BinReader.hpp \
Io/File.hpp \
Io/Hdf5File.hpp \
Io/Io.hpp \
Io/IoObject.hpp \
Io/XmlReader.hpp \
Numerical/Derivative.hpp \
Numerical/DWT.hpp \
Numerical/DWTFilters.hpp \
Numerical/FFT.hpp \
Numerical/GslFFT.hpp \
Numerical/GslHybridRootFinder.hpp\
Numerical/GslMinimizer.hpp \
Numerical/GslQagsIntegrator.hpp \
Numerical/Integrator.hpp \
Numerical/Minimizer.hpp \
Numerical/RootFinder.hpp \
Numerical/Solver.hpp \
Physics/CorrelatorFitter.hpp \
Physics/DataFilter.hpp \
Physics/EffectiveMass.hpp \
Statistics/Dataset.hpp \
Statistics/FitInterface.hpp \
Statistics/Histogram.hpp \
Statistics/MatSample.hpp \
Statistics/Random.hpp \
Statistics/StatArray.hpp \
Statistics/XYSampleData.hpp \
Statistics/XYStatData.hpp
if HAVE_MINUIT
libLatAnalyze_la_SOURCES += Numerical/MinuitMinimizer.cpp
HPPFILES += Numerical/MinuitMinimizer.hpp
endif
if HAVE_NLOPT
libLatAnalyze_la_SOURCES += Numerical/NloptMinimizer.cpp
HPPFILES += Numerical/NloptMinimizer.hpp
endif
libLatAnalyze_la_CXXFLAGS = $(COM_CXXFLAGS)
libLatAnalyze_la_LIBADD = libLexers.la
if HAVE_AM_MINOR_LE_11
Io/AsciiParser.hpp: Io/AsciiParser.ypp
$(AM_V_YACC) $(YACC) -o Io/AsciiParser.cpp --defines=Io/AsciiParser.hpp $<
Core/MathParser.hpp: Core/MathParser.ypp
$(AM_V_YACC) $(YACC) -o Core/MathParser.cpp --defines=Core/MathParser.hpp $<
endif
BUILT_SOURCES = Io/AsciiParser.hpp Core/MathParser.hpp
CLEANFILES = \
Core/MathLexer.cpp \
Core/MathParser.cpp\
Core/MathParser.hpp\
Io/AsciiLexer.cpp \
Io/AsciiParser.cpp \
Io/AsciiParser.hpp
nobase_dist_pkginclude_HEADERS = $(HPPFILES) $(eigen_files) Io/Xml/tinyxml2.hpp
ACLOCAL_AMFLAGS = -I .buildutils/m4

View File

@ -0,0 +1,205 @@
/*
* DWT.cpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/DWT.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* DWT implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
DWT::DWT(const DWTFilter &filter)
: filter_(filter)
{}
// convolution primitive ///////////////////////////////////////////////////////
template <typename MatType>
void filterConvolution(MatType &out, const MatType &data,
const std::vector<double> &filter, const Index offset)
{
Index n = data.rows(), nf = n*filter.size();
out.resizeLike(data);
out.fill(0.);
for (unsigned int i = 0; i < filter.size(); ++i)
{
FOR_MAT(out, j, k)
{
out(j, k) += filter[i]*data((j + i + nf - offset) % n, k);
}
}
}
void DWT::filterConvolution(DVec &out, const DVec &data,
const std::vector<double> &filter, const Index offset)
{
::filterConvolution(out, data, filter, offset);
}
void DWT::filterConvolution(DMat &out, const DMat &data,
const std::vector<double> &filter, const Index offset)
{
::filterConvolution(out, data, filter, offset);
}
// downsampling/upsampling primitives //////////////////////////////////////////
template <typename MatType>
void downsample(MatType &out, const MatType &in)
{
if (out.rows() < in.rows()/2)
{
LATAN_ERROR(Size, "output rows smaller than half the input vector rows");
}
if (out.cols() != in.cols())
{
LATAN_ERROR(Size, "output and input number of columns mismatch");
}
for (Index j = 0; j < in.cols(); j++)
for (Index i = 0; i < in.rows(); i += 2)
{
out(i/2, j) = in(i, j);
}
}
void DWT::downsample(DVec &out, const DVec &in)
{
::downsample(out, in);
}
void DWT::downsample(DMat &out, const DMat &in)
{
::downsample(out, in);
}
template <typename MatType>
void upsample(MatType &out, const MatType &in)
{
if (out.size() < 2*in.size())
{
LATAN_ERROR(Size, "output rows smaller than twice the input rows");
}
if (out.cols() != in.cols())
{
LATAN_ERROR(Size, "output and input number of columns mismatch");
}
out.block(0, 0, 2*in.size(), out.cols()).fill(0.);
for (Index j = 0; j < in.cols(); j++)
for (Index i = 0; i < in.size(); i ++)
{
out(2*i, j) = in(i, j);
}
}
void DWT::upsample(DVec &out, const DVec &in)
{
::upsample(out, in);
}
void DWT::upsample(DMat &out, const DMat &in)
{
::upsample(out, in);
}
// DWT /////////////////////////////////////////////////////////////////////////
std::vector<DWT::DWTLevel>
DWT::forward(const DVec &data, const unsigned int level) const
{
std::vector<DWTLevel> dwt(level);
DVec *finePt = const_cast<DVec *>(&data);
DVec tmp;
Index n = data.size(), o = filter_.fwdL.size()/2, minSize;
minSize = 1;
for (unsigned int l = 0; l < level; ++l) minSize *= 2;
if (n < minSize)
{
LATAN_ERROR(Size, "data vector too small for a " + strFrom(level)
+ "-level DWT (data size is " + strFrom(n) + ")");
}
for (unsigned int l = 0; l < level; ++l)
{
n /= 2;
dwt[l].first.resize(n);
dwt[l].second.resize(n);
filterConvolution(tmp, *finePt, filter_.fwdL, o);
downsample(dwt[l].first, tmp);
filterConvolution(tmp, *finePt, filter_.fwdH, o);
downsample(dwt[l].second, tmp);
finePt = &dwt[l].first;
}
return dwt;
}
DVec DWT::backward(const std::vector<DWTLevel>& dwt) const
{
unsigned int level = dwt.size();
Index n = dwt.back().second.size(), o = filter_.bwdL.size()/2 - 1;
DVec res, tmp, conv;
res = dwt.back().first;
for (int l = level - 2; l >= 0; --l)
{
n *= 2;
if (dwt[l].second.size() != n)
{
LATAN_ERROR(Size, "DWT result size mismatch");
}
}
n = dwt.back().second.size();
for (int l = level - 1; l >= 0; --l)
{
n *= 2;
tmp.resize(n);
upsample(tmp, res);
filterConvolution(conv, tmp, filter_.bwdL, o);
res = conv;
upsample(tmp, dwt[l].second);
filterConvolution(conv, tmp, filter_.bwdH, o);
res += conv;
}
return res;
}
// concatenate levels //////////////////////////////////////////////////////////
DVec DWT::concat(const std::vector<DWTLevel> &dwt, const int maxLevel, const bool dropLow)
{
unsigned int level = ((maxLevel >= 0) ? (maxLevel + 1) : dwt.size());
Index nlast = dwt[level - 1].first.size();
Index n = 2*dwt.front().first.size() - ((dropLow) ? nlast : 0);
Index pt = n, nl;
DVec res(n);
for (unsigned int l = 0; l < level; ++l)
{
nl = dwt[l].second.size();
pt -= nl;
res.segment(pt, nl) = dwt[l].second;
}
if (!dropLow)
{
res.segment(0, nl) = dwt[level-1].first;
}
return res;
}

View File

@ -0,0 +1,62 @@
/*
* DWT.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_DWT_hpp_
#define Latan_DWT_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Numerical/DWTFilters.hpp>
#include <LatAnalyze/Core/Mat.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Discrete wavelet transform class *
******************************************************************************/
class DWT
{
public:
typedef std::pair<DVec, DVec> DWTLevel;
public:
// constructor
DWT(const DWTFilter &filter);
// destructor
virtual ~DWT(void) = default;
// convolution primitive
static void filterConvolution(DVec &out, const DVec &data,
const std::vector<double> &filter, const Index offset);
static void filterConvolution(DMat &out, const DMat &data,
const std::vector<double> &filter, const Index offset);
// downsampling/upsampling primitives
static void downsample(DVec &out, const DVec &in);
static void downsample(DMat &out, const DMat &in);
static void upsample(DVec &out, const DVec &in);
static void upsample(DMat &out, const DMat &in);
// DWT
std::vector<DWTLevel> forward(const DVec &data, const unsigned int level) const;
DVec backward(const std::vector<DWTLevel>& dwt) const;
// concatenate levels
static DVec concat(const std::vector<DWTLevel>& dwt, const int maxLevel = -1, const bool dropLow = false);
private:
DWTFilter filter_;
};
END_LATAN_NAMESPACE
#endif // Latan_DWT_hpp_

View File

@ -0,0 +1,528 @@
/*
* DWTFilters.cpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/DWTFilters.hpp>
#include <LatAnalyze/includes.hpp>
// cf. http://wavelets.pybytes.com
// *here we implement the reverse filters more convenient for convolutions*
using namespace std;
using namespace Latan;
#define FILTDICT(x) {#x, &DWTFilters::x}
std::map<std::string, const DWTFilter *> DWTFilters::fromName = {
FILTDICT(haar),
FILTDICT(db2),
FILTDICT(db3),
FILTDICT(db4),
FILTDICT(db5),
FILTDICT(db6),
FILTDICT(bior13),
FILTDICT(bior15),
FILTDICT(bior22),
FILTDICT(bior24),
FILTDICT(bior31),
FILTDICT(bior33),
FILTDICT(bior35)
};
DWTFilter DWTFilters::haar = {
// fwdL
{0.7071067811865476,
0.7071067811865476},
// fwdH
{0.7071067811865476,
-0.7071067811865476},
// bwdL
{0.7071067811865476,
0.7071067811865476},
// bwdH
{-0.7071067811865476,
0.7071067811865476}
};
DWTFilter DWTFilters::db2 = {
// fwdL
{0.48296291314469025,
0.836516303737469,
0.22414386804185735,
-0.12940952255092145},
// fwdH
{-0.12940952255092145,
-0.22414386804185735,
0.836516303737469,
-0.48296291314469025},
// bwdL
{-0.12940952255092145,
0.22414386804185735,
0.836516303737469,
0.48296291314469025},
// bwdH
{-0.48296291314469025,
0.836516303737469,
-0.22414386804185735,
-0.12940952255092145}
};
DWTFilter DWTFilters::db3 = {
// fwdL
{0.3326705529509569,
0.8068915093133388,
0.4598775021193313,
-0.13501102001039084,
-0.08544127388224149,
0.035226291882100656},
// fwdH
{0.035226291882100656,
0.08544127388224149,
-0.13501102001039084,
-0.4598775021193313,
0.8068915093133388,
-0.3326705529509569},
// bwdL
{0.035226291882100656,
-0.08544127388224149,
-0.13501102001039084,
0.4598775021193313,
0.8068915093133388,
0.3326705529509569},
// bwdH
{-0.3326705529509569,
0.8068915093133388,
-0.4598775021193313,
-0.13501102001039084,
0.08544127388224149,
0.035226291882100656}
};
DWTFilter DWTFilters::db4 = {
// fwdL
{0.23037781330885523,
0.7148465705525415,
0.6308807679295904,
-0.02798376941698385,
-0.18703481171888114,
0.030841381835986965,
0.032883011666982945,
-0.010597401784997278},
// fwdH
{-0.010597401784997278,
-0.032883011666982945,
0.030841381835986965,
0.18703481171888114,
-0.02798376941698385,
-0.6308807679295904,
0.7148465705525415,
-0.23037781330885523},
// bwdL
{-0.010597401784997278,
0.032883011666982945,
0.030841381835986965,
-0.18703481171888114,
-0.02798376941698385,
0.6308807679295904,
0.7148465705525415,
0.23037781330885523},
// bwdH
{-0.23037781330885523,
0.7148465705525415,
-0.6308807679295904,
-0.02798376941698385,
0.18703481171888114,
0.030841381835986965,
-0.032883011666982945,
-0.010597401784997278}
};
DWTFilter DWTFilters::db5 = {
// fwdL
{0.160102397974125,
0.6038292697974729,
0.7243085284385744,
0.13842814590110342,
-0.24229488706619015,
-0.03224486958502952,
0.07757149384006515,
-0.006241490213011705,
-0.012580751999015526,
0.003335725285001549},
// fwdH
{0.003335725285001549,
0.012580751999015526,
-0.006241490213011705,
-0.07757149384006515,
-0.03224486958502952,
0.24229488706619015,
0.13842814590110342,
-0.7243085284385744,
0.6038292697974729,
-0.160102397974125},
// bwdL
{0.003335725285001549,
-0.012580751999015526,
-0.006241490213011705,
0.07757149384006515,
-0.03224486958502952,
-0.24229488706619015,
0.13842814590110342,
0.7243085284385744,
0.6038292697974729,
0.160102397974125},
// bwdH
{-0.160102397974125,
0.6038292697974729,
-0.7243085284385744,
0.13842814590110342,
0.24229488706619015,
-0.03224486958502952,
-0.07757149384006515,
-0.006241490213011705,
0.012580751999015526,
0.003335725285001549}
};
DWTFilter DWTFilters::db6 = {
// fwdL
{0.11154074335008017,
0.4946238903983854,
0.7511339080215775,
0.3152503517092432,
-0.22626469396516913,
-0.12976686756709563,
0.09750160558707936,
0.02752286553001629,
-0.031582039318031156,
0.0005538422009938016,
0.004777257511010651,
-0.00107730108499558},
// fwdH
{-0.00107730108499558,
-0.004777257511010651,
0.0005538422009938016,
0.031582039318031156,
0.02752286553001629,
-0.09750160558707936,
-0.12976686756709563,
0.22626469396516913,
0.3152503517092432,
-0.7511339080215775,
0.4946238903983854,
-0.11154074335008017},
// bwdL
{-0.00107730108499558,
0.004777257511010651,
0.0005538422009938016,
-0.031582039318031156,
0.02752286553001629,
0.09750160558707936,
-0.12976686756709563,
-0.22626469396516913,
0.3152503517092432,
0.7511339080215775,
0.4946238903983854,
0.11154074335008017},
// bwdH
{-0.11154074335008017,
0.4946238903983854,
-0.7511339080215775,
0.3152503517092432,
0.22626469396516913,
-0.12976686756709563,
-0.09750160558707936,
0.02752286553001629,
0.031582039318031156,
0.0005538422009938016,
-0.004777257511010651,
-0.00107730108499558}
};
DWTFilter DWTFilters::bior13 = {
// fwdL
{-0.08838834764831845,
0.08838834764831845,
0.7071067811865476,
0.7071067811865476,
0.08838834764831845,
-0.08838834764831845},
// fwdH
{0.0,
0.0,
0.7071067811865476,
-0.7071067811865476,
0.0,
0.0},
// bwdL
{0.0,
0.0,
0.7071067811865476,
0.7071067811865476,
0.0,
0.0},
// bwdH
{0.08838834764831845,
0.08838834764831845,
-0.7071067811865476,
0.7071067811865476,
-0.08838834764831845,
-0.08838834764831845}
};
DWTFilter DWTFilters::bior15 = {
// fwdL
{0.01657281518405971,
-0.01657281518405971,
-0.12153397801643787,
0.12153397801643787,
0.7071067811865476,
0.7071067811865476,
0.12153397801643787,
-0.12153397801643787,
-0.01657281518405971,
0.01657281518405971},
// fwdH
{0.0,
0.0,
0.0,
0.0,
0.7071067811865476,
-0.7071067811865476,
0.0,
0.0,
0.0,
0.0},
// bwdL
{0.0,
0.0,
0.0,
0.0,
0.7071067811865476,
0.7071067811865476,
0.0,
0.0,
0.0,
0.0},
// bwdH
{-0.01657281518405971,
-0.01657281518405971,
0.12153397801643787,
0.12153397801643787,
-0.7071067811865476,
0.7071067811865476,
-0.12153397801643787,
-0.12153397801643787,
0.01657281518405971,
0.01657281518405971}
};
DWTFilter DWTFilters::bior22 = {
// fwdL
{-0.1767766952966369,
0.3535533905932738,
1.0606601717798214,
0.3535533905932738,
-0.1767766952966369,
0.0},
// fwdH
{0.0,
0.0,
0.3535533905932738,
-0.7071067811865476,
0.3535533905932738,
0.0},
// bwdL
{0.0,
0.0,
0.3535533905932738,
0.7071067811865476,
0.3535533905932738,
0.0},
// bwdH
{0.1767766952966369,
0.3535533905932738,
-1.0606601717798214,
0.3535533905932738,
0.1767766952966369,
0.0}
};
DWTFilter DWTFilters::bior24 = {
// fwdL
{0.03314563036811942,
-0.06629126073623884,
-0.1767766952966369,
0.4198446513295126,
0.9943689110435825,
0.4198446513295126,
-0.1767766952966369,
-0.06629126073623884,
0.03314563036811942,
0.0},
// fwdH
{0.0,
0.0,
0.0,
0.0,
0.3535533905932738,
-0.7071067811865476,
0.3535533905932738,
0.0,
0.0,
0.0},
// bwdL
{0.0,
0.0,
0.0,
0.0,
0.3535533905932738,
0.7071067811865476,
0.3535533905932738,
0.0,
0.0,
0.0},
// bwdH
{-0.03314563036811942,
-0.06629126073623884,
0.1767766952966369,
0.4198446513295126,
-0.9943689110435825,
0.4198446513295126,
0.1767766952966369,
-0.06629126073623884,
-0.03314563036811942,
0.0}
};
DWTFilter DWTFilters::bior31 = {
// fwdL
{-0.3535533905932738,
1.0606601717798214,
1.0606601717798214,
-0.3535533905932738},
// fwdH
{0.1767766952966369,
-0.5303300858899107,
0.5303300858899107,
-0.1767766952966369},
// bwdL
{0.1767766952966369,
0.5303300858899107,
0.5303300858899107,
0.1767766952966369},
// bwdH
{0.3535533905932738,
1.0606601717798214,
-1.0606601717798214,
-0.3535533905932738}
};
DWTFilter DWTFilters::bior33 = {
// fwdL
{0.06629126073623884,
-0.19887378220871652,
-0.15467960838455727,
0.9943689110435825,
0.9943689110435825,
-0.15467960838455727,
-0.19887378220871652,
0.06629126073623884},
// fwdH
{0.0,
0.0,
0.1767766952966369,
-0.5303300858899107,
0.5303300858899107,
-0.1767766952966369,
0.0,
0.0},
// bwdL
{0.0,
0.0,
0.1767766952966369,
0.5303300858899107,
0.5303300858899107,
0.1767766952966369,
0.0,
0.0},
// bwdH
{-0.06629126073623884,
-0.19887378220871652,
0.15467960838455727,
0.9943689110435825,
-0.9943689110435825,
-0.15467960838455727,
0.19887378220871652,
0.06629126073623884}
};
DWTFilter DWTFilters::bior35 = {
// fwdL
{-0.013810679320049757,
0.04143203796014927,
0.052480581416189075,
-0.26792717880896527,
-0.07181553246425874,
0.966747552403483,
0.966747552403483,
-0.07181553246425874,
-0.26792717880896527,
0.052480581416189075,
0.04143203796014927,
-0.013810679320049757},
// fwdH
{0.0,
0.0,
0.0,
0.0,
0.1767766952966369,
-0.5303300858899107,
0.5303300858899107,
-0.1767766952966369,
0.0,
0.0,
0.0,
0.0},
// bwdL
{0.0,
0.0,
0.0,
0.0,
0.1767766952966369,
0.5303300858899107,
0.5303300858899107,
0.1767766952966369,
0.0,
0.0,
0.0,
0.0},
// bwdH
{0.013810679320049757,
0.04143203796014927,
-0.052480581416189075,
-0.26792717880896527,
0.07181553246425874,
0.966747552403483,
-0.966747552403483,
-0.07181553246425874,
0.26792717880896527,
0.052480581416189075,
-0.04143203796014927,
-0.013810679320049757}
};

View File

@ -0,0 +1,53 @@
/*
* DWTFilters.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_DWTFilters_hpp_
#define Latan_DWTFilters_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
struct DWTFilter
{
const std::vector<double> fwdL, fwdH, bwdL, bwdH;
};
namespace DWTFilters
{
extern DWTFilter haar;
extern DWTFilter db2;
extern DWTFilter db3;
extern DWTFilter db4;
extern DWTFilter db5;
extern DWTFilter db6;
extern DWTFilter bior13;
extern DWTFilter bior15;
extern DWTFilter bior22;
extern DWTFilter bior24;
extern DWTFilter bior31;
extern DWTFilter bior33;
extern DWTFilter bior35;
extern std::map<std::string, const DWTFilter *> fromName;
}
END_LATAN_NAMESPACE
#endif // Latan_DWTFilters_hpp_

View File

@ -0,0 +1,234 @@
/*
* Derivative.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/Derivative.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Math.hpp>
using namespace std;
using namespace Latan;
using namespace Math;
/******************************************************************************
* Derivative implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
Derivative::Derivative(const DoubleFunction &f, const Index dir,
const double step)
: buffer_(new DVec(f.getNArg()))
{
setFunction(f);
setDir(dir);
setStep(step);
}
Derivative::Derivative(const DoubleFunction &f, const Index dir,
const Index order, const DVec &point, const double step)
: Derivative(f, dir, step)
{
setOrderAndPoint(order, point);
}
// access //////////////////////////////////////////////////////////////////////
Index Derivative::getDir(void) const
{
return dir_;
}
Index Derivative::getOrder(void) const
{
return order_;
}
Index Derivative::getNPoint(void) const
{
return point_.size();
}
double Derivative::getStep(void) const
{
return step_;
}
void Derivative::setDir(const Index dir)
{
dir_ = dir;
}
void Derivative::setFunction(const DoubleFunction &f)
{
f_ = f;
}
void Derivative::setOrderAndPoint(const Index order, const DVec &point)
{
if (order >= point.size())
{
LATAN_ERROR(Size, "derivative order is superior or equal to the number of point");
}
order_ = order;
point_ = point;
coefficient_.resize(point.size());
makeCoefficients();
}
void Derivative::setStep(const double step)
{
step_ = step;
}
// coefficient generation //////////////////////////////////////////////////////
// from B. Fornberg, “Generation of finite difference formulas on arbitrarily
// spaced grids,” Math. Comp., vol. 51, no. 184, pp. 699706, 1988.
// http://dx.doi.org/10.1090/S0025-5718-1988-0935077-0
void Derivative::makeCoefficients(void)
{
double c[3];
const Index N = point_.size() - 1, M = order_;
DMat curr(M + 1, N + 1), prev(M + 1, N + 1);
curr.fill(0.);
prev.fill(0.);
prev(0, 0) = 1.;
c[0] = 1.;
for (Index n = 1; n <= N; ++n)
{
c[1] = 1.;
for (Index nu = 0; nu <= n - 1; ++nu)
{
c[2] = point_(n) - point_(nu);
c[1] *= c[2];
for (Index m = 0; m <= min(n, M); ++m)
{
curr(m, nu) = point_(n)*prev(m, nu);
if (m)
{
curr(m, nu) -= m*prev(m-1, nu);
}
curr(m, nu) /= c[2];
}
}
for (Index m = 0; m <= min(n, M); ++m)
{
curr(m, n) = -point_(n-1)*prev(m, n-1);
if (m)
{
curr(m, n) += m*prev(m-1, n-1);
}
curr(m, n) *= c[0]/c[1];
}
c[0] = c[1];
prev = curr;
}
coefficient_ = curr.row(M);
}
// function call ///////////////////////////////////////////////////////////////
double Derivative::operator()(const double *x) const
{
ConstMap<DVec> xMap(x, f_.getNArg());
double res = 0.;
*buffer_ = xMap;
FOR_VEC(point_, i)
{
(*buffer_)(dir_) = x[dir_] + point_(i)*step_;
res += coefficient_[i]*f_(*buffer_);
}
res /= pow(step_, order_);
return res;
}
// function factory ////////////////////////////////////////////////////////////
DoubleFunction Derivative::makeFunction(const bool makeHardCopy) const
{
DoubleFunction res;
if (makeHardCopy)
{
Derivative copy(*this);
res.setFunction([copy](const double *x){return copy(x);}, f_.getNArg());
}
else
{
res.setFunction([this](const double *x){return (*this)(x);},
f_.getNArg());
}
return res;
}
DoubleFunction Latan::derivative(const DoubleFunction &f, const Index dir,
const Index order, const DVec point,
const double step)
{
return Derivative(f, dir, order, point, step).makeFunction();
}
/******************************************************************************
* CentralDerivative implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
CentralDerivative::CentralDerivative(const DoubleFunction &f, const Index dir,
const Index order, const Index precOrder)
: Derivative(f, dir)
{
setOrder(order, precOrder);
}
// access //////////////////////////////////////////////////////////////////////
Index CentralDerivative::getPrecOrder(void) const
{
return precOrder_;
}
void CentralDerivative::setOrder(const Index order, const Index precOrder)
{
const Index nPoint = 2*(precOrder + (order - 1)/2) + 1;
DVec point(nPoint);
precOrder_ = precOrder;
FOR_VEC(point, i)
{
point(i) = static_cast<double>(i - (nPoint - 1)/2);
}
setOrderAndPoint(order, point);
tuneStep();
}
// step tuning /////////////////////////////////////////////////////////////////
// the rounding error should be O(N*epsilon/h^order)
//
void CentralDerivative::tuneStep(void)
{
const Index nPoint = getNPoint();
const double epsilon = numeric_limits<double>::epsilon();
const double step = pow(epsilon*nPoint, 1./(2.*precOrder_+getOrder()));
setStep(step);
}
// function factory ////////////////////////////////////////////////////////////
DoubleFunction Latan::centralDerivative(const DoubleFunction &f,
const Index dir, const Index order,
const Index precOrder)
{
return CentralDerivative(f, dir, order, precOrder).makeFunction();
}

View File

@ -0,0 +1,103 @@
/*
* Derivative.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Derivative_hpp_
#define Latan_Derivative_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Derivative *
******************************************************************************/
class Derivative: public DoubleFunctionFactory
{
public:
static constexpr double defaultStep = 1.0e-2;
public:
// constructor
Derivative(const DoubleFunction &f, const Index dir, const Index order,
const DVec &point, const double step = defaultStep);
// destructor
virtual ~Derivative(void) = default;
// access
Index getDir(void) const;
Index getNPoint(void) const;
Index getOrder(void) const;
double getStep(void) const;
void setDir(const Index dir);
void setFunction(const DoubleFunction &f);
void setOrderAndPoint(const Index order, const DVec &point);
void setStep(const double step);
// function call
double operator()(const double *x) const;
// function factory
virtual DoubleFunction makeFunction(const bool makeHardCopy = true) const;
protected:
// constructor
Derivative(const DoubleFunction &f, const Index dir,
const double step = defaultStep);
private:
void makeCoefficients(void);
private:
DoubleFunction f_;
Index dir_, order_;
double step_;
DVec point_, coefficient_;
std::shared_ptr<DVec> buffer_;
};
DoubleFunction derivative(const DoubleFunction &f, const Index dir,
const Index order, const DVec point,
const double step = Derivative::defaultStep);
class CentralDerivative: public Derivative
{
public:
static const Index defaultPrecOrder = 2;
public:
// constructor
CentralDerivative(const DoubleFunction &f = DoubleFunction(),
const Index dir = 0,
const Index order = 1,
const Index precOrder = defaultPrecOrder);
// destructor
virtual ~CentralDerivative(void) = default;
// access
Index getPrecOrder(void) const;
void setOrder(const Index order, const Index precOrder = defaultPrecOrder);
// function call
using Derivative::operator();
private:
// step tuning
void tuneStep(void);
private:
Index precOrder_;
};
DoubleFunction centralDerivative(const DoubleFunction &f, const Index dir = 0,
const Index order = 1,
const Index precOrder =
CentralDerivative::defaultPrecOrder);
END_LATAN_NAMESPACE
#endif // Latan_Derivative_hpp_

View File

@ -0,0 +1,53 @@
/*
* FFT.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_FFT_hpp_
#define Latan_FFT_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* FFT abstract class *
******************************************************************************/
class FFT
{
public:
enum
{
Forward = 0,
Backward = 1
};
public:
// constructor
FFT(void) = default;
FFT(const Index size);
// destructor
virtual ~FFT(void) = default;
// size
virtual void resize(const Index size) = 0;
// FFT
virtual void operator()(CMat &x, const unsigned int dir = FFT::Forward) = 0;
};
END_LATAN_NAMESPACE
#endif // Latan_FFT_hpp_

View File

@ -0,0 +1,90 @@
/*
* GslFFT.cpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/GslFFT.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* GslFFT implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
GslFFT::GslFFT(const Index size)
{
resize(size);
}
// destructor //////////////////////////////////////////////////////////////////
GslFFT::~GslFFT(void)
{
clear();
}
// size ////////////////////////////////////////////////////////////////////////
void GslFFT::resize(const Index size)
{
if (size_ != size)
{
clear();
size_ = size;
wavetable_ = gsl_fft_complex_wavetable_alloc(size_);
workspace_ = gsl_fft_complex_workspace_alloc(size_);
}
}
// fft /////////////////////////////////////////////////////////////////////////
void GslFFT::operator()(CMat &x, const unsigned int dir)
{
if (x.size() != size_)
{
LATAN_ERROR(Size, "wrong input vector size");
}
else
{
switch (dir)
{
case FFT::Forward:
gsl_fft_complex_forward((double *)x.data(), 1, size_,
wavetable_, workspace_);
break;
case FFT::Backward:
gsl_fft_complex_backward((double *)x.data(), 1, size_,
wavetable_, workspace_);
break;
default:
LATAN_ERROR(Argument, "invalid FT direction");
break;
}
}
}
// destroy GSL objects /////////////////////////////////////////////////////////
void GslFFT::clear(void)
{
if (!wavetable_)
{
gsl_fft_complex_wavetable_free(wavetable_);
}
if (!workspace_)
{
gsl_fft_complex_workspace_free(workspace_);
}
}

View File

@ -0,0 +1,57 @@
/*
* GslFFT.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_GslFFT_hpp_
#define Latan_GslFFT_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Numerical/FFT.hpp>
#include <gsl/gsl_fft_complex.h>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* GSL FFT *
******************************************************************************/
class GslFFT: public FFT
{
public:
// constructors
GslFFT(void) = default;
GslFFT(const Index size);
// destructor
virtual ~GslFFT(void);
// size
void resize(const Index size);
// fft
virtual void operator()(CMat &x, const unsigned int dir = FFT::Forward);
private:
// destroy GSL objects
void clear(void);
private:
Index size_{0};
gsl_fft_complex_wavetable *wavetable_{nullptr};
gsl_fft_complex_workspace *workspace_{nullptr};
};
END_LATAN_NAMESPACE
#endif // Latan_GslFFT_hpp_

View File

@ -0,0 +1,145 @@
/*
* GslHybridRootFinder.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/GslHybridRootFinder.hpp>
#include <LatAnalyze/includes.hpp>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_multiroots.h>
using namespace std;
using namespace Latan;
/******************************************************************************
* GslHybridRootFinder implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
GslHybridRootFinder::GslHybridRootFinder(const Index dim)
: RootFinder(dim)
{}
// output //////////////////////////////////////////////////////////////////////
void GslHybridRootFinder::printState(void)
{
if (solver_)
{
cout << "x=";
for (size_t i = 0; i < solver_->x->size; ++i)
{
cout << " " << scientific << gsl_vector_get(solver_->x, i);
}
cout << endl;
cout << "f=";
for (size_t i = 0; i < solver_->f->size; ++i)
{
cout << " " << scientific << gsl_vector_get(solver_->f, i);
}
cout << endl;
}
}
// solver //////////////////////////////////////////////////////////////////////
const DVec &
GslHybridRootFinder::operator()(const vector<DoubleFunction *> &func)
{
DVec &res = getState();
Verbosity verbosity = getVerbosity();
int status;
unsigned int iter = 0;
const size_t nFunc = func.size();
Index nArg;
gsl_vector *x;
gsl_multiroot_function fStruct;
int (*fWrap)(const gsl_vector *, void *, gsl_vector *) =
[](const gsl_vector *var, void *vFunc, gsl_vector *f)->int
{
vector<DoubleFunction *> &fPt =
*static_cast<vector<DoubleFunction *> *>(vFunc);
for (unsigned int i = 0; i < fPt.size(); ++i)
{
gsl_vector_set(f, i, (*fPt[i])(var->data));
}
return GSL_SUCCESS;
};
nArg = func[0]->getNArg();
for (auto f: func)
{
if (f->getNArg() != nArg)
{
LATAN_ERROR(Size,
"equations do not have the same number of unknown");
}
}
if (nArg != static_cast<Index>(nFunc))
{
LATAN_ERROR(Size, "equation and unknown number mismatch");
}
if (res.size() != nArg)
{
resize(nArg);
}
solver_ = gsl_multiroot_fsolver_alloc(gsl_multiroot_fsolver_hybrids, nFunc);
x = gsl_vector_alloc(nFunc);
FOR_VEC(res, i)
{
gsl_vector_set(x, static_cast<size_t>(i), res(i));
}
fStruct.n = nFunc;
fStruct.params = reinterpret_cast<void *>(
const_cast<vector<DoubleFunction *> *>(&func));
fStruct.f = fWrap;
gsl_multiroot_fsolver_set(solver_, &fStruct, x);
do
{
iter++;
status = gsl_multiroot_fsolver_iterate(solver_);
if (verbosity >= Verbosity::Debug)
{
cout << "--- iteration " << iter << endl;
printState();
}
if (status)
{
break;
}
status = gsl_multiroot_test_residual(solver_->f, getPrecision());
} while ((status == GSL_CONTINUE) and (iter < getMaxIteration()));
if (verbosity >= Verbosity::Debug)
{
cout << "--- done" << endl;
cout << "end status: " << gsl_strerror(status) << endl;
}
if (status)
{
LATAN_WARNING("GSL hybrid root finder ended with status '" +
strFrom(gsl_strerror(status)) + "'");
}
FOR_VEC(res, i)
{
res(i) = gsl_vector_get(solver_->x, static_cast<size_t>(i));
}
gsl_vector_free(x);
gsl_multiroot_fsolver_free(solver_);
solver_ = nullptr;
return res;
}

View File

@ -0,0 +1,52 @@
/*
* GslHybridRootFinder.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_GslHybridRootFinder_hpp_
#define Latan_GslHybridRootFinder_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Numerical/RootFinder.hpp>
#include <gsl/gsl_multiroots.h>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* GslHybridRootFinder *
******************************************************************************/
class GslHybridRootFinder: public RootFinder
{
public:
// constructors
GslHybridRootFinder(void) = default;
explicit GslHybridRootFinder(const Index dim);
// destructor
virtual ~GslHybridRootFinder(void) = default;
// solver
virtual const DVec & operator()(const std::vector<DoubleFunction *> &func);
private:
// output
void printState(void);
private:
gsl_multiroot_fsolver *solver_{nullptr};
};
END_LATAN_NAMESPACE
#endif // Latan_GslHybridRootFinder_hpp_

View File

@ -0,0 +1,361 @@
/*
* GslMinimizer.cpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/GslMinimizer.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Math.hpp>
#include <gsl/gsl_multimin.h>
#include <gsl/gsl_blas.h>
using namespace std;
using namespace Latan;
/******************************************************************************
* GslMinimizer implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
GslMinimizer::GslMinimizer(const Algorithm algorithm)
{
setAlgorithm(algorithm);
der_.setOrder(1, 1);
}
// access //////////////////////////////////////////////////////////////////////
GslMinimizer::Algorithm GslMinimizer::getAlgorithm(void) const
{
return algorithm_;
}
void GslMinimizer::setAlgorithm(const Algorithm algorithm)
{
algorithm_ = algorithm;
}
bool GslMinimizer::supportLimits(void) const
{
return false;
}
// test ////////////////////////////////////////////////////////////////////////
bool GslMinimizer::isDerAlgorithm(const Algorithm algorithm)
{
return (algorithm <= Algorithm::lastDerAlg);
}
// minimization ////////////////////////////////////////////////////////////////
const DVec & GslMinimizer::operator()(const DoubleFunction &f)
{
DVec &x = getState();
// resize minimizer state to match function number of arguments
if (f.getNArg() != x.size())
{
resize(f.getNArg());
}
// set function data
GslFuncData data;
der_.setFunction(f);
data.f = &f;
data.d = &der_;
// set initial position
gsl_vector *gslX = gsl_vector_alloc(getDim());
for (Index i = 0; i < getDim(); ++i)
{
gsl_vector_set(gslX, i, x(i));
}
// minimization
int status;
if (isDerAlgorithm(getAlgorithm()))
{
// set function
gsl_multimin_function_fdf gslFunc;
gslFunc.n = getDim();
gslFunc.f = &fWrapper;
gslFunc.df = &dfWrapper;
gslFunc.fdf = &fdfWrapper;
gslFunc.params = &data;
// create and set minimizer
const gsl_multimin_fdfminimizer_type *gslAlg;
gsl_multimin_fdfminimizer *gslMin;
switch (getAlgorithm())
{
case Algorithm::cgFR:
gslAlg = gsl_multimin_fdfminimizer_conjugate_fr;
break;
case Algorithm::cgPR:
gslAlg = gsl_multimin_fdfminimizer_conjugate_pr;
break;
case Algorithm::bfgs:
gslAlg = gsl_multimin_fdfminimizer_vector_bfgs;
break;
case Algorithm::bfgs2:
gslAlg = gsl_multimin_fdfminimizer_vector_bfgs2;
break;
case Algorithm::steepDesc:
gslAlg = gsl_multimin_fdfminimizer_vector_bfgs2;
break;
default:
LATAN_ERROR(Argument, "unknow GSL minization algorithm "
+ strFrom(static_cast<int>(getAlgorithm())));
break;
}
gslMin = gsl_multimin_fdfminimizer_alloc(gslAlg, getDim());
// minimize
unsigned int pass = 0, it;
double dxRel;
do
{
pass++;
gsl_multimin_fdfminimizer_set(gslMin, &gslFunc, gslX, 0.01, 0.001);
if (getVerbosity() >= Verbosity::Normal)
{
cout << "========== GSL minimization, pass #" << pass;
cout << " ==========" << endl;
cout << "Algorithm: " << getAlgorithmName(getAlgorithm());
cout << endl;
cout << "Max eval.= " << getMaxIteration();
cout << " -- Precision= " << getPrecision() << endl;
printf("Starting f(x)= %.10e\n", f(x));
}
it = 0;
do
{
it++;
gsl_multimin_fdfminimizer_iterate(gslMin);
dxRel = gsl_blas_dnrm2(gslMin->dx)/gsl_blas_dnrm2(gslMin->x);
status = (dxRel < getPrecision()) ? GSL_SUCCESS : GSL_CONTINUE;
if (getVerbosity() >= Verbosity::Debug)
{
printf("iteration %4d: f= %.10e dxrel= %.10e eval= %d\n",
it, gslMin->f, dxRel, data.evalCount);
}
} while (status == GSL_CONTINUE and
(data.evalCount < getMaxIteration()));
if (getVerbosity() >= Verbosity::Normal)
{
printf("Found minimum %.10e at:\n", gslMin->f);
for (Index i = 0; i < x.size(); ++i)
{
printf("%8s= %.10e\n", f.varName().getName(i).c_str(),
gsl_vector_get(gslMin->x, i));
}
cout << "after " << data.evalCount << " evaluations" << endl;
cout << "Minimization ended with code " << status;
cout << endl;
}
data.evalCount = 0;
for (Index i = 0; i < getDim(); ++i)
{
gsl_vector_set(gslX, i, gsl_vector_get(gslMin->x, i));
}
} while (status != GSL_SUCCESS and (pass < getMaxPass()));
// deallocate GSL minimizer
gsl_multimin_fdfminimizer_free(gslMin);
}
else
{
// set function
gsl_multimin_function gslFunc;
gslFunc.n = getDim();
gslFunc.f = &fWrapper;
gslFunc.params = &data;
// create and set minimizer
const gsl_multimin_fminimizer_type *gslAlg;
gsl_multimin_fminimizer *gslMin;
switch (getAlgorithm())
{
case Algorithm::simplex:
gslAlg = gsl_multimin_fminimizer_nmsimplex;
break;
case Algorithm::simplex2:
gslAlg = gsl_multimin_fminimizer_nmsimplex2;
break;
case Algorithm::simplex2R:
gslAlg = gsl_multimin_fminimizer_nmsimplex2rand;
break;
default:
LATAN_ERROR(Argument, "unknow GSL minization algorithm "
+ strFrom(static_cast<int>(getAlgorithm())));
break;
}
gslMin = gsl_multimin_fminimizer_alloc(gslAlg, getDim());
// minimize
unsigned int pass = 0, it;
gsl_vector *step = gsl_vector_alloc(getDim());
double relSize;
gsl_vector_set_all(step, 0.01);
do
{
pass++;
gsl_multimin_fminimizer_set(gslMin, &gslFunc, gslX, step);
if (getVerbosity() >= Verbosity::Normal)
{
cout << "========== GSL minimization, pass #" << pass;
cout << " ==========" << endl;
cout << "Algorithm: " << getAlgorithmName(getAlgorithm());
cout << endl;
cout << "Max eval.= " << getMaxIteration();
cout << " -- Precision= " << getPrecision() << endl;
printf("Starting f(x)= %.10e\n", f(x));
}
it = 0;
do
{
it++;
gsl_multimin_fminimizer_iterate(gslMin);
relSize = Math::pow<2>(gslMin->size)/gsl_blas_dnrm2(gslMin->x);
status = (relSize < getPrecision()) ? GSL_SUCCESS
: GSL_CONTINUE;
if (getVerbosity() >= Verbosity::Debug)
{
printf("iteration %4d: f= %.10e relSize= %.10e eval= %d\n",
it, gslMin->fval, relSize, data.evalCount);
}
} while (status == GSL_CONTINUE and
(data.evalCount < getMaxIteration()));
if (getVerbosity() >= Verbosity::Normal)
{
printf("Found minimum %.10e at:\n", gslMin->fval);
for (Index i = 0; i < x.size(); ++i)
{
printf("%8s= %.10e\n", f.varName().getName(i).c_str(),
gsl_vector_get(gslMin->x, i));
}
cout << "after " << data.evalCount << " evaluations" << endl;
cout << "Minimization ended with code " << status;
cout << endl;
}
data.evalCount = 0;
for (Index i = 0; i < getDim(); ++i)
{
gsl_vector_set(gslX, i, gsl_vector_get(gslMin->x, i));
}
} while (status != GSL_SUCCESS and (pass < getMaxPass()));
// deallocate GSL minimizer
gsl_multimin_fminimizer_free(gslMin);
gsl_vector_free(step);
}
if (status != GSL_SUCCESS)
{
LATAN_WARNING("invalid minimum: maximum number of call reached");
}
// save final result
for (Index i = 0; i < getDim(); ++i)
{
x(i) = gsl_vector_get(gslX, i);
}
// deallocate GSL state and return
gsl_vector_free(gslX);
return x;
}
// function wrappers ///////////////////////////////////////////////////////////
double GslMinimizer::fWrapper(const gsl_vector *x, void *vdata)
{
GslFuncData &data = *static_cast<GslFuncData *>(vdata);
data.evalCount++;
return (*data.f)(x->data);
}
void GslMinimizer::dfWrapper(const gsl_vector *x, void *vdata, gsl_vector * df)
{
GslFuncData &data = *static_cast<GslFuncData *>(vdata);
const unsigned int n = data.f->getNArg();
for (unsigned int i = 0; i < n; ++i)
{
data.d->setDir(i);
gsl_vector_set(df, i, (*(data.d))(x->data));
}
data.evalCount += data.d->getNPoint()*n;
}
void GslMinimizer::fdfWrapper(const gsl_vector *x, void *vdata, double *f,
gsl_vector * df)
{
GslFuncData &data = *static_cast<GslFuncData *>(vdata);
const unsigned int n = data.f->getNArg();
for (unsigned int i = 0; i < n; ++i)
{
data.d->setDir(i);
gsl_vector_set(df, i, (*(data.d))(x->data));
}
*f = (*data.f)(x->data);
data.evalCount += data.d->getNPoint()*n + 1;
}
// algorithm names /////////////////////////////////////////////////////////////
string GslMinimizer::getAlgorithmName(const Algorithm algorithm)
{
switch (algorithm)
{
case Algorithm::cgFR:
return "Fletcher-Reeves conjugate gradient";
break;
case Algorithm::cgPR:
return "Polak-Ribiere conjugate gradient";
break;
case Algorithm::bfgs:
return "Broyden-Fletcher-Goldfarb-Shanno";
break;
case Algorithm::bfgs2:
return "improved Broyden-Fletcher-Goldfarb-Shanno";
break;
case Algorithm::steepDesc:
return "steepest descent";
break;
case Algorithm::simplex:
return "Nelder-Mead simplex";
break;
case Algorithm::simplex2:
return "improved Nelder-Mead simplex";
break;
case Algorithm::simplex2R:
return "improved Nelder-Mead simplex with random start";
break;
}
return "";
}

View File

@ -0,0 +1,86 @@
/*
* GslMinimizer.hpp, part of LatAnalyze
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_GslMinimizer_hpp_
#define Latan_GslMinimizer_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Numerical/Derivative.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Numerical/Minimizer.hpp>
#include <gsl/gsl_vector.h>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* interface to the GSL minimizers *
******************************************************************************/
class GslMinimizer: public Minimizer
{
public:
enum class Algorithm
{
cgFR = 1,
cgPR = 2,
bfgs = 3,
bfgs2 = 4,
steepDesc = 5,
lastDerAlg = 5,
simplex = 6,
simplex2 = 7,
simplex2R = 8
};
private:
struct GslFuncData
{
const DoubleFunction *f{nullptr};
Derivative *d{nullptr};
unsigned int evalCount{0};
};
public:
// constructor
explicit GslMinimizer(const Algorithm algorithm = defaultAlg_);
// destructor
virtual ~GslMinimizer(void) = default;
// access
Algorithm getAlgorithm(void) const;
void setAlgorithm(const Algorithm algorithm);
virtual bool supportLimits(void) const;
// minimization
virtual const DVec & operator()(const DoubleFunction &f);
private:
// test
static bool isDerAlgorithm(const Algorithm algorithm);
// function wrappers
static double fWrapper(const gsl_vector *x, void * params);
static void dfWrapper(const gsl_vector *x, void * params,
gsl_vector * df);
static void fdfWrapper(const gsl_vector *x, void *params, double *f,
gsl_vector * df);
// algorithm names
std::string getAlgorithmName(const Algorithm algorithm);
private:
Algorithm algorithm_;
static constexpr Algorithm defaultAlg_ = Algorithm::simplex2;
CentralDerivative der_;
};
END_LATAN_NAMESPACE
#endif // Latan_GslMinimizer_hpp_

View File

@ -0,0 +1,87 @@
/*
* GslQagsIntegrator.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/GslQagsIntegrator.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Math.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* GslQagIntegrator implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
GslQagsIntegrator::GslQagsIntegrator(const unsigned int limit,
const double precision)
: limit_(limit)
, precision_(precision)
{
workspace_ = gsl_integration_workspace_alloc(limit);
}
// destructor //////////////////////////////////////////////////////////////////
GslQagsIntegrator::~GslQagsIntegrator(void)
{
gsl_integration_workspace_free(workspace_);
}
// integral calculation ////////////////////////////////////////////////////////
double GslQagsIntegrator::operator()(const DoubleFunction &f, const double xMin,
const double xMax)
{
double (*fWrap)(double, void *) = [](double x, void *fPt)->double
{
return (*static_cast<DoubleFunction *>(fPt))(&x);
};
gsl_function gslF;
double result;
gslF.function = fWrap;
gslF.params = reinterpret_cast<void *>(&const_cast<DoubleFunction &>(f));
if ((xMin > -Math::inf) and (xMax < Math::inf))
{
gsl_integration_qags(&gslF, xMin, xMax, 0.0, precision_, limit_,
workspace_, &result, &error_);
}
else if (xMax < Math::inf)
{
gsl_integration_qagil(&gslF, xMax, 0.0, precision_, limit_,
workspace_, &result, &error_);
}
else if (xMin > -Math::inf)
{
gsl_integration_qagiu(&gslF, xMin, 0.0, precision_, limit_,
workspace_, &result, &error_);
}
else
{
gsl_integration_qagi(&gslF, 0.0, precision_, limit_,
workspace_, &result, &error_);
}
return result;
}
// get last error //////////////////////////////////////////////////////////////
double GslQagsIntegrator::getLastError(void) const
{
return error_;
}

View File

@ -0,0 +1,58 @@
/*
* GslQagsIntegrator.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_GslQagsIntegrator_hpp_
#define Latan_GslQagsIntegrator_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Numerical/Integrator.hpp>
#include <gsl/gsl_integration.h>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* GSL general quadrature adaptive integration with singularities *
******************************************************************************/
class GslQagsIntegrator: public Integrator
{
public:
static const unsigned int defaultLimit = 1000;
static constexpr double defaultPrec = 1.0e-7;
public:
// constructor
GslQagsIntegrator(const unsigned int limit = defaultLimit,
const double precision = defaultPrec);
// destructor
virtual ~GslQagsIntegrator(void);
// integral calculation
virtual double operator()(const DoubleFunction &f, const double xMin,
const double xMax);
// get last error
double getLastError(void) const;
private:
unsigned int limit_;
double precision_, error_;
gsl_integration_workspace *workspace_;
};
END_LATAN_NAMESPACE
#endif // Latan_GslQagsIntegrator_hpp_

View File

@ -0,0 +1,46 @@
/*
* Integrator.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Integrator_hpp_
#define Latan_Integrator_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* abstract integrator class *
******************************************************************************/
class Integrator
{
public:
// constructor
Integrator(void) = default;
// destructor
virtual ~Integrator(void) = default;
// integral calculation
virtual double operator()(const DoubleFunction &f, const double xMin,
const double xMax) = 0;
};
END_LATAN_NAMESPACE
#endif // Latan_Integrator_hpp_

View File

@ -0,0 +1,194 @@
/*
* Minimizer.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/Minimizer.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
// access //////////////////////////////////////////////////////////////////////
void Minimizer::resize(const Index dim)
{
const Index oldDim = getDim();
Solver::resize(dim);
highLimit_.conservativeResize(dim);
lowLimit_.conservativeResize(dim);
hasHighLimit_.conservativeResize(dim);
hasLowLimit_.conservativeResize(dim);
if (dim > oldDim)
{
highLimit_.segment(oldDim, dim - oldDim).fill(0.);
highLimit_.segment(oldDim, dim - oldDim).fill(0.);
lowLimit_.segment(oldDim, dim - oldDim).fill(0.);
hasHighLimit_.segment(oldDim, dim - oldDim).fill(false);
hasLowLimit_.segment(oldDim, dim - oldDim).fill(false);
}
}
#define checkSupport \
if (!supportLimits())\
{\
LATAN_ERROR(Implementation, "minimizer does not support limits");\
}
double Minimizer::getHighLimit(const Index i) const
{
checkSupport;
if (i >= getDim())
{
LATAN_ERROR(Size, "invalid variable index");
}
return highLimit_(i);
}
const DVec & Minimizer::getHighLimit(const PlaceHolder ph __dumb) const
{
checkSupport;
return highLimit_;
}
double Minimizer::getLowLimit(const Index i) const
{
checkSupport;
if (i >= getDim())
{
LATAN_ERROR(Size, "invalid variable index");
}
return lowLimit_(i);
}
const DVec & Minimizer::getLowLimit(const PlaceHolder ph __dumb) const
{
checkSupport;
return lowLimit_;
}
bool Minimizer::hasHighLimit(const Index i) const
{
checkSupport;
if (i >= getDim())
{
LATAN_ERROR(Size, "invalid variable index");
}
return hasHighLimit_(i);
}
bool Minimizer::hasLowLimit(const Index i) const
{
checkSupport;
if (i >= getDim())
{
LATAN_ERROR(Size, "invalid variable index");
}
return hasLowLimit_(i);
}
void Minimizer::setHighLimit(const Index i, const double l)
{
checkSupport;
if (i >= getDim())
{
resize(i + 1);
}
highLimit_(i) = l;
useHighLimit(i);
}
void Minimizer::setHighLimit(const PlaceHolder ph __dumb, const DVec &l)
{
checkSupport;
if (l.size() != getDim())
{
resize(l.size());
}
highLimit_ = l;
useHighLimit(_);
}
void Minimizer::setLowLimit(const Index i, const double l)
{
checkSupport;
if (i >= getDim())
{
resize(i + 1);
}
lowLimit_(i) = l;
useLowLimit(i);
}
void Minimizer::setLowLimit(const PlaceHolder ph __dumb, const DVec &l)
{
checkSupport;
if (l.size() != getDim())
{
resize(l.size());
}
lowLimit_ = l;
useLowLimit(_);
}
void Minimizer::useHighLimit(const Index i, const bool use)
{
checkSupport;
if (i >= getDim())
{
resize(i + 1);
}
hasHighLimit_(i) = use;
}
void Minimizer::useHighLimit(const PlaceHolder ph __dumb, const bool use)
{
checkSupport;
hasHighLimit_.fill(use);
}
void Minimizer::useLowLimit(const Index i, const bool use)
{
checkSupport;
if (i >= getDim())
{
resize(i + 1);
}
hasLowLimit_(i) = use;
}
void Minimizer::useLowLimit(const PlaceHolder ph __dumb, const bool use)
{
checkSupport;
hasLowLimit_.fill(use);
}
unsigned int Minimizer::getMaxPass(void) const
{
return maxPass_;
}
void Minimizer::setMaxPass(const unsigned int maxPass)
{
maxPass_ = maxPass;
}

View File

@ -0,0 +1,72 @@
/*
* Minimizer.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Minimizer_hpp_
#define Latan_Minimizer_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Numerical/Solver.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Abstract minimizer class *
******************************************************************************/
class Minimizer: public Solver
{
public:
// constructor
Minimizer(void) = default;
// destructor
virtual ~Minimizer(void) = default;
// access
virtual void resize(const Index dim);
virtual double getHighLimit(const Index i) const ;
virtual const DVec & getHighLimit(const PlaceHolder ph = _) const;
virtual double getLowLimit(const Index i) const;
virtual const DVec & getLowLimit(const PlaceHolder ph = _) const;
virtual bool hasHighLimit(const Index i) const;
virtual bool hasLowLimit(const Index i) const;
virtual void setHighLimit(const Index i, const double l);
virtual void setHighLimit(const PlaceHolder ph, const DVec &l);
virtual void setLowLimit(const Index i, const double l);
virtual void setLowLimit(const PlaceHolder ph, const DVec &l);
virtual void useHighLimit(const Index i, const bool use = true);
virtual void useHighLimit(const PlaceHolder ph = _,
const bool use = true);
virtual void useLowLimit(const Index i, const bool use = true);
virtual void useLowLimit(const PlaceHolder ph = _,
const bool use = true);
virtual bool supportLimits(void) const = 0;
virtual unsigned int getMaxPass(void) const;
virtual void setMaxPass(const unsigned int maxPass);
// minimization
virtual const DVec & operator()(const DoubleFunction &f) = 0;
private:
DVec highLimit_, lowLimit_;
Vec<bool> hasHighLimit_, hasLowLimit_;
unsigned int maxPass_{5u};
};
END_LATAN_NAMESPACE
#endif // Latan_Minimizer_hpp_

View File

@ -0,0 +1,198 @@
/*
* MinuitMinimizer.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/MinuitMinimizer.hpp>
#include <LatAnalyze/includes.hpp>
// forward declaration necessary in the ROOT-based version of Minuit2
namespace ROOT
{
namespace Fit
{
class ParameterSettings;
};
};
// macros necessary in the ROOT-based version of Minuit2
#define ROOT_Math_VecTypes
#define MATHCORE_STANDALONE
#include <Minuit2/Minuit2Minimizer.h>
#include <Math/Functor.h>
using namespace std;
using namespace Latan;
static constexpr double initErr = 0.1;
/******************************************************************************
* MinuitMinimizer implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
MinuitMinimizer::MinuitMinimizer(const Algorithm algorithm)
{
setAlgorithm(algorithm);
}
// access //////////////////////////////////////////////////////////////////////
MinuitMinimizer::Algorithm MinuitMinimizer::getAlgorithm(void) const
{
return algorithm_;
}
void MinuitMinimizer::setAlgorithm(const Algorithm algorithm)
{
algorithm_ = algorithm;
}
bool MinuitMinimizer::supportLimits(void) const
{
return true;
}
// minimization ////////////////////////////////////////////////////////////////
const DVec & MinuitMinimizer::operator()(const DoubleFunction &f)
{
using namespace ROOT;
using namespace Minuit2;
DVec &x = getState();
int printLevel = 0;
EMinimizerType minuitAlg = kCombined;
double prec = getPrecision();
// convert Latan parameters to Minuit parameters
switch (getVerbosity())
{
case Verbosity::Silent:
printLevel = 0;
break;
case Verbosity::Normal:
printLevel = 2;
break;
case Verbosity::Debug:
printLevel = 3;
break;
}
// The factor of 0.002 here is to compensate the dirty hack in Minuit
// source used to match the C++ and F77 versions
// (cf. VariableMetricBuilder.cxx)
switch (getAlgorithm())
{
case Algorithm::migrad:
minuitAlg = kMigrad;
prec /= 0.002;
break;
case Algorithm::simplex:
minuitAlg = kSimplex;
break;
case Algorithm::combined:
minuitAlg = kCombined;
prec /= 0.002;
break;
}
// resize minimizer state to match function number of arguments
if (f.getNArg() != x.size())
{
resize(f.getNArg());
}
// create and set minimizer
Minuit2Minimizer min(minuitAlg);
min.SetStrategy(2);
min.SetMaxFunctionCalls(getMaxIteration());
min.SetTolerance(prec);
min.SetPrintLevel(printLevel);
// set function and variables
Math::Functor minuitF(f, x.size());
string name;
double val, step;
min.SetFunction(minuitF);
for (Index i = 0; i < x.size(); ++i)
{
name = f.varName().getName(i);
val = x(i);
step = (fabs(x(i)) != 0.) ? initErr*fabs(x(i)) : 1.;
if (hasHighLimit(i) and !hasLowLimit(i))
{
min.SetUpperLimitedVariable(i, name, val, step, getHighLimit(i));
}
else if (!hasHighLimit(i) and hasLowLimit(i))
{
min.SetLowerLimitedVariable(i, name, val, step, getLowLimit(i));
}
else if (hasHighLimit(i) and hasLowLimit(i))
{
min.SetLimitedVariable(i, name, val, step, getLowLimit(i),
getHighLimit(i));
}
else
{
min.SetVariable(i, name, val, step);
}
}
// minimize
int status;
unsigned int n = 0;
do
{
if (getVerbosity() >= Verbosity::Normal)
{
cout << "========== Minuit minimization, pass #" << n + 1;
cout << " =========" << endl;
}
min.Minimize();
status = min.Status();
n++;
} while ((status >= 2) and (n < getMaxPass()));
if (getVerbosity() >= Verbosity::Normal)
{
cout << "=================================================" << endl;
}
switch (status)
{
case 1:
// covariance matrix was made positive, the minimum is still good
// it just means that Minuit error analysis is inaccurate
break;
case 2:
LATAN_WARNING("invalid minimum: Hesse analysis is not valid");
break;
case 3:
LATAN_WARNING("invalid minimum: requested precision not reached");
break;
case 4:
LATAN_WARNING("invalid minimum: iteration limit reached");
break;
}
// save and return result
for (Index i = 0; i < x.size(); ++i)
{
x(i) = min.X()[i];
}
return x;
}

View File

@ -0,0 +1,60 @@
/*
* MinuitMinimizer.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_MinuitMinimizer_hpp_
#define Latan_MinuitMinimizer_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Numerical/Minimizer.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* interface to CERN Minuit minimizer *
* ( http://www.cern.ch/minuit ) *
******************************************************************************/
class MinuitMinimizer: public Minimizer
{
public:
enum class Algorithm
{
migrad = 1,
simplex = 2,
combined = 3
};
public:
// constructor
explicit MinuitMinimizer(const Algorithm algorithm = defaultAlg_);
// destructor
virtual ~MinuitMinimizer(void) = default;
// access
Algorithm getAlgorithm(void) const;
void setAlgorithm(const Algorithm algorithm);
virtual bool supportLimits(void) const;
// minimization
virtual const DVec & operator()(const DoubleFunction &f);
private:
Algorithm algorithm_;
static constexpr Algorithm defaultAlg_ = Algorithm::combined;
};
END_LATAN_NAMESPACE
#endif // Latan_MinuitMinimizer_hpp_

View File

@ -0,0 +1,201 @@
/*
* NloptMinimizer.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/NloptMinimizer.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* NloptMinimizer implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
NloptMinimizer::NloptMinimizer(const Algorithm algorithm)
{
setAlgorithm(algorithm);
der_.setOrder(1, 1);
}
// access //////////////////////////////////////////////////////////////////////
NloptMinimizer::Algorithm NloptMinimizer::getAlgorithm(void) const
{
return algorithm_;
}
void NloptMinimizer::setAlgorithm(const Algorithm algorithm)
{
algorithm_ = algorithm;
}
bool NloptMinimizer::supportLimits(void) const
{
return true;
}
// minimization ////////////////////////////////////////////////////////////////
const DVec & NloptMinimizer::operator()(const DoubleFunction &f)
{
DVec &x = getState();
// resize minimizer state to match function number of arguments
if (f.getNArg() != x.size())
{
resize(f.getNArg());
}
// create and set minimizer
nlopt::opt min(getAlgorithm(), x.size());
NloptFuncData data;
vector<double> lb(x.size()), hb(x.size());
min.set_maxeval(getMaxIteration());
min.set_xtol_rel(getPrecision());
min.set_ftol_rel(-1.);
der_.setFunction(f);
data.f = &f;
data.d = &der_;
min.set_min_objective(&funcWrapper, &data);
for (Index i = 0; i < x.size(); ++i)
{
lb[i] = hasLowLimit(i) ? getLowLimit(i) : -HUGE_VAL;
hb[i] = hasHighLimit(i) ? getHighLimit(i) : HUGE_VAL;
}
min.set_lower_bounds(lb);
min.set_upper_bounds(hb);
// minimize
double res;
vector<double> vx(x.size());
nlopt::result status;
unsigned int n = 0;
for (Index i = 0; i < x.size(); ++i)
{
vx[i] = x(i);
}
do
{
if (getVerbosity() >= Verbosity::Normal)
{
cout << "========== NLopt minimization, pass #" << n + 1;
cout << " ==========" << endl;
cout << "Algorithm: " << min.get_algorithm_name() << endl;
cout << "Max eval.= " << min.get_maxeval();
cout << " -- Precision= " << min.get_xtol_rel() << endl;
printf("Starting f(x)= %.10e\n", f(x));
}
try
{
status = min.optimize(vx, res);
}
catch (invalid_argument &e)
{
LATAN_ERROR(Runtime, "NLopt has reported receving invalid "
"arguments (if you are using a global minimizer, did "
"you specify limits for all variables?)");
}
if (getVerbosity() >= Verbosity::Normal)
{
printf("Found minimum %.10e at:\n", res);
for (Index i = 0; i < x.size(); ++i)
{
printf("%8s= %.10e\n", f.varName().getName(i).c_str(), vx[i]);
}
cout << "after " << data.evalCount << " evaluations" << endl;
cout << "Minimization ended with code " << status;
cout << " (" << returnMessage(status) << ")";
cout << endl;
}
data.evalCount = 0;
for (Index i = 0; i < x.size(); ++i)
{
x(i) = vx[i];
}
n++;
} while (!minSuccess(status) and (n < getMaxPass()));
if (getVerbosity() >= Verbosity::Normal)
{
cout << "=================================================" << endl;
}
if (!minSuccess(status))
{
LATAN_WARNING("invalid minimum: " + returnMessage(status));
}
return x;
}
// NLopt return code parser ////////////////////////////////////////////////////
string NloptMinimizer::returnMessage(const nlopt::result status)
{
switch (status)
{
case nlopt::SUCCESS:
return "success";
case nlopt::STOPVAL_REACHED:
return "stopping value reached";
case nlopt::FTOL_REACHED:
return "tolerance on function reached";
case nlopt::XTOL_REACHED:
return "tolerance on variable reached";
case nlopt::MAXEVAL_REACHED:
return "maximum function evaluation reached";
case nlopt::MAXTIME_REACHED:
return "maximum time reached";
default:
return "";
}
}
// NLopt function wrapper //////////////////////////////////////////////////////
double NloptMinimizer::funcWrapper(unsigned int n, const double *arg,
double *grad , void *vdata)
{
NloptFuncData &data = *static_cast<NloptFuncData *>(vdata);
if (grad)
{
for (unsigned int i = 0; i < n; ++i)
{
data.d->setDir(i);
grad[i] = (*(data.d))(arg);
}
data.evalCount += data.d->getNPoint()*n;
}
data.evalCount++;
return (*data.f)(arg);
}
// NLopt return status parser //////////////////////////////////////////////////
bool NloptMinimizer::minSuccess(const nlopt::result status)
{
switch (status)
{
case nlopt::SUCCESS:
case nlopt::FTOL_REACHED:
case nlopt::XTOL_REACHED:
return true;
break;
default:
return false;
break;
}
}

View File

@ -0,0 +1,76 @@
/*
* NloptMinimizer.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_NloptMinimizer_hpp_
#define Latan_NloptMinimizer_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Numerical/Derivative.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Numerical/Minimizer.hpp>
#include <nlopt.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* interface to NLOpt minimizers *
* ( http://ab-initio.mit.edu/wiki/index.php/NLopt ) *
* -------------------------------------------------------------------------- *
* cf. http://ab-initio.mit.edu/wiki/index.php/NLopt_Algorithms for algorithm *
* references and naming conventions *
******************************************************************************/
class NloptMinimizer: public Minimizer
{
public:
typedef nlopt::algorithm Algorithm;
private:
struct NloptFuncData
{
const DoubleFunction *f{nullptr};
Derivative *d{nullptr};
unsigned int evalCount{0};
};
public:
// constructor
explicit NloptMinimizer(const Algorithm algorithm = defaultAlg_);
// destructor
virtual ~NloptMinimizer(void) = default;
// access
Algorithm getAlgorithm(void) const;
void setAlgorithm(const Algorithm algorithm);
virtual bool supportLimits(void) const;
// minimization
virtual const DVec & operator()(const DoubleFunction &f);
private:
// NLopt return code parser
static std::string returnMessage(const nlopt::result status);
// NLopt function wrapper
static double funcWrapper(unsigned int n, const double *arg,
double *grad , void *vdata);
// NLopt return status parser
static bool minSuccess(const nlopt::result status);
private:
Algorithm algorithm_;
static constexpr Algorithm defaultAlg_ = Algorithm::LN_NELDERMEAD;
CentralDerivative der_;
};
END_LATAN_NAMESPACE
#endif // Latan_NloptMinimizer_hpp_

View File

@ -0,0 +1,29 @@
/*
* RootFinder.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/RootFinder.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
// constructor /////////////////////////////////////////////////////////////////
RootFinder::RootFinder(const Index dim)
: Solver(dim)
{}

View File

@ -0,0 +1,48 @@
/*
* RootFinder.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_RootFinder_hpp_
#define Latan_RootFinder_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Function.hpp>
#include <LatAnalyze/Numerical/Solver.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* RootFinder *
******************************************************************************/
class RootFinder: public Solver
{
public:
// constructors
RootFinder(void) = default;
explicit RootFinder(const Index dim);
// destructor
virtual ~RootFinder(void) = default;
// solver
virtual const DVec & operator()(const std::vector<DoubleFunction *> &func)
= 0;
};
END_LATAN_NAMESPACE
#endif // Latan_RootFinder_hpp_

View File

@ -0,0 +1,96 @@
/*
* Solver.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Numerical/Solver.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* Solver implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
Solver::Solver(const double precision, const unsigned int maxIteration)
{
setMaxIteration(maxIteration);
setPrecision(precision);
}
Solver::Solver(const Index dim, const double precision,
const unsigned int maxIteration)
: Solver(precision, maxIteration)
{
resize(dim);
}
// access //////////////////////////////////////////////////////////////////////
Index Solver::getDim(void) const
{
return x_.size();
}
unsigned int Solver::getMaxIteration(void) const
{
return maxIteration_;
}
double Solver::getPrecision(void) const
{
return precision_;
}
DVec & Solver::getState(void)
{
return x_;
}
Solver::Verbosity Solver::getVerbosity(void) const
{
return verbosity_;
}
void Solver::setInit(const DVec &x0)
{
if (x0.size() != x_.size())
{
resize(x0.size());
}
x_ = x0;
}
void Solver::setMaxIteration(const unsigned int maxIteration)
{
maxIteration_ = maxIteration;
}
void Solver::setPrecision(const double precision)
{
precision_ = precision;
}
void Solver::setVerbosity(const Verbosity verbosity)
{
verbosity_ = verbosity;
}
void Solver::resize(const Index dim)
{
x_.resize(dim);
}

View File

@ -0,0 +1,73 @@
/*
* Solver.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Solver_hpp_
#define Latan_Solver_hpp_
#include <LatAnalyze/Global.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Solver *
******************************************************************************/
class Solver
{
public:
static const unsigned int defaultMaxIteration = 10000u;
static constexpr double defaultPrec = 1.0e-7;
public:
enum class Verbosity
{
Silent = 0,
Normal = 1,
Debug = 2
};
public:
// constructors
Solver(const double precision = defaultPrec,
const unsigned int maxIteration = defaultMaxIteration);
explicit Solver(const Index dim, const double precision = defaultPrec,
const unsigned int maxIteration = defaultMaxIteration);
// destructor
virtual ~Solver(void) = default;
// access
Index getDim(void) const;
virtual double getPrecision(void) const;
virtual unsigned int getMaxIteration(void) const;
Verbosity getVerbosity(void) const;
virtual void setInit(const DVec &x0);
virtual void setPrecision(const double precision);
virtual void setMaxIteration(const unsigned int maxIteration);
void setVerbosity(const Verbosity verbosity);
virtual void resize(const Index dim);
protected:
// access
DVec & getState(void);
private:
unsigned int maxIteration_;
double precision_;
DVec x_;
Verbosity verbosity_{Verbosity::Silent};
};
END_LATAN_NAMESPACE
#endif // Latan_Solver_hpp_

View File

@ -0,0 +1,417 @@
/*
* CorrelatorFitter.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Physics/CorrelatorFitter.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* Correlator models *
******************************************************************************/
DoubleModel CorrelatorModels::makeExpModel(const Index nState)
{
DoubleModel mod;
mod.setFunction([nState](const double *x, const double *p)
{
double res = 0.;
for (unsigned int i = 0; i < nState; ++i)
{
res += p[2*i + 1]*exp(-p[2*i]*x[0]);
}
return res;
}, 1, 2*nState);
for (unsigned int i = 0; i < nState; ++i)
{
mod.parName().setName(2*i, "E_" + strFrom(i));
mod.parName().setName(2*i + 1, "Z_" + strFrom(i));
}
return mod;
}
DoubleModel CorrelatorModels::makeCoshModel(const Index nState, const Index nt)
{
DoubleModel mod;
mod.setFunction([nState, nt](const double *x, const double *p)
{
double res = 0.;
for (unsigned int i = 0; i < nState; ++i)
{
res += p[2*i + 1]*(exp(-p[2*i]*x[0]) + exp(-p[2*i]*(nt - x[0])));
}
return res;
}, 1, 2*nState);
for (unsigned int i = 0; i < nState; ++i)
{
mod.parName().setName(2*i, "E_" + strFrom(i));
mod.parName().setName(2*i + 1, "Z_" + strFrom(i));
}
return mod;
}
DoubleModel CorrelatorModels::makeSinhModel(const Index nState, const Index nt)
{
DoubleModel mod;
mod.setFunction([nState, nt](const double *x, const double *p)
{
double res = 0.;
for (unsigned int i = 0; i < nState; ++i)
{
res += p[2*i + 1]*(exp(-p[2*i]*x[0]) - exp(-p[2*i]*(nt - x[0])));
}
return res;
}, 1, 2*nState);
for (unsigned int i = 0; i < nState; ++i)
{
mod.parName().setName(2*i, "E_" + strFrom(i));
mod.parName().setName(2*i + 1, "Z_" + strFrom(i));
}
return mod;
}
DoubleModel CorrelatorModels::makeConstModel(void)
{
DoubleModel mod;
mod.setFunction([](const double *x __dumb, const double *p __dumb)
{
return p[0];
}, 1, 1);
mod.parName().setName(0, "cst");
return mod;
}
DoubleModel CorrelatorModels::makeLinearModel(void)
{
DoubleModel mod;
mod.setFunction([](const double *x, const double *p)
{
return p[1] + p[0]*x[0];
}, 1, 2);
return mod;
}
CorrelatorModels::ModelPar CorrelatorModels::parseModel(const string s)
{
smatch sm;
ModelPar par;
if (regex_match(s, sm, regex("exp([0-9]+)")))
{
par.type = CorrelatorType::exp;
par.nState = strTo<Index>(sm[1].str());
}
else if (regex_match(s, sm, regex("cosh([0-9]+)")))
{
par.type = CorrelatorType::cosh;
par.nState = strTo<Index>(sm[1].str());
}
else if (regex_match(s, sm, regex("sinh([0-9]+)")))
{
par.type = CorrelatorType::sinh;
par.nState = strTo<Index>(sm[1].str());
}
else if (s == "linear")
{
par.type = CorrelatorType::linear;
par.nState = 1;
}
else if (s == "cst")
{
par.type = CorrelatorType::cst;
par.nState = 1;
}
else
{
par.type = CorrelatorType::undefined;
par.nState = 0;
}
return par;
}
DoubleModel CorrelatorModels::makeModel(const CorrelatorModels::ModelPar par,
const Index nt)
{
switch (par.type)
{
case CorrelatorType::undefined:
LATAN_ERROR(Argument, "correlator type is undefined");
break;
case CorrelatorType::exp:
return makeExpModel(par.nState);
break;
case CorrelatorType::cosh:
return makeCoshModel(par.nState, nt);
break;
case CorrelatorType::sinh:
return makeSinhModel(par.nState, nt);
break;
case CorrelatorType::linear:
return makeLinearModel();
break;
case CorrelatorType::cst:
return makeConstModel();
break;
}
}
DVec CorrelatorModels::parameterGuess(const DMatSample &corr,
const ModelPar par)
{
DVec init;
Index nt = corr[central].size();
switch (par.type)
{
case CorrelatorType::undefined:
LATAN_ERROR(Argument, "correlator type is undefined");
break;
case CorrelatorType::exp:
case CorrelatorType::cosh:
case CorrelatorType::sinh:
init.resize(2*par.nState);
init(0) = log(corr[central](nt/4)/corr[central](nt/4 + 1));
init(1) = corr[central](nt/4)/(exp(-init(0)*nt/4));
for (Index p = 2; p < init.size(); p += 2)
{
init(p) = 2*init(p - 2);
init(p + 1) = init(p - 1)/2.;
}
break;
case CorrelatorType::linear:
init.resize(2);
init(0) = corr[central](nt/4) - corr[central](nt/4 + 1, 0);
init(1) = corr[central](nt/4, 0) + nt/4*init(0);
break;
case CorrelatorType::cst:
init.resize(1);
init(0) = corr[central](nt/4);
break;
default:
break;
}
return init;
}
/******************************************************************************
* Correlator utilities *
******************************************************************************/
DMatSample CorrelatorUtils::shift(const DMatSample &c, const Index ts)
{
if (ts != 0)
{
const Index nt = c[central].rows();
DMatSample buf = c;
FOR_STAT_ARRAY(buf, s)
{
for (Index t = 0; t < nt; ++t)
{
buf[s]((t - ts + nt)%nt) = c[s](t);
}
}
return buf;
}
else
{
return c;
}
}
DMatSample CorrelatorUtils::fold(const DMatSample &c)
{
const Index nt = c[central].rows();
DMatSample buf = c;
FOR_STAT_ARRAY(buf, s)
{
for (Index t = 0; t < nt; ++t)
{
buf[s](t) = 0.5*(c[s](t) + c[s]((nt - t) % nt));
}
}
return buf;
}
DMatSample CorrelatorUtils::fourierTransform(const DMatSample &c, FFT &fft,
const unsigned int dir)
{
const Index nSample = c.size();
const Index nt = c[central].rows();
bool isComplex = (c[central].cols() > 1);
CMatSample buf(nSample, nt, 1);
DMatSample out(nSample, nt, 2);
fft.resize(nt);
FOR_STAT_ARRAY(buf, s)
{
buf[s].real() = c[s].col(0);
if (isComplex)
{
buf[s].imag() = c[s].col(1);
}
else
{
buf[s].imag() = DVec::Constant(nt, 0.);
}
fft(buf[s], dir);
out[s].col(0) = buf[s].real();
out[s].col(1) = buf[s].imag();
}
return out;
}
/******************************************************************************
* CorrelatorFitter implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
CorrelatorFitter::CorrelatorFitter(const DMatSample &corr)
{
setCorrelator(corr);
}
CorrelatorFitter::CorrelatorFitter(const std::vector<DMatSample> &corr)
{
setCorrelators(corr);
}
// access //////////////////////////////////////////////////////////////////////
XYSampleData & CorrelatorFitter::data(void)
{
return *data_;
}
void CorrelatorFitter::setCorrelator(const DMatSample &corr)
{
std::vector<DMatSample> vec;
vec.push_back(corr);
setCorrelators(vec);
}
void CorrelatorFitter::setCorrelators(const std::vector<DMatSample> &corr)
{
Index nSample = corr[0].size();
DMatSample tVec(nSample);
std::vector<const DMatSample *> ptVec;
nt_ = corr[0][central].rows();
tVec.fill(DVec::LinSpaced(nt_, 0, nt_ - 1));
for (auto &c: corr)
{
ptVec.push_back(&c);
}
data_.reset(new XYSampleData(corr[0].size()));
data_->addXDim(nt_, "t/a", true);
for (unsigned int i = 0; i < corr.size(); ++i)
{
data_->addYDim("C_" + strFrom(i) + "(t)");
}
data_->setUnidimData(tVec, ptVec);
model_.resize(corr.size());
range_.resize(corr.size(), make_pair(0, nt_ - 1));
thinning_.resize(corr.size(), 1);
}
void CorrelatorFitter::setModel(const DoubleModel &model, const Index i)
{
model_[i] = model;
}
const DoubleModel & CorrelatorFitter::getModel(const Index i) const
{
return model_.at(i);
}
void CorrelatorFitter::setFitRange(const Index tMin, const Index tMax,
const Index i)
{
range_[i] = make_pair(tMin, tMax);
refreshRanges();
}
void CorrelatorFitter::setCorrelation(const bool isCorrelated, const Index i,
const Index j)
{
data_->assumeYYCorrelated(isCorrelated, i, j);
}
DMat CorrelatorFitter::getVarianceMatrix(void) const
{
return data_->getFitVarMat();
}
void CorrelatorFitter::setThinning(const Index thinning, const Index i)
{
thinning_[i] = thinning;
refreshRanges();
}
// fit functions ///////////////////////////////////////////////////////////////
SampleFitResult CorrelatorFitter::fit(Minimizer &minimizer, const DVec &init)
{
vector<Minimizer *> vecPt = {&minimizer};
return fit(vecPt, init);
}
SampleFitResult CorrelatorFitter::fit(vector<Minimizer *> &minimizer,
const DVec &init)
{
vector<const DoubleModel *> vecPt(model_.size());
for (unsigned int i = 0; i < model_.size(); ++i)
{
vecPt[i] = &(model_[i]);
}
return data_->fit(minimizer, init, vecPt);
}
// internal function to refresh fit ranges /////////////////////////////////////
void CorrelatorFitter::refreshRanges(void)
{
for (unsigned int i = 0; i < range_.size(); ++i)
for (Index t = 0; t < nt_; ++t)
{
data_->fitPoint((t >= range_[i].first) and (t <= range_[i].second)
and ((t - range_[i].first) % thinning_[i] == 0), t);
}
}

View File

@ -0,0 +1,104 @@
/*
* CorrelatorFitter.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_CorrelatorFitter_hpp_
#define Latan_CorrelatorFitter_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Functional/Model.hpp>
#include <LatAnalyze/Numerical/FFT.hpp>
#include <LatAnalyze/Statistics/XYSampleData.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Correlator types & models *
******************************************************************************/
enum class CorrelatorType {undefined, exp, cosh, sinh, linear, cst};
namespace CorrelatorModels
{
struct ModelPar
{
CorrelatorType type;
Index nState;
};
DoubleModel makeExpModel(const Index nState);
DoubleModel makeCoshModel(const Index nState, const Index nt);
DoubleModel makeSinhModel(const Index nState, const Index nt);
DoubleModel makeConstModel(void);
DoubleModel makeLinearModel(void);
ModelPar parseModel(const std::string s);
DoubleModel makeModel(const ModelPar par, const Index nt);
DVec parameterGuess(const DMatSample &corr, const ModelPar par);
};
/******************************************************************************
* Correlator utilities *
******************************************************************************/
namespace CorrelatorUtils
{
DMatSample shift(const DMatSample &c, const Index ts);
DMatSample fold(const DMatSample &c);
DMatSample fourierTransform(const DMatSample &c, FFT &fft,
const unsigned int dir = FFT::Forward);
};
/******************************************************************************
* Correlator fit utility class *
******************************************************************************/
class CorrelatorFitter
{
public:
// constructors
CorrelatorFitter(const DMatSample &corr);
CorrelatorFitter(const std::vector<DMatSample> &corr);
// destructor
virtual ~CorrelatorFitter(void) = default;
// access
XYSampleData & data(void);
void setCorrelator(const DMatSample &corr);
void setCorrelators(const std::vector<DMatSample> &corr);
const DMatSample & getCorrelator(const Index i = 0) const;
const std::vector<DMatSample> & getCorrelators(void) const;
void setModel(const DoubleModel &model, const Index i = 0);
const DoubleModel & getModel(const Index i = 0) const;
void setFitRange(const Index tMin, const Index tMax, const Index i = 0);
void setCorrelation(const bool isCorrelated, const Index i = 0,
const Index j = 0);
DMat getVarianceMatrix(void) const;
void setThinning(const Index thinning, const Index i = 0);
// fit functions
SampleFitResult fit(Minimizer &minimizer, const DVec &init);
SampleFitResult fit(std::vector<Minimizer *> &minimizer, const DVec &init);
private:
// internal function to refresh fit ranges
void refreshRanges(void);
private:
Index nt_{0};
std::unique_ptr<XYSampleData> data_;
std::vector<DoubleModel> model_;
std::vector<std::pair<Index, Index>> range_;
std::vector<Index> thinning_;
};
END_LATAN_NAMESPACE
#endif // Latan_CorrelatorFitter_hpp_

View File

@ -0,0 +1,83 @@
/*
* DataFilter.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Physics/DataFilter.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Numerical/DWT.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* DataFilter implementation *
******************************************************************************/
// constructor ////////////////////////////////////////////////////////////////
DataFilter::DataFilter(const vector<double> &filter, const bool downsample)
: filter_(filter), downsample_(downsample)
{}
// filtering //////////////////////////////////////////////////////////////////
template <typename MatType>
void filter(MatType &out, const MatType &in, const vector<double> &filter,
const bool downsample, MatType &buf)
{
if (!downsample)
{
out.resizeLike(in);
DWT::filterConvolution(out, in, filter, filter.size()/2);
}
else
{
out.resize(in.rows()/2, in.cols());
buf.resizeLike(in);
DWT::filterConvolution(buf, in, filter, filter.size()/2);
DWT::downsample(out, buf);
}
}
void DataFilter::operator()(DVec &out, const DVec &in)
{
filter(out, in, filter_, downsample_, vBuf_);
}
void DataFilter::operator()(DMat &out, const DMat &in)
{
filter(out, in, filter_, downsample_, mBuf_);
}
/******************************************************************************
* LaplaceDataFilter implementation *
******************************************************************************/
// constructor ////////////////////////////////////////////////////////////////
LaplaceDataFilter::LaplaceDataFilter(const bool downsample)
: DataFilter({1., -2. , 1.}, downsample)
{}
// filtering //////////////////////////////////////////////////////////////////
void LaplaceDataFilter::operator()(DVec &out, const DVec &in, const double lambda)
{
filter_[1] = -2. - lambda;
DataFilter::operator()(out, in);
}
void LaplaceDataFilter::operator()(DMat &out, const DMat &in, const double lambda)
{
filter_[1] = -2. - lambda;
DataFilter::operator()(out, in);
}

View File

@ -0,0 +1,139 @@
/*
* DataFilter.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_DataFilter_hpp_
#define Latan_DataFilter_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Math.hpp>
#include <LatAnalyze/Statistics/StatArray.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
#include <LatAnalyze/Numerical/Minimizer.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Generic convolution filter class *
******************************************************************************/
class DataFilter
{
public:
// constructor
DataFilter(const std::vector<double> &filter, const bool downsample = false);
// filtering
void operator()(DVec &out, const DVec &in);
void operator()(DMat &out, const DMat &in);
template <typename MatType, Index o>
void operator()(StatArray<MatType, o> &out, const StatArray<MatType, o> &in);
protected:
std::vector<double> filter_;
private:
bool downsample_;
DVec vBuf_;
DMat mBuf_;
};
/******************************************************************************
* Laplacian filter class *
******************************************************************************/
class LaplaceDataFilter: public DataFilter
{
public:
// constructor
LaplaceDataFilter(const bool downsample = false);
// filtering
void operator()(DVec &out, const DVec &in, const double lambda = 0.);
void operator()(DMat &out, const DMat &in, const double lambda = 0.);
template <typename MatType, Index o>
void operator()(StatArray<MatType, o> &out, const StatArray<MatType, o> &in,
const double lambda = 0.);
// correlation optimisation
template <typename MatType, Index o>
double optimiseCdr(const StatArray<MatType, o> &data, Minimizer &min,
const unsigned int nPass = 3);
};
/******************************************************************************
* DataFilter class template implementation *
******************************************************************************/
// filtering //////////////////////////////////////////////////////////////////
template <typename MatType, Index o>
void DataFilter::operator()(StatArray<MatType, o> &out, const StatArray<MatType, o> &in)
{
FOR_STAT_ARRAY(in, s)
{
(*this)(out[s], in[s]);
}
}
/******************************************************************************
* LaplaceDataFilter class template implementation *
******************************************************************************/
// filtering //////////////////////////////////////////////////////////////////
template <typename MatType, Index o>
void LaplaceDataFilter::operator()(StatArray<MatType, o> &out,
const StatArray<MatType, o> &in, const double lambda)
{
FOR_STAT_ARRAY(in, s)
{
(*this)(out[s], in[s], lambda);
}
}
// correlation optimisation ///////////////////////////////////////////////////
template <typename MatType, Index o>
double LaplaceDataFilter::optimiseCdr(const StatArray<MatType, o> &data,
Minimizer &min, const unsigned int nPass)
{
StatArray<MatType, o> fdata(data.size());
DVec init(1);
double reg, prec;
DoubleFunction cdr([&data, &fdata, this](const double *x)
{
double res;
(*this)(fdata, data, x[0]);
res = Math::cdr(fdata.correlationMatrix());
return res;
}, 1);
min.setLowLimit(0., -0.1);
min.setHighLimit(0., 100000.);
init(0) = 0.1;
min.setInit(init);
prec = 0.1;
min.setPrecision(prec);
reg = min(cdr)(0);
for (unsigned int pass = 0; pass < nPass; pass++)
{
min.setLowLimit(0., (1.-10.*prec)*reg);
min.setHighLimit(0., (1.+10.*prec)*reg);
init(0) = reg;
min.setInit(init);
prec *= 0.1;
min.setPrecision(prec);
reg = min(cdr)(0);
}
return reg;
}
END_LATAN_NAMESPACE
#endif // Latan_DataFilter_hpp_

View File

@ -0,0 +1,132 @@
/*
* EffectiveMass.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Physics/EffectiveMass.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* EffectiveMass implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
EffectiveMass::EffectiveMass(const CorrelatorType type)
{
setType(type);
}
// access //////////////////////////////////////////////////////////////////////
CorrelatorType EffectiveMass::getType(void) const
{
return type_;
}
void EffectiveMass::setType(const CorrelatorType type)
{
type_ = type;
}
DVec EffectiveMass::getTime(const Index nt) const
{
DVec tvec;
switch (type_)
{
case CorrelatorType::undefined:
LATAN_ERROR(Argument, "correlator type is undefined");
break;
case CorrelatorType::exp:
case CorrelatorType::linear:
tvec = DVec::LinSpaced(nt - 1, 0, nt - 2);
break;
case CorrelatorType::cosh:
case CorrelatorType::sinh:
tvec = DVec::LinSpaced(nt - 2, 1, nt - 2);
break;
case CorrelatorType::cst:
tvec = DVec::LinSpaced(nt, 0, nt - 1);
break;
}
return tvec;
}
// compute effective mass //////////////////////////////////////////////////////
DVec EffectiveMass::operator()(const DVec &corr) const
{
Index nt = corr.size();
DVec em;
if (nt < 2)
{
LATAN_ERROR(Size, "input vector has less than 2 elements");
}
switch (type_)
{
case CorrelatorType::undefined:
LATAN_ERROR(Argument, "correlator type is undefined");
break;
case CorrelatorType::exp:
em.resize(nt - 1);
for (Index t = 1; t < nt; ++t)
{
em(t - 1) = log(corr(t - 1)/corr(t));
}
break;
case CorrelatorType::cosh:
em.resize(nt - 2);
for (Index t = 1; t < nt - 1; ++t)
{
em(t - 1) = acosh((corr(t - 1) + corr(t + 1))/(2.*corr(t)));
}
break;
case CorrelatorType::sinh:
em.resize(nt - 2);
for (Index t = 1; t < nt - 1; ++t)
{
em(t - 1) = acosh((corr(t - 1) + corr(t + 1))/(2.*corr(t)));
}
break;
case CorrelatorType::linear:
em.resize(nt - 1);
for (Index t = 1; t < nt; ++t)
{
em(t - 1) = corr(t) - corr(t - 1);
}
break;
case CorrelatorType::cst:
em = corr;
break;
}
return em;
}
DMatSample EffectiveMass::operator()(const DMatSample &corr) const
{
DMatSample em(corr.size());
FOR_STAT_ARRAY(em, s)
{
em[s] = (*this)(corr[s]);
}
return em;
}

View File

@ -0,0 +1,50 @@
/*
* EffectiveMass.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_EffectiveMass_hpp_
#define Latan_EffectiveMass_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
#include <LatAnalyze/Physics/CorrelatorFitter.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Effective mass class *
******************************************************************************/
class EffectiveMass
{
public:
// constructors
EffectiveMass(const CorrelatorType type = CorrelatorType::exp);
// access
CorrelatorType getType(void) const;
void setType(const CorrelatorType type);
DVec getTime(const Index nt) const;
// compute effective mass
DVec operator()(const DVec &corr) const;
DMatSample operator()(const DMatSample &corr) const;
private:
CorrelatorType type_;
};
END_LATAN_NAMESPACE
#endif // Latan_EffectiveMass_hpp_

View File

@ -0,0 +1,152 @@
/*
* Dataset.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Dataset_hpp_
#define Latan_Dataset_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Io/File.hpp>
#include <LatAnalyze/Statistics/StatArray.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Dataset class *
******************************************************************************/
template <typename T>
class Dataset: public StatArray<T>
{
public:
typedef std::random_device::result_type SeedType;
public:
// constructors
Dataset(void) = default;
Dataset(const Index size);
EIGEN_EXPR_CTOR(Dataset, Dataset<T>, StatArray<T>, ArrayExpr)
// destructor
virtual ~Dataset(void) = default;
// IO
template <typename FileType>
void load(const std::string &listFileName, const std::string &dataName);
// resampling
Sample<T> bootstrapMean(const Index nSample, const SeedType seed);
Sample<T> bootstrapMean(const Index nSample);
void dumpBootstrapSeq(std::ostream &out, const Index nSample,
const SeedType seed);
private:
// mean from pointer vector for resampling
void ptVectorMean(T &m, const std::vector<const T *> &v);
};
/******************************************************************************
* Dataset template implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
template <typename T>
Dataset<T>::Dataset(const Index size)
: StatArray<T>(size)
{}
// IO //////////////////////////////////////////////////////////////////////////
template <typename T>
template <typename FileType>
void Dataset<T>::load(const std::string &listFileName,
const std::string &dataName)
{
FileType file;
std::vector<std::string> dataFileName;
dataFileName = readManifest(listFileName);
this->resize(dataFileName.size());
for (Index i = 0; i < static_cast<Index>(dataFileName.size()); ++i)
{
file.open(dataFileName[i], File::Mode::read);
(*this)[i] = file.template read<T>(dataName);
file.close();
}
}
// resampling //////////////////////////////////////////////////////////////////
template <typename T>
Sample<T> Dataset<T>::bootstrapMean(const Index nSample, const SeedType seed)
{
std::vector<const T *> data(this->size());
Sample<T> s(nSample);
std::mt19937 gen(seed);
std::uniform_int_distribution<Index> dis(0, this->size() - 1);
for (unsigned int j = 0; j < this->size(); ++j)
{
data[j] = &((*this)[static_cast<Index>(j)]);
}
ptVectorMean(s[central], data);
for (Index i = 0; i < nSample; ++i)
{
for (unsigned int j = 0; j < this->size(); ++j)
{
data[j] = &((*this)[dis(gen)]);
}
ptVectorMean(s[i], data);
}
return s;
}
template <typename T>
Sample<T> Dataset<T>::bootstrapMean(const Index nSample)
{
std::random_device rd;
return bootstrapMean(nSample, rd());
}
template <typename T>
void Dataset<T>::dumpBootstrapSeq(std::ostream &out, const Index nSample,
const SeedType seed)
{
std::mt19937 gen(seed);
std::uniform_int_distribution<Index> dis(0, this->size() - 1);
for (Index i = 0; i < nSample; ++i)
{
for (unsigned int j = 0; j < this->size(); ++j)
{
out << dis(gen) << " " << std::endl;
}
out << std::endl;
}
}
template <typename T>
void Dataset<T>::ptVectorMean(T &m, const std::vector<const T *> &v)
{
if (v.size())
{
m = *(v[0]);
for (unsigned int i = 1; i < v.size(); ++i)
{
m += *(v[i]);
}
m /= static_cast<double>(v.size());
}
}
END_LATAN_NAMESPACE
#endif // Latan_Dataset_hpp_

View File

@ -0,0 +1,794 @@
/*
* FitInterface.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Statistics/FitInterface.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* FitInterface implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
FitInterface::FitInterface(void)
: xName_("x")
, yName_("y")
{}
// copy object (not as a constructor to be accessed from derived class) ////////
void FitInterface::copyInterface(const FitInterface &d)
{
*this = d;
scheduleFitVarMatInit();
}
// add dimensions //////////////////////////////////////////////////////////////
void FitInterface::addXDim(const Index nData, const string name,
const bool isExact)
{
if (getYSize() != 0)
{
LATAN_ERROR(Logic, "cannot add an X dimension if fit data is "
"not empty");
}
else
{
xSize_.push_back(nData);
xIsExact_.push_back(isExact);
maxDataIndex_ *= nData;
createXData(name, nData);
scheduleLayoutInit();
scheduleDataCoordInit();
if (!name.empty())
{
xName().setName(getNXDim() - 1, name);
}
}
}
void FitInterface::addYDim(const string name)
{
yDataIndex_.push_back(map<Index, bool>());
createYData(name);
scheduleLayoutInit();
if (!name.empty())
{
yName().setName(getNYDim() - 1, name);
}
}
// access //////////////////////////////////////////////////////////////////////
Index FitInterface::getNXDim(void) const
{
return xSize_.size();
}
Index FitInterface::getNYDim(void) const
{
return yDataIndex_.size();
}
Index FitInterface::getXSize(void) const
{
Index size = 0;
for (Index i = 0; i < getNXDim(); ++i)
{
size += getXSize(i);
}
return size;
}
Index FitInterface::getXSize(const Index i) const
{
checkXDim(i);
return xSize_[i];
}
Index FitInterface::getYSize(void) const
{
Index size = 0;
for (Index j = 0; j < getNYDim(); ++j)
{
size += getYSize(j);
}
return size;
}
Index FitInterface::getYSize(const Index j) const
{
checkYDim(j);
return static_cast<Index>(yDataIndex_[j].size());
}
Index FitInterface::getXFitSize(void) const
{
Index size = 0;
for (Index i = 0; i < getNXDim(); ++i)
{
size += getXFitSize(i);
}
return size;
}
Index FitInterface::getXFitSize(const Index i) const
{
set<Index> fitCoord;
vector<Index> v;
checkXDim(i);
for (Index j = 0; j < getNYDim(); ++j)
{
for (auto &p: yDataIndex_[j])
{
if (p.second)
{
v = dataCoord(p.first);
fitCoord.insert(v[i]);
}
}
}
return fitCoord.size();
}
Index FitInterface::getYFitSize(void) const
{
Index size = 0;
for (Index j = 0; j < getNYDim(); ++j)
{
size += getYFitSize(j);
}
return size;
}
Index FitInterface::getYFitSize(const Index j) const
{
Index size;
auto pred = [](const pair<Index, bool> &p)
{
return p.second;
};
checkYDim(j);
size = count_if(yDataIndex_[j].begin(), yDataIndex_[j].end(), pred);
return size;
}
Index FitInterface::getMaxDataIndex(void) const
{
return maxDataIndex_;
}
const set<Index> & FitInterface::getDataIndexSet(void) const
{
return dataIndexSet_;
}
double FitInterface::getSvdTolerance(void) const
{
return svdTol_;
}
void FitInterface::setSvdTolerance(const double &tol)
{
svdTol_ = tol;
scheduleLayoutInit();
}
VarName & FitInterface::xName(void)
{
return xName_;
}
const VarName & FitInterface::xName(void) const
{
return xName_;
}
VarName & FitInterface::yName(void)
{
return yName_;
}
const VarName & FitInterface::yName(void) const
{
return yName_;
}
// Y dimension index helper ////////////////////////////////////////////////////
Index FitInterface::dataIndex(const vector<Index> &v) const
{
Index k, n = v.size();
checkDataCoord(v);
k = xSize_[1]*v[0];
for (unsigned int d = 1; d < n-1; ++d)
{
k = xSize_[d+1]*(v[d] + k);
}
k += v[n-1];
return k;
}
const vector<Index> & FitInterface::dataCoord(const Index k) const
{
checkDataIndex(k);
updateDataCoord();
return dataCoord_.at(k);
}
// enable fit points ///////////////////////////////////////////////////////////
void FitInterface::fitPoint(const bool isFitPoint, const Index k, const Index j)
{
checkPoint(k, j);
yDataIndex_[j][k] = isFitPoint;
scheduleLayoutInit();
}
// variance interface //////////////////////////////////////////////////////////
void FitInterface::assumeXExact(const bool isExact, const Index i)
{
checkXDim(i);
xIsExact_[i] = isExact;
scheduleLayoutInit();
}
void FitInterface::addCorr(set<array<Index, 4>> &s, const bool isCorr,
const array<Index, 4> &c)
{
if (isCorr)
{
s.insert(c);
}
else
{
auto it = s.find(c);
if (it != s.end())
{
s.erase(it);
}
}
}
void FitInterface::assumeXXCorrelated(const bool isCorr, const Index r1,
const Index i1, const Index r2,
const Index i2)
{
array<Index, 4> c{{r1, i1, r2, i2}};
checkXIndex(r1, i1);
checkXIndex(r2, i2);
if ((i1 != i2) or (r1 != r2))
{
addCorr(xxCorr_, isCorr, c);
}
scheduleFitVarMatInit();
}
void FitInterface::assumeXXCorrelated(const bool isCorr, const Index i1,
const Index i2)
{
for (Index r1 = 0; r1 < getXSize(i1); ++r1)
for (Index r2 = 0; r2 < getXSize(i2); ++r2)
{
assumeXXCorrelated(isCorr, r1, i1, r2, i2);
}
}
void FitInterface::assumeYYCorrelated(const bool isCorr, const Index k1,
const Index j1, const Index k2,
const Index j2)
{
array<Index, 4> c{{k1, j1, k2, j2}};
checkPoint(k1, j1);
checkPoint(k2, j2);
if ((j1 != j2) or (k1 != k2))
{
addCorr(yyCorr_, isCorr, c);
}
scheduleFitVarMatInit();
}
void FitInterface::assumeYYCorrelated(const bool isCorr, const Index j1,
const Index j2)
{
checkYDim(j1);
checkYDim(j2);
for (auto &p1: yDataIndex_[j1])
for (auto &p2: yDataIndex_[j2])
{
assumeYYCorrelated(isCorr, p1.first, j1, p2.first, j2);
}
}
void FitInterface::assumeXYCorrelated(const bool isCorr, const Index r,
const Index i, const Index k,
const Index j)
{
array<Index, 4> c{{r, i, k, j}};
checkXIndex(r, i);
checkPoint(k, j);
addCorr(xyCorr_, isCorr, c);
scheduleFitVarMatInit();
}
void FitInterface::assumeXYCorrelated(const bool isCorr, const Index i,
const Index j)
{
checkYDim(j);
for (Index r = 0; r < getXSize(i); ++r)
for (auto &p: yDataIndex_[j])
{
assumeXYCorrelated(isCorr, r, i, p.first, j);
}
}
// tests ///////////////////////////////////////////////////////////////////////
bool FitInterface::pointExists(const Index k) const
{
bool isUsed = false;
for (Index j = 0; j < getNYDim(); ++j)
{
isUsed = isUsed or pointExists(k, j);
}
return isUsed;
}
bool FitInterface::pointExists(const Index k, const Index j) const
{
checkDataIndex(k);
checkYDim(j);
return !(yDataIndex_[j].find(k) == yDataIndex_[j].end());
}
bool FitInterface::isXExact(const Index i) const
{
checkXDim(i);
return xIsExact_[i];
}
bool FitInterface::isXUsed(const Index r, const Index i, const bool inFit) const
{
vector<Index> v;
checkXDim(i);
for (Index j = 0; j < getNYDim(); ++j)
{
for (auto &p: yDataIndex_[j])
{
if (p.second or !inFit)
{
v = dataCoord(p.first);
if (v[i] == r)
{
return true;
}
}
}
}
return false;
}
bool FitInterface::isFitPoint(const Index k, const Index j) const
{
checkPoint(k, j);
return yDataIndex_[j].at(k);
}
bool FitInterface::isXXCorrelated(const Index r1, const Index i1,
const Index r2, const Index i2) const
{
array<Index, 4> c{{r1, i1, r2, i2}};
auto it = xxCorr_.find(c);
return (it != xxCorr_.end());
}
bool FitInterface::isYYCorrelated(const Index k1, const Index j1,
const Index k2, const Index j2) const
{
array<Index, 4> c{{k1, j1, k2, j2}};
auto it = yyCorr_.find(c);
return (it != yyCorr_.end());
}
bool FitInterface::isXYCorrelated(const Index r, const Index i,
const Index k, const Index j) const
{
array<Index, 4> c{{r, i, k, j}};
auto it = xyCorr_.find(c);
return (it != xyCorr_.end());
}
bool FitInterface::hasCorrelations(void) const
{
return ((xxCorr_.size() != 0) or (yyCorr_.size() != 0)
or (xyCorr_.size() != 0));
}
// make correlation filter for fit variance matrix /////////////////////////////
DMat FitInterface::makeCorrFilter(void)
{
updateLayout();
DMat f = DMat::Identity(layout.totalSize, layout.totalSize);
Index row, col;
for (auto &c: xxCorr_)
{
row = indX(c[0], c[1]);
col = indX(c[2], c[3]);
if ((row != -1) and (col != -1))
{
f(row, col) = 1.;
f(col, row) = 1.;
}
}
for (auto &c: yyCorr_)
{
row = indY(c[0], c[1]);
col = indY(c[2], c[3]);
if ((row != -1) and (col != -1))
{
f(row, col) = 1.;
f(col, row) = 1.;
}
}
for (auto &c: xyCorr_)
{
row = indX(c[0], c[1]);
col = indY(c[2], c[3]);
if ((row != -1) and (col != -1))
{
f(row, col) = 1.;
f(col, row) = 1.;
}
}
return f;
}
// schedule variance matrix initialization /////////////////////////////////////
void FitInterface::scheduleFitVarMatInit(const bool init)
{
initVarMat_ = init;
}
// register a data point ///////////////////////////////////////////////////////
void FitInterface::registerDataPoint(const Index k, const Index j)
{
checkYDim(j);
yDataIndex_[j][k] = true;
dataIndexSet_.insert(k);
scheduleLayoutInit();
}
// coordinate buffering ////////////////////////////////////////////////////////
void FitInterface::scheduleDataCoordInit(void)
{
initDataCoord_ = true;
scheduleFitVarMatInit();
}
void FitInterface::updateDataCoord(void) const
{
FitInterface * modThis = const_cast<FitInterface *>(this);
if (initDataCoord_)
{
modThis->dataCoord_.clear();
for (auto k: getDataIndexSet())
{
modThis->dataCoord_[k] = rowMajToCoord(k);
}
modThis->initDataCoord_ = false;
}
}
// global layout management ////////////////////////////////////////////////////
void FitInterface::scheduleLayoutInit(void)
{
initLayout_ = true;
scheduleFitVarMatInit();
}
bool FitInterface::initVarMat(void) const
{
return initVarMat_;
}
void FitInterface::updateLayout(void) const
{
if (initLayout_)
{
FitInterface * modThis = const_cast<FitInterface *>(this);
Layout & l = modThis->layout;
Index size, ifit;
vector<Index> v;
l.nXFitDim = 0;
l.nYFitDim = 0;
l.totalSize = 0;
l.totalXSize = 0;
l.totalYSize = 0;
l.xSize.clear();
l.ySize.clear();
l.dataIndexSet.clear();
l.xDim.clear();
l.yDim.clear();
l.xFitDim.clear();
l.yFitDim.clear();
l.x.clear();
l.y.clear();
l.xFit.clear();
l.yFit.clear();
ifit = 0;
for (Index i = 0; i < getNXDim(); ++i)
{
if (!xIsExact_[i])
{
l.nXFitDim++;
size = getXFitSize(i);
l.xSize.push_back(size);
l.totalXSize += size;
l.xDim.push_back(i);
l.xFitDim.push_back(layout.xDim.size() - 1);
l.x.push_back(vector<Index>());
l.xFit.push_back(vector<Index>());
for (Index r = 0; r < getXSize(i); ++r)
{
if (isXUsed(r, i))
{
l.x[ifit].push_back(r);
l.xFit[i].push_back(layout.x[ifit].size() - 1);
}
else
{
l.xFit[i].push_back(-1);
}
}
ifit++;
}
else
{
l.xFitDim.push_back(-1);
l.xFit.push_back(vector<Index>());
for (Index r = 0; r < getXSize(i); ++r)
{
l.xFit[i].push_back(-1);
}
}
}
for (Index j = 0; j < getNYDim(); ++j)
{
Index s = 0;
l.nYFitDim++;
size = getYFitSize(j);
l.ySize.push_back(size);
l.totalYSize += size;
l.yDim.push_back(j);
l.yFitDim.push_back(layout.yDim.size() - 1);
l.y.push_back(vector<Index>());
l.yFit.push_back(vector<Index>());
l.data.push_back(vector<Index>());
l.yFitFromData.push_back(map<Index, Index>());
for (auto &p: yDataIndex_[j])
{
if (p.second)
{
l.dataIndexSet.insert(p.first);
l.y[j].push_back(s);
l.yFit[j].push_back(layout.y[j].size() - 1);
l.data[j].push_back(p.first);
l.yFitFromData[j][p.first] = layout.y[j].size() - 1;
}
else
{
l.yFit[j].push_back(-1);
l.yFitFromData[j][p.first] = -1;
}
s++;
}
}
l.totalSize = layout.totalXSize + layout.totalYSize;
l.nXFitDim = static_cast<Index>(layout.xSize.size());
l.nYFitDim = static_cast<Index>(layout.ySize.size());
l.xIndFromData.resize(getMaxDataIndex());
for (Index k: layout.dataIndexSet)
{
v = dataCoord(k);
for (Index i = 0; i < getNXDim(); ++i)
{
l.xIndFromData[k].push_back(indX(v[i], i));
}
}
modThis->initLayout_ = false;
}
}
Index FitInterface::indX(const Index r, const Index i) const
{
Index ind = -1;
if (layout.xFit[i][r] != -1)
{
Index ifit = layout.xFitDim[i], rfit = layout.xFit[i][r];
ind = layout.totalYSize;
for (Index a = 0; a < ifit; ++a)
{
ind += layout.xSize[a];
}
ind += rfit;
}
return ind;
}
Index FitInterface::indY(const Index k, const Index j) const
{
Index ind = -1;
if (layout.yFitFromData[j].at(k) != -1)
{
Index jfit = layout.yFitDim[j], sfit = layout.yFitFromData[j].at(k);
ind = 0;
for (Index b = 0; b < jfit; ++b)
{
ind += layout.ySize[b];
}
ind += sfit;
}
return ind;
}
// function to convert an row-major index into coordinates /////////////////////
vector<Index> FitInterface::rowMajToCoord(const Index k) const
{
vector<Index> v(getNXDim());
Index buf, dimProd;
checkDataIndex(k);
buf = k;
dimProd = 1;
for (Index d = getNXDim() - 1; d >= 0; --d)
{
v[d] = (buf/dimProd)%xSize_[d];
buf -= dimProd*v[d];
dimProd *= xSize_[d];
}
return v;
}
// IO //////////////////////////////////////////////////////////////////////////
ostream & Latan::operator<<(ostream &out, FitInterface &f)
{
out << "X dimensions: " << f.getNXDim() << endl;
for (Index i = 0; i < f.getNXDim(); ++i)
{
out << " * " << i << " \"" << f.xName().getName(i) << "\": ";
out << f.getXSize(i) << " value(s)";
if (f.isXExact(i))
{
out << " (assumed exact)";
}
out << endl;
}
out << "Y dimensions: " << f.getNYDim() << endl;
for (Index j = 0; j < f.getNYDim(); ++j)
{
out << " * " << j << " \"" << f.yName().getName(j) << "\": ";
out << f.getYSize(j) << " value(s)" << endl;
for (auto &p: f.yDataIndex_[j])
{
out << " " << setw(3) << p.first << " (";
for (auto vi: f.dataCoord(p.first))
{
out << vi << ",";
}
out << "\b) fit: " << (p.second ? "true" : "false") << endl;
}
}
out << "X/X correlations (r1 i1 r2 i2): ";
if (f.xxCorr_.empty())
{
out << "no" << endl;
}
else
{
out << endl;
for (auto &c: f.xxCorr_)
{
out << " * ";
for (auto i: c)
{
out << i << " ";
}
out << endl;
}
}
out << "Y/Y correlations (k1 j1 k2 j2): ";
if (f.yyCorr_.empty())
{
out << "no" << endl;
}
else
{
out << endl;
for (auto &c: f.yyCorr_)
{
out << " * ";
for (auto i: c)
{
out << i << " ";
}
out << endl;
}
}
out << "X/Y correlations (r i k j): ";
if (f.xyCorr_.empty())
{
out << "no";
}
else
{
out << endl;
for (auto &c: f.xyCorr_)
{
out << " * ";
for (auto i: c)
{
out << i << " ";
}
out << endl;
}
}
return out;
}

View File

@ -0,0 +1,228 @@
/*
* FitInterface.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_FitInterface_hpp_
#define Latan_FitInterface_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* FitInterface *
******************************************************************************/
class FitInterface
{
private:
typedef struct
{
Index nXFitDim, nYFitDim;
// X/Y block sizes
Index totalSize, totalXSize, totalYSize;
// size of each X/Y dimension
std::vector<Index> xSize, ySize;
// set of active data indices
std::set<Index> dataIndexSet;
// lookup tables
// xDim : x fit dim ifit -> x dim i
// x : x fit point ifit,rfit -> x point r
// xFitDim : x dim i -> x fit dim ifit (-1 if empty)
// xFit : x point i,r -> x fit point rfit (-1 if empty)
// data : y fit point jfit,sfit -> y point index k
// yFitFromData: y point index k,j -> y fit point sfit (-1 if empty)
// xIndFromData: data index k -> index of coordinates of associated x
std::vector<Index> xDim, yDim, xFitDim, yFitDim;
std::vector<std::vector<Index>> x, y, data, xFit, yFit;
std::vector<std::map<Index, Index>> yFitFromData;
// no map here for fit performance
std::vector<std::vector<Index>> xIndFromData;
} Layout;
public:
// constructor
FitInterface(void);
// destructor
virtual ~FitInterface(void) = default;
// copy object (not as a constructor to be accessed from derived class)
void copyInterface(const FitInterface &d);
// add dimensions
void addXDim(const Index nData, const std::string name = "",
const bool isExact = false);
void addYDim(const std::string name = "");
// access
Index getNXDim(void) const;
Index getNYDim(void) const;
Index getXSize(void) const;
Index getXSize(const Index i) const;
Index getYSize(void) const;
Index getYSize(const Index j) const;
Index getXFitSize(void) const;
Index getXFitSize(const Index i) const;
Index getYFitSize(void) const;
Index getYFitSize(const Index j) const;
Index getMaxDataIndex(void) const;
const std::set<Index> & getDataIndexSet(void) const;
double getSvdTolerance(void) const;
void setSvdTolerance(const double &tol);
VarName & xName(void);
const VarName & xName(void) const;
VarName & yName(void);
const VarName & yName(void) const;
// Y dimension index helper
template <typename... Ts>
Index dataIndex(const Ts... is) const;
Index dataIndex(const std::vector<Index> &v) const;
const std::vector<Index> & dataCoord(const Index k) const;
// enable fit points
void fitPoint(const bool isFitPoint, const Index k, const Index j = 0);
// variance interface
void assumeXExact(const bool isExact, const Index i);
void assumeXXCorrelated(const bool isCorr, const Index r1, const Index i1,
const Index r2, const Index i2);
void assumeXXCorrelated(const bool isCorr, const Index i1, const Index i2);
void assumeYYCorrelated(const bool isCorr, const Index k1, const Index j1,
const Index k2, const Index j2);
void assumeYYCorrelated(const bool isCorr, const Index j1, const Index j2);
void assumeXYCorrelated(const bool isCorr, const Index r, const Index i,
const Index k, const Index j);
void assumeXYCorrelated(const bool isCorr, const Index i, const Index j);
// tests
bool pointExists(const Index k) const;
bool pointExists(const Index k, const Index j) const;
bool isXExact(const Index i) const;
bool isXUsed(const Index r, const Index i, const bool inFit = true) const;
bool isFitPoint(const Index k, const Index j) const;
bool isXXCorrelated(const Index r1, const Index i1, const Index r2,
const Index i2) const;
bool isYYCorrelated(const Index k1, const Index j1, const Index k2,
const Index j2) const;
bool isXYCorrelated(const Index r, const Index i, const Index k,
const Index j) const;
bool hasCorrelations(void) const;
// make correlation filter for fit variance matrix
DMat makeCorrFilter(void);
// schedule variance matrix initialization
void scheduleFitVarMatInit(const bool init = true);
// IO
friend std::ostream & operator<<(std::ostream &out, FitInterface &f);
protected:
// register a data point
void registerDataPoint(const Index k, const Index j = 0);
// add correlation to a set
static void addCorr(std::set<std::array<Index, 4>> &s, const bool isCorr,
const std::array<Index, 4> &c);
// abstract methods to create data containers
virtual void createXData(const std::string name, const Index nData) = 0;
virtual void createYData(const std::string name) = 0;
// coordinate buffering
void scheduleDataCoordInit(void);
void updateDataCoord(void) const;
// global layout management
void scheduleLayoutInit(void);
bool initVarMat(void) const;
void updateLayout(void) const;
Index indX(const Index r, const Index i) const;
Index indY(const Index k, const Index j) const;
private:
// function to convert an row-major index into coordinates
std::vector<Index> rowMajToCoord(const Index k) const;
protected:
Layout layout;
private:
VarName xName_, yName_;
std::vector<Index> xSize_;
std::vector<bool> xIsExact_;
std::map<Index, std::vector<Index>> dataCoord_;
std::set<Index> dataIndexSet_;
std::vector<std::map<Index, bool>> yDataIndex_;
std::set<std::array<Index, 4>> xxCorr_, yyCorr_, xyCorr_;
Index maxDataIndex_{1};
bool initLayout_{true};
bool initVarMat_{true};
bool initDataCoord_{true};
double svdTol_{1.e-10};
};
std::ostream & operator<<(std::ostream &out, FitInterface &f);
/******************************************************************************
* FitInterface template implementation *
******************************************************************************/
// Y dimension index helper ////////////////////////////////////////////////////
template <typename... Ts>
Index FitInterface::dataIndex(const Ts... coords) const
{
static_assert(static_or<std::is_convertible<Index, Ts>::value...>::value,
"fitPoint arguments are not compatible with Index");
const std::vector<Index> coord = {coords...};
return dataIndex(coord);
}
/******************************************************************************
* error check macros *
******************************************************************************/
#define checkXDim(i)\
if ((i) >= getNXDim())\
{\
LATAN_ERROR(Range, "X dimension " + strFrom(i) + " out of range");\
}
#define checkXIndex(vi, i)\
if ((vi) >= getXSize(i))\
{\
LATAN_ERROR(Range, "index " + strFrom(vi) + " in X dimension "\
+ strFrom(i) + " out of range");\
}
#define checkYDim(j)\
if ((j) >= getNYDim())\
{\
LATAN_ERROR(Range, "Y dimension " + strFrom(j) + " out of range");\
}
#define checkDataIndex(k)\
if ((k) >= getMaxDataIndex())\
{\
LATAN_ERROR(Range, "data point index " + strFrom(k) + " invalid");\
}
#define checkDataCoord(v)\
if (static_cast<Index>((v).size()) != getNXDim())\
{\
LATAN_ERROR(Size, "number of coordinates and number of X dimensions "\
"mismatch");\
}\
for (unsigned int i_ = 0; i_ < (v).size(); ++i_)\
{\
checkXIndex((v)[i_], i_);\
}
#define checkPoint(k, j)\
if (!pointExists(k, j))\
{\
LATAN_ERROR(Range, "no data point in Y dimension " + strFrom(j)\
+ " with index " + strFrom(k));\
}
END_LATAN_NAMESPACE
#endif // Latan_FitInterface_hpp_

View File

@ -0,0 +1,228 @@
/*
* Histogram.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Statistics/Histogram.hpp>
#include <LatAnalyze/includes.hpp>
#include <gsl/gsl_histogram.h>
#include <gsl/gsl_sf.h>
#include <gsl/gsl_sort.h>
using namespace std;
using namespace Latan;
#define DECL_GSL_HIST(h) \
gsl_histogram h{static_cast<size_t>(bin_.size()), x_.data(), bin_.data()}
#define DECL_CONST_GSL_HIST(h) \
const gsl_histogram h{static_cast<size_t>(bin_.size()),\
const_cast<double *>(x_.data()),\
const_cast<double *>(bin_.data())}
/******************************************************************************
* Histogram implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
Histogram::Histogram(const DVec &data, const double xMin, const double xMax,
const Index nBin)
: Histogram()
{
setFromData(data, xMin, xMax, nBin);
}
Histogram::Histogram(const DVec &data, const DVec &w, const double xMin,
const double xMax, const Index nBin)
: Histogram()
{
setFromData(data, w, xMin, xMax, nBin);
}
// resize //////////////////////////////////////////////////////////////////////
void Histogram::resize(const Index nBin)
{
x_.resize(nBin + 1);
bin_.resize(nBin);
}
// generate from data //////////////////////////////////////////////////////////
void Histogram::setFromData(const DVec &data, const DVec &w, const double xMin,
const double xMax, const Index nBin)
{
if (data.size() != w.size())
{
LATAN_ERROR(Size, "data vector and weight vector size mismatch");
}
resize(nBin);
data_ = data.array();
w_ = w.array();
xMax_ = xMax;
xMin_ = xMin;
makeHistogram();
}
void Histogram::setFromData(const DVec &data, const double xMin,
const double xMax, const Index nBin)
{
resize(nBin);
data_ = data.array();
xMax_ = xMax;
xMin_ = xMin;
w_.resize(data.size());
w_.fill(1.);
makeHistogram();
}
// histogram calculation ///////////////////////////////////////////////////////
void Histogram::makeHistogram(void)
{
DECL_GSL_HIST(h);
gsl_histogram_set_ranges_uniform(&h, xMin_, xMax_);
FOR_STAT_ARRAY(data_, i)
{
gsl_histogram_accumulate(&h, data_[i], w_[i]);
}
total_ = w_.sum();
sortIndices();
computeNorm();
}
// generate sorted indices /////////////////////////////////////////////////////
void Histogram::sortIndices(void)
{
sInd_.resize(data_.size());
gsl_sort_index(sInd_.data(), data_.data(), 1, data_.size());
}
// compute normalization factor ////////////////////////////////////////////////
void Histogram::computeNorm(void)
{
norm_ = static_cast<double>(bin_.size())/(total_*(xMax_ - xMin_));
}
// normalize as a probablility /////////////////////////////////////////////////
void Histogram::normalize(const bool n)
{
normalize_ = n;
}
bool Histogram::isNormalized(void) const
{
return normalize_;
}
// access //////////////////////////////////////////////////////////////////////
Index Histogram::size(void) const
{
return bin_.size();
}
const StatArray<double> & Histogram::getData(void) const
{
return data_;
}
const StatArray<double> & Histogram::getWeight(void) const
{
return w_;
}
double Histogram::getX(const Index i) const
{
return x_(i);
}
double Histogram::operator[](const Index i) const
{
return bin_(i)*(isNormalized() ? norm_ : 1.);
}
double Histogram::operator()(const double x) const
{
size_t i;
DECL_CONST_GSL_HIST(h);
gsl_histogram_find(&h, x, &i);
return (*this)[static_cast<Index>(i)];
}
// percentiles & confidence interval ///////////////////////////////////////////
double Histogram::percentile(const double p) const
{
if ((p < 0.0) or (p > 100.0))
{
LATAN_ERROR(Range, "percentile (" + strFrom(p) + ")"
" is outside the [0, 100] range");
}
// cf. http://en.wikipedia.org/wiki/Percentile
double wPSum, p_i, p_im1, w_i, res = 0.;
bool haveResult;
wPSum = w_[sInd_[0]];
p_i = (100./total_)*wPSum*0.5;
if (p < p_i)
{
res = data_[sInd_[0]];
}
else
{
haveResult = false;
p_im1 = p_i;
for (Index i = 1; i < data_.size(); ++i)
{
w_i = w_[sInd_[i]];
wPSum += w_i;
p_i = (100./total_)*(wPSum-0.5*w_i);
if ((p >= p_im1) and (p < p_i))
{
double d_i = data_[sInd_[i]], d_im1 = data_[sInd_[i-1]];
res = d_im1 + (p-p_im1)/(p_i-p_im1)*(d_i-d_im1);
haveResult = true;
break;
}
}
if (!haveResult)
{
res = data_[sInd_[data_.size()-1]];
}
}
return res;
}
double Histogram::median(void) const
{
return percentile(50.);
}
pair<double, double> Histogram::confidenceInterval(const double nSigma) const
{
pair<double, double> interval, p;
double cl;
cl = gsl_sf_erf(nSigma/sqrt(2.));
p.first = 50.*(1. - cl);
p.second = 50.*(1. + cl);
interval.first = percentile(p.first);
interval.second = percentile(p.second);
return interval;
}

View File

@ -0,0 +1,80 @@
/*
* Histogram.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Histogram_hpp_
#define Latan_Histogram_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Statistics/StatArray.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Histogram class *
******************************************************************************/
class Histogram
{
public:
// constructor
Histogram(void) = default;
Histogram(const DVec &data, const double xMin, const double xMax,
const Index nBin);
Histogram(const DVec &data, const DVec &w, const double xMin,
const double xMax, const Index nBin);
// destructor
virtual ~Histogram(void) = default;
// generate from data
void setFromData(const DVec &data, const double xMin, const double xMax,
const Index nBin);
void setFromData(const DVec &data, const DVec &w, const double xMin,
const double xMax, const Index nBin);
// normalize as a probablility
void normalize(const bool n = true);
bool isNormalized(void) const;
// access
Index size(void) const;
const StatArray<double> & getData(void) const;
const StatArray<double> & getWeight(void) const;
double getX(const Index i) const;
double operator[](const Index i) const;
double operator()(const double x) const;
// percentiles & confidence interval
double percentile(const double p) const;
double median(void) const;
std::pair<double, double> confidenceInterval(const double nSigma) const;
private:
// resize
void resize(const Index nBin);
// histogram calculation
void makeHistogram(void);
// generate sorted indices
void sortIndices(void);
// compute normalization factor
void computeNorm(void);
private:
StatArray<double> data_, w_;
DVec x_, bin_;
Vec<size_t> sInd_;
double total_, norm_, xMax_, xMin_;
bool normalize_{false};
};
END_LATAN_NAMESPACE
#endif // Latan_Histogram_hpp_

View File

@ -0,0 +1,384 @@
/*
* MatSample.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_MatSample_hpp_
#define Latan_MatSample_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#include <LatAnalyze/Statistics/StatArray.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* matrix sample class *
******************************************************************************/
#define SCAL_OP_RETURN(op, s, x) s.unaryExpr(\
std::bind(MatSample<T>::scalar##op,\
std::placeholders::_1, x))
template <typename T>
class MatSample: public Sample<Mat<T>>
{
public:
// block type template
template <class S>
class BlockTemplate
{
private:
typedef typename std::remove_const<S>::type NonConstType;
public:
// constructors
BlockTemplate(S &sample, const Index i, const Index j, const Index nRow,
const Index nCol);
BlockTemplate(BlockTemplate<NonConstType> &b);
BlockTemplate(BlockTemplate<NonConstType> &&b);
// destructor
~BlockTemplate(void) = default;
// access
S & getSample(void);
const S & getSample(void) const;
Index getStartRow(void) const;
Index getStartCol(void) const;
Index getNRow(void) const;
Index getNCol(void) const;
// assignement operators
BlockTemplate<S> & operator=(const S &sample);
BlockTemplate<S> & operator=(const S &&sample);
private:
S &sample_;
const Index i_, j_, nRow_, nCol_;
};
// block types
typedef BlockTemplate<Sample<Mat<T>>> Block;
typedef const BlockTemplate<const Sample<Mat<T>>> ConstBlock;
public:
// constructors
MatSample(void) = default;
MatSample(const Index nSample);
MatSample(const Index nSample, const Index nRow, const Index nCol);
MatSample(ConstBlock &sampleBlock);
MatSample(ConstBlock &&sampleBlock);
EIGEN_EXPR_CTOR(MatSample, MatSample<T>, Sample<Mat<T>>, ArrayExpr)
// destructor
virtual ~MatSample(void) = default;
// assignement operator
MatSample<T> & operator=(Block &sampleBlock);
MatSample<T> & operator=(Block &&sampleBlock);
MatSample<T> & operator=(ConstBlock &sampleBlock);
MatSample<T> & operator=(ConstBlock &&sampleBlock);
// product/division by scalar operators (not provided by Eigen)
static inline Mat<T> scalarMul(const Mat<T> &m, const T &x)
{
return m*x;
}
static inline Mat<T> scalarDiv(const Mat<T> &m, const T &x)
{
return m/x;
}
MatSample<T> & operator*=(const T &x);
MatSample<T> & operator*=(const T &&x);
MatSample<T> & operator/=(const T &x);
MatSample<T> & operator/=(const T &&x);
// block access
ConstBlock block(const Index i, const Index j, const Index nRow,
const Index nCol) const;
Block block(const Index i, const Index j, const Index nRow,
const Index nCol);
// resize all matrices
void resizeMat(const Index nRow, const Index nCol);
};
// non-member operators
template <typename T>
inline auto operator*(MatSample<T> s, const T &x)
->decltype(SCAL_OP_RETURN(Mul, s, x))
{
return SCAL_OP_RETURN(Mul, s, x);
}
template <typename T>
inline auto operator*(MatSample<T> s, const T &&x)
->decltype(SCAL_OP_RETURN(Mul, s, x))
{
return SCAL_OP_RETURN(Mul, s, x);
}
template <typename T>
inline auto operator*(const T &x, MatSample<T> s)->decltype(s*x)
{
return s*x;
}
template <typename T>
inline auto operator*(const T &&x, MatSample<T> s)->decltype(s*x)
{
return s*x;
}
template <typename T>
inline auto operator/(MatSample<T> s, const T &x)
->decltype(SCAL_OP_RETURN(Div, s, x))
{
return SCAL_OP_RETURN(Div, s, x);
}
template <typename T>
inline auto operator/(MatSample<T> s, const T &&x)
->decltype(SCAL_OP_RETURN(Div, s, x))
{
return SCAL_OP_RETURN(Div, s, x);
}
// type aliases
typedef MatSample<double> DMatSample;
typedef MatSample<std::complex<double>> CMatSample;
/******************************************************************************
* Block template implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
template <typename T>
template <class S>
MatSample<T>::BlockTemplate<S>::BlockTemplate(S &sample, const Index i,
const Index j, const Index nRow,
const Index nCol)
: sample_(sample)
, i_(i)
, j_(j)
, nRow_(nRow)
, nCol_(nCol)
{}
template <typename T>
template <class S>
MatSample<T>::BlockTemplate<S>::BlockTemplate(BlockTemplate<NonConstType> &b)
: sample_(b.getSample())
, i_(b.getStartRow())
, j_(b.getStartCol())
, nRow_(b.getNRow())
, nCol_(b.getNCol())
{}
template <typename T>
template <class S>
MatSample<T>::BlockTemplate<S>::BlockTemplate(BlockTemplate<NonConstType> &&b)
: BlockTemplate(b)
{}
// access //////////////////////////////////////////////////////////////////////
template <typename T>
template <class S>
S & MatSample<T>::BlockTemplate<S>::getSample(void)
{
return sample_;
}
template <typename T>
template <class S>
const S & MatSample<T>::BlockTemplate<S>::getSample(void) const
{
return sample_;
}
template <typename T>
template <class S>
Index MatSample<T>::BlockTemplate<S>::getStartRow(void) const
{
return i_;
}
template <typename T>
template <class S>
Index MatSample<T>::BlockTemplate<S>::getStartCol(void) const
{
return j_;
}
template <typename T>
template <class S>
Index MatSample<T>::BlockTemplate<S>::getNRow(void) const
{
return nRow_;
}
template <typename T>
template <class S>
Index MatSample<T>::BlockTemplate<S>::getNCol(void) const
{
return nCol_;
}
// assignement operators ///////////////////////////////////////////////////////
template <typename T>
template <class S>
typename MatSample<T>::template BlockTemplate<S> &
MatSample<T>::BlockTemplate<S>::operator=(const S &sample)
{
FOR_STAT_ARRAY(sample_, s)
{
sample_[s].block(i_, j_, nRow_, nCol_) = sample[s];
}
return *this;
}
template <typename T>
template <class S>
typename MatSample<T>::template BlockTemplate<S> &
MatSample<T>::BlockTemplate<S>::operator=(const S &&sample)
{
*this = sample;
return *this;
}
/******************************************************************************
* DMatSample implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
template <typename T>
MatSample<T>::MatSample(const Index nSample)
: Sample<Mat<T>>(nSample)
{}
template <typename T>
MatSample<T>::MatSample(const Index nSample, const Index nRow,
const Index nCol)
: MatSample(nSample)
{
resizeMat(nRow, nCol);
}
template <typename T>
MatSample<T>::MatSample(ConstBlock &sampleBlock)
: MatSample(sampleBlock.getSample().size(), sampleBlock.getNRow(),
sampleBlock.getNCol())
{
const MatSample<T> &sample = sampleBlock.getSample();
this->resize(sample.size());
FOR_STAT_ARRAY(*this, s)
{
(*this)[s] = sample[s].block(sampleBlock.getStartRow(),
sampleBlock.getStartCol(),
sampleBlock.getNRow(),
sampleBlock.getNCol());
}
}
template <typename T>
MatSample<T>::MatSample(ConstBlock &&sampleBlock)
: MatSample(sampleBlock)
{}
// assignement operator ////////////////////////////////////////////////////////
template <typename T>
MatSample<T> & MatSample<T>::operator=(Block &sampleBlock)
{
MatSample<T> tmp(sampleBlock);
this->swap(tmp);
return *this;
}
template <typename T>
MatSample<T> & MatSample<T>::operator=(Block &&sampleBlock)
{
*this = sampleBlock;
return *this;
}
template <typename T>
MatSample<T> & MatSample<T>::operator=(ConstBlock &sampleBlock)
{
MatSample<T> tmp(sampleBlock);
this->swap(tmp);
return *this;
}
template <typename T>
MatSample<T> & MatSample<T>::operator=(ConstBlock &&sampleBlock)
{
*this = sampleBlock;
return *this;
}
// product/division by scalar operators (not provided by Eigen) ////////////////
template <typename T>
MatSample<T> & MatSample<T>::operator*=(const T &x)
{
return *this = (*this)*x;
}
template <typename T>
MatSample<T> & MatSample<T>::operator*=(const T &&x)
{
return *this = (*this)*x;
}
template <typename T>
MatSample<T> & MatSample<T>::operator/=(const T &x)
{
return *this = (*this)/x;
}
template <typename T>
MatSample<T> & MatSample<T>::operator/=(const T &&x)
{
return *this = (*this)/x;
}
// block access ////////////////////////////////////////////////////////////////
template <typename T>
typename MatSample<T>::ConstBlock MatSample<T>::block(const Index i,
const Index j,
const Index nRow,
const Index nCol) const
{
return ConstBlock(*this, i, j, nRow, nCol);
}
template <typename T>
typename MatSample<T>::Block MatSample<T>::block(const Index i,
const Index j,
const Index nRow,
const Index nCol)
{
return Block(*this, i, j, nRow, nCol);
}
// resize all matrices /////////////////////////////////////////////////////////
template <typename T>
void MatSample<T>::resizeMat(const Index nRow, const Index nCol)
{
FOR_STAT_ARRAY(*this, s)
{
(*this)[s].resize(nRow, nCol);
}
}
END_LATAN_NAMESPACE
#endif // Latan_MatSample_hpp_

View File

@ -0,0 +1,56 @@
/*
* Random.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Plot.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Statistics/Random.hpp>
using namespace std;
using namespace Latan;
RandomNormal::RandomNormal(const DVec &mean, const DMat &var, const SeedType seed)
: mean_(mean), buf_(mean.size()), var_(var), gen_(seed)
{
if (var_.rows() != var_.cols())
{
LATAN_ERROR(Size, "variance matrix not square");
}
if (mean_.size() != var_.rows())
{
LATAN_ERROR(Size, "variance matrix and mean vector size mismatch");
}
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esolver(var);
Eigen::VectorXd ev = esolver.eigenvalues();
ev = ev.unaryExpr([](const double x){return (x > 0.) ? x : 0.;});
transform_ = esolver.eigenvectors()*ev.cwiseSqrt().asDiagonal();
}
DVec RandomNormal::operator()(void)
{
std::normal_distribution<> dist;
FOR_VEC(buf_, i)
{
buf_(i) = dist(gen_);
}
return mean_ + transform_*buf_;
}

View File

@ -0,0 +1,49 @@
/*
* Random.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Random_hpp_
#define Latan_Random_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Multivariate Gaussian RNG *
******************************************************************************/
class RandomNormal
{
public:
// constructors
RandomNormal(const DVec &mean, const DMat &var, const SeedType seed);
// destructor
virtual ~RandomNormal(void) = default;
// draw a random vector
DVec operator()(void);
private:
DVec mean_, buf_;
DMat var_;
Eigen::MatrixXd transform_;
std::mt19937 gen_;
};
END_LATAN_NAMESPACE
#endif // Latan_Random_hpp_

View File

@ -0,0 +1,32 @@
/*
* StatArray.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Statistics/StatArray.hpp>
#include <LatAnalyze/includes.hpp>
using namespace std;
namespace Latan
{
template <>
IoObject::IoType StatArray<Mat<double>, -1>::getType(void) const
{
return IoType::dMatSample;
}
}

View File

@ -0,0 +1,296 @@
/*
* StatArray.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_StatArray_hpp_
#define Latan_StatArray_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
#define FOR_STAT_ARRAY(ar, i) \
for (Latan::Index i = -(ar).offset; i < (ar).size(); ++i)
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Array class with statistics *
******************************************************************************/
template <typename T, Index os = 0>
class StatArray: public Array<T, dynamic, 1>, public IoObject
{
protected:
typedef Array<T, dynamic, 1> Base;
public:
// constructors
StatArray(void);
explicit StatArray(const Index size);
EIGEN_EXPR_CTOR(StatArray, unique_arg(StatArray<T, os>), Base, ArrayExpr)
// destructor
virtual ~StatArray(void) = default;
// access
Index size(void) const;
void resize(const Index size);
// operators
T & operator[](const Index s);
const T & operator[](const Index s) const;
// statistics
void bin(Index binSize);
T sum(const Index pos = 0, const Index n = -1) const;
T mean(const Index pos = 0, const Index n = -1) const;
T covariance(const StatArray<T, os> &array) const;
T variance(void) const;
T covarianceMatrix(const StatArray<T, os> &data) const;
T varianceMatrix(void) const;
T correlationMatrix(void) const;
// IO type
virtual IoType getType(void) const;
public:
static constexpr Index offset = os;
};
// reduction operations
namespace StatOp
{
// general templates
template <typename T>
inline T prod(const T &a, const T &b);
template <typename T>
inline T tensProd(const T &v1, const T &v2);
template <typename T>
inline T sum(const T &a, const T &b);
}
// Sample types
const int central = -1;
template <typename T>
using Sample = StatArray<T, 1>;
typedef Sample<double> DSample;
typedef Sample<std::complex<double>> CSample;
/******************************************************************************
* StatArray class template implementation *
******************************************************************************/
// constructors ////////////////////////////////////////////////////////////////
template <typename T, Index os>
StatArray<T, os>::StatArray(void)
: Base(static_cast<typename Base::Index>(os))
{}
template <typename T, Index os>
StatArray<T, os>::StatArray(const Index size)
: Base(static_cast<typename Base::Index>(size + os))
{}
// access //////////////////////////////////////////////////////////////////////
template <typename T, Index os>
Index StatArray<T, os>::size(void) const
{
return Base::size() - os;
}
template <typename T, Index os>
void StatArray<T, os>::resize(const Index size)
{
Base::resize(size + os);
}
// operators ///////////////////////////////////////////////////////////////////
template <typename T, Index os>
T & StatArray<T, os>::operator[](const Index s)
{
return Base::operator[](s + os);
}
template <typename T, Index os>
const T & StatArray<T, os>::operator[](const Index s) const
{
return Base::operator[](s + os);
}
// statistics //////////////////////////////////////////////////////////////////
template <typename T, Index os>
void StatArray<T, os>::bin(Index binSize)
{
Index q = size()/binSize, r = size()%binSize;
for (Index i = 0; i < q; ++i)
{
(*this)[i] = mean(i*binSize, binSize);
}
if (r != 0)
{
(*this)[q] = mean(q*binSize, r);
this->conservativeResize(os + q + 1);
}
else
{
this->conservativeResize(os + q);
}
}
template <typename T, Index os>
T StatArray<T, os>::sum(const Index pos, const Index n) const
{
T result;
const Index m = (n >= 0) ? n : size();
result = (*this)[pos];
for (Index i = pos + 1; i < pos + m; ++i)
{
result += (*this)[i];
}
return result;
}
template <typename T, Index os>
T StatArray<T, os>::mean(const Index pos, const Index n) const
{
const Index m = (n >= 0) ? n : size();
return sum(pos, n)/static_cast<double>(m);
}
template <typename T, Index os>
T StatArray<T, os>::covariance(const StatArray<T, os> &array) const
{
T s1, s2, res;
s1 = array.sum();
s2 = this->sum();
res = StatOp::prod<T>(array[0], (*this)[0]);
for (Index i = 1; i < size(); ++i)
{
res += StatOp::prod<T>(array[i], (*this)[i]);
}
res -= StatOp::prod<T>(s1, s2)/static_cast<double>(size());
res /= static_cast<double>(size() - 1);
return res;
}
template <typename T, Index os>
T StatArray<T, os>::variance(void) const
{
return covariance(*this);
}
template <typename MatType, Index os>
MatType StatArray<MatType, os>::covarianceMatrix(
const StatArray<MatType, os> &data) const
{
if (((*this)[central].cols() != 1) or (data[central].cols() != 1))
{
LATAN_ERROR(Size, "samples have more than one column");
}
Index n1 = (*this)[central].rows(), n2 = data[central].rows();
Index nSample = this->size();
MatType tmp1(n1, nSample), tmp2(n2, nSample), res(n1, n2);
MatType s1(n1, 1), s2(n2, 1), one(nSample, 1);
one.fill(1.);
s1.fill(0.);
s2.fill(0.);
for (unsigned int s = 0; s < nSample; ++s)
{
s1 += (*this)[s];
tmp1.col(s) = (*this)[s];
}
tmp1 -= s1*one.transpose()/static_cast<double>(nSample);
for (unsigned int s = 0; s < nSample; ++s)
{
s2 += data[s];
tmp2.col(s) = data[s];
}
tmp2 -= s2*one.transpose()/static_cast<double>(nSample);
res = tmp1*tmp2.transpose()/static_cast<double>(nSample - 1);
return res;
}
template <typename MatType, Index os>
MatType StatArray<MatType, os>::varianceMatrix(void) const
{
if ((*this)[0].cols() != 1)
{
LATAN_ERROR(Size, "samples have more than one column");
}
Index n1 = (*this)[0].rows();
Index nSample = this->size();
MatType tmp1(n1, nSample), res(n1, n1);
MatType s1(n1, 1), one(nSample, 1);
one.fill(1.);
s1.fill(0.);
for (unsigned int s = 0; s < nSample; ++s)
{
s1 += (*this)[s];
tmp1.col(s) = (*this)[s];
}
tmp1 -= s1*one.transpose()/static_cast<double>(nSample);
res = tmp1*tmp1.transpose()/static_cast<double>(nSample - 1);
return res;
}
template <typename MatType, Index os>
MatType StatArray<MatType, os>::correlationMatrix(void) const
{
MatType res = varianceMatrix();
MatType invDiag(res.rows(), 1);
invDiag = res.diagonal();
invDiag = invDiag.cwiseInverse().cwiseSqrt();
res = (invDiag*invDiag.transpose()).cwiseProduct(res);
return res;
}
// reduction operations ////////////////////////////////////////////////////////
namespace StatOp
{
template <typename T>
inline T prod(const T &a, const T &b)
{
return a*b;
}
template <>
inline Mat<double> prod(const Mat<double> &a, const Mat<double> &b)
{
return a.cwiseProduct(b);
}
}
// IO type /////////////////////////////////////////////////////////////////////
template <typename T, Index os>
IoObject::IoType StatArray<T, os>::getType(void) const
{
return IoType::noType;
}
END_LATAN_NAMESPACE
#endif // Latan_StatArray_hpp_

View File

@ -0,0 +1,565 @@
/*
* XYSampleData.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Statistics/XYSampleData.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Math.hpp>
using namespace std;
using namespace Latan;
/******************************************************************************
* SampleFitResult implementation *
******************************************************************************/
double SampleFitResult::getChi2(const Index s) const
{
return chi2_[s];
}
const DSample & SampleFitResult::getChi2(const PlaceHolder ph __dumb) const
{
return chi2_;
}
double SampleFitResult::getChi2PerDof(const Index s) const
{
return chi2_[s]/getNDof();
}
DSample SampleFitResult::getChi2PerDof(const PlaceHolder ph __dumb) const
{
return chi2_/getNDof();
}
double SampleFitResult::getNDof(void) const
{
return static_cast<double>(nDof_);
}
Index SampleFitResult::getNPar(void) const
{
return nPar_;
}
double SampleFitResult::getPValue(const Index s) const
{
return Math::chi2PValue(getChi2(s), getNDof());
}
double SampleFitResult::getCorrRangeDb(void) const
{
return corrRangeDb_;
}
double SampleFitResult::getCcdf(const Index s) const
{
return Math::chi2Ccdf(getChi2(s), getNDof());
}
const DoubleFunction & SampleFitResult::getModel(const Index s,
const Index j) const
{
return model_[j][s];
}
const DoubleFunctionSample & SampleFitResult::getModel(
const PlaceHolder ph __dumb,
const Index j) const
{
return model_[j];
}
FitResult SampleFitResult::getFitResult(const Index s) const
{
FitResult fit;
fit = (*this)[s];
fit.chi2_ = getChi2();
fit.nDof_ = static_cast<Index>(getNDof());
fit.model_.resize(model_.size());
for (unsigned int k = 0; k < model_.size(); ++k)
{
fit.model_[k] = model_[k][s];
}
return fit;
}
// IO //////////////////////////////////////////////////////////////////////////
void SampleFitResult::print(const bool printXsi, ostream &out) const
{
char buf[256];
Index pMax = printXsi ? size() : nPar_;
DMat err = this->variance().cwiseSqrt();
sprintf(buf, "chi^2/dof= %.1e/%d= %.2e -- chi^2 CCDF= %.2e -- p-value= %.2e",
getChi2(), static_cast<int>(getNDof()), getChi2PerDof(), getCcdf(),
getPValue());
out << buf << endl;
sprintf(buf, "correlation dynamic range= %.1f dB", getCorrRangeDb());
out << buf << endl;
for (Index p = 0; p < pMax; ++p)
{
sprintf(buf, "%12s= % e +/- %e", parName_[p].c_str(),
(*this)[central](p), err(p));
out << buf << endl;
}
}
/******************************************************************************
* XYSampleData implementation *
******************************************************************************/
// constructor /////////////////////////////////////////////////////////////////
XYSampleData::XYSampleData(const Index nSample)
: nSample_(nSample)
{}
// data access /////////////////////////////////////////////////////////////////
DSample & XYSampleData::x(const Index r, const Index i)
{
checkXIndex(r, i);
scheduleDataInit();
scheduleComputeVarMat();
if (xData_[i][r].size() == 0)
{
xData_[i][r].resize(nSample_);
}
return xData_[i][r];
}
const DSample & XYSampleData::x(const Index r, const Index i) const
{
checkXIndex(r, i);
return xData_[i][r];
}
const DMatSample & XYSampleData::x(const Index k)
{
checkDataIndex(k);
updateXMap();
return xMap_.at(k);
}
DSample & XYSampleData::y(const Index k, const Index j)
{
checkYDim(j);
if (!pointExists(k, j))
{
registerDataPoint(k, j);
}
scheduleDataInit();
scheduleComputeVarMat();
if (yData_[j][k].size() == 0)
{
yData_[j][k].resize(nSample_);
}
return yData_[j][k];
}
const DSample & XYSampleData::y(const Index k, const Index j) const
{
checkPoint(k, j);
return yData_[j].at(k);
}
void XYSampleData::setUnidimData(const DMatSample &xData,
const vector<const DMatSample *> &v)
{
FOR_STAT_ARRAY(xData, s)
FOR_VEC(xData[central], r)
{
x(r, 0)[s] = xData[s](r);
for (unsigned int j = 0; j < v.size(); ++j)
{
y(r, j)[s] = (*(v[j]))[s](r);
}
}
}
const DMat & XYSampleData::getXXVar(const Index i1, const Index i2)
{
checkXDim(i1);
checkXDim(i2);
computeVarMat();
return data_.getXXVar(i1, i2);
}
const DMat & XYSampleData::getYYVar(const Index j1, const Index j2)
{
checkYDim(j1);
checkYDim(j2);
computeVarMat();
return data_.getYYVar(j1, j2);
}
const DMat & XYSampleData::getXYVar(const Index i, const Index j)
{
checkXDim(i);
checkYDim(j);
computeVarMat();
return data_.getXYVar(i, j);
}
DVec XYSampleData::getXError(const Index i)
{
checkXDim(i);
computeVarMat();
return data_.getXError(i);
}
DVec XYSampleData::getYError(const Index j)
{
checkYDim(j);
computeVarMat();
return data_.getYError(j);
}
// get total fit variance matrix and its pseudo-inverse ////////////////////////
const DMat & XYSampleData::getFitVarMat(void)
{
computeVarMat();
return data_.getFitVarMat();
}
const DMat & XYSampleData::getFitVarMatPInv(void)
{
computeVarMat();
return data_.getFitVarMatPInv();
}
const DMat & XYSampleData::getFitCorrMat(void)
{
computeVarMat();
return data_.getFitCorrMat();
}
const DMat & XYSampleData::getFitCorrMatPInv(void)
{
computeVarMat();
return data_.getFitCorrMatPInv();
}
// set data to a particular sample /////////////////////////////////////////////
void XYSampleData::setDataToSample(const Index s)
{
if (initData_ or (s != dataSample_))
{
for (Index i = 0; i < getNXDim(); ++i)
for (Index r = 0; r < getXSize(i); ++r)
{
data_.x(r, i) = xData_[i][r][s];
}
for (Index j = 0; j < getNYDim(); ++j)
for (auto &p: yData_[j])
{
data_.y(p.first, j) = p.second[s];
}
dataSample_ = s;
initData_ = false;
}
}
// get internal XYStatData /////////////////////////////////////////////////////
const XYStatData & XYSampleData::getData(void)
{
setDataToSample(central);
computeVarMat();
return data_;
}
// fit /////////////////////////////////////////////////////////////////////////
SampleFitResult XYSampleData::fit(std::vector<Minimizer *> &minimizer,
const DVec &init,
const std::vector<const DoubleModel *> &v)
{
computeVarMat();
SampleFitResult result;
FitResult sampleResult;
DVec initCopy = init;
Minimizer::Verbosity verbCopy = minimizer.back()->getVerbosity();
result.resize(nSample_);
result.chi2_.resize(nSample_);
result.model_.resize(v.size());
FOR_STAT_ARRAY(result, s)
{
setDataToSample(s);
if (s == central)
{
sampleResult = data_.fit(minimizer, initCopy, v);
initCopy = sampleResult.segment(0, initCopy.size());
if (verbCopy != Minimizer::Verbosity::Debug)
{
minimizer.back()->setVerbosity(Minimizer::Verbosity::Silent);
}
}
else
{
sampleResult = data_.fit(*(minimizer.back()), initCopy, v);
}
result[s] = sampleResult;
result.chi2_[s] = sampleResult.getChi2();
for (unsigned int j = 0; j < v.size(); ++j)
{
result.model_[j].resize(nSample_);
result.model_[j][s] = sampleResult.getModel(j);
}
}
minimizer.back()->setVerbosity(verbCopy);
result.nPar_ = sampleResult.getNPar();
result.nDof_ = sampleResult.nDof_;
result.parName_ = sampleResult.parName_;
result.corrRangeDb_ = Math::cdr(getFitCorrMat());
return result;
}
SampleFitResult XYSampleData::fit(Minimizer &minimizer,
const DVec &init,
const std::vector<const DoubleModel *> &v)
{
vector<Minimizer *> mv{&minimizer};
return fit(mv, init, v);
}
// residuals ///////////////////////////////////////////////////////////////////
XYSampleData XYSampleData::getResiduals(const SampleFitResult &fit)
{
XYSampleData res(*this);
for (Index j = 0; j < getNYDim(); ++j)
{
const DoubleFunctionSample &f = fit.getModel(_, j);
for (auto &p: yData_[j])
{
res.y(p.first, j) -= f(x(p.first));
}
}
return res;
}
XYSampleData XYSampleData::getNormalisedResiduals(const SampleFitResult &fit)
{
XYSampleData res(*this);
for (Index j = 0; j < getNYDim(); ++j)
{
const DoubleFunctionSample &f = fit.getModel(_, j);
for (auto &p: yData_[j])
{
res.y(p.first, j) -= f(x(p.first));
}
const DMat &var = res.getYYVar(j, j);
for (auto &p: yData_[j])
{
res.y(p.first, j) /= sqrt(var(p.first, p.first));
}
}
return res;
}
XYSampleData XYSampleData::getPartialResiduals(const SampleFitResult &fit,
const DVec &ref, const Index i)
{
XYSampleData res(*this);
DMatSample buf(nSample_);
buf.fill(ref);
for (Index j = 0; j < getNYDim(); ++j)
{
const DoubleFunctionSample &f = fit.getModel(_, j);
for (auto &p: yData_[j])
{
FOR_STAT_ARRAY(buf, s)
{
buf[s](i) = x(p.first)[s](i);
}
res.y(p.first, j) -= f(x(p.first)) - f(buf);
}
}
return res;
}
// buffer list of x vectors ////////////////////////////////////////////////////
void XYSampleData::scheduleXMapInit(void)
{
initXMap_ = true;
}
void XYSampleData::updateXMap(void)
{
if (initXMap_)
{
for (Index s = central; s < nSample_; ++s)
{
setDataToSample(s);
for (auto k: getDataIndexSet())
{
if (s == central)
{
xMap_[k].resize(nSample_);
}
xMap_[k][s] = data_.x(k);
}
}
initXMap_ = false;
}
}
// schedule data initilization from samples ////////////////////////////////////
void XYSampleData::scheduleDataInit(void)
{
initData_ = true;
}
// variance matrix computation /////////////////////////////////////////////////
void XYSampleData::scheduleComputeVarMat(void)
{
computeVarMat_ = true;
}
void XYSampleData::computeVarMat(void)
{
if (computeVarMat_)
{
// initialize data if necessary
setDataToSample(central);
// compute relevant sizes
Index size = 0, ySize = 0;
for (Index j = 0; j < getNYDim(); ++j)
{
size += getYSize(j);
}
ySize = size;
for (Index i = 0; i < getNXDim(); ++i)
{
size += getXSize(i);
}
// compute total matrix
DMatSample z(nSample_, size, 1);
DMat var;
Index a;
FOR_STAT_ARRAY(z, s)
{
a = 0;
for (Index j = 0; j < getNYDim(); ++j)
for (auto &p: yData_[j])
{
z[s](a, 0) = p.second[s];
a++;
}
for (Index i = 0; i < getNXDim(); ++i)
for (Index r = 0; r < getXSize(i); ++r)
{
z[s](a, 0) = xData_[i][r][s];
a++;
}
}
var = z.varianceMatrix();
// assign blocks to data
Index a1, a2;
a1 = ySize;
for (Index i1 = 0; i1 < getNXDim(); ++i1)
{
a2 = ySize;
for (Index i2 = 0; i2 < getNXDim(); ++i2)
{
data_.setXXVar(i1, i2,
var.block(a1, a2, getXSize(i1), getXSize(i2)));
a2 += getXSize(i2);
}
a1 += getXSize(i1);
}
a1 = 0;
for (Index j1 = 0; j1 < getNYDim(); ++j1)
{
a2 = 0;
for (Index j2 = 0; j2 < getNYDim(); ++j2)
{
data_.setYYVar(j1, j2,
var.block(a1, a2, getYSize(j1), getYSize(j2)));
a2 += getYSize(j2);
}
a1 += getYSize(j1);
}
a1 = ySize;
for (Index i = 0; i < getNXDim(); ++i)
{
a2 = 0;
for (Index j = 0; j < getNYDim(); ++j)
{
data_.setXYVar(i, j,
var.block(a1, a2, getXSize(i), getYSize(j)));
a2 += getYSize(j);
}
a1 += getXSize(i);
}
computeVarMat_ = false;
}
if (initVarMat())
{
data_.copyInterface(*this);
scheduleFitVarMatInit(false);
}
}
// create data /////////////////////////////////////////////////////////////////
void XYSampleData::createXData(const string name, const Index nData)
{
data_.addXDim(nData, name);
xData_.push_back(vector<DSample>(nData));
}
void XYSampleData::createYData(const string name)
{
data_.addYDim(name);
yData_.push_back(map<Index, DSample>());
}

View File

@ -0,0 +1,184 @@
/*
* XYSampleData.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_XYSampleData_hpp_
#define Latan_XYSampleData_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Statistics/FitInterface.hpp>
#include <LatAnalyze/Numerical/Minimizer.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
#include <LatAnalyze/Functional/Model.hpp>
#include <LatAnalyze/Statistics/XYStatData.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* object for fit result *
******************************************************************************/
class SampleFitResult: public DMatSample
{
friend class XYSampleData;
public:
// constructors
SampleFitResult(void) = default;
EIGEN_EXPR_CTOR(SampleFitResult, SampleFitResult, DMatSample, ArrayExpr)
// destructor
virtual ~SampleFitResult(void) = default;
// access
double getChi2(const Index s = central) const;
const DSample & getChi2(const PlaceHolder ph) const;
double getChi2PerDof(const Index s = central) const;
DSample getChi2PerDof(const PlaceHolder ph) const;
double getNDof(void) const;
Index getNPar(void) const;
double getPValue(const Index s = central) const;
double getCorrRangeDb(void) const;
double getCcdf(const Index s = central) const;
const DoubleFunction & getModel(const Index s = central,
const Index j = 0) const;
const DoubleFunctionSample & getModel(const PlaceHolder ph,
const Index j = 0) const;
FitResult getFitResult(const Index s = central) const;
// IO
void print(const bool printXsi = false,
std::ostream &out = std::cout) const;
private:
DSample chi2_;
double corrRangeDb_{0.};
Index nDof_{0}, nPar_{0};
std::vector<DoubleFunctionSample> model_;
std::vector<std::string> parName_;
};
/******************************************************************************
* XYSampleData *
******************************************************************************/
class XYSampleData: public FitInterface
{
public:
// constructor
explicit XYSampleData(const Index nSample);
// destructor
virtual ~XYSampleData(void) = default;
// data access
DSample & x(const Index r, const Index i);
const DSample & x(const Index r, const Index i) const;
const DMatSample & x(const Index k);
DSample & y(const Index k, const Index j);
const DSample & y(const Index k, const Index j) const;
void setUnidimData(const DMatSample &xData,
const std::vector<const DMatSample *> &v);
template <typename... Ts>
void setUnidimData(const DMatSample &xData,
const Ts & ...yDatas);
const DMat & getXXVar(const Index i1, const Index i2);
const DMat & getYYVar(const Index j1, const Index j2);
const DMat & getXYVar(const Index i, const Index j);
DVec getXError(const Index i);
DVec getYError(const Index j);
// get total fit variance & correlation matrices and their pseudo-inverse
const DMat & getFitVarMat(void);
const DMat & getFitVarMatPInv(void);
const DMat & getFitCorrMat(void);
const DMat & getFitCorrMatPInv(void);
// set data to a particular sample
void setDataToSample(const Index s);
// get internal XYStatData
const XYStatData & getData(void);
// fit
SampleFitResult fit(std::vector<Minimizer *> &minimizer, const DVec &init,
const std::vector<const DoubleModel *> &v);
SampleFitResult fit(Minimizer &minimizer, const DVec &init,
const std::vector<const DoubleModel *> &v);
template <typename... Ts>
SampleFitResult fit(std::vector<Minimizer *> &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models);
template <typename... Ts>
SampleFitResult fit(Minimizer &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models);
// residuals
XYSampleData getResiduals(const SampleFitResult &fit);
XYSampleData getNormalisedResiduals(const SampleFitResult &fit);
XYSampleData getPartialResiduals(const SampleFitResult &fit, const DVec &x,
const Index i);
private:
// buffer list of x vectors
void scheduleXMapInit(void);
void updateXMap(void);
// schedule data initilization from samples
void scheduleDataInit(void);
// variance matrix computation
void scheduleComputeVarMat(void);
void computeVarMat(void);
// create data
virtual void createXData(const std::string name, const Index nData);
virtual void createYData(const std::string name);
private:
std::vector<std::map<Index, DSample>> yData_;
std::vector<std::vector<DSample>> xData_;
std::map<Index, DMatSample> xMap_;
XYStatData data_;
Index nSample_, dataSample_{central};
bool initData_{true}, computeVarMat_{true};
bool initXMap_{true};
};
/******************************************************************************
* XYSampleData template implementation *
******************************************************************************/
template <typename... Ts>
void XYSampleData::setUnidimData(const DMatSample &xData, const Ts & ...yDatas)
{
static_assert(static_or<std::is_assignable<DMatSample, Ts>::value...>::value,
"y data arguments are not compatible with DMatSample");
std::vector<const DMatSample *> v{&yDatas...};
setUnidimData(xData, v);
}
template <typename... Ts>
SampleFitResult XYSampleData::fit(std::vector<Minimizer *> &minimizer,
const DVec &init,
const DoubleModel &model, const Ts... models)
{
static_assert(static_or<std::is_assignable<DoubleModel &, Ts>::value...>::value,
"model arguments are not compatible with DoubleModel");
std::vector<const DoubleModel *> modelVector{&model, &models...};
return fit(minimizer, init, modelVector);
}
template <typename... Ts>
SampleFitResult XYSampleData::fit(Minimizer &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models)
{
static_assert(static_or<std::is_assignable<DoubleModel &, Ts>::value...>::value,
"model arguments are not compatible with DoubleModel");
std::vector<Minimizer *> mv{&minimizer};
return fit(mv, init, model, models...);
}
END_LATAN_NAMESPACE
#endif // Latan_XYSampleData_hpp_

View File

@ -0,0 +1,662 @@
/*
* XYStatData.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Statistics/XYStatData.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Core/Math.hpp>
using namespace std;
using namespace Latan;
static constexpr double maxXsiDev = 10.;
/******************************************************************************
* FitResult implementation *
******************************************************************************/
// access //////////////////////////////////////////////////////////////////////
double FitResult::getChi2(void) const
{
return chi2_;
}
double FitResult::getChi2PerDof(void) const
{
return chi2_/getNDof();
}
double FitResult::getNDof(void) const
{
return static_cast<double>(nDof_);
}
Index FitResult::getNPar(void) const
{
return nPar_;
}
double FitResult::getPValue(void) const
{
return Math::chi2PValue(getChi2(), getNDof());;
}
double FitResult::getCcdf(void) const
{
return Math::chi2Ccdf(getChi2(), getNDof());;
}
double FitResult::getCorrRangeDb(void) const
{
return corrRangeDb_;
}
const DoubleFunction & FitResult::getModel(const Index j) const
{
return model_[j];
}
// IO //////////////////////////////////////////////////////////////////////////
void FitResult::print(const bool printXsi, ostream &out) const
{
char buf[256];
Index pMax = printXsi ? size() : nPar_;
sprintf(buf, "chi^2/dof= %.1e/%d= %.2e -- chi^2 CCDF= %.2e -- p-value= %.2e",
getChi2(), static_cast<int>(getNDof()), getChi2PerDof(), getCcdf(),
getPValue());
out << buf << endl;
sprintf(buf, "correlation dynamic range= %.1f dB", getCorrRangeDb());
out << buf << endl;
for (Index p = 0; p < pMax; ++p)
{
sprintf(buf, "%12s= %e", parName_[p].c_str(), (*this)(p));
out << buf << endl;
}
}
/******************************************************************************
* XYStatData implementation *
******************************************************************************/
// data access /////////////////////////////////////////////////////////////////
double & XYStatData::x(const Index r, const Index i)
{
checkXIndex(r, i);
scheduleXMapInit();
scheduleChi2DataVecInit();
return xData_[i](r);
}
const double & XYStatData::x(const Index r, const Index i) const
{
checkXIndex(r, i);
return xData_[i](r);
}
const DVec & XYStatData::x(const Index k) const
{
checkDataIndex(k);
updateXMap();
return xMap_.at(k);
}
double & XYStatData::y(const Index k, const Index j)
{
checkYDim(j);
if (!pointExists(k, j))
{
registerDataPoint(k, j);
resizeVarMat();
}
scheduleXMapInit();
scheduleChi2DataVecInit();
return yData_[j][k];
}
const double & XYStatData::y(const Index k, const Index j) const
{
checkPoint(k, j);
return yData_[j].at(k);
}
void XYStatData::setXXVar(const Index i1, const Index i2, const DMat &m)
{
checkXDim(i1);
checkXDim(i2);
checkVarMat(m, xxVar_(i1, i2));
xxVar_(i1, i2) = m;
if (i1 != i2)
{
xxVar_(i2, i1) = m.transpose();
}
scheduleFitVarMatInit();
}
void XYStatData::setYYVar(const Index j1, const Index j2, const DMat &m)
{
checkYDim(j1);
checkYDim(j2);
checkVarMat(m, yyVar_(j1, j2));
yyVar_(j1, j2) = m;
if (j1 != j2)
{
yyVar_(j2, j1) = m.transpose();
}
scheduleFitVarMatInit();
}
void XYStatData::setXYVar(const Index i, const Index j, const DMat &m)
{
checkXDim(i);
checkYDim(j);
checkVarMat(m, xyVar_(i, j));
xyVar_(i, j) = m;
scheduleFitVarMatInit();
}
void XYStatData::setXError(const Index i, const DVec &err)
{
checkXDim(i);
checkErrVec(err, xxVar_(i, i));
xxVar_(i, i).diagonal() = err.cwiseProduct(err);
scheduleFitVarMatInit();
}
void XYStatData::setYError(const Index j, const DVec &err)
{
checkXDim(j);
checkErrVec(err, yyVar_(j, j));
yyVar_(j, j).diagonal() = err.cwiseProduct(err);
scheduleFitVarMatInit();
}
const DMat & XYStatData::getXXVar(const Index i1, const Index i2) const
{
checkXDim(i1);
checkXDim(i2);
return xxVar_(i1, i2);
}
const DMat & XYStatData::getYYVar(const Index j1, const Index j2) const
{
checkYDim(j1);
checkYDim(j2);
return yyVar_(j1, j2);
}
const DMat & XYStatData::getXYVar(const Index i, const Index j) const
{
checkXDim(i);
checkYDim(j);
return xyVar_(i, j);
}
DVec XYStatData::getXError(const Index i) const
{
checkXDim(i);
return xxVar_(i, i).diagonal().cwiseSqrt();
}
DVec XYStatData::getYError(const Index j) const
{
checkYDim(j);
return yyVar_(j, j).diagonal().cwiseSqrt();
}
DMat XYStatData::getTable(const Index i, const Index j) const
{
checkXDim(i);
checkYDim(j);
DMat table(getYSize(j), 4);
Index row = 0;
for (auto &p: yData_[j])
{
Index k = p.first;
Index r = dataCoord(k)[i];
table(row, 0) = x(k)(i);
table(row, 2) = p.second;
table(row, 1) = xxVar_(i, i).diagonal().cwiseSqrt()(r);
table(row, 3) = yyVar_(j, j).diagonal().cwiseSqrt()(row);
row++;
}
return table;
}
// get total fit variance matrix ///////////////////////////////////////////////
const DMat & XYStatData::getFitVarMat(void)
{
updateFitVarMat();
return fitVar_;
}
const DMat & XYStatData::getFitVarMatPInv(void)
{
updateFitVarMat();
return fitVarInv_;
}
const DMat & XYStatData::getFitCorrMat(void)
{
updateFitVarMat();
return fitCorr_;
}
const DMat & XYStatData::getFitCorrMatPInv(void)
{
updateFitVarMat();
return fitCorrInv_;
}
// fit /////////////////////////////////////////////////////////////////////////
FitResult XYStatData::fit(vector<Minimizer *> &minimizer, const DVec &init,
const vector<const DoubleModel *> &v)
{
// check model consistency
checkModelVec(v);
// buffering
updateLayout();
updateFitVarMat();
updateChi2DataVec();
// get number of parameters
Index nPar = v[0]->getNPar();
Index nXDim = getNXDim();
Index totalNPar = nPar + layout.totalXSize;
// chi^2 functions
auto corrChi2Func = [this, nPar, nXDim, totalNPar, &v](const double *x)->double
{
ConstMap<DVec> p(x, totalNPar);
updateChi2ModVec(p, v, nPar, nXDim);
chi2Vec_ = (chi2ModVec_ - chi2DataVec_);
return chi2Vec_.dot(fitVarInv_*chi2Vec_);
};
DoubleFunction corrChi2(corrChi2Func, totalNPar);
auto uncorrChi2Func = [this, nPar, nXDim, totalNPar, &v](const double *x)->double
{
ConstMap<DVec> p(x, totalNPar);
updateChi2ModVec(p, v, nPar, nXDim);
chi2Vec_ = (chi2ModVec_ - chi2DataVec_);
return chi2Vec_.dot(chi2Vec_.cwiseQuotient(fitVar_.diagonal()));
};
DoubleFunction uncorrChi2(uncorrChi2Func, totalNPar);
DoubleFunction &chi2 = hasCorrelations() ? corrChi2 : uncorrChi2;
for (Index p = 0; p < nPar; ++p)
{
chi2.varName().setName(p, v[0]->parName().getName(p));
}
for (Index p = 0; p < totalNPar - nPar; ++p)
{
chi2.varName().setName(p + nPar, "xsi_" + strFrom(p));
}
// minimization
FitResult result;
DVec totalInit(totalNPar);
//// set total init vector
totalInit.segment(0, nPar) = init;
totalInit.segment(nPar, layout.totalXSize) =
chi2DataVec_.segment(layout.totalYSize, layout.totalXSize);
for (auto &m: minimizer)
{
m->setInit(totalInit);
if (m->supportLimits())
{
//// do not allow more than maxXsiDev std. deviations on the x-axis
for (Index p = nPar; p < totalNPar; ++p)
{
double err;
err = sqrt(fitVar_.diagonal()(layout.totalYSize + p - nPar));
m->useLowLimit(p);
m->useHighLimit(p);
m->setLowLimit(p, totalInit(p) - maxXsiDev*err);
m->setHighLimit(p, totalInit(p) + maxXsiDev*err);
}
}
//// minimize and store results
result = (*m)(chi2);
totalInit = result;
}
result.corrRangeDb_ = Math::cdr(getFitCorrMat());
result.chi2_ = chi2(result);
result.nPar_ = nPar;
result.nDof_ = layout.totalYSize - nPar;
result.model_.resize(v.size());
for (unsigned int j = 0; j < v.size(); ++j)
{
result.model_[j] = v[j]->fixPar(result);
}
for (Index p = 0; p < totalNPar; ++p)
{
result.parName_.push_back(chi2.varName().getName(p));
}
return result;
}
FitResult XYStatData::fit(Minimizer &minimizer, const DVec &init,
const vector<const DoubleModel *> &v)
{
vector<Minimizer *> mv{&minimizer};
return fit(mv, init, v);
}
// residuals ///////////////////////////////////////////////////////////////////
XYStatData XYStatData::getResiduals(const FitResult &fit)
{
XYStatData res(*this);
for (Index j = 0; j < getNYDim(); ++j)
{
const DoubleFunction &f = fit.getModel(j);
for (auto &p: yData_[j])
{
res.y(p.first, j) -= f(x(p.first));
}
}
return res;
}
XYStatData XYStatData::getNormalisedResiduals(const FitResult &fit)
{
XYStatData res(*this);
for (Index j = 0; j < getNYDim(); ++j)
{
const DoubleFunction &f = fit.getModel(j);
const DVec err = getYError(j);
Index row = 0;
for (auto &p: yData_[j])
{
res.y(p.first, j) -= f(x(p.first));
res.y(p.first, j) /= err(row);
row++;
}
}
return res;
}
XYStatData XYStatData::getPartialResiduals(const FitResult &fit,
const DVec &ref, const Index i)
{
XYStatData res(*this);
DVec buf(ref);
for (Index j = 0; j < res.getNYDim(); ++j)
{
const DoubleFunction &f = fit.getModel(j);
for (auto &p: yData_[j])
{
buf(i) = x(p.first)(i);
res.y(p.first, j) -= f(x(p.first)) - f(buf);
}
}
return res;
}
// create data /////////////////////////////////////////////////////////////////
void XYStatData::createXData(const std::string name __dumb, const Index nData)
{
xData_.push_back(DVec::Zero(nData));
xBuf_.resize(xData_.size());
resizeVarMat();
}
void XYStatData::createYData(const std::string name __dumb)
{
yData_.push_back(map<Index, double>());
resizeVarMat();
}
void XYStatData::resizeVarMat(void)
{
xxVar_.conservativeResize(getNXDim(), getNXDim());
for (Index i1 = 0; i1 < getNXDim(); ++i1)
for (Index i2 = 0; i2 < getNXDim(); ++i2)
{
xxVar_(i1, i2).conservativeResize(getXSize(i1), getXSize(i2));
}
yyVar_.conservativeResize(getNYDim(), getNYDim());
for (Index j1 = 0; j1 < getNYDim(); ++j1)
for (Index j2 = 0; j2 < getNYDim(); ++j2)
{
yyVar_(j1, j2).conservativeResize(getYSize(j1), getYSize(j2));
}
xyVar_.conservativeResize(getNXDim(), getNYDim());
for (Index i = 0; i < getNXDim(); ++i)
for (Index j = 0; j < getNYDim(); ++j)
{
xyVar_(i, j).conservativeResize(getXSize(i), getYSize(j));
}
scheduleFitVarMatInit();
}
// schedule buffer computation /////////////////////////////////////////////////
void XYStatData::scheduleXMapInit(void)
{
initXMap_ = true;
}
void XYStatData::scheduleChi2DataVecInit(void)
{
initChi2DataVec_ = true;
}
// buffer total fit variance matrix ////////////////////////////////////////////
void XYStatData::updateFitVarMat(void)
{
if (initVarMat())
{
updateLayout();
DMat &v = fitVar_;
Index roffs, coffs;
v.resize(layout.totalSize, layout.totalSize);
roffs = layout.totalYSize;
for (Index ifit1 = 0; ifit1 < layout.nXFitDim; ++ifit1)
{
coffs = layout.totalYSize;
for (Index ifit2 = 0; ifit2 < layout.nXFitDim; ++ifit2)
{
for (Index rfit1 = 0; rfit1 < layout.xSize[ifit1]; ++rfit1)
for (Index rfit2 = 0; rfit2 < layout.xSize[ifit2]; ++rfit2)
{
Index i1, i2, r1, r2;
i1 = layout.xDim[ifit1];
i2 = layout.xDim[ifit2];
r1 = layout.x[ifit1][rfit1];
r2 = layout.x[ifit2][rfit2];
v(roffs+rfit1, coffs+rfit2) = xxVar_(i1, i2)(r1, r2);
v(coffs+rfit2, roffs+rfit1) = v(roffs+rfit1, coffs+rfit2);
}
coffs += layout.xSize[ifit2];
}
roffs += layout.xSize[ifit1];
}
roffs = 0;
for (Index jfit1 = 0; jfit1 < layout.nYFitDim; ++jfit1)
{
coffs = 0;
for (Index jfit2 = 0; jfit2 < layout.nYFitDim; ++jfit2)
{
for (Index sfit1 = 0; sfit1 < layout.ySize[jfit1]; ++sfit1)
for (Index sfit2 = 0; sfit2 < layout.ySize[jfit2]; ++sfit2)
{
Index j1, j2, s1, s2;
j1 = layout.yDim[jfit1];
j2 = layout.yDim[jfit2];
s1 = layout.y[jfit1][sfit1];
s2 = layout.y[jfit2][sfit2];
v(roffs+sfit1, coffs+sfit2) = yyVar_(j1, j2)(s1, s2);
v(coffs+sfit2, roffs+sfit1) = v(roffs+sfit1, coffs+sfit2);
}
coffs += layout.ySize[jfit2];
}
roffs += layout.ySize[jfit1];
}
roffs = layout.totalYSize;
for (Index ifit = 0; ifit < layout.nXFitDim; ++ifit)
{
coffs = 0;
for (Index jfit = 0; jfit < layout.nYFitDim; ++jfit)
{
for (Index rfit = 0; rfit < layout.xSize[ifit]; ++rfit)
for (Index sfit = 0; sfit < layout.ySize[jfit]; ++sfit)
{
Index i, j, r, s;
i = layout.xDim[ifit];
j = layout.yDim[jfit];
r = layout.x[ifit][rfit];
s = layout.y[jfit][sfit];
v(roffs+rfit, coffs+sfit) = xyVar_(i, j)(r, s);
v(coffs+sfit, roffs+rfit) = v(roffs+rfit, coffs+sfit);
}
coffs += layout.ySize[jfit];
}
roffs += layout.xSize[ifit];
}
chi2DataVec_.resize(layout.totalSize);
chi2ModVec_.resize(layout.totalSize);
chi2Vec_.resize(layout.totalSize);
fitVar_ = fitVar_.cwiseProduct(makeCorrFilter());
fitCorr_ = Math::varToCorr(fitVar_);
fitCorrInv_ = fitCorr_.pInverse(getSvdTolerance());
fitVarInv_ = Math::corrToVar(fitCorrInv_, fitVar_.diagonal().cwiseInverse());
scheduleFitVarMatInit(false);
}
}
// buffer list of x vectors ////////////////////////////////////////////////////
void XYStatData::updateXMap(void) const
{
if (initXMap_)
{
XYStatData * modThis = const_cast<XYStatData *>(this);
modThis->xMap_.clear();
modThis->xMap_.resize(getMaxDataIndex());
for (auto k: getDataIndexSet())
{
modThis->xMap_[k] = DVec(getNXDim());
for (Index i = 0; i < getNXDim(); ++i)
{
modThis->xMap_[k](i) = xData_[i](dataCoord(k)[i]);
}
}
modThis->initXMap_ = false;
}
}
// buffer chi^2 vectors ////////////////////////////////////////////////////////
void XYStatData::updateChi2DataVec(void)
{
if (initChi2DataVec_)
{
Index a = 0, j, k, i, r;
updateLayout();
for (Index jfit = 0; jfit < layout.nYFitDim; ++jfit)
for (Index sfit = 0; sfit < layout.ySize[jfit]; ++sfit)
{
j = layout.yDim[jfit];
k = layout.data[jfit][sfit];
chi2DataVec_(a) = yData_[j][k];
a++;
}
for (Index ifit = 0; ifit < layout.nXFitDim; ++ifit)
for (Index rfit = 0; rfit < layout.xSize[ifit]; ++rfit)
{
i = layout.xDim[ifit];
r = layout.x[ifit][rfit];
chi2DataVec_(a) = xData_[i](r);
a++;
}
initChi2DataVec_ = false;
}
}
// WARNING: updateChi2ModVec is heavily called by fit
void XYStatData::updateChi2ModVec(const DVec p,
const vector<const DoubleModel *> &v,
const Index nPar, const Index nXDim)
{
updateLayout();
updateXMap();
Index a = 0, j, k, ind;
auto &par = p.segment(0, nPar), &xsi = p.segment(nPar, layout.totalXSize);
for (Index jfit = 0; jfit < layout.nYFitDim; ++jfit)
{
j = layout.yDim[jfit];
for (Index sfit = 0; sfit < layout.ySize[jfit]; ++sfit)
{
k = layout.data[jfit][sfit];
for (Index i = 0; i < nXDim; ++i)
{
ind = layout.xIndFromData[k][i] - layout.totalYSize;
xBuf_(i) = (ind >= 0) ? xsi(ind) : xMap_[k](i);
}
chi2ModVec_(a) = (*v[j])(xBuf_.data(), par.data());
a++;
}
}
chi2ModVec_.segment(a, layout.totalXSize) = xsi;
}

View File

@ -0,0 +1,241 @@
/*
* XYStatData.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_XYStatData_hpp_
#define Latan_XYStatData_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Statistics/FitInterface.hpp>
#include <LatAnalyze/Numerical/Minimizer.hpp>
#include <LatAnalyze/Functional/Model.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* object for fit result *
******************************************************************************/
class FitResult: public DVec
{
friend class XYStatData;
friend class XYSampleData;
friend class SampleFitResult;
public:
// constructors
FitResult(void) = default;
EIGEN_EXPR_CTOR(FitResult, FitResult, Base, MatExpr)
// destructor
virtual ~FitResult(void) = default;
// access
double getChi2(void) const;
double getChi2PerDof(void) const;
double getNDof(void) const;
Index getNPar(void) const;
double getPValue(void) const;
double getCcdf(void) const;
double getCorrRangeDb(void) const;
const DoubleFunction & getModel(const Index j = 0) const;
// IO
void print(const bool printXsi = false,
std::ostream &out = std::cout) const;
private:
double chi2_{0.}, corrRangeDb_{0.};
Index nDof_{0}, nPar_{0};
std::vector<DoubleFunction> model_;
std::vector<std::string> parName_;
};
/******************************************************************************
* class for X vs. Y statistical data *
******************************************************************************/
class XYStatData: public FitInterface
{
public:
// constructor
XYStatData(void) = default;
// destructor
virtual ~XYStatData(void) = default;
// data access
double & x(const Index r, const Index i);
const double & x(const Index r, const Index i) const;
const DVec & x(const Index k) const;
double & y(const Index k, const Index j);
const double & y(const Index k, const Index j) const;
void setXXVar(const Index i1, const Index i2, const DMat &m);
void setYYVar(const Index j1, const Index j2, const DMat &m);
void setXYVar(const Index i, const Index j, const DMat &m);
void setXError(const Index i, const DVec &err);
void setYError(const Index j, const DVec &err);
template <typename... Ts>
void setUnidimData(const DMat &xData, const Ts & ...yDatas);
const DMat & getXXVar(const Index i1, const Index i2) const;
const DMat & getYYVar(const Index j1, const Index j2) const;
const DMat & getXYVar(const Index i, const Index j) const;
DVec getXError(const Index i) const;
DVec getYError(const Index j) const;
DMat getTable(const Index i, const Index j) const;
// get total fit variance & correlation matrices and their pseudo-inverse
const DMat & getFitVarMat(void);
const DMat & getFitVarMatPInv(void);
const DMat & getFitCorrMat(void);
const DMat & getFitCorrMatPInv(void);
// fit
FitResult fit(std::vector<Minimizer *> &minimizer, const DVec &init,
const std::vector<const DoubleModel *> &v);
FitResult fit(Minimizer &minimizer, const DVec &init,
const std::vector<const DoubleModel *> &v);
template <typename... Ts>
FitResult fit(std::vector<Minimizer *> &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models);
template <typename... Ts>
FitResult fit(Minimizer &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models);
// residuals
XYStatData getResiduals(const FitResult &fit);
XYStatData getNormalisedResiduals(const FitResult &fit);
XYStatData getPartialResiduals(const FitResult &fit, const DVec &ref,
const Index i);
protected:
// create data
virtual void createXData(const std::string name, const Index nData);
virtual void createYData(const std::string name);
void resizeVarMat(void);
private:
// schedule buffer computation
void scheduleXMapInit(void);
void scheduleChi2DataVecInit(void);
// buffer total fit variance matrix
void updateFitVarMat(void);
// buffer list of x vectors
void updateXMap(void) const;
// buffer chi^2 vectors
void updateChi2DataVec(void);
void updateChi2ModVec(const DVec p,
const std::vector<const DoubleModel *> &v,
const Index nPar, const Index nXDim);
private:
std::vector<std::map<Index, double>> yData_;
// no map here for fit performance
std::vector<DVec> xData_;
std::vector<DVec> xMap_;
Mat<DMat> xxVar_, yyVar_, xyVar_;
DMat fitVar_, fitVarInv_, fitCorr_, fitCorrInv_;
DVec chi2DataVec_, chi2ModVec_, chi2Vec_;
DVec xBuf_;
bool initXMap_{true};
bool initChi2DataVec_{true};
};
/******************************************************************************
* XYStatData template implementation *
******************************************************************************/
template <typename... Ts>
void XYStatData::setUnidimData(const DMat &xData, const Ts & ...yDatas)
{
static_assert(static_or<std::is_assignable<DMat, Ts>::value...>::value,
"y data arguments are not compatible with DMat");
std::vector<const DMat *> yData{&yDatas...};
FOR_VEC(xData, r)
{
x(r, 0) = xData(r);
for (unsigned int j = 0; j < yData.size(); ++j)
{
y(r, j) = (*(yData[j]))(r);
}
}
}
template <typename... Ts>
FitResult XYStatData::fit(std::vector<Minimizer *> &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models)
{
static_assert(static_or<std::is_assignable<DoubleModel &, Ts>::value...>::value,
"model arguments are not compatible with DoubleModel");
std::vector<const DoubleModel *> modelVector{&model, &models...};
return fit(minimizer, init, modelVector);
}
template <typename... Ts>
FitResult XYStatData::fit(Minimizer &minimizer, const DVec &init,
const DoubleModel &model, const Ts... models)
{
static_assert(static_or<std::is_assignable<DoubleModel &, Ts>::value...>::value,
"model arguments are not compatible with DoubleModel");
std::vector<Minimizer *> mv{&minimizer};
return fit(mv, init, model, models...);
}
/******************************************************************************
* error check macros *
******************************************************************************/
#define checkVarMat(m, var)\
if (((m).rows() != (var).rows()) or ((m).cols() != (var).cols()))\
{\
LATAN_ERROR(Size, "provided variance matrix has a wrong size"\
" (expected " + strFrom((var).rows()) + "x"\
+ strFrom((var).cols()) + ", got " + strFrom((m).rows())\
+ "x" + strFrom((m).cols()) + ")");\
}
#define checkErrVec(err, var)\
if ((err).size() != (var).rows())\
{\
LATAN_ERROR(Size, "provided error vector has a wrong size"\
" (expected " + strFrom((var).rows()) + ", got "\
+ strFrom((err).size()) + ")");\
}
#define checkModelVec(v)\
if (static_cast<Index>((v).size()) != getNYDim())\
{\
LATAN_ERROR(Size, "provided model vector has a wrong size"\
" (expected " + strFrom(getNYDim()) + ", got "\
+ strFrom((v).size()) + ")");\
}\
for (unsigned int _i = 1; _i < (v).size(); ++_i)\
{\
if ((v)[_i]->getNArg() != getNXDim())\
{\
LATAN_ERROR(Size, "model " + strFrom(_i) + " has a wrong"\
+ " number of argument (expected " + strFrom(getNXDim())\
+ ", got " + strFrom((v)[_i]->getNArg()));\
}\
}\
{\
Index _nPar = (v)[0]->getNPar();\
for (unsigned int _i = 1; _i < (v).size(); ++_i)\
{\
if ((v)[_i]->getNPar() != _nPar)\
{\
LATAN_ERROR(Size, "model " + strFrom(_i) + " has a wrong"\
+ " number of parameter (expected " + strFrom(_nPar)\
+ ", got " + strFrom((v)[_i]->getNPar()));\
}\
}\
}
END_LATAN_NAMESPACE
#endif // Latan_XYStatData_hpp_

View File

@ -0,0 +1,25 @@
/*
* includes.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2020 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_includes_hpp_
#define Latan_includes_hpp_
#include <config.h>
#endif // Latan_includes_hpp_