From 3720103f41dd2e2f6a92d36007602c1bfe35ad84 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Sat, 9 Feb 2019 17:12:36 +0000 Subject: [PATCH] Adding Eigen::Tensor still WIP --- Grid/serialisation/BaseIO.h | 57 +++++++++++++++++++++++----- Grid/serialisation/VectorUtils.h | 2 +- tests/hadrons/Test_hadrons_distil.cc | 8 +++- 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/Grid/serialisation/BaseIO.h b/Grid/serialisation/BaseIO.h index e96a2f9a..26b6ab5a 100644 --- a/Grid/serialisation/BaseIO.h +++ b/Grid/serialisation/BaseIO.h @@ -55,14 +55,21 @@ namespace Grid { template typename std::enable_if::value, void>::type write(const std::string& s, const U &output); + template + typename std::enable_if::value || Grid::is_complex::value, void>::type + write(const std::string &s, const Eigen::Tensor &output); + template + void write(const std::string &s, const Eigen::Tensor, NumIndices_, Options_, IndexType_> &output); + template + void write(const std::string &s, const Eigen::Tensor, NumIndices_, Options_, IndexType_> &output); + template + void write(const std::string &s, const Eigen::Tensor, NumIndices_, Options_, IndexType_> &output); template void write(const std::string &s, const iScalar &output); template void write(const std::string &s, const iVector &output); template void write(const std::string &s, const iMatrix &output); - template - void write(const std::string &s, const Eigen::Tensor &output); void scientificFormat(const bool set); bool isScientific(void); @@ -145,6 +152,44 @@ namespace Grid { upcast->writeDefault(s, output); } + // Eigen::Tensors of arithmetic/complex base type + template + template + typename std::enable_if::value || Grid::is_complex::value, void>::type + Writer::write(const std::string &s, const Eigen::Tensor &output) + { + //upcast->writeDefault(s, tensorToVec(output)); + std::cout << "I really should add code to write Eigen::Tensor (arithmetic/complex) ..." << std::endl; + } + + // Eigen::Tensors of iScalar + template + template + void Writer::write(const std::string &s, const Eigen::Tensor, NumIndices_, Options_, IndexType_> &output) + { + //upcast->writeDefault(s, tensorToVec(output)); + std::cout << "I really should add code to write Eigen::Tensor (iScalar) ..." << std::endl; + } + + // Eigen::Tensors of iVector + template + template + void Writer::write(const std::string &s, const Eigen::Tensor, NumIndices_, Options_, IndexType_> &output) + { + //upcast->writeDefault(s, tensorToVec(output)); + std::cout << "I really should add code to write Eigen::Tensor (iVector) ..." << std::endl; + } + + // Eigen::Tensors of iMatrix + template + template + void Writer::write(const std::string &s, const Eigen::Tensor, NumIndices_, Options_, IndexType_> &output) + { + //upcast->writeDefault(s, tensorToVec(output)); + std::cout << "I really should add code to write Eigen::Tensor (iMatrix) ..." << std::endl; + } + + template template void Writer::write(const std::string &s, const iScalar &output) @@ -166,14 +211,6 @@ namespace Grid { upcast->writeDefault(s, tensorToVec(output)); } - template - template - void Writer::write(const std::string &s, const Eigen::Tensor &output) - { - //upcast->writeDefault(s, tensorToVec(output)); - std::cout << "I really should add code to write Eigen::Tensor ..." << std::endl; - } - template void Writer::scientificFormat(const bool set) { diff --git a/Grid/serialisation/VectorUtils.h b/Grid/serialisation/VectorUtils.h index b6b95c10..3372c8ad 100644 --- a/Grid/serialisation/VectorUtils.h +++ b/Grid/serialisation/VectorUtils.h @@ -436,4 +436,4 @@ std::string vecToStr(const std::vector &v) return sstr.str(); } -#endif \ No newline at end of file +#endif diff --git a/tests/hadrons/Test_hadrons_distil.cc b/tests/hadrons/Test_hadrons_distil.cc index 1529eb15..6be63841 100644 --- a/tests/hadrons/Test_hadrons_distil.cc +++ b/tests/hadrons/Test_hadrons_distil.cc @@ -424,12 +424,15 @@ bool DebugEigenTest() } typedef iMatrix OddBall; +typedef Eigen::Tensor TensorInt; +typedef Eigen::Tensor, 3, Eigen::RowMajor> TensorComplex; +typedef Eigen::Tensor TensorOddBall; // From Test_serialisation.cc class myclass: Serializable { public: GRID_SERIALIZABLE_CLASS_MEMBERS(myclass - , OddBall, critter + //, OddBall, critter , SpinColourVector, scv , SpinColourMatrix, scm ); @@ -465,6 +468,9 @@ bool DebugIOTest(void) { ioTest("iotest_vector.h5", scv, "SpinColourVector"); myclass o; ioTest("iotest_object.h5", o, "myclass_object_instance_name"); + TensorInt t(3,6,2); + ioTest("iotest_tensor.h5", t, "eigen_tensor_instance_name"); + return true; } #endif