Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MersenneTwister: hash seeds like for Xoshiro #51436

Merged
merged 2 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions stdlib/Random/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ Random.SamplerSimple
Decoupling pre-computation from actually generating the values is part of the API, and is also available to the user. As an example, assume that `rand(rng, 1:20)` has to be called repeatedly in a loop: the way to take advantage of this decoupling is as follows:

```julia
rng = MersenneTwister()
sp = Random.Sampler(rng, 1:20) # or Random.Sampler(MersenneTwister, 1:20)
rng = Xoshiro()
sp = Random.Sampler(rng, 1:20) # or Random.Sampler(Xoshiro, 1:20)
for x in X
n = rand(rng, sp) # similar to n = rand(rng, 1:20)
# use n
Expand Down Expand Up @@ -159,8 +159,8 @@ Scalar and array methods for `Die` now work as expected:
julia> rand(Die)
Die(5)

julia> rand(MersenneTwister(0), Die)
Die(11)
julia> rand(Xoshiro(0), Die)
Die(10)

julia> rand(Die, 3)
3-element Vector{Die}:
Expand Down
3 changes: 2 additions & 1 deletion stdlib/Random/src/DSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ function dsfmt_init_gen_rand(s::DSFMT_state, seed::UInt32)
s.val, seed)
end

function dsfmt_init_by_array(s::DSFMT_state, seed::Vector{UInt32})
function dsfmt_init_by_array(s::DSFMT_state, seed::StridedVector{UInt32})
strides(seed) == (1,) || throw(ArgumentError("seed must have its stride equal to 1"))
ccall((:dsfmt_init_by_array,:libdSFMT),
Cvoid,
(Ptr{Cvoid}, Ptr{UInt32}, Int32),
Expand Down
132 changes: 68 additions & 64 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The entropy is obtained from the operating system.
"""
struct RandomDevice <: AbstractRNG; end
RandomDevice(seed::Nothing) = RandomDevice()
seed!(rng::RandomDevice) = rng
seed!(rng::RandomDevice, ::Nothing) = rng

rand(rd::RandomDevice, sp::SamplerBoolBitInteger) = Libc.getrandom!(Ref{sp[]}())[]
rand(rd::RandomDevice, ::SamplerType{Bool}) = rand(rd, UInt8) % Bool
Expand Down Expand Up @@ -44,7 +44,7 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache
@assert dsfmt_get_min_array_size() <= MT_CACHE_F

mutable struct MersenneTwister <: AbstractRNG
seed::Vector{UInt32}
seed::Any
state::DSFMT_state
vals::Vector{Float64}
ints::Vector{UInt128}
Expand All @@ -70,7 +70,7 @@ mutable struct MersenneTwister <: AbstractRNG
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state::DSFMT_state) =
MersenneTwister(seed, state,
Vector{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
Expand All @@ -92,19 +92,17 @@ See the [`seed!`](@ref) function for reseeding an already existing `MersenneTwis

# Examples
```jldoctest
julia> rng = MersenneTwister(1234);
julia> rng = MersenneTwister(123);

julia> x1 = rand(rng, 2)
2-element Vector{Float64}:
0.5908446386657102
0.7667970365022592
0.37453777969575874
0.8735343642013971

julia> rng = MersenneTwister(1234);

julia> x2 = rand(rng, 2)
julia> x2 = rand(MersenneTwister(123), 2)
2-element Vector{Float64}:
0.5908446386657102
0.7667970365022592
0.37453777969575874
0.8735343642013971

julia> x1 == x2
true
Expand All @@ -115,7 +113,7 @@ MersenneTwister(seed=nothing) =


function copy!(dst::MersenneTwister, src::MersenneTwister)
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
dst.seed = src.seed
copy!(dst.state, src.state)
copyto!(dst.vals, src.vals)
copyto!(dst.ints, src.ints)
Expand All @@ -129,7 +127,7 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
MersenneTwister(src.seed, copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints)


Expand All @@ -144,12 +142,10 @@ hash(r::MersenneTwister, h::UInt) =

function show(io::IO, rng::MersenneTwister)
# seed
seed = from_seed(rng.seed)
seed_str = seed <= typemax(Int) ? string(seed) : "0x" * string(seed, base=16) # DWIM
if rng.adv_jump == 0 && rng.adv == 0
return print(io, MersenneTwister, "(", seed_str, ")")
return print(io, MersenneTwister, "(", repr(rng.seed), ")")
end
print(io, MersenneTwister, "(", seed_str, ", (")
print(io, MersenneTwister, "(", repr(rng.seed), ", (")
# state
adv = Integer[rng.adv_jump, rng.adv]
if rng.adv_vals != -1 || rng.adv_ints != -1
Expand Down Expand Up @@ -277,76 +273,84 @@ end

### seeding

#### make_seed()
#### random_seed() & hash_seed()

# make_seed produces values of type Vector{UInt32}, suitable for MersenneTwister seeding
function make_seed()
# random_seed tries to produce a random seed of type UInt128 from system entropy
function random_seed()
try
return rand(RandomDevice(), UInt32, 4)
# as MersenneTwister prints its seed when `show`ed, 128 bits is a good compromise for
# almost surely always getting distinct seeds, while having them printed reasonably tersely
return rand(RandomDevice(), UInt128)
catch ex
ex isa IOError || rethrow()
@warn "Entropy pool not available to seed RNG; using ad-hoc entropy sources."
return make_seed(Libc.rand())
return Libc.rand()
end
end

"""
make_seed(n::Integer) -> Vector{UInt32}

Transform `n` into a bit pattern encoded as a `Vector{UInt32}`, suitable for
RNG seeding routines.

`make_seed` is "injective" : if `n != m`, then `make_seed(n) != `make_seed(m)`.
Moreover, if `n == m`, then `make_seed(n) == make_seed(m)`.

