1
0
mirror of https://github.com/aportelli/LatAnalyze.git synced 2024-11-10 00:45:36 +00:00

latan-sample-combine now accepts mix of DSample & DMatSample

This commit is contained in:
Antonin Portelli 2017-07-09 14:15:37 +01:00
parent 92822348f6
commit e43a7c6d0b

View File

@ -25,14 +25,21 @@ using namespace std;
using namespace Latan; using namespace Latan;
template <typename T> template <typename T>
static void loadAndCheck(vector<T> &sample, const vector<string> &fileName) static void loadAndCheck(vector<T> &sample __dumb,
const vector<string> &fileName __dumb)
{ {
const unsigned int n = sample.size(); abort();
}
template <>
void loadAndCheck(vector<DSample> &sample, const vector<string> &fileName)
{
const unsigned int n = static_cast<unsigned int>(sample.size());
Index nSample = 0; Index nSample = 0;
for (unsigned int i = 0; i < n; ++i) for (unsigned int i = 0; i < n; ++i)
{ {
sample[i] = Io::load<T>(fileName[i]); sample[i] = Io::load<DSample>(fileName[i]);
if (i == 0) if (i == 0)
{ {
nSample = sample[i].size(); nSample = sample[i].size();
@ -46,6 +53,65 @@ static void loadAndCheck(vector<T> &sample, const vector<string> &fileName)
} }
} }
template <>
void loadAndCheck(vector<DMatSample> &sample, const vector<string> &fileName)
{
const unsigned int n = static_cast<unsigned int>(sample.size());
Index nSample = 0;
set<unsigned int> failed;
bool gotSize = false;
Index nRow = 0, nCol = 0;
for (unsigned int i = 0; i < n; ++i)
{
try
{
sample[i] = Io::load<DMatSample>(fileName[i]);
if (!gotSize)
{
nRow = sample[i][central].rows();
nCol = sample[i][central].cols();
gotSize = true;
}
}
catch (Exceptions::Definition)
{
failed.insert(i);
}
if (i == 0)
{
nSample = sample[i].size();
}
}
for (unsigned int i: failed)
{
DSample buf;
buf = Io::load<DSample>(fileName[i]);
sample[i].resize(nSample);
FOR_STAT_ARRAY(sample[i], s)
{
sample[i][s] = DMat::Constant(nRow, nCol, buf[s]);
}
}
for (unsigned int i = 0; i < n; ++i)
{
if (sample[i].size() != nSample)
{
cerr << "error: number of sample mismatch (between '";
cerr << fileName[0] << "' and '" << fileName[i] << "')" << endl;
abort();
}
if ((sample[i][central].rows() != nRow) and
(sample[i][central].cols() != nCol))
{
cerr << "error: matrix size mismatch (between '";
cerr << fileName[0] << "' and '" << fileName[i] << "')" << endl;
abort();
}
}
}
template <typename T> template <typename T>
static void combine(const string &outFileName __dumb, static void combine(const string &outFileName __dumb,
const vector<T> &sample __dumb, const string &code __dumb) const vector<T> &sample __dumb, const string &code __dumb)
@ -57,7 +123,7 @@ template <>
void combine(const string &outFileName, const vector<DSample> &sample, void combine(const string &outFileName, const vector<DSample> &sample,
const string &code) const string &code)
{ {
const unsigned int n = sample.size(); const unsigned int n = static_cast<unsigned int>(sample.size());
DoubleFunction f = compile(code, n); DoubleFunction f = compile(code, n);
DSample result(sample[0]); DSample result(sample[0]);
DVec buf(n); DVec buf(n);
@ -87,7 +153,7 @@ template <>
void combine(const string &outFileName, const vector<DMatSample> &sample, void combine(const string &outFileName, const vector<DMatSample> &sample,
const string &code) const string &code)
{ {
const unsigned int n = sample.size(); const unsigned int n = static_cast<unsigned int>(sample.size());
DoubleFunction f = compile(code, n); DoubleFunction f = compile(code, n);
DVec buf(n); DVec buf(n);
DMatSample result(sample[0]); DMatSample result(sample[0]);