diff --git a/lib/lattice/rng/rng-state.h b/lib/lattice/rng/rng-state.h index eab070ba..af1ea815 100644 --- a/lib/lattice/rng/rng-state.h +++ b/lib/lattice/rng/rng-state.h @@ -31,6 +31,7 @@ #include #include #include +#include #ifdef CURRENT_DEFAULT_NAMESPACE_NAME namespace CURRENT_DEFAULT_NAMESPACE_NAME { @@ -69,9 +70,9 @@ struct RngState unsigned long index; // uint64_t cache[3]; - double gaussion; + double gaussian; int cacheAvail; - bool gaussionAvail; + bool gaussianAvail; // inline void init() { @@ -111,37 +112,74 @@ struct RngState } }; +const size_t RNG_STATE_SIZE_OF_INT32 = 2 + 8 + 2 + 3 * 2 + 2 + 1 + 1; + +inline uint64_t patchTwoUint32(const uint32_t a, const uint32_t b) +{ + return (uint64_t)a << 32 | (uint64_t)b; +} + +inline void splitTwoUint32(uint32_t& a, uint32_t& b, const uint64_t x) +{ + b = (uint32_t)x; + a = (uint32_t)(x >> 32); + assert(x == patchTwoUint32(a, b)); +} + +inline void exportRngState(std::vector& v, const RngState& rs) +{ + assert(22 == RNG_STATE_SIZE_OF_INT32); + v.resize(RNG_STATE_SIZE_OF_INT32); + splitTwoUint32(v[0], v[1], rs.numBytes); + for (int i = 0; i < 8; ++i) { + v[2 + i] = rs.hash[i]; + } + splitTwoUint32(v[10], v[11], rs.index); + for (int i = 0; i < 3; ++i) { + splitTwoUint32(v[12 + i * 2], v[12 + i * 2 + 1], rs.cache[i]); + } + const uint64_t* p = (const uint64_t*)&rs.gaussian; + splitTwoUint32(v[18], v[19], *p); + v[20] = rs.cacheAvail; + v[21] = rs.gaussianAvail; +} + +inline void importRngState(RngState& rs, const std::vector& v) +{ + assert(RNG_STATE_SIZE_OF_INT32 == v.size()); + assert(22 == RNG_STATE_SIZE_OF_INT32); + rs.numBytes = patchTwoUint32(v[0], v[1]); + for (int i = 0; i < 8; ++i) { + rs.hash[i] = v[2 + i]; + } + rs.index = patchTwoUint32(v[10], v[11]); + for (int i = 0; i < 3; ++i) { + rs.cache[i] = patchTwoUint32(v[12 + i * 2], v[12 + i * 2 + 1]); + } + uint64_t* p = (uint64_t*)&rs.gaussian; + *p = patchTwoUint32(v[18], v[19]); + rs.cacheAvail = v[20]; + rs.gaussianAvail = v[21]; +} + inline std::ostream& operator<<(std::ostream& os, const RngState& rs) { - os << rs.numBytes << " "; - for (int i = 0; i < 8; ++i) { - os << rs.hash[i] << " "; + std::vector v(RNG_STATE_SIZE_OF_INT32); + exportRngState(v, rs); + for (size_t i = 0; i < v.size() - 1; ++i) { + os << v[i] << " "; } - os << rs.index << " "; - for (int i = 0; i < 3; ++i) { - os << rs.cache[i] << " "; - } - const uint64_t* p = (const uint64_t*)&rs.gaussion; - os << *p << " "; - os << rs.cacheAvail << " "; - os << rs.gaussionAvail; + os << v.back(); return os; } inline std::istream& operator>>(std::istream& is, RngState& rs) { - is >> rs.numBytes; - for (int i = 0; i < 8; ++i) { - is >> rs.hash[i]; + std::vector v(RNG_STATE_SIZE_OF_INT32); + for (size_t i = 0; i < v.size(); ++i) { + is >> v[i]; } - is >> rs.index; - for (int i = 0; i < 3; ++i) { - is >> rs.cache[i]; - } - uint64_t* p = (uint64_t*)&rs.gaussion; - is >> *p; - is >> rs.cacheAvail; - is >> rs.gaussionAvail; + importRngState(rs, v); return is; } @@ -430,9 +468,9 @@ inline void reset(RngState& rs) rs.cache[0] = 0; rs.cache[1] = 0; rs.cache[2] = 0; - rs.gaussion = 0.0; + rs.gaussian = 0.0; rs.cacheAvail = 0; - rs.gaussionAvail = false; + rs.gaussianAvail = false; } inline void reset(RngState& rs, const std::string& seed) @@ -458,14 +496,9 @@ inline void splitRngState(RngState& rs, const RngState& rs0, const std::string& rs.cache[0] = 0; rs.cache[1] = 0; rs.cache[2] = 0; - rs.gaussion = 0.0; + rs.gaussian = 0.0; rs.cacheAvail = 0; - rs.gaussionAvail = false; -} - -inline uint64_t patchTwoUint32(const uint32_t a, const uint32_t b) -{ - return (uint64_t)a << 32 | (uint64_t)b; + rs.gaussianAvail = false; } inline void computeHashWithInput(uint32_t hash[8], const RngState& rs, const std::string& input) @@ -503,9 +536,9 @@ inline double uRandGen(RngState& rs, const double upper, const double lower) inline double gRandGen(RngState& rs, const double sigma, const double center) { rs.index += 1; - if (rs.gaussionAvail) { - rs.gaussionAvail = false; - return rs.gaussion * sigma + center; + if (rs.gaussianAvail) { + rs.gaussianAvail = false; + return rs.gaussian * sigma + center; } else { // pick 2 uniform numbers in the square extending from // -1 to 1 in each direction, see if they are in the @@ -526,8 +559,8 @@ inline double gRandGen(RngState& rs, const double sigma, const double center) return 1e+10; } double fac = std::sqrt(-2.0 * std::log(rsq)/rsq); - rs.gaussion = v1 * fac; - rs.gaussionAvail = true; + rs.gaussian = v1 * fac; + rs.gaussianAvail = true; return v2 * fac * sigma + center; } } diff --git a/lib/lattice/rng/show.h b/lib/lattice/rng/show.h index 41a92ce9..60953875 100644 --- a/lib/lattice/rng/show.h +++ b/lib/lattice/rng/show.h @@ -105,6 +105,16 @@ T& reads(T& x, const std::string& str) return x; } +void fdisplay(FILE* fp, const std::string& str) +{ + fprintf(fp, "%s", str.c_str()); +} + +void fdisplayln(FILE* fp, const std::string& str) +{ + fprintf(fp, "%s\n", str.c_str()); +} + #ifdef CURRENT_DEFAULT_NAMESPACE_NAME } #endif