mirror of
				https://github.com/paboyle/Grid.git
				synced 2025-10-30 11:34:32 +00:00 
			
		
		
		
	Eigen tensor serialisation fixes after Antonin's review
This commit is contained in:
		| @@ -61,15 +61,25 @@ namespace Grid { | ||||
|     template<typename T, typename V = void> struct is_tensor_of_container : public std::false_type {}; | ||||
|     template<typename T> struct is_tensor_of_container<T, typename std::enable_if<is_tensor<T>::value && isGridTensor<typename T::Scalar>::value>::type> : public std::true_type {}; | ||||
|  | ||||
|     // Traits are the default for scalars, or come from GridTypeMapper for GridTensors | ||||
|     // These traits describe the scalars inside Eigen tensors | ||||
|     // I wish I could define these in reference to the scalar type (so there would be fewer traits defined) | ||||
|     // but I'm unable to find a syntax to make this work | ||||
|     template<typename T, typename V = void> struct Traits {}; | ||||
|     template<typename T> struct Traits<T, typename std::enable_if<is_tensor_of_scalar<T>::value>::type> : public GridTypeMapper_Base { | ||||
|       using scalar_type = typename T::Scalar; | ||||
|     // Traits are the default for scalars, or come from GridTypeMapper for GridTensors | ||||
|     template<typename T> struct Traits<T, typename std::enable_if<is_tensor_of_scalar<T>::value>::type> | ||||
|       : public GridTypeMapper_Base { | ||||
|       using scalar_type   = typename T::Scalar; // ultimate base scalar | ||||
|       static constexpr bool is_complex = ::Grid::EigenIO::is_complex<scalar_type>::value; | ||||
|     }; | ||||
|     template<typename T> struct Traits<T, typename std::enable_if<is_tensor_of_container<T>::value>::type> : public GridTypeMapper<typename T::Scalar> { | ||||
|       using scalar_type = typename GridTypeMapper<typename T::Scalar>::scalar_type; | ||||
|       static constexpr bool is_complex = ::Grid::EigenIO::is_complex<scalar_type>::value; | ||||
|     // Traits are the default for scalars, or come from GridTypeMapper for GridTensors | ||||
|     template<typename T> struct Traits<T, typename std::enable_if<is_tensor_of_container<T>::value>::type> { | ||||
|       using BaseTraits  = GridTypeMapper<typename T::Scalar>; | ||||
|       using scalar_type = typename BaseTraits::scalar_type; // ultimate base scalar | ||||
|       static constexpr bool   is_complex = ::Grid::EigenIO::is_complex<scalar_type>::value; | ||||
|       static constexpr int   TensorLevel = BaseTraits::TensorLevel; | ||||
|       static constexpr int          Rank = BaseTraits::Rank; | ||||
|       static constexpr std::size_t count = BaseTraits::count; | ||||
|       static constexpr int Dimension(int dim) { return BaseTraits::Dimension(dim); } | ||||
|     }; | ||||
|  | ||||
|     // Is this a fixed-size Eigen tensor | ||||
| @@ -310,15 +320,15 @@ namespace Grid { | ||||
|     // If the Tensor isn't in Row-Major order, then we'll need to copy it's data | ||||
|     const bool CopyData{NumElements > 1 && ETensor::Layout != Eigen::StorageOptions::RowMajor}; | ||||
|     const Scalar * pWriteBuffer; | ||||
|     Scalar * pCopyBuffer = nullptr; | ||||
|     std::vector<Scalar> CopyBuffer; | ||||
|     const Index TotalNumElements = NumElements * Traits::count; | ||||
|     if( !CopyData ) { | ||||
|       pWriteBuffer = getFirstScalar( output ); | ||||
|     } else { | ||||
|       // Regardless of the Eigen::Tensor storage order, the copy will be Row Major | ||||
|       pCopyBuffer = new Scalar[TotalNumElements]; | ||||
|       pWriteBuffer = pCopyBuffer; | ||||
|       Scalar * pCopy = pCopyBuffer; | ||||
|       CopyBuffer.resize( TotalNumElements ); | ||||
|       Scalar * pCopy = &CopyBuffer[0]; | ||||
|       pWriteBuffer = pCopy; | ||||
|       std::array<Index, TensorRank> MyIndex; | ||||
|       for( auto &idx : MyIndex ) idx = 0; | ||||
|       for( auto n = 0; n < NumElements; n++ ) { | ||||
| @@ -330,7 +340,6 @@ namespace Grid { | ||||
|       } | ||||
|     } | ||||
|     upcast->template writeMultiDim<Scalar>(s, TotalDims, pWriteBuffer, TotalNumElements); | ||||
|     if( pCopyBuffer ) delete [] pCopyBuffer; | ||||
|   } | ||||
|  | ||||
|   template <typename T> | ||||
|   | ||||
| @@ -51,7 +51,7 @@ namespace Grid | ||||
|     std::vector<std::string> path_; | ||||
|     H5NS::H5File             file_; | ||||
|     H5NS::Group              group_; | ||||
|     unsigned int             dataSetThres_{HDF5_DEF_DATASET_THRES}; | ||||
|     const unsigned int       dataSetThres_{HDF5_DEF_DATASET_THRES}; | ||||
|   }; | ||||
|    | ||||
|   class Hdf5Reader: public Reader<Hdf5Reader> | ||||
| @@ -117,15 +117,8 @@ namespace Grid | ||||
|     // write the entire dataset to file | ||||
|     H5NS::DataSpace dataSpace(rank, dim.data()); | ||||
|  | ||||
|     size_t DataSize = NumElements * sizeof(U); | ||||
|     if (DataSize > dataSetThres_) | ||||
|     if (NumElements > dataSetThres_) | ||||
|     { | ||||
|       // First few prime numbers from https://oeis.org/A000040 | ||||
|       static const unsigned short Primes[] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, | ||||
|         37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, | ||||
|         113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, | ||||
|         197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271 }; | ||||
|       constexpr int NumPrimes = sizeof( Primes ) / sizeof( Primes[0] ); | ||||
|       // Make sure 1) each dimension; and 2) chunk size is < 4GB | ||||
|       const hsize_t MaxElements = ( sizeof( U ) == 1 ) ? 0xffffffff : 0x100000000 / sizeof( U ); | ||||
|       hsize_t ElementsPerChunk = 1; | ||||
| @@ -136,13 +129,9 @@ namespace Grid | ||||
|           d = 1; // Chunk size is already as big as can be - remaining dimensions = 1 | ||||
|         else { | ||||
|           // If individual dimension too big, reduce by prime factors if possible | ||||
|           for( int PrimeIdx = 0; d > MaxElements && PrimeIdx < NumPrimes; ) { | ||||
|             if( d % Primes[PrimeIdx] ) | ||||
|               PrimeIdx++; | ||||
|             else | ||||
|               d /= Primes[PrimeIdx]; | ||||
|           } | ||||
|           const char ErrorMsg[] = " dimension > 4GB without small prime factors. " | ||||
|           while( d > MaxElements && ( d & 1 ) == 0 ) | ||||
|             d >>= 1; | ||||
|           const char ErrorMsg[] = " dimension > 4GB and not divisible by 2^n. " | ||||
|                                   "Hdf5IO chunk size will be inefficient. NB Serialisation is not intended for large datasets - please consider alternatives."; | ||||
|           if( d > MaxElements ) { | ||||
|             std::cout << GridLogWarning << "Individual" << ErrorMsg << std::endl; | ||||
| @@ -156,17 +145,13 @@ namespace Grid | ||||
|           ElementsPerChunk *= d; | ||||
|           assert( OverflowCheck == ElementsPerChunk / d && "Product of dimensions overflowed hsize_t" ); | ||||
|           // If product of dimensions too big, reduce by prime factors | ||||
|           for( int PrimeIdx = 0; ElementsPerChunk > MaxElements && PrimeIdx < NumPrimes; ) { | ||||
|           while( ElementsPerChunk > MaxElements && ( ElementsPerChunk & 1 ) == 0 ) { | ||||
|             bTooBig = true; | ||||
|             if( d % Primes[PrimeIdx] ) | ||||
|               PrimeIdx++; | ||||
|             else { | ||||
|               d /= Primes[PrimeIdx]; | ||||
|               ElementsPerChunk /= Primes[PrimeIdx]; | ||||
|             } | ||||
|             d >>= 1; | ||||
|             ElementsPerChunk >>= 1; | ||||
|           } | ||||
|           if( ElementsPerChunk > MaxElements ) { | ||||
|             std::cout << GridLogMessage << "Product of" << ErrorMsg << std::endl; | ||||
|             std::cout << GridLogWarning << "Product of" << ErrorMsg << std::endl; | ||||
|             hsize_t quotient = ElementsPerChunk / MaxElements; | ||||
|             if( ElementsPerChunk % MaxElements ) | ||||
|               quotient++; | ||||
| @@ -270,7 +255,7 @@ namespace Grid | ||||
|     // read the flat vector | ||||
|     buf.resize(size); | ||||
|      | ||||
|     if (size * sizeof(Element) > dataSetThres_) | ||||
|     if (size > dataSetThres_) | ||||
|     { | ||||
|       H5NS::DataSet dataSet; | ||||
|        | ||||
|   | ||||
		Reference in New Issue
	
	Block a user