diff --git a/utils/sample-combine.cpp b/utils/sample-combine.cpp index 453434c..4323034 100644 --- a/utils/sample-combine.cpp +++ b/utils/sample-combine.cpp @@ -25,14 +25,21 @@ using namespace std; using namespace Latan; template -static void loadAndCheck(vector &sample, const vector &fileName) +static void loadAndCheck(vector &sample __dumb, + const vector &fileName __dumb) { - const unsigned int n = sample.size(); + abort(); +} + +template <> +void loadAndCheck(vector &sample, const vector &fileName) +{ + const unsigned int n = static_cast(sample.size()); Index nSample = 0; for (unsigned int i = 0; i < n; ++i) { - sample[i] = Io::load(fileName[i]); + sample[i] = Io::load(fileName[i]); if (i == 0) { nSample = sample[i].size(); @@ -46,6 +53,65 @@ static void loadAndCheck(vector &sample, const vector &fileName) } } +template <> +void loadAndCheck(vector &sample, const vector &fileName) +{ + const unsigned int n = static_cast(sample.size()); + Index nSample = 0; + set failed; + bool gotSize = false; + Index nRow = 0, nCol = 0; + + for (unsigned int i = 0; i < n; ++i) + { + try + { + sample[i] = Io::load(fileName[i]); + if (!gotSize) + { + nRow = sample[i][central].rows(); + nCol = sample[i][central].cols(); + gotSize = true; + } + } + catch (Exceptions::Definition) + { + failed.insert(i); + } + if (i == 0) + { + nSample = sample[i].size(); + } + } + for (unsigned int i: failed) + { + DSample buf; + + buf = Io::load(fileName[i]); + sample[i].resize(nSample); + FOR_STAT_ARRAY(sample[i], s) + { + sample[i][s] = DMat::Constant(nRow, nCol, buf[s]); + } + } + for (unsigned int i = 0; i < n; ++i) + { + if (sample[i].size() != nSample) + { + cerr << "error: number of sample mismatch (between '"; + cerr << fileName[0] << "' and '" << fileName[i] << "')" << endl; + abort(); + } + if ((sample[i][central].rows() != nRow) and + (sample[i][central].cols() != nCol)) + { + cerr << "error: matrix size mismatch (between '"; + cerr << fileName[0] << "' and '" << fileName[i] << "')" << endl; + abort(); + } + } +} + template static void combine(const string &outFileName __dumb, const vector &sample __dumb, const string &code __dumb) @@ -57,7 +123,7 @@ template <> void combine(const string &outFileName, const vector &sample, const string &code) { - const unsigned int n = sample.size(); + const unsigned int n = static_cast(sample.size()); DoubleFunction f = compile(code, n); DSample result(sample[0]); DVec buf(n); @@ -87,7 +153,7 @@ template <> void combine(const string &outFileName, const vector &sample, const string &code) { - const unsigned int n = sample.size(); + const unsigned int n = static_cast(sample.size()); DoubleFunction f = compile(code, n); DVec buf(n); DMatSample result(sample[0]);