1
0
mirror of https://github.com/aportelli/LatAnalyze.git synced 2024-09-19 21:25:36 +01:00

multivariate Gaussian RNG

This commit is contained in:
Antonin Portelli 2019-03-25 23:20:09 +00:00
parent 0bf6d8c8ae
commit e37f2ab124
4 changed files with 130 additions and 0 deletions

View File

@ -1,11 +1,15 @@
#include <LatAnalyze/Io/Io.hpp>
#include <LatAnalyze/Functional/CompiledFunction.hpp>
#include <LatAnalyze/Core/Plot.hpp>
#include <LatAnalyze/Statistics/Random.hpp>
#include <LatAnalyze/Statistics/MatSample.hpp>
using namespace std;
using namespace Latan;
constexpr Index size = 8;
constexpr Index nDraw = 20000;
constexpr Index nSample = 2000;
const string stateFileName = "exRand.seed";
int main(void)
@ -36,5 +40,24 @@ int main(void)
p << PlotFunction(compile("return exp(-x_0^2/2)/sqrt(2*pi);", 1), -5., 5.);
p.display();
DMat var(size, size);
DVec mean(size);
DMatSample sample(nSample, size, 1);
cout << "-- generating " << nSample << " Gaussian random vectors..." << endl;
var = DMat::Random(size, size);
var *= var.adjoint();
mean = DVec::Random(size);
RandomNormal mgauss(mean, var, rd());
sample[central] = mgauss();
FOR_STAT_ARRAY(sample, s)
{
sample[s] = mgauss();
}
cout << "* original variance matrix:\n" << var << endl;
cout << "* measured variance matrix:\n" << sample.varianceMatrix() << endl;
cout << "* original mean:\n" << mean << endl;
cout << "* measured mean:\n" << sample.mean() << endl;
return EXIT_SUCCESS;
}

View File

@ -56,6 +56,7 @@ libLatAnalyze_la_SOURCES = \
Numerical/Solver.cpp \
Statistics/FitInterface.cpp \
Statistics/Histogram.cpp \
Statistics/Random.cpp \
Statistics/StatArray.cpp \
Statistics/XYSampleData.cpp \
Statistics/XYStatData.cpp \
@ -100,6 +101,7 @@ HPPFILES = \
Statistics/FitInterface.hpp \
Statistics/Histogram.hpp \
Statistics/MatSample.hpp \
Statistics/Random.hpp \
Statistics/StatArray.hpp \
Statistics/XYSampleData.hpp \
Statistics/XYStatData.hpp

56
lib/Statistics/Random.cpp Normal file
View File

@ -0,0 +1,56 @@
/*
* Random.cpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2016 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#include <LatAnalyze/Core/Plot.hpp>
#include <LatAnalyze/includes.hpp>
#include <LatAnalyze/Statistics/Random.hpp>
using namespace std;
using namespace Latan;
RandomNormal::RandomNormal(const DVec &mean, const DMat &var, const SeedType seed)
: mean_(mean), buf_(mean.size()), var_(var), gen_(seed)
{
if (var_.rows() != var_.cols())
{
LATAN_ERROR(Size, "variance matrix not square");
}
if (mean_.size() != var_.rows())
{
LATAN_ERROR(Size, "variance matrix and mean vector size mismatch");
}
Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> esolver(var);
Eigen::VectorXd ev = esolver.eigenvalues();
ev = ev.unaryExpr([](const double x){return (x > 0.) ? x : 0.;});
transform_ = esolver.eigenvectors()*ev.cwiseSqrt().asDiagonal();
}
DVec RandomNormal::operator()(void)
{
std::normal_distribution<> dist;
FOR_VEC(buf_, i)
{
buf_(i) = dist(gen_);
}
return mean_ + transform_*buf_;
}

49
lib/Statistics/Random.hpp Normal file
View File

@ -0,0 +1,49 @@
/*
* Random.hpp, part of LatAnalyze 3
*
* Copyright (C) 2013 - 2019 Antonin Portelli
*
* LatAnalyze 3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* LatAnalyze 3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with LatAnalyze 3. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef Latan_Random_hpp_
#define Latan_Random_hpp_
#include <LatAnalyze/Global.hpp>
#include <LatAnalyze/Core/Mat.hpp>
BEGIN_LATAN_NAMESPACE
/******************************************************************************
* Multivariate Gaussian RNG *
******************************************************************************/
class RandomNormal
{
public:
// constructors
RandomNormal(const DVec &mean, const DMat &var, const SeedType seed);
// destructor
virtual ~RandomNormal(void) = default;
// draw a random vector
DVec operator()(void);
private:
DVec mean_, buf_;
DMat var_;
Eigen::MatrixXd transform_;
std::mt19937 gen_;
};
END_LATAN_NAMESPACE
#endif // Latan_Random_hpp_