From 500210a2eb5249562373f56164b76e245cac70a7 Mon Sep 17 00:00:00 2001 From: Antonin Portelli Date: Fri, 18 Feb 2022 14:06:52 +0000 Subject: [PATCH] DWT working and tested --- examples/Makefile.am | 5 +++ examples/exDWT.cpp | 28 +++++++++++++ lib/Numerical/DWT.cpp | 92 ++++++++++++++++++++++++++++++++++++++++--- lib/Numerical/DWT.hpp | 7 +++- 4 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 examples/exDWT.cpp diff --git a/examples/Makefile.am b/examples/Makefile.am index 123bafb..bf23b0d 100644 --- a/examples/Makefile.am +++ b/examples/Makefile.am @@ -9,6 +9,7 @@ endif noinst_PROGRAMS = \ exCompiledDoubleFunction\ exDerivative \ + exDWT \ exFit \ exFitSample \ exIntegrator \ @@ -30,6 +31,10 @@ exDerivative_SOURCES = exDerivative.cpp exDerivative_CXXFLAGS = $(COM_CXXFLAGS) exDerivative_LDFLAGS = -L../lib/.libs -lLatAnalyze +exDWT_SOURCES = exDWT.cpp +exDWT_CXXFLAGS = $(COM_CXXFLAGS) +exDWT_LDFLAGS = -L../lib/.libs -lLatAnalyze + exFit_SOURCES = exFit.cpp exFit_CXXFLAGS = $(COM_CXXFLAGS) exFit_LDFLAGS = -L../lib/.libs -lLatAnalyze diff --git a/examples/exDWT.cpp b/examples/exDWT.cpp new file mode 100644 index 0000000..082d8b5 --- /dev/null +++ b/examples/exDWT.cpp @@ -0,0 +1,28 @@ +#include + +using namespace std; +using namespace Latan; + +int main(void) +{ + DVec data, dataRec; + vector dataDWT; + DWT dwt(DWTFilters::db3); + + cout << "-- random data" << endl; + data.setRandom(16); + cout << data.transpose() << endl; + cout << "-- compute Daubechies 3 DWT" << endl; + dataDWT = dwt.forward(data, 4); + for (unsigned int l = 0; l < dataDWT.size(); ++l) + { + cout << "* level " << l << endl; + cout << "L= " << dataDWT[l].first.transpose() << endl; + cout << "H= " << dataDWT[l].second.transpose() << endl; + } + cout << "-- check inverse DWT" << endl; + dataRec = dwt.backward(dataDWT); + cout << "rel diff = " << 2.*(data - dataRec).norm()/(data + dataRec).norm() << endl; + + return EXIT_SUCCESS; +} diff --git a/lib/Numerical/DWT.cpp b/lib/Numerical/DWT.cpp index b5c49a7..bc0e03a 100644 --- a/lib/Numerical/DWT.cpp +++ b/lib/Numerical/DWT.cpp @@ -32,26 +32,106 @@ DWT::DWT(const DWTFilter &filter) {} // convolution primitive /////////////////////////////////////////////////////// -DVec DWT::filterConvolution(const DVec &data, const DWTFilter &filter, - const Index offset) +void DWT::filterConvolution(DVec &out, const DVec &data, + const std::vector &filter, const Index offset) { - DVec res(data.size()); + Index n = data.size(), nf = n*filter.size(); - return res; + out.resize(n); + out.fill(0.); + for (unsigned int i = 0; i < filter.size(); ++i) + { + FOR_VEC(out, j) + { + out(j) += filter[i]*data((j + i + nf - offset) % n); + } + } +} + +// downsampling/upsampling primitives ////////////////////////////////////////// +void DWT::downsample(DVec &out, const DVec &in) +{ + if (out.size() < in.size()/2) + { + LATAN_ERROR(Size, "output vector smaller than half the input vector size"); + } + for (Index i = 0; i < in.size(); i += 2) + { + out(i/2) = in(i); + } +} + +void DWT::upsample(DVec &out, const DVec &in) +{ + if (out.size() < 2*in.size()) + { + LATAN_ERROR(Size, "output vector smaller than twice the input vector size"); + } + out.segment(0, 2*in.size()).fill(0.); + for (Index i = 0; i < in.size(); i ++) + { + out(2*i) = in(i); + } } // DWT ///////////////////////////////////////////////////////////////////////// std::vector DWT::forward(const DVec &data, const unsigned int level) const { - std::vector dwt(level); + std::vector dwt(level); + DVec *finePt = const_cast(&data); + DVec tmp; + Index n = data.size(), o = filter_.fwdL.size()/2, minSize; + + minSize = 1; + for (unsigned int l = 0; l < level; ++l) minSize *= 2; + if (n < minSize) + { + LATAN_ERROR(Size, "data vector too small for a " + strFrom(level) + + "-level DWT (data size is " + strFrom(n) + ")"); + } + for (unsigned int l = 0; l < level; ++l) + { + n /= 2; + dwt[l].first.resize(n); + dwt[l].second.resize(n); + filterConvolution(tmp, *finePt, filter_.fwdL, o); + downsample(dwt[l].first, tmp); + filterConvolution(tmp, *finePt, filter_.fwdH, o); + downsample(dwt[l].second, tmp); + finePt = &dwt[l].first; + } return dwt; } DVec DWT::backward(const std::vector& dwt) const { - DVec res; + unsigned int level = dwt.size(); + Index n = dwt.back().second.size(), o = filter_.bwdL.size()/2 - 1; + DVec res, tmp, conv; + + res = dwt.back().first; + for (int l = level - 2; l >= 0; --l) + { + n *= 2; + if (dwt[l].second.size() != n) + { + LATAN_ERROR(Size, "DWT result size mismatch"); + } + } + n = dwt.back().second.size(); + for (int l = level - 1; l >= 0; --l) + { + n *= 2; + tmp.resize(n); + upsample(tmp, res); + filterConvolution(conv, tmp, filter_.bwdL, o); + res = conv; + upsample(tmp, dwt[l].second); + filterConvolution(conv, tmp, filter_.bwdH, o); + res += conv; + } return res; } diff --git a/lib/Numerical/DWT.hpp b/lib/Numerical/DWT.hpp index f1ee646..5745245 100644 --- a/lib/Numerical/DWT.hpp +++ b/lib/Numerical/DWT.hpp @@ -38,8 +38,11 @@ public: // destructor virtual ~DWT(void) = default; // convolution primitive - static DVec filterConvolution(const DVec &data, const DWTFilter &filter, - const Index offset); + static void filterConvolution(DVec &out, const DVec &data, + const std::vector &filter, const Index offset); + // downsampling/upsampling primitives + static void downsample(DVec &out, const DVec &in); + static void upsample(DVec &out, const DVec &in); // DWT std::vector forward(const DVec &data, const unsigned int level) const; DVec backward(const std::vector& dwt) const;