1
0
mirror of https://github.com/aportelli/LatAnalyze.git synced 2024-11-10 00:45:36 +00:00

DWT working and tested

This commit is contained in:
Antonin Portelli 2022-02-18 14:06:52 +00:00 committed by Andrew Zhen Ning Yong
parent b9f61d8c17
commit 500210a2eb
4 changed files with 124 additions and 8 deletions

View File

@ -9,6 +9,7 @@ endif
noinst_PROGRAMS = \
exCompiledDoubleFunction\
exDerivative \
exDWT \
exFit \
exFitSample \
exIntegrator \
@ -30,6 +31,10 @@ exDerivative_SOURCES = exDerivative.cpp
exDerivative_CXXFLAGS = $(COM_CXXFLAGS)
exDerivative_LDFLAGS = -L../lib/.libs -lLatAnalyze
exDWT_SOURCES = exDWT.cpp
exDWT_CXXFLAGS = $(COM_CXXFLAGS)
exDWT_LDFLAGS = -L../lib/.libs -lLatAnalyze
exFit_SOURCES = exFit.cpp
exFit_CXXFLAGS = $(COM_CXXFLAGS)
exFit_LDFLAGS = -L../lib/.libs -lLatAnalyze

28
examples/exDWT.cpp Normal file
View File

@ -0,0 +1,28 @@
#include <LatAnalyze/Numerical/DWT.hpp>
using namespace std;
using namespace Latan;
int main(void)
{
DVec data, dataRec;
vector<DWT::DWTLevel> dataDWT;
DWT dwt(DWTFilters::db3);
cout << "-- random data" << endl;
data.setRandom(16);
cout << data.transpose() << endl;
cout << "-- compute Daubechies 3 DWT" << endl;
dataDWT = dwt.forward(data, 4);
for (unsigned int l = 0; l < dataDWT.size(); ++l)
{
cout << "* level " << l << endl;
cout << "L= " << dataDWT[l].first.transpose() << endl;
cout << "H= " << dataDWT[l].second.transpose() << endl;
}
cout << "-- check inverse DWT" << endl;
dataRec = dwt.backward(dataDWT);
cout << "rel diff = " << 2.*(data - dataRec).norm()/(data + dataRec).norm() << endl;
return EXIT_SUCCESS;
}

View File

@ -32,26 +32,106 @@ DWT::DWT(const DWTFilter &filter)
{}
// convolution primitive ///////////////////////////////////////////////////////
DVec DWT::filterConvolution(const DVec &data, const DWTFilter &filter,
const Index offset)
void DWT::filterConvolution(DVec &out, const DVec &data,
const std::vector<double> &filter, const Index offset)
{
DVec res(data.size());
Index n = data.size(), nf = n*filter.size();
return res;
out.resize(n);
out.fill(0.);
for (unsigned int i = 0; i < filter.size(); ++i)
{
FOR_VEC(out, j)
{
out(j) += filter[i]*data((j + i + nf - offset) % n);
}
}
}
// downsampling/upsampling primitives //////////////////////////////////////////
void DWT::downsample(DVec &out, const DVec &in)
{
if (out.size() < in.size()/2)
{
LATAN_ERROR(Size, "output vector smaller than half the input vector size");
}
for (Index i = 0; i < in.size(); i += 2)
{
out(i/2) = in(i);
}
}
void DWT::upsample(DVec &out, const DVec &in)
{
if (out.size() < 2*in.size())
{
LATAN_ERROR(Size, "output vector smaller than twice the input vector size");
}
out.segment(0, 2*in.size()).fill(0.);
for (Index i = 0; i < in.size(); i ++)
{
out(2*i) = in(i);
}
}
// DWT /////////////////////////////////////////////////////////////////////////
std::vector<DWT::DWTLevel>
DWT::forward(const DVec &data, const unsigned int level) const
{
std::vector<DWT::DWTLevel> dwt(level);
std::vector<DWTLevel> dwt(level);
DVec *finePt = const_cast<DVec *>(&data);
DVec tmp;
Index n = data.size(), o = filter_.fwdL.size()/2, minSize;
minSize = 1;
for (unsigned int l = 0; l < level; ++l) minSize *= 2;
if (n < minSize)
{
LATAN_ERROR(Size, "data vector too small for a " + strFrom(level)
+ "-level DWT (data size is " + strFrom(n) + ")");
}
for (unsigned int l = 0; l < level; ++l)
{
n /= 2;
dwt[l].first.resize(n);
dwt[l].second.resize(n);
filterConvolution(tmp, *finePt, filter_.fwdL, o);
downsample(dwt[l].first, tmp);
filterConvolution(tmp, *finePt, filter_.fwdH, o);
downsample(dwt[l].second, tmp);
finePt = &dwt[l].first;
}
return dwt;
}
DVec DWT::backward(const std::vector<DWTLevel>& dwt) const
{
DVec res;
unsigned int level = dwt.size();
Index n = dwt.back().second.size(), o = filter_.bwdL.size()/2 - 1;
DVec res, tmp, conv;
res = dwt.back().first;
for (int l = level - 2; l >= 0; --l)
{
n *= 2;
if (dwt[l].second.size() != n)
{
LATAN_ERROR(Size, "DWT result size mismatch");
}
}
n = dwt.back().second.size();
for (int l = level - 1; l >= 0; --l)
{
n *= 2;
tmp.resize(n);
upsample(tmp, res);
filterConvolution(conv, tmp, filter_.bwdL, o);
res = conv;
upsample(tmp, dwt[l].second);
filterConvolution(conv, tmp, filter_.bwdH, o);
res += conv;
}
return res;
}

View File

@ -38,8 +38,11 @@ public:
// destructor
virtual ~DWT(void) = default;
// convolution primitive
static DVec filterConvolution(const DVec &data, const DWTFilter &filter,
const Index offset);
static void filterConvolution(DVec &out, const DVec &data,
const std::vector<double> &filter, const Index offset);
// downsampling/upsampling primitives
static void downsample(DVec &out, const DVec &in);
static void upsample(DVec &out, const DVec &in);
// DWT
std::vector<DWTLevel> forward(const DVec &data, const unsigned int level) const;
DVec backward(const std::vector<DWTLevel>& dwt) const;