From 47f5b1e2b51af1793ecc825489d13930d50c49bb Mon Sep 17 00:00:00 2001 From: Michael Marshall <43034299+mmphys@users.noreply.github.com> Date: Mon, 25 Mar 2019 18:19:55 +0000 Subject: [PATCH] Iterator added. Will wait for review comments before finalising. --- Grid/util/EigenUtil.h | 89 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 81 insertions(+), 8 deletions(-) diff --git a/Grid/util/EigenUtil.h b/Grid/util/EigenUtil.h index ae8a1e50..63a0a860 100644 --- a/Grid/util/EigenUtil.h +++ b/Grid/util/EigenUtil.h @@ -30,6 +30,83 @@ #include #include +namespace Grid { + // Custom iterator for Eigen tensors + namespace EigenUtil { + template // Is the tensor constant + class TensorIterator_raw{ + public: + using Index = typename ETensor::Index; + using Scalar = typename ETensor::Scalar; + using FullIndex = std::array; + const ETensor * pET; + const Index end; // same as pET->size() + Index position; // position (memory order) + Index Seq; // sequence (what our position would be if we were column major) + FullIndex indexPos; + FullIndex indexSize; + + inline TensorIterator_raw( ETensor & eT, Index pos = 0 ) : pET{&eT}, position{pos}, Seq{pos}, end{pET->size()} { + for( int i = 0 ; i < ETensor::NumIndices ; i++ ) { + indexPos[i] = 0; + indexSize[i] = pET->dimension(i); + } + } + inline TensorIterator_raw & operator++() { + auto sz = pET->size(); + if( position < sz ) { + position++; + if( ETensor::Options & Eigen::RowMajor ) { + for( int i = ETensor::NumIndices - 1; i != -1 && ++indexPos[i] == indexSize[i]; i-- ) + indexPos[i] = 0; + Seq++; + } else { + for( int i = 0; i < ETensor::NumIndices && ++indexPos[i] == indexSize[i]; i++ ) + indexPos[i] = 0; + Seq = 0; + for( int i = 0; i < ETensor::NumIndices; i++ ) { + Seq *= indexSize[i]; + Seq += indexPos[i]; + } + } + } + return * this; + } + inline typename std::conditional::type operator*() const { + assert( position >= 0 && position < pET->size() && "Attempt to access Eigen tensor iterator out of range" ); + return ( ( typename std::conditional::type ) pET->data() )[position]; + } + inline bool operator!=(const TensorIterator_raw &r) + { return pET == nullptr || pET != r.pET || position != r.position; } + // These functions aren't required for iterators, but they make using them easier + inline bool AtEnd() { return position == end; } + inline void DumpIndex(void) { + for( auto dim : indexPos ) + std::cout << "[" << dim << "]"; + } + }; + } +} + +// The only way I could get these iterators to work is to put the begin() and end() functions in the Eigen namespace +// So if Eigen ever defines these, we'll have a conflict and have to change this +namespace Eigen { + template using TensorIterator = Grid::EigenUtil::TensorIterator_raw< ETensor, false>; + template using TensorIteratorConst = Grid::EigenUtil::TensorIterator_raw; + template + inline typename std::enable_if::value, TensorIterator>::type + begin( ETensor & ET ) { return TensorIterator(ET); } + template + inline typename std::enable_if::value, TensorIterator>::type + end( ETensor & ET ) { return TensorIterator(ET, ET.size()); } + template + inline typename std::enable_if::value, TensorIteratorConst>::type + begin( const ETensor & ET ) { return TensorIteratorConst(ET); } + template + inline typename std::enable_if::value, TensorIteratorConst>::type + end( const ETensor & ET ) { return TensorIteratorConst(ET, ET.size()); } +} + namespace Grid { // for_all helper function to call the lambda for scalar template @@ -174,15 +251,11 @@ namespace Grid { for( int i = 0 ; i < rank; i++ ) std::cout << "[" << dims[i] << "]"; for( int i = 0 ; i < Traits::Rank; i++ ) std::cout << "(" << Traits::Dimension(i) << ")"; std::cout << " in memory order:" << std::endl; - for_all( t, [&](scalar_type &c, Index n, const std::array &TensorIndex, - const std::array &ScalarIndex ){ + for( auto it = begin(t); !it.AtEnd(); ++it ) { std::cout << " "; - for( auto dim : TensorIndex ) - std::cout << "[" << dim << "]"; - for( auto dim : ScalarIndex ) - std::cout << "(" << dim << ")"; - std::cout << " = " << c << std::endl; - } ); + it.DumpIndex(); + std::cout << " = " << (const typename T::Scalar)(*it) << std::endl; + } std::cout << "========================================" << std::endl; }