diff --git a/stdlib/Random/Project.toml b/stdlib/Random/Project.toml index 6aa9f65374539..51915c2e6e388 100644 --- a/stdlib/Random/Project.toml +++ b/stdlib/Random/Project.toml @@ -3,6 +3,7 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [deps] Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/stdlib/Random/src/DSFMT.jl b/stdlib/Random/src/DSFMT.jl index f72a9dd5e9a0a..6aab7c8b8d0e7 100644 --- a/stdlib/Random/src/DSFMT.jl +++ b/stdlib/Random/src/DSFMT.jl @@ -65,7 +65,7 @@ function dsfmt_init_gen_rand(s::DSFMT_state, seed::UInt32) s.val, seed) end -function dsfmt_init_by_array(s::DSFMT_state, seed::Vector{UInt32}) +function dsfmt_init_by_array(s::DSFMT_state, seed::AbstractVector{UInt32}) ccall((:dsfmt_init_by_array,:libdSFMT), Cvoid, (Ptr{Cvoid}, Ptr{UInt32}, Int32), diff --git a/stdlib/Random/src/RNGs.jl b/stdlib/Random/src/RNGs.jl index 281acad533dad..ce52235ff7b6a 100644 --- a/stdlib/Random/src/RNGs.jl +++ b/stdlib/Random/src/RNGs.jl @@ -278,9 +278,17 @@ end #### seed!() -function seed!(r::MersenneTwister, seed::Vector{UInt32}) +function seed!(r::MersenneTwister, seed::Vector{UInt32}; hash=true) copyto!(resize!(r.seed, length(seed)), seed) - dsfmt_init_by_array(r.state, r.seed) + if hash # hash the seed to make it defensively more robust + c = SHA.SHA2_512_CTX() + # hash VERSION to help people not rely on stream stability between minor releases + SHA.update!(c, reinterpret(UInt8, [VERSION.major, VERSION.minor])) + SHA.update!(c, reinterpret(UInt8, seed)) + dsfmt_init_by_array(r.state, reinterpret(UInt32, SHA.digest!(c))) + else + dsfmt_init_by_array(r.state, r.seed) + end mt_setempty!(r) mt_setempty!(r, UInt128) fillcache_zeros!(r) diff --git a/stdlib/Random/src/Random.jl b/stdlib/Random/src/Random.jl index 5197ac1c34e7b..819d8a33ba05c 100644 --- a/stdlib/Random/src/Random.jl +++ b/stdlib/Random/src/Random.jl @@ -14,6 +14,8 @@ using .DSFMT using Base.GMP.MPZ using Base.GMP: Limb +import SHA + using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing import Base: copymutable, copy, copy!, ==, hash, convert,