mirror of
https://github.com/paboyle/Grid.git
synced 2024-11-14 01:35:36 +00:00
Added iterator for Eigen tensors
This commit is contained in:
parent
b3b9e608e1
commit
d1e02f50ff
@ -30,6 +30,83 @@
|
|||||||
#include <Grid/tensors/Tensor_traits.h>
|
#include <Grid/tensors/Tensor_traits.h>
|
||||||
#include <Grid/Eigen/unsupported/CXX11/Tensor>
|
#include <Grid/Eigen/unsupported/CXX11/Tensor>
|
||||||
|
|
||||||
|
namespace Grid {
|
||||||
|
// Custom iterator for Eigen tensors
|
||||||
|
namespace EigenUtil {
|
||||||
|
template <typename ETensor, bool bConst> // Is the tensor constant
|
||||||
|
class TensorIterator_raw{
|
||||||
|
public:
|
||||||
|
using Index = typename ETensor::Index;
|
||||||
|
using Scalar = typename ETensor::Scalar;
|
||||||
|
using FullIndex = std::array<Index, ETensor::NumIndices>;
|
||||||
|
const ETensor * pET;
|
||||||
|
const Index end; // same as pET->size()
|
||||||
|
Index position; // position (memory order)
|
||||||
|
Index Seq; // sequence (what our position would be if we were column major)
|
||||||
|
FullIndex indexPos;
|
||||||
|
FullIndex indexSize;
|
||||||
|
|
||||||
|
inline TensorIterator_raw( ETensor & eT, Index pos = 0 ) : pET{&eT}, position{pos}, Seq{pos}, end{pET->size()} {
|
||||||
|
for( int i = 0 ; i < ETensor::NumIndices ; i++ ) {
|
||||||
|
indexPos[i] = 0;
|
||||||
|
indexSize[i] = pET->dimension(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline TensorIterator_raw<ETensor, bConst> & operator++() {
|
||||||
|
auto sz = pET->size();
|
||||||
|
if( position < sz ) {
|
||||||
|
position++;
|
||||||
|
if( ETensor::Options & Eigen::RowMajor ) {
|
||||||
|
for( int i = ETensor::NumIndices - 1; i != -1 && ++indexPos[i] == indexSize[i]; i-- )
|
||||||
|
indexPos[i] = 0;
|
||||||
|
Seq++;
|
||||||
|
} else {
|
||||||
|
for( int i = 0; i < ETensor::NumIndices && ++indexPos[i] == indexSize[i]; i++ )
|
||||||
|
indexPos[i] = 0;
|
||||||
|
Seq = 0;
|
||||||
|
for( int i = 0; i < ETensor::NumIndices; i++ ) {
|
||||||
|
Seq *= indexSize[i];
|
||||||
|
Seq += indexPos[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return * this;
|
||||||
|
}
|
||||||
|
inline typename std::conditional<bConst,const Scalar &,Scalar &>::type operator*() const {
|
||||||
|
assert( position >= 0 && position < pET->size() && "Attempt to access Eigen tensor iterator out of range" );
|
||||||
|
return ( ( typename std::conditional<bConst,const Scalar *,Scalar*>::type ) pET->data() )[position];
|
||||||
|
}
|
||||||
|
inline bool operator!=(const TensorIterator_raw<ETensor, bConst> &r)
|
||||||
|
{ return pET == nullptr || pET != r.pET || position != r.position; }
|
||||||
|
// These functions aren't rerquired for iterators, but they make using them easier
|
||||||
|
inline bool AtEnd() { return position == end; }
|
||||||
|
inline void DumpIndex(void) {
|
||||||
|
for( auto dim : indexPos )
|
||||||
|
std::cout << "[" << dim << "]";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The only way I could get these iterators to work is to put the begin() and end() functions in the Eigen namespace
|
||||||
|
// So if Eigen ever defines these, we'll have a conflict and have to change this
|
||||||
|
namespace Eigen {
|
||||||
|
template <typename ETensor> using TensorIterator = Grid::EigenUtil::TensorIterator_raw< ETensor, false>;
|
||||||
|
template <typename ETensor> using TensorIteratorConst = Grid::EigenUtil::TensorIterator_raw<const ETensor, true>;
|
||||||
|
template <typename ETensor>
|
||||||
|
inline typename std::enable_if<Grid::EigenIO::is_tensor<ETensor>::value, TensorIterator<ETensor>>::type
|
||||||
|
begin( ETensor & ET ) { return TensorIterator<ETensor>(ET); }
|
||||||
|
template <typename ETensor>
|
||||||
|
inline typename std::enable_if<Grid::EigenIO::is_tensor<ETensor>::value, TensorIterator<ETensor>>::type
|
||||||
|
end( ETensor & ET ) { return TensorIterator<ETensor>(ET, ET.size()); }
|
||||||
|
template <typename ETensor>
|
||||||
|
inline typename std::enable_if<Grid::EigenIO::is_tensor<ETensor>::value, TensorIteratorConst<ETensor>>::type
|
||||||
|
begin( const ETensor & ET ) { return TensorIteratorConst<ETensor>(ET); }
|
||||||
|
template <typename ETensor>
|
||||||
|
inline typename std::enable_if<Grid::EigenIO::is_tensor<ETensor>::value, TensorIteratorConst<ETensor>>::type
|
||||||
|
end( const ETensor & ET ) { return TensorIteratorConst<ETensor>(ET, ET.size()); }
|
||||||
|
}
|
||||||
|
|
||||||
namespace Grid {
|
namespace Grid {
|
||||||
// for_all helper function to call the lambda for scalar
|
// for_all helper function to call the lambda for scalar
|
||||||
template <typename ETensor, typename Lambda>
|
template <typename ETensor, typename Lambda>
|
||||||
@ -174,6 +251,7 @@ namespace Grid {
|
|||||||
for( int i = 0 ; i < rank; i++ ) std::cout << "[" << dims[i] << "]";
|
for( int i = 0 ; i < rank; i++ ) std::cout << "[" << dims[i] << "]";
|
||||||
for( int i = 0 ; i < Traits::Rank; i++ ) std::cout << "(" << Traits::Dimension(i) << ")";
|
for( int i = 0 ; i < Traits::Rank; i++ ) std::cout << "(" << Traits::Dimension(i) << ")";
|
||||||
std::cout << " in memory order:" << std::endl;
|
std::cout << " in memory order:" << std::endl;
|
||||||
|
#ifdef OLD_DEFINITION
|
||||||
for_all( t, [&](scalar_type &c, Index n, const std::array<Index, rank> &TensorIndex,
|
for_all( t, [&](scalar_type &c, Index n, const std::array<Index, rank> &TensorIndex,
|
||||||
const std::array<int, Traits::Rank> &ScalarIndex ){
|
const std::array<int, Traits::Rank> &ScalarIndex ){
|
||||||
std::cout << " ";
|
std::cout << " ";
|
||||||
@ -183,6 +261,13 @@ namespace Grid {
|
|||||||
std::cout << "(" << dim << ")";
|
std::cout << "(" << dim << ")";
|
||||||
std::cout << " = " << c << std::endl;
|
std::cout << " = " << c << std::endl;
|
||||||
} );
|
} );
|
||||||
|
#else
|
||||||
|
for( auto it = begin(t); !it.AtEnd(); ++it ) {
|
||||||
|
std::cout << " ";
|
||||||
|
it.DumpIndex();
|
||||||
|
std::cout << " = " << (const typename T::Scalar)(*it) << std::endl;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
std::cout << "========================================" << std::endl;
|
std::cout << "========================================" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user