This is an internal function, subject to change.
"""
function make_seed(n::Integer)
neg = signbit(n)
function hash_seed(seed::Integer)
ctx = SHA.SHA2_256_CTX()
neg = signbit(seed)
if neg
n = ~n
end
@assert n >= 0
seed = UInt32[]
# we directly encode the bit pattern of `n` into the resulting vector `seed`;
# to greatly limit breaking the streams of random numbers, we encode the sign bit
# as the upper bit of `seed[end]` (i.e. for most positive seeds, `make_seed` returns
# the same vector as when we didn't encode the sign bit)
while !iszero(n)
push!(seed, n & 0xffffffff)
n >>>= 32
seed = ~seed
end
if isempty(seed) || !iszero(seed[end] & 0x80000000)
push!(seed, zero(UInt32))
end
if neg
seed[end] |= 0x80000000
@assert seed >= 0
while true
word = (seed % UInt32) & 0xffffffff
seed >>>= 32
SHA.update!(ctx, reinterpret(NTuple{4, UInt8}, word))
iszero(seed) && break
end
seed
# make sure the hash of negative numbers is different from the hash of positive numbers
neg && SHA.update!(ctx, (0x01,))
SHA.digest!(ctx)
end

# inverse of make_seed(::Integer)
function from_seed(a::Vector{UInt32})::BigInt
neg = !iszero(a[end] & 0x80000000)
seed = sum((i == length(a) ? a[i] & 0x7fffffff : a[i]) * big(2)^(32*(i-1))
for i in 1:length(a))
neg ? ~seed : seed
function hash_seed(seed::Union{AbstractArray{UInt32}, AbstractArray{UInt64}})
ctx = SHA.SHA2_256_CTX()
for xx in seed
SHA.update!(ctx, reinterpret(NTuple{8, UInt8}, UInt64(xx)))
end
# discriminate from hash_seed(::Integer)
SHA.update!(ctx, (0x10,))
SHA.digest!(ctx)
end


"""
hash_seed(seed) -> AbstractVector{UInt8}

Return a cryptographic hash of `seed` of size 256 bits (32 bytes).
`seed` can currently be of type `Union{Integer, DenseArray{UInt32}, DenseArray{UInt64}}`,
but modules can extend this function for types they own.

`hash_seed` is "injective" : if `n != m`, then `hash_seed(n) != `hash_seed(m)`.
Moreover, if `n == m`, then `hash_seed(n) == hash_seed(m)`.

This is an internal function subject to change.
"""
hash_seed

#### seed!()

function seed!(r::MersenneTwister, seed::Vector{UInt32})
copyto!(resize!(r.seed, length(seed)), seed)
dsfmt_init_by_array(r.state, r.seed)
function initstate!(r::MersenneTwister, data::StridedVector, seed)
# we deepcopy `seed` because the caller might mutate it, and it's useful
# to keep it constant inside `MersenneTwister`; but multiple instances
# can share the same seed without any problem (e.g. in `copy`)
r.seed = deepcopy(seed)
dsfmt_init_by_array(r.state, reinterpret(UInt32, data))
reset_caches!(r)
r.adv = 0
r.adv_jump = 0
return r
end

seed!(r::MersenneTwister) = seed!(r, make_seed())
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
# when a seed is not provided, we generate one via `RandomDevice()` in `random_seed()` rather
# than calling directly `initstate!` with `rand(RandomDevice(), UInt32, whatever)` because the
# seed is printed in `show(::MersenneTwister)`, so we need one; the cost of `hash_seed` is a
# small overhead compared to `initstate!`, so this simple solution is fine
seed!(r::MersenneTwister, ::Nothing) = seed!(r, random_seed())
seed!(r::MersenneTwister, seed) = initstate!(r, hash_seed(seed), seed)


### Global RNG
Expand Down Expand Up @@ -713,7 +717,7 @@ end
function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
adv = r.adv
adv_jump = r.adv_jump
s = MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))
s = MersenneTwister(r.seed, DSFMT.dsfmt_jump(r.state, jumppoly))
reset_caches!(s)
s.adv = adv
s.adv_jump = adv_jump
Expand Down
25 changes: 14 additions & 11 deletions stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ julia> rand(Int, 2)

julia> using Random

julia> rand(MersenneTwister(0), Dict(1=>2, 3=>4))
1=>2
julia> rand(Xoshiro(0), Dict(1=>2, 3=>4))
3 => 4

julia> rand((2, 3))
3
Expand Down Expand Up @@ -370,15 +370,13 @@ but without allocating a new array.

# Examples
```jldoctest
julia> rng = MersenneTwister(1234);

julia> rand!(rng, zeros(5))
julia> rand!(Xoshiro(123), zeros(5))
5-element Vector{Float64}:
0.5908446386657102
0.7667970365022592
0.5662374165061859
0.4600853424625171
0.7940257103317943
0.521213795535383
0.5868067574533484
0.8908786980927811
0.19090669902576285
0.5256623915420473
```
"""
rand!
Expand Down Expand Up @@ -433,6 +431,11 @@ julia> rand(Xoshiro(), Bool) # not reproducible either
true
```
"""
seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)
seed!(rng::AbstractRNG) = seed!(rng, nothing)
#=
We have this generic definition instead of the alternative option
`seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)`
because it would lead too easily to ambiguities, e.g. when we define `seed!(::Xoshiro, seed)`.
=#

end # module
17 changes: 8 additions & 9 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,20 @@ rng_native_52(::TaskLocalRNG) = UInt64
## Shared implementation between Xoshiro and TaskLocalRNG

# this variant of setstate! initializes the internal splitmix state, a.k.a. `s4`
@inline initstate!(x::Union{TaskLocalRNG, Xoshiro}, (s0, s1, s2, s3)::NTuple{4, UInt64}) =
@inline function initstate!(x::Union{TaskLocalRNG, Xoshiro}, state)
length(state) == 4 && eltype(state) == UInt64 ||
throw(ArgumentError("initstate! expects a list of 4 `UInt64` values"))
s0, s1, s2, s3 = state
setstate!(x, (s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3))
end

copy(rng::Union{TaskLocalRNG, Xoshiro}) = Xoshiro(getstate(rng)...)
copy!(dst::Union{TaskLocalRNG, Xoshiro}, src::Union{TaskLocalRNG, Xoshiro}) = setstate!(dst, getstate(src))
==(x::Union{TaskLocalRNG, Xoshiro}, y::Union{TaskLocalRNG, Xoshiro}) = getstate(x) == getstate(y)
# use a magic (random) number to scramble `h` so that `hash(x)` is distinct from `hash(getstate(x))`
hash(x::Union{TaskLocalRNG, Xoshiro}, h::UInt) = hash(getstate(x), h + 0x49a62c2dda6fa9be % UInt)

function seed!(rng::Union{TaskLocalRNG, Xoshiro})
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, ::Nothing)
# as we get good randomness from RandomDevice, we can skip hashing
rd = RandomDevice()
s0 = rand(rd, UInt64)
Expand All @@ -249,14 +253,9 @@ function seed!(rng::Union{TaskLocalRNG, Xoshiro})
initstate!(rng, (s0, s1, s2, s3))
end

function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
c = SHA.SHA2_256_CTX()
SHA.update!(c, reinterpret(UInt8, seed))
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
initstate!(rng, (s0, s1, s2, s3))
end
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed) =
initstate!(rng, reinterpret(UInt64, hash_seed(seed)))

seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))

@inline function rand(x::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt64})
s0, s1, s2, s3 = getstate(x)
Expand Down
Loading