1
0
mirror of https://github.com/paboyle/Grid.git synced 2025-06-15 14:27:06 +01:00

update rng-state, change output format

This commit is contained in:
Luchang Jin
2016-09-18 13:47:48 -04:00
parent 4fb37ececd
commit 1abbe2fd0c
2 changed files with 81 additions and 38 deletions

View File

@ -31,6 +31,7 @@
#include <string> #include <string>
#include <ostream> #include <ostream>
#include <istream> #include <istream>
#include <vector>
#ifdef CURRENT_DEFAULT_NAMESPACE_NAME #ifdef CURRENT_DEFAULT_NAMESPACE_NAME
namespace CURRENT_DEFAULT_NAMESPACE_NAME { namespace CURRENT_DEFAULT_NAMESPACE_NAME {
@ -69,9 +70,9 @@ struct RngState
unsigned long index; unsigned long index;
// //
uint64_t cache[3]; uint64_t cache[3];
double gaussion; double gaussian;
int cacheAvail; int cacheAvail;
bool gaussionAvail; bool gaussianAvail;
// //
inline void init() inline void init()
{ {
@ -111,37 +112,74 @@ struct RngState
} }
}; };
const size_t RNG_STATE_SIZE_OF_INT32 = 2 + 8 + 2 + 3 * 2 + 2 + 1 + 1;
inline uint64_t patchTwoUint32(const uint32_t a, const uint32_t b)
{
return (uint64_t)a << 32 | (uint64_t)b;
}
inline void splitTwoUint32(uint32_t& a, uint32_t& b, const uint64_t x)
{
b = (uint32_t)x;
a = (uint32_t)(x >> 32);
assert(x == patchTwoUint32(a, b));
}
inline void exportRngState(std::vector<uint32_t>& v, const RngState& rs)
{
assert(22 == RNG_STATE_SIZE_OF_INT32);
v.resize(RNG_STATE_SIZE_OF_INT32);
splitTwoUint32(v[0], v[1], rs.numBytes);
for (int i = 0; i < 8; ++i) {
v[2 + i] = rs.hash[i];
}
splitTwoUint32(v[10], v[11], rs.index);
for (int i = 0; i < 3; ++i) {
splitTwoUint32(v[12 + i * 2], v[12 + i * 2 + 1], rs.cache[i]);
}
const uint64_t* p = (const uint64_t*)&rs.gaussian;
splitTwoUint32(v[18], v[19], *p);
v[20] = rs.cacheAvail;
v[21] = rs.gaussianAvail;
}
inline void importRngState(RngState& rs, const std::vector<uint32_t>& v)
{
assert(RNG_STATE_SIZE_OF_INT32 == v.size());
assert(22 == RNG_STATE_SIZE_OF_INT32);
rs.numBytes = patchTwoUint32(v[0], v[1]);
for (int i = 0; i < 8; ++i) {
rs.hash[i] = v[2 + i];
}
rs.index = patchTwoUint32(v[10], v[11]);
for (int i = 0; i < 3; ++i) {
rs.cache[i] = patchTwoUint32(v[12 + i * 2], v[12 + i * 2 + 1]);
}
uint64_t* p = (uint64_t*)&rs.gaussian;
*p = patchTwoUint32(v[18], v[19]);
rs.cacheAvail = v[20];
rs.gaussianAvail = v[21];
}
inline std::ostream& operator<<(std::ostream& os, const RngState& rs) inline std::ostream& operator<<(std::ostream& os, const RngState& rs)
{ {
os << rs.numBytes << " "; std::vector<uint32_t> v(RNG_STATE_SIZE_OF_INT32);
for (int i = 0; i < 8; ++i) { exportRngState(v, rs);
os << rs.hash[i] << " "; for (size_t i = 0; i < v.size() - 1; ++i) {
os << v[i] << " ";
} }
os << rs.index << " "; os << v.back();
for (int i = 0; i < 3; ++i) {
os << rs.cache[i] << " ";
}
const uint64_t* p = (const uint64_t*)&rs.gaussion;
os << *p << " ";
os << rs.cacheAvail << " ";
os << rs.gaussionAvail;
return os; return os;
} }
inline std::istream& operator>>(std::istream& is, RngState& rs) inline std::istream& operator>>(std::istream& is, RngState& rs)
{ {
is >> rs.numBytes; std::vector<uint32_t> v(RNG_STATE_SIZE_OF_INT32);
for (int i = 0; i < 8; ++i) { for (size_t i = 0; i < v.size(); ++i) {
is >> rs.hash[i]; is >> v[i];
} }
is >> rs.index; importRngState(rs, v);
for (int i = 0; i < 3; ++i) {
is >> rs.cache[i];
}
uint64_t* p = (uint64_t*)&rs.gaussion;
is >> *p;
is >> rs.cacheAvail;
is >> rs.gaussionAvail;
return is; return is;
} }
@ -430,9 +468,9 @@ inline void reset(RngState& rs)
rs.cache[0] = 0; rs.cache[0] = 0;
rs.cache[1] = 0; rs.cache[1] = 0;
rs.cache[2] = 0; rs.cache[2] = 0;
rs.gaussion = 0.0; rs.gaussian = 0.0;
rs.cacheAvail = 0; rs.cacheAvail = 0;
rs.gaussionAvail = false; rs.gaussianAvail = false;
} }
inline void reset(RngState& rs, const std::string& seed) inline void reset(RngState& rs, const std::string& seed)
@ -458,14 +496,9 @@ inline void splitRngState(RngState& rs, const RngState& rs0, const std::string&
rs.cache[0] = 0; rs.cache[0] = 0;
rs.cache[1] = 0; rs.cache[1] = 0;
rs.cache[2] = 0; rs.cache[2] = 0;
rs.gaussion = 0.0; rs.gaussian = 0.0;
rs.cacheAvail = 0; rs.cacheAvail = 0;
rs.gaussionAvail = false; rs.gaussianAvail = false;
}
inline uint64_t patchTwoUint32(const uint32_t a, const uint32_t b)
{
return (uint64_t)a << 32 | (uint64_t)b;
} }
inline void computeHashWithInput(uint32_t hash[8], const RngState& rs, const std::string& input) inline void computeHashWithInput(uint32_t hash[8], const RngState& rs, const std::string& input)
@ -503,9 +536,9 @@ inline double uRandGen(RngState& rs, const double upper, const double lower)
inline double gRandGen(RngState& rs, const double sigma, const double center) inline double gRandGen(RngState& rs, const double sigma, const double center)
{ {
rs.index += 1; rs.index += 1;
if (rs.gaussionAvail) { if (rs.gaussianAvail) {
rs.gaussionAvail = false; rs.gaussianAvail = false;
return rs.gaussion * sigma + center; return rs.gaussian * sigma + center;
} else { } else {
// pick 2 uniform numbers in the square extending from // pick 2 uniform numbers in the square extending from
// -1 to 1 in each direction, see if they are in the // -1 to 1 in each direction, see if they are in the
@ -526,8 +559,8 @@ inline double gRandGen(RngState& rs, const double sigma, const double center)
return 1e+10; return 1e+10;
} }
double fac = std::sqrt(-2.0 * std::log(rsq)/rsq); double fac = std::sqrt(-2.0 * std::log(rsq)/rsq);
rs.gaussion = v1 * fac; rs.gaussian = v1 * fac;
rs.gaussionAvail = true; rs.gaussianAvail = true;
return v2 * fac * sigma + center; return v2 * fac * sigma + center;
} }
} }

View File

@ -105,6 +105,16 @@ T& reads(T& x, const std::string& str)
return x; return x;
} }
void fdisplay(FILE* fp, const std::string& str)
{
fprintf(fp, "%s", str.c_str());
}
void fdisplayln(FILE* fp, const std::string& str)
{
fprintf(fp, "%s\n", str.c_str());
}
#ifdef CURRENT_DEFAULT_NAMESPACE_NAME #ifdef CURRENT_DEFAULT_NAMESPACE_NAME
} }
#endif #endif