Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jihodori committed Nov 14, 2024
1 parent cd9d264 commit ec733ce
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
GaussianFilters = "08d575fb-911d-541e-8431-6cfa767e7851"

[compat]
POMDPLinter = "0.1"
Expand All @@ -19,6 +20,7 @@ POMDPs = "0.9, 1"
Statistics = "1"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.1"
GaussianFilters = "0.1.2"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand All @@ -38,4 +40,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
VegaLite = "112f6efa-9a02-5b7d-90c0-432ed331239a"

[targets]
test = ["DelimitedFiles", "Distributions", "InteractiveUtils", "LinearAlgebra", "Markdown", "Plots", "PlutoUI", "POMDPModels", "POMDPs", "POMDPTools", "Random", "Reel", "StaticArrays", "Test", "VegaLite"]
test = ["DelimitedFiles", "Distributions", "InteractiveUtils", "LinearAlgebra", "Markdown", "Plots", "PlutoUI", "POMDPModels", "POMDPs", "POMDPTools", "Random", "Reel", "StaticArrays", "Test", "VegaLite", "GaussianFilters"]
75 changes: 75 additions & 0 deletions src/raoblackwellized.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
abstract type RaoBlackwellizedParticleBelief end

struct RaoBlackwellizedParticleFilter <: Updater
particle_filter::BasicParticleFilter
analytical_filter::AbstractFilter # https://github.com/sisl/GaussianFilters.jl/blob/master/src/kf_classes.jl#L7
end

mutable struct RaoBlackwellizedParticleCollection{P, T<:AbstractVector{<:Number}, S<:Symmetric{<:Number}, SP, AP, R} <: RaoBlackwellizedParticleBelief
particles::AbstractParticleBelief{P}
analytical_beliefs::Vector{GaussianBelief{T, S}}
sampled_part::SP
analytical_part::AP
reconstruct::R
end

function RaoBlackwellizedParticleCollection(
particles::AbstractParticleBelief{P},
analytical_beliefs::Vector{GaussianBelief{T, S}},
sampled_part::SP,
analytical_part::AP,
reconstruct::R) where {P, T<:AbstractVector{<:Number}, S<:Symmetric{<:Number}, SP, AP, R}
@assert n_particles(particles) == length(analytical_beliefs) "The number of particles must match the number of analytical beliefs."
return RaoBlackwellizedParticleCollection{P, T, S, SP, AP, R}(
particles,
analytical_beliefs,
sampled_part,
analytical_part,
reconstruct)
end

ParticleFilters.n_particles(b::RaoBlackwellizedParticleCollection) = n_particles(b.particles)
ParticleFilters.particles(b::RaoBlackwellizedParticleCollection) = ParticleFilters.particles(b.particles)
ParticleFilters.particle(b::RaoBlackwellizedParticleCollection, i::Int) = ParticleFilters.particle(b.particles, i)

function Statistics.mean(b::RaoBlackwellizedParticleCollection{T}) where T
particles_mean = mean(b.particles)
normalized_weights = weights ./ sum(weights)
analytical_means = [belief.μ for belief in b.analytical_beliefs]
weighted_analytical_mean = sum(normalized_weights .* analytical_means)
return (particles_mean, weighted_analytical_mean)
end

function Statistics.cov(b::RaoBlackwellizedParticleCollection{T}) where T
particles_cov = cov(b.particles)
if b.particles isa ParticleCollection
weights = fill(1.0 / n_particles(b.particles), n_particles(b.particles))
else
weights = b.particles.weights
end
normalized_weights = weights ./ sum(weights)
analytical_variances = [belief.Σ for belief in b.analytical_beliefs]
weighted_analytical_variances = sum(normalized_weights .* analytical_variances)
return (particles_cov, weighted_analytical_variances)
end

function Random.rand(rng::AbstractRNG, sampler::Random.SamplerTrivial{<:RaoBlackwellizedParticleCollection})
b = sampler[]
t = rand(rng) * weight_sum(b.particles)
i = 1
cw = b.particles.weights[1]
while cw < t && i < length(b.particles.weights)
i += 1
@inbounds cw += b.particles.weights[i]
end
particle_states = b.particles.particles[i]
linear_belief = b.analytical_beliefs[i]

return RaoBlackwellizedParticleCollection(
ParticleCollection([particle_states]),
[linear_belief],
b.sampled_part,
b.analytical_part,
b.reconstruct)
return new_rbpc
end

0 comments on commit ec733ce

Please sign in to comment.