From 69bbd9892ec6b50874077b1ce246aeae0284e858 Mon Sep 17 00:00:00 2001 From: alix <154560637+a1ix2@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:09:52 -0500 Subject: [PATCH] keep parameter order (#466) * keep parameter order * Test order of parameters for get(c; section) --------- Co-authored-by: Markus Hauru --- src/chains.jl | 4 ++-- src/utils.jl | 2 +- test/sections_tests.jl | 7 +++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index 8f00438d..877316a0 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -240,7 +240,7 @@ julia> get(chn, :param_1; flatten=true) ``` """ function Base.get(c::Chains, vs::Vector{Symbol}; flatten=false) - pairs = Dict() + pairs = OrderedCollections.OrderedDict() for v in vs syms = namesingroup(c, v) len = length(syms) @@ -289,7 +289,7 @@ function Base.get( section::Union{Symbol,AbstractVector{Symbol}}, flatten = false ) - names = Set(Symbol[]) + names = OrderedCollections.OrderedSet(Symbol[]) regex = r"[^\[]*" _section = section isa Symbol ? (section,) : section for v in _section diff --git a/src/utils.jl b/src/utils.jl index 1eeebcdc..a384d253 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -56,7 +56,7 @@ function cummean(x::AbstractVector) return y end -function _dict2namedtuple(d::Dict) +function _dict2namedtuple(d::AbstractDict) t_keys = ntuple(x -> Symbol(collect(keys(d))[x]), length(keys(d))) t_vals = ntuple(x -> collect(values(d))[x], length(values(d))) return NamedTuple{t_keys}(t_vals) diff --git a/test/sections_tests.jl b/test/sections_tests.jl index 1f463073..9ea64b34 100644 --- a/test/sections_tests.jl +++ b/test/sections_tests.jl @@ -1,9 +1,12 @@ using MCMCChains, Test # https://github.com/TuringLang/AdvancedMH.jl/pull/63 +# https://github.com/TuringLang/MCMCChains.jl/issues/443 @testset "order of parameters" begin - chains = Chains(rand(2, 10, 4), vcat(["μ[$i]" for i in 1:9], :lp), (internals=[:lp],)) - @test names(chains, :parameters) == [Symbol("μ[$i]") for i in 1:9] + params = vcat(["μ[$i]" for i in 1:9], :p1, :p2, :p3) + chains = Chains(rand(2, length(params)+1, 4), vcat(params, :lp), (internals=[:lp],)) + @test names(chains, :parameters) == map(Symbol, params) + @test collect(keys(get(chains; section=:parameters))) == [:μ, :p1, :p2, :p3] end @testset "describe sections" begin