From 070cb22f5687e191784455611bbc81c494e84d99 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Sat, 27 Sep 2025 13:26:24 +0100 Subject: [PATCH 01/20] Unify RB BF/GF --- .../src/GFTest/models/linear_gaussian.jl | 16 +- GeneralisedFilters/src/GeneralisedFilters.jl | 7 +- GeneralisedFilters/src/algorithms/kalman.jl | 14 +- .../src/algorithms/particles.jl | 264 +++++++++++------- GeneralisedFilters/src/algorithms/rbpf.jl | 112 +++----- GeneralisedFilters/src/containers.jl | 136 +++++---- GeneralisedFilters/src/models/hierarchical.jl | 37 ++- GeneralisedFilters/src/resamplers.jl | 46 ++- .../test/combination_test_script.jl | 102 +++++++ SSMProblems/src/SSMProblems.jl | 9 +- 10 files changed, 476 insertions(+), 267 deletions(-) create mode 100644 GeneralisedFilters/test/combination_test_script.jl diff --git a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl index 67ea04fc..58887d9e 100644 --- a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl @@ -1,10 +1,13 @@ +using StaticArrays + function create_linear_gaussian_model( rng::AbstractRNG, Dx::Integer, Dy::Integer, T::Type{<:Real}=Float64, process_noise_scale=T(0.1), - obs_noise_scale=T(1.0), + obs_noise_scale=T(1.0); + static_arrays::Bool=false, ) μ0 = rand(rng, T, Dx) Σ0 = rand_cov(rng, T, Dx) @@ -15,6 +18,17 @@ function create_linear_gaussian_model( c = rand(rng, T, Dy) R = rand_cov(rng, T, Dy; scale=obs_noise_scale) + if static_arrays + μ0 = SVector{Dx,T}(μ0) + Σ0 = SMatrix{Dx,Dx,T}(Σ0) + A = SMatrix{Dx,Dx,T}(A) + b = SVector{Dx,T}(b) + Q = SMatrix{Dx,Dx,T}(Q) + H = SMatrix{Dy,Dx,T}(H) + c = SVector{Dy,T}(c) + R = SMatrix{Dy,Dy,T}(R) + end + return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) end diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 7e712c87..1eeee2ab 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -3,6 +3,7 @@ module GeneralisedFilters using AbstractMCMC: AbstractMCMC, AbstractSampler import Distributions: MvNormal import Random: AbstractRNG, default_rng, rand +import SSMProblems: prior, dyn, obs using OffsetArrays using SSMProblems using StatsBase @@ -66,7 +67,7 @@ function filter( kwargs..., ) # draw from the prior - init_state = initialise(rng, model, algo; kwargs...) + init_state = initialise(rng, prior(model), algo; kwargs...) callback(model, algo, init_state, observations, PostInit; kwargs...) # iterations starts here for type stability @@ -117,10 +118,10 @@ function move( callback::CallbackType=nothing, kwargs..., ) - state = predict(rng, model, algo, iter, state, observation; kwargs...) + state = predict(rng, dyn(model), algo, iter, state, observation; kwargs...) callback(model, algo, iter, state, observation, PostPredict; kwargs...) - state, ll_increment = update(model, algo, iter, state, observation; kwargs...) + state, ll_increment = update(obs(model), algo, iter, state, observation; kwargs...) callback(model, algo, iter, state, observation, PostUpdate; kwargs...) 
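    # With this dispatch structure a single step can also be driven manually
    # through the model components. Illustrative sketch (not part of this
    # function), assuming a linear-Gaussian model and a Kalman filter:
    #
    #   state = initialise(rng, prior(model), KalmanFilter())
    #   state = predict(rng, dyn(model), KalmanFilter(), 1, state, observation)
    #   state, ll = update(obs(model), KalmanFilter(), 1, state, observation)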
return state, ll_increment diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 80be96f7..3edfb56e 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -8,23 +8,21 @@ struct KalmanFilter <: AbstractFilter end KF() = KalmanFilter() -function initialise( - rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs... -) - μ0, Σ0 = calc_initial(model.prior; kwargs...) +function initialise(rng::AbstractRNG, prior::GaussianPrior, filter::KalmanFilter; kwargs...) + μ0, Σ0 = calc_initial(prior; kwargs...) return GaussianDistribution(μ0, Σ0) end function predict( rng::AbstractRNG, - model::LinearGaussianStateSpaceModel, + dyn::LinearGaussianLatentDynamics, algo::KalmanFilter, iter::Integer, state::GaussianDistribution, observation=nothing; kwargs..., ) - params = calc_params(model.dyn, iter; kwargs...) + params = calc_params(dyn, iter; kwargs...) state = kalman_predict(state, params) return state end @@ -39,14 +37,14 @@ function kalman_predict(state, params) end function update( - model::LinearGaussianStateSpaceModel, + obs::LinearGaussianObservationProcess, algo::KalmanFilter, iter::Integer, state::GaussianDistribution, observation::AbstractVector; kwargs..., ) - params = calc_params(model.obs, iter; kwargs...) + params = calc_params(obs, iter; kwargs...) state, ll = kalman_update(state, params, observation) return state, ll end diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 3888fc40..e4a54722 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -6,48 +6,79 @@ import SSMProblems: distribution, simulate, logdensity abstract type AbstractProposal end function SSMProblems.distribution( - model::AbstractStateSpaceModel, - prop::AbstractProposal, - iter::Integer, - state, - observation; - kwargs..., + prop::AbstractProposal, iter::Integer, state, observation; kwargs... ) - return throw( - MethodError(distribution, (model, prop, iter, state, observation, kwargs...)) - ) + return throw(MethodError(distribution, (prop, iter, state, observation, kwargs...))) end function SSMProblems.simulate( + rng::AbstractRNG, prop::AbstractProposal, iter::Integer, state, observation; kwargs... +) + return rand(rng, SSMProblems.distribution(prop, iter, state, observation; kwargs...)) +end + +function SSMProblems.logdensity( + prop::AbstractProposal, iter::Integer, prev_state, new_state, observation; kwargs... +) + return logpdf( + SSMProblems.distribution(prop, iter, prev_state, observation; kwargs...), new_state + ) +end + +abstract type AbstractParticleFilter <: AbstractFilter end +function num_particles end +function resampler end + +function initialise( rng::AbstractRNG, - model::AbstractStateSpaceModel, - prop::AbstractProposal, + prior::StatePrior, + algo::AbstractParticleFilter; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + N = num_particles(algo) + particles = map(1:N) do i + initialise_particle(rng, prior, algo; ref_state, kwargs...) 
+ end + + # TODO: need to check this is correct in the GF case + prev_logsumexp = logsumexp(map(p -> p.log_w, particles)) + return ParticleDistribution(particles, prev_logsumexp) +end + +function predict( + rng::AbstractRNG, + dyn::LatentDynamics, + algo::AbstractParticleFilter, iter::Integer, state, observation; + ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - return rand( - rng, SSMProblems.distribution(model, prop, iter, state, observation; kwargs...) - ) + particles = map(state.particles) do particle + predict_particle(rng, dyn, algo, iter, particle, observation; ref_state, kwargs...) + end + state.particles = particles + return state end -function SSMProblems.logdensity( - model::AbstractStateSpaceModel, - prop::AbstractProposal, +function update( + obs::ObservationProcess, + algo::AbstractParticleFilter, iter::Integer, - prev_state, - new_state, + state::ParticleDistribution, observation; kwargs..., ) - return logpdf( - SSMProblems.distribution(model, prop, iter, prev_state, observation; kwargs...), - new_state, - ) -end + particles = map(state.particles) do particle + update_particle(obs, algo, iter, particle, observation; kwargs...) + end + state.particles = particles + ll_increment = marginalise!(state) -abstract type AbstractParticleFilter <: AbstractFilter end + return state, ll_increment +end struct ParticleFilter{RS,PT} <: AbstractParticleFilter N::Int @@ -64,6 +95,51 @@ function ParticleFilter( return ParticleFilter{ESSResampler,PT}(N, conditional_resampler, proposal) end +num_particles(algo::ParticleFilter) = algo.N +resampler(algo::ParticleFilter) = algo.resampler + +function initialise_particle( + rng::AbstractRNG, + prior::StatePrior, + algo::ParticleFilter; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + x = sample_prior(rng, prior, algo; ref_state, kwargs...) + # TODO (RB): determine the correct type for the log_w field or use a NoWeight type + return Particle(x, 0.0, 0) +end + +function predict_particle( + rng::AbstractRNG, + dyn::LatentDynamics, + algo::ParticleFilter, + iter::Integer, + particle::Particle, + observation; + ref_state, + kwargs..., +) + new_x, logw_inc = propogate( + rng, dyn, algo, iter, particle.state, observation; ref_state, kwargs... + ) + return Particle(new_x, particle.log_w + logw_inc, particle.ancestor) +end + +function update_particle( + obs::ObservationProcess, + ::ParticleFilter, + iter::Integer, + particle::Particle, + observation; + kwargs..., +) + log_increment = SSMProblems.logdensity( + obs, iter, particle.state, observation; kwargs... + ) + return Particle(particle.state, particle.log_w + log_increment, particle.ancestor) +end + function step( rng::AbstractRNG, model::AbstractStateSpaceModel, @@ -75,80 +151,55 @@ function step( callback::CallbackType=nothing, kwargs..., ) - state = resample(rng, algo.resampler, state; ref_state) + rs = resampler(algo) + state = resample(rng, rs, state; ref_state) callback(model, algo, iter, state, observation, PostResample; kwargs...) return move(rng, model, algo, iter, state, observation; ref_state, callback, kwargs...) end -function initialise( +function sample_prior( rng::AbstractRNG, - model::StateSpaceModel, + prior::StatePrior, algo::ParticleFilter; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - particles = map(1:(algo.N)) do i - if !isnothing(ref_state) && i == 1 - ref_state[0] - else - SSMProblems.simulate(rng, model.prior; kwargs...) - end + x = if isnothing(ref_state) + SSMProblems.simulate(rng, prior; kwargs...) 
+ else + ref_state[1] end - - return Particles(particles) + return x end -function predict( +function propogate( rng::AbstractRNG, - model::StateSpaceModel, + dyn, algo::ParticleFilter, iter::Integer, - state, + x, observation; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state, kwargs..., ) - proposed_particles = map(enumerate(state)) do (i, particle) - if !isnothing(ref_state) && i == 1 - ref_state[iter] - else - simulate(rng, model, algo.proposal, iter, particle, observation; kwargs...) - end - end - - log_increments = map(zip(proposed_particles, state)) do (new_state, prev_state) - log_f = SSMProblems.logdensity(model.dyn, iter, prev_state, new_state; kwargs...) - - log_q = SSMProblems.logdensity( - model, algo.proposal, iter, prev_state, new_state, observation; kwargs... - ) - - (log_f - log_q) + # TODO: use a trait to compute the sample and logpdf in one go if distribution is defined + new_x = if isnothing(ref_state) + SSMProblems.simulate(rng, algo.proposal, iter, x, observation; kwargs...) + else + ref_state[iter] end - - state.particles = proposed_particles - state = update_weights(state, log_increments) - return state + log_p = SSMProblems.logdensity(dyn, iter, x, new_x; kwargs...) + log_q = SSMProblems.logdensity(algo.proposal, iter, x, new_x, observation; kwargs...) + logw_inc = log_p - log_q + return new_x, logw_inc end -function update( - model::StateSpaceModel, - algo::ParticleFilter, - iter::Integer, - state, - observation; - kwargs..., -) - log_increments = map( - x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...), - state.particles, - ) - - state = update_weights(state, log_increments) - ll_increment = marginalise!(state) - - return state, ll_increment -end +# function update( +# obs, algo::ParticleFilter, iter::Integer, p::Particle, observation; kwargs... +# ) +# log_increment = SSMProblems.logdensity(obs, iter, p.state, observation; kwargs...) +# return Particle(p.state, p.log_w + log_increment, p.ancestor) +# end struct LatentProposal <: AbstractProposal end @@ -180,41 +231,42 @@ function SSMProblems.logdensity( return SSMProblems.logdensity(model.dyn, iter, prev_state, new_state; kwargs...) end -# overwrite predict for the bootstrap filter to remove redundant computation -function predict( +# overwrite propogate for the bootstrap filter to remove redundant computation +function propogate( rng::AbstractRNG, - model::StateSpaceModel, + dyn, algo::BootstrapFilter, iter::Integer, - state, - observation=nothing; - ref_state::Union{Nothing,AbstractVector}=nothing, + x, + observation; + ref_state, kwargs..., ) - state.particles = map(enumerate(state)) do (i, particle) - if !isnothing(ref_state) && i == 1 - ref_state[iter] - else - SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...) - end + new_x = if isnothing(ref_state) + SSMProblems.simulate(rng, dyn, iter, x; kwargs...) 
+ else + ref_state[iter] end - return state + # TODO: make this type consistent + # Will have to do a lazy zero or change propogate to accept a particle (in which case + # we'll need to construct a particle in the RBPF predict method) + return new_x, 0.0 end # Application of particle filter to hierarchical models -function filter( - rng::AbstractRNG, - model::HierarchicalSSM, - algo::ParticleFilter, - observations::AbstractVector; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) - ssm = StateSpaceModel( - HierarchicalPrior(model.outer_prior, model.inner_model.prior), - HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), - HierarchicalObservations(model.inner_model.obs), - ) - return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) -end +# function filter( +# rng::AbstractRNG, +# model::HierarchicalSSM, +# algo::ParticleFilter, +# observations::AbstractVector; +# ref_state::Union{Nothing,AbstractVector}=nothing, +# kwargs..., +# ) +# ssm = StateSpaceModel( +# HierarchicalPrior(model.outer_prior, model.inner_model.prior), +# HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), +# HierarchicalObservations(model.inner_model.obs), +# ) +# return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) +# end diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index fc1f082b..2c5db62b 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -5,93 +5,65 @@ import StatsBase: Weights export RBPF -struct RBPF{F<:AbstractFilter,RS<:AbstractResampler} <: AbstractParticleFilter - inner_algo::F - N::Int - resampler::RS +struct RBPF{PFT<:AbstractParticleFilter,AFT<:AbstractFilter} <: AbstractParticleFilter + pf::PFT + af::AFT end -function RBPF( - inner_algo::AbstractFilter, - N::Integer; - threshold::Real=1.0, - resampler::AbstractResampler=Systematic(), -) - return RBPF(inner_algo, N, ESSResampler(threshold, resampler)) -end +num_particles(algo::RBPF) = num_particles(algo.pf) +resampler(algo::RBPF) = resampler(algo.pf) -function initialise( +function initialise_particle( rng::AbstractRNG, - model::HierarchicalSSM{T}, + prior::HierarchicalPrior, algo::RBPF; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., -) where {T} - particles = map(1:(algo.N)) do i - x = if !isnothing(ref_state) && i == 1 - ref_state[0] - else - SSMProblems.simulate(rng, model.outer_prior; kwargs...) - end - z = initialise(rng, model.inner_model, algo.inner_algo; new_outer=x, kwargs...) - RaoBlackwellisedParticle(x, z) - end - - return Particles(particles) +) + x = sample_prior(rng, prior.outer_prior, algo.pf; ref_state, kwargs...) + z = initialise(rng, prior.inner_prior, algo.af; new_outer=x, kwargs...) + # TODO (RB): determine the correct type for the log_w field or use a NoWeight type + return RBParticle(x, z, 0.0, 0) end -function predict( +function predict_particle( rng::AbstractRNG, - model::HierarchicalSSM, + dyn::HierarchicalDynamics, algo::RBPF, iter::Integer, - state, + particle::RBParticle, observation; - ref_state::Union{Nothing,AbstractVector}=nothing, + ref_state, kwargs..., ) - state.particles = map(enumerate(state.particles)) do (i, particle) - new_x = if !isnothing(ref_state) && i == 1 - ref_state[iter] - else - SSMProblems.simulate(rng, model.outer_dyn, iter, particle.x; kwargs...) 
- end - new_z = predict( - rng, - model.inner_model, - algo.inner_algo, - iter, - particle.z, - observation; - prev_outer=particle.x, - new_outer=new_x, - kwargs..., - ) - - RaoBlackwellisedParticle(new_x, new_z) - end + new_x, logw_inc = propogate( + rng, dyn.outer_dyn, algo.pf, iter, particle.x, observation; ref_state, kwargs... + ) + new_z = predict( + rng, + dyn.inner_dyn, + algo.af, + iter, + particle.z, + observation; + prev_outer=particle.x, + new_outer=new_x, + kwargs..., + ) - return state + return RBParticle(new_x, new_z, particle.log_w + logw_inc, particle.ancestor) end -function update( - model::HierarchicalSSM, algo::RBPF, iter::Integer, state, observation; kwargs... +function update_particle( + obs::ObservationProcess, + algo::RBPF, + iter::Integer, + particle::RBParticle, + observation; + kwargs..., ) - log_increments = map(enumerate(state.particles)) do (i, particle) - state.particles[i].z, log_increment = update( - model.inner_model, - algo.inner_algo, - iter, - particle.z, - observation; - new_outer=particle.x, - kwargs..., - ) - log_increment - end - - state = update_weights(state, log_increments) - ll_increment = marginalise!(state) - - return state, ll_increment + new_z, log_increment = update( + obs, algo.af, iter, particle.z, observation; new_outer=particle.x, kwargs... + ) + return RBParticle(particle.x, new_z, particle.log_w + log_increment, particle.ancestor) end diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 1ca6266a..6fe29716 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -2,70 +2,100 @@ ## PARTICLES ############################################################################### -mutable struct ParticleWeights{WT<:Real} - log_weights::Vector{WT} - prev_logsumexp::WT -end - -""" - ParticleDistribution - -A container for particle filters which composes the weighted sample into a distibution-like -object, with the states (or particles) distributed accoring to their log-weights. -""" -abstract type ParticleDistribution{PT} end - -Base.collect(state::ParticleDistribution) = state.particles -Base.length(state::ParticleDistribution) = length(state.particles) -Base.keys(state::ParticleDistribution) = LinearIndices(state.particles) - -Base.iterate(state::ParticleDistribution, i) = iterate(state.particles, i) -Base.iterate(state::ParticleDistribution) = iterate(state.particles) - -# not sure if this is kosher, since it doesn't follow the convention of Base.getindex -Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i] - -mutable struct Particles{PT} <: ParticleDistribution{PT} - particles::Vector{PT} - ancestors::Vector{Int} -end +abstract type AbstractParticle{WT} end -mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution{PT} - particles::Vector{PT} - ancestors::Vector{Int} - weights::ParticleWeights{WT} +# New types +# TODO (RB): could the RB particle be a regular particle with a RB state? 
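# One possible shape for that (an illustrative sketch, not part of this patch):
# keep the single `Particle` type below and move the Rao-Blackwellised
# structure into its state, e.g.
#
#   struct RBState{XT,ZT}
#       x::XT  # sampled (outer) component
#       z::ZT  # marginalised (inner) distribution, e.g. a GaussianDistribution
#   end
#
# so that `Particle{RBState{XT,ZT},WT}` would play the role of `RBParticle`.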
+struct Particle{ST,WT} <: AbstractParticle{WT} + state::ST + log_w::WT + ancestor::Int end -function Particles(particles::AbstractVector) - N = length(particles) - return Particles(particles, Vector{Int}(1:N)) +struct RBParticle{XT,ZT,WT} <: AbstractParticle{WT} + x::XT + z::ZT + log_w::WT + ancestor::Int end -function WeightedParticles(particles::AbstractVector, log_weights::AbstractVector) - N = length(particles) - weights = ParticleWeights(log_weights, logsumexp(log_weights)) - return WeightedParticles(particles, Vector{Int}(1:N), weights) +mutable struct ParticleDistribution{WT,PT<:AbstractParticle{WT},VT<:AbstractVector{PT}} + particles::VT + prev_logsumexp::WT end -StatsBase.weights(state::Particles) = Weights(fill(1 / length(state), length(state))) -StatsBase.weights(state::WeightedParticles) = Weights(softmax(state.weights.log_weights)) - -function update_weights(state::Particles, log_weights::Vector{WT}) where {WT} - weights = ParticleWeights(log_weights, WT(log(length(state)))) - return WeightedParticles(state.particles, state.ancestors, weights) +function marginalise!(state::ParticleDistribution) + log_marginalisation = logsumexp(map(p -> p.log_w, state.particles)) + ll_increment = (log_marginalisation - state.prev_logsumexp) + state.prev_logsumexp = log_marginalisation + return ll_increment end -function update_weights(state::WeightedParticles, log_weights) - state.weights.log_weights += log_weights - return state +# Old code +mutable struct ParticleWeights{WT<:Real} + log_weights::Vector{WT} + prev_logsumexp::WT end -function marginalise!(state::WeightedParticles) - log_marginalisation = logsumexp(state.weights.log_weights) - ll_increment = (log_marginalisation - state.weights.prev_logsumexp) - state.weights.prev_logsumexp = log_marginalisation - return ll_increment -end +# """ +# ParticleDistribution + +# A container for particle filters which composes the weighted sample into a distibution-like +# object, with the states (or particles) distributed accoring to their log-weights. 
+# """ +# abstract type ParticleDistribution{PT} end + +# Base.collect(state::ParticleDistribution) = state.particles +# Base.length(state::ParticleDistribution) = length(state.particles) +# Base.keys(state::ParticleDistribution) = LinearIndices(state.particles) + +# Base.iterate(state::ParticleDistribution, i) = iterate(state.particles, i) +# Base.iterate(state::ParticleDistribution) = iterate(state.particles) + +# # not sure if this is kosher, since it doesn't follow the convention of Base.getindex +# Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i] + +# mutable struct Particles{PT} <: ParticleDistribution{PT} +# particles::Vector{PT} +# ancestors::Vector{Int} +# end + +# mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution{PT} +# particles::Vector{PT} +# ancestors::Vector{Int} +# weights::ParticleWeights{WT} +# end + +# function Particles(particles::AbstractVector) +# N = length(particles) +# return Particles(particles, Vector{Int}(1:N)) +# end + +# function WeightedParticles(particles::AbstractVector, log_weights::AbstractVector) +# N = length(particles) +# weights = ParticleWeights(log_weights, logsumexp(log_weights)) +# return WeightedParticles(particles, Vector{Int}(1:N), weights) +# end + +# StatsBase.weights(state::Particles) = Weights(fill(1 / length(state), length(state))) +# StatsBase.weights(state::WeightedParticles) = Weights(softmax(state.weights.log_weights)) + +# function update_weights(state::Particles, log_weights::Vector{WT}) where {WT} +# weights = ParticleWeights(log_weights, WT(log(length(state)))) +# return WeightedParticles(state.particles, state.ancestors, weights) +# end + +# function update_weights(state::WeightedParticles, log_weights) +# state.weights.log_weights += log_weights +# return state +# end + +# function marginalise!(state::WeightedParticles) +# log_marginalisation = logsumexp(state.weights.log_weights) +# ll_increment = (log_marginalisation - state.weights.prev_logsumexp) +# state.weights.prev_logsumexp = log_marginalisation +# return ll_increment +# end ## GAUSSIAN STATES ######################################################################### diff --git a/GeneralisedFilters/src/models/hierarchical.jl b/GeneralisedFilters/src/models/hierarchical.jl index 23d75551..1d336815 100644 --- a/GeneralisedFilters/src/models/hierarchical.jl +++ b/GeneralisedFilters/src/models/hierarchical.jl @@ -7,6 +7,31 @@ struct HierarchicalSSM{PT<:StatePrior,LD<:LatentDynamics,MT<:StateSpaceModel} <: outer_dyn::LD inner_model::MT end +outer_prior(model::HierarchicalSSM) = model.outer_prior +inner_prior(model::HierarchicalSSM) = model.inner_model.prior +outer_dyn(model::HierarchicalSSM) = model.outer_dyn +inner_dyn(model::HierarchicalSSM) = model.inner_model.dyn +SSMProblems.obs(model::HierarchicalSSM) = model.inner_model.obs + +struct HierarchicalPrior{P1<:StatePrior,P2<:StatePrior} <: StatePrior + outer_prior::P1 + inner_prior::P2 +end +function SSMProblems.prior(model::HierarchicalSSM) + return HierarchicalPrior(model.outer_prior, model.inner_model.prior) +end +outer_prior(prior::HierarchicalPrior) = prior.outer_prior +inner_prior(prior::HierarchicalPrior) = prior.inner_prior + +struct HierarchicalDynamics{D1<:LatentDynamics,D2<:LatentDynamics} <: LatentDynamics + outer_dyn::D1 + inner_dyn::D2 +end +function SSMProblems.dyn(model::HierarchicalSSM) + return HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn) +end +outer_dyn(dyn::HierarchicalDynamics) = dyn.outer_dyn 
+inner_dyn(dyn::HierarchicalDynamics) = dyn.inner_dyn function HierarchicalSSM( outer_prior::StatePrior, @@ -50,21 +75,11 @@ function AbstractMCMC.sample( return x0, z0, xs, zs, ys end -## Methods to make HierarchicalSSM compatible with the bootstrap filter -struct HierarchicalDynamics{D1<:LatentDynamics,D2<:LatentDynamics} <: LatentDynamics - outer_dyn::D1 - inner_dyn::D2 -end - -struct HierarchicalPrior{P1<:StatePrior,P2<:StatePrior} <: StatePrior - outer_prior::P1 - inner_prior::P2 -end - function SSMProblems.simulate(rng::AbstractRNG, prior::HierarchicalPrior; kwargs...) outer_prior, inner_prior = prior.outer_prior, prior.inner_prior x0 = simulate(rng, outer_prior; kwargs...) z0 = simulate(rng, inner_prior; new_outer=x0, kwargs...) + # TODO (RB): this isn't really RB at all, just hierarchical state return RaoBlackwellisedParticle(x0, z0) end diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index b486f489..e256c3b9 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -10,24 +10,33 @@ abstract type AbstractResampler end function resample( rng::AbstractRNG, resampler::AbstractResampler, - states; + state; ref_state::Union{Nothing,AbstractVector}=nothing, ) - idxs = sample_ancestors(rng, resampler, weights(states)) + weights = softmax(map(p -> p.log_w, state.particles)) + idxs = sample_ancestors(rng, resampler, weights) # Set reference trajectory indices if !isnothing(ref_state) CUDA.@allowscalar idxs[1] = 1 end - return construct_new_state(states, idxs) + return construct_new_state(state, idxs) end -function construct_new_state(states::Particles{PT}, idxs) where {PT} - return Particles{PT}(states.particles[idxs], idxs) +function construct_new_state(state, idxs) + new_particles = similar(state.particles) + for i in 1:length(state.particles) + new_particles[i] = resample_ancestor(state.particles[idxs[i]], idxs[i]) + end + return ParticleDistribution(new_particles, log(length(state.particles))) end -function construct_new_state(states::WeightedParticles{PT,WT}, idxs) where {PT,WT} - weights = ParticleWeights(zeros(WT, length(states)), WT(log(length(states)))) - return WeightedParticles{PT,WT}(states.particles[idxs], idxs, weights) +# TODO (RB): this can probably be cleaned up if we allow mutation (I'm just playing it safe +# whilst developing) +function resample_ancestor(particle::Particle, ancestor::Int) + return Particle(particle.state, 0.0, ancestor) +end +function resample_ancestor(particle::RBParticle, ancestor::Int) + return RBParticle(particle.x, particle.z, 0.0, ancestor) end ## CONDITIONAL RESAMPLING ################################################################## @@ -48,19 +57,30 @@ function resample( state; ref_state::Union{Nothing,AbstractVector}=nothing, ) - n = length(state) - # TODO: computing weights twice. 
Should create a wrapper to avoid this - weights = StatsBase.weights(state) + n = length(state.particles) + weights = softmax(map(p -> p.log_w, state.particles)) ess = inv(sum(abs2, weights)) if cond_resampler.threshold * n ≥ ess return resample(rng, cond_resampler.resampler, state; ref_state) else - state.ancestors = collect(1:n) - return state + new_particles = similar(state.particles) + for i in 1:n + new_particles[i] = set_ancestor(state.particles[i], i) + end + return ParticleDistribution(new_particles, state.prev_logsumexp) end end +# TODO (RB): this can probably be cleaned up if we allow mutation (I'm just playing it safe +# whilst developing) +function set_ancestor(particle::Particle, ancestor::Int) + return Particle(particle.state, particle.log_w, ancestor) +end +function set_ancestor(particle::RBParticle, ancestor::Int) + return RBParticle(particle.x, particle.z, particle.log_w, ancestor) +end + ## CATEGORICAL RESAMPLE #################################################################### # this is adapted from AdvancedPS diff --git a/GeneralisedFilters/test/combination_test_script.jl b/GeneralisedFilters/test/combination_test_script.jl new file mode 100644 index 00000000..2443c453 --- /dev/null +++ b/GeneralisedFilters/test/combination_test_script.jl @@ -0,0 +1,102 @@ +using Distributions +using GeneralisedFilters +using LinearAlgebra +using LogExpFunctions +using SSMProblems +using StableRNGs +using StatsBase +using Test + +println() +println("########################") +println("#### STARTING TESTS ####") +println("########################") +println() + +rng = StableRNG(1234) + +model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) +_, _, ys = sample(rng, model, 3) + +bf = BF(10^6; threshold=0.8) +bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) +kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) + +xs = getfield.(bf_state.particles, :state) +log_ws = getfield.(bf_state.particles, :log_w) +ws = softmax(log_ws) + +# Compare log-likelihood and states +println("BF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3) +println("BF LL: ", @test llkf ≈ llbf atol = 1e-3) + +struct OptimalProposal <: AbstractProposal + dyn::LinearGaussianLatentDynamics + obs::LinearGaussianObservationProcess + dummy::Bool # if using dummy hierarchical model +end +function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...) + A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) + H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...) 
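    # Locally optimal proposal for a linear-Gaussian model: the product of the
    # transition density N(x'; Ax + b, Q) and the likelihood N(y; Hx' + c, R)
    # is Gaussian in x', with
    #     Σ = (Q⁻¹ + HᵀR⁻¹H)⁻¹,  μ = Σ(Q⁻¹(Ax + b) + HᵀR⁻¹(y − c)),
    # which is what the two lines below compute.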
+ Σ = inv(inv(Q) + H' * inv(R) * H) + μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c)) + if prop.dummy + μ = μ[[1]] + Σ = Σ[[1], [1]] + end + return MvNormal(μ, Σ) +end +# Propose from observation distribution +# proposal = PeturbationProposal(only(model.obs.R)) +proposal = OptimalProposal(model.dyn, model.obs, false) +gf = ParticleFilter(10^6, proposal; threshold=1.0) + +gf_state, llgf = GeneralisedFilters.filter(rng, model, gf, ys) +xs = getfield.(gf_state.particles, :state) +log_ws = getfield.(gf_state.particles, :log_w) +ws = softmax(log_ws) + +# Fairly sure this is correct but would be good to confirm (needs to be faster — SArrays) +println("GF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3) +println("GF LL: ", @test llkf ≈ llgf atol = 1e-3) + +############################## +#### RAO-BLACKWELLISATION #### +############################## + +full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( + rng, 1, 1, 1; static_arrays=true +) +_, _, ys = sample(rng, hier_model, 3) + +rbbf = RBPF(bf, KalmanFilter()) + +rbbf_state, llrbbf = GeneralisedFilters.filter(rng, hier_model, rbbf, ys) +xs = getfield.(rbbf_state.particles, :x) +zs = getfield.(rbbf_state.particles, :z) +log_ws = getfield.(rbbf_state.particles, :log_w) +ws = softmax(log_ws) + +kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys) + +println("RBBF Outer: ", @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3) +println( + "RBBF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-3 +) +println("RBBF LL: ", @test llkf ≈ llrbbf atol = 1e-3) + +proposal = OptimalProposal(model.dyn, model.obs, true) +gf = ParticleFilter(10^6, proposal; threshold=1.0) +rbgf = RBPF(gf, KalmanFilter()) +rbgf_state, llrbgf = GeneralisedFilters.filter(rng, hier_model, rbgf, ys) +xs = getfield.(rbgf_state.particles, :x) +zs = getfield.(rbgf_state.particles, :z) +log_ws = getfield.(rbgf_state.particles, :log_w) +ws = softmax(log_ws) + +# Reduce tolerance since this is a bit harder to filter to high precision +println("RBGF Outer: ", @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-2) +println( + "RBGF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-2 +) +println("RBGF LL: ", @test llkf ≈ llrbgf atol = 1e-2) diff --git a/SSMProblems/src/SSMProblems.jl b/SSMProblems/src/SSMProblems.jl index dfa940af..bfffe41d 100644 --- a/SSMProblems/src/SSMProblems.jl +++ b/SSMProblems/src/SSMProblems.jl @@ -8,8 +8,9 @@ import Base: eltype import Random: AbstractRNG, default_rng import Distributions: logpdf -export StatePrior, - LatentDynamics, ObservationProcess, AbstractStateSpaceModel, StateSpaceModel +export StatePrior, LatentDynamics, ObservationProcess +export AbstractStateSpaceModel, StateSpaceModel +export prior, dyn, obs """ Initial state prior of a state space model. 
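# Illustrative usage of the newly exported accessors (assuming `my_prior`,
# `my_dyn` and `my_obs` are valid model components); the methods themselves
# are defined for `StateSpaceModel` in the hunk below:
#
#   model = StateSpaceModel(my_prior, my_dyn, my_obs)
#   prior(model) === my_prior   # true
#   dyn(model)   === my_dyn     # true
#   obs(model)   === my_obs     # true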
@@ -236,6 +237,10 @@ struct StateSpaceModel{PT,LD,OP} <: AbstractStateSpaceModel obs::OP end +prior(model::StateSpaceModel) = model.prior +dyn(model::StateSpaceModel) = model.dyn +obs(model::StateSpaceModel) = model.obs + include("batch_methods.jl") include("utils/forward_simulation.jl") From 2d751e7ec11071b61d141c1e92cdab46d37d73ff Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Sat, 27 Sep 2025 22:12:33 +0100 Subject: [PATCH 02/20] Add CSMC execution tests --- GeneralisedFilters/test/combination_test_script.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/GeneralisedFilters/test/combination_test_script.jl b/GeneralisedFilters/test/combination_test_script.jl index 2443c453..11d7566e 100644 --- a/GeneralisedFilters/test/combination_test_script.jl +++ b/GeneralisedFilters/test/combination_test_script.jl @@ -100,3 +100,16 @@ println( "RBGF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-2 ) println("RBGF LL: ", @test llkf ≈ llrbgf atol = 1e-2) + +################################ +#### REFERENCE TRAJECTORIES #### +################################ + +# Hard to verify these are correct until the code is faster and we can run a full loop +# For now we just check they run without error + +ref_traj = [randn(rng, 1) for _ in 1:3] +GeneralisedFilters.filter(rng, model, bf, ys; ref_state=ref_traj) +GeneralisedFilters.filter(rng, model, gf, ys; ref_state=ref_traj) +GeneralisedFilters.filter(rng, hier_model, rbbf, ys; ref_state=ref_traj) +GeneralisedFilters.filter(rng, hier_model, rbgf, ys; ref_state=ref_traj) From 2c0a19a1544337510c5a199de2919d1a02db4f68 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Mon, 29 Sep 2025 15:05:50 +0100 Subject: [PATCH 03/20] Implement APF --- .../src/algorithms/particles.jl | 157 +++++++++++++++--- GeneralisedFilters/src/algorithms/rbpf.jl | 52 +++++- GeneralisedFilters/src/containers.jl | 4 +- .../test/combination_test_script.jl | 53 +++++- 4 files changed, 222 insertions(+), 44 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index e4a54722..823cef72 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -1,5 +1,7 @@ export BootstrapFilter, BF export ParticleFilter, PF, AbstractProposal +export AuxiliaryParticleFilter, PredictivePosterior +export MeanPredictive, ModePredictive, DrawPredictive import SSMProblems: distribution, simulate, logdensity @@ -38,7 +40,8 @@ function initialise( ) N = num_particles(algo) particles = map(1:N) do i - initialise_particle(rng, prior, algo; ref_state, kwargs...) + ref = !isnothing(ref_state) && i == 1 ? ref_state[0] : nothing + initialise_particle(rng, prior, algo, ref; kwargs...) end # TODO: need to check this is correct in the GF case @@ -56,8 +59,10 @@ function predict( ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - particles = map(state.particles) do particle - predict_particle(rng, dyn, algo, iter, particle, observation; ref_state, kwargs...) + particles = map(1:length(state.particles)) do i + particle = state.particles[i] + ref = !isnothing(ref_state) && i == 1 ? ref_state[iter] : nothing + predict_particle(rng, dyn, algo, iter, particle, observation, ref; kwargs...) 
end state.particles = particles return state @@ -99,13 +104,9 @@ num_particles(algo::ParticleFilter) = algo.N resampler(algo::ParticleFilter) = algo.resampler function initialise_particle( - rng::AbstractRNG, - prior::StatePrior, - algo::ParticleFilter; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., + rng::AbstractRNG, prior::StatePrior, algo::ParticleFilter, ref_state; kwargs... ) - x = sample_prior(rng, prior, algo; ref_state, kwargs...) + x = sample_prior(rng, prior, algo, ref_state; kwargs...) # TODO (RB): determine the correct type for the log_w field or use a NoWeight type return Particle(x, 0.0, 0) end @@ -116,12 +117,12 @@ function predict_particle( algo::ParticleFilter, iter::Integer, particle::Particle, - observation; - ref_state, + observation, + ref_state; kwargs..., ) new_x, logw_inc = propogate( - rng, dyn, algo, iter, particle.state, observation; ref_state, kwargs... + rng, dyn, algo, iter, particle.state, observation, ref_state; kwargs... ) return Particle(new_x, particle.log_w + logw_inc, particle.ancestor) end @@ -158,16 +159,12 @@ function step( end function sample_prior( - rng::AbstractRNG, - prior::StatePrior, - algo::ParticleFilter; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., + rng::AbstractRNG, prior::StatePrior, algo::ParticleFilter, ref_state; kwargs... ) x = if isnothing(ref_state) SSMProblems.simulate(rng, prior; kwargs...) else - ref_state[1] + ref_state end return x end @@ -178,15 +175,15 @@ function propogate( algo::ParticleFilter, iter::Integer, x, - observation; - ref_state, + observation, + ref_state; kwargs..., ) # TODO: use a trait to compute the sample and logpdf in one go if distribution is defined new_x = if isnothing(ref_state) SSMProblems.simulate(rng, algo.proposal, iter, x, observation; kwargs...) else - ref_state[iter] + ref_state end log_p = SSMProblems.logdensity(dyn, iter, x, new_x; kwargs...) log_q = SSMProblems.logdensity(algo.proposal, iter, x, new_x, observation; kwargs...) @@ -238,14 +235,14 @@ function propogate( algo::BootstrapFilter, iter::Integer, x, - observation; - ref_state, + observation, + ref_state; kwargs..., ) new_x = if isnothing(ref_state) SSMProblems.simulate(rng, dyn, iter, x; kwargs...) else - ref_state[iter] + ref_state end # TODO: make this type consistent @@ -270,3 +267,115 @@ end # ) # return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) # end + +abstract type PredictivePosterior end + +struct AuxiliaryParticleFilter{PFT<:AbstractParticleFilter,PPT<:PredictivePosterior} <: + AbstractFilter + pf::PFT + pp::PPT +end + +resampler(algo::AuxiliaryParticleFilter) = resampler(algo.pf) +num_particles(algo::AuxiliaryParticleFilter) = num_particles(algo.pf) + +function initialise( + rng::AbstractRNG, + prior::StatePrior, + algo::AuxiliaryParticleFilter; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + return initialise(rng, prior, algo.pf; ref_state, kwargs...) +end + +function step( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + algo::AuxiliaryParticleFilter, + iter::Integer, + state, + observation; + ref_state::Union{Nothing,AbstractVector}=nothing, + callback::CallbackType=nothing, + kwargs..., +) + # Incorporate lookahead weights to form first-stage weights + log_ξs = map(state.particles) do particle + p_star = predictive_state(rng, dyn(model), algo, iter, particle; kwargs...) + predictive_loglik(obs(model), algo.pf, iter, p_star, observation; kwargs...) 
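        # Each log_ξ is the APF lookahead weight: an estimate of
        # log p(y_t | x_{t-1}^(i)) with the transition collapsed to a single
        # predictive statistic (mean, mode, or draw; see PredictivePosterior).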
+ end + # Normalise + log_ξs .-= logsumexp(log_ξs) + for (i, particle) in enumerate(state.particles) + particle.log_w += log_ξs[i] + end + + # Resample as usual + state = resample(rng, resampler(algo), state; ref_state) + + # Compensate for lookahead weights in the final weights + for particle in state.particles + particle.log_w += -log_ξs[particle.ancestor] + end + + # Bit ugly looking. Maybe we should store the prev logsumexp with -log(N) + state.prev_logsumexp = 2 * log(num_particles(algo)) + + # Dispatch to wrapped filter for predict and update + callback(model, algo, iter, state, observation, PostResample; kwargs...) + return move( + rng, model, algo.pf, iter, state, observation; ref_state, callback, kwargs... + ) +end + +struct MeanPredictive <: PredictivePosterior end + +function predictive_statistic( + ::AbstractRNG, ::MeanPredictive, dyn, iter::Integer, state; kwargs... +) + transition_dist = SSMProblems.distribution(dyn, iter, state; kwargs...) + return mean(transition_dist) +end + +struct ModePredictive <: PredictivePosterior end + +function predictive_statistic( + ::AbstractRNG, ::ModePredictive, dyn, iter::Integer, state; kwargs... +) + transition_dist = SSMProblems.distribution(dyn, iter, state; kwargs...) + return mode(transition_dist) +end + +struct DrawPredictive <: PredictivePosterior end + +function predictive_statistic( + rng::AbstractRNG, ::DrawPredictive, dyn, iter::Integer, state; kwargs... +) + return SSMProblems.simulate(rng, dyn, iter, state; kwargs...) +end + +# TODO (RB): Really these should be returning a state rather than a particle but we would +# need to define a RB state first +function predictive_state( + rng::AbstractRNG, + dyn::LatentDynamics, + apf::AuxiliaryParticleFilter{<:AbstractParticleFilter}, + iter::Integer, + particle::Particle; + kwargs..., +) + x_star = predictive_statistic(rng, apf.pp, dyn, iter, particle.state; kwargs...) + return Particle(x_star, particle.log_w, particle.ancestor) +end + +function predictive_loglik( + obs::ObservationProcess, + algo::ParticleFilter, + iter::Integer, + p_star::Particle, + observation; + kwargs..., +) + return SSMProblems.logdensity(obs, iter, p_star.state, observation; kwargs...) +end diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index 2c5db62b..ee89eb66 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -14,13 +14,9 @@ num_particles(algo::RBPF) = num_particles(algo.pf) resampler(algo::RBPF) = resampler(algo.pf) function initialise_particle( - rng::AbstractRNG, - prior::HierarchicalPrior, - algo::RBPF; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., + rng::AbstractRNG, prior::HierarchicalPrior, algo::RBPF, ref_state; kwargs... ) - x = sample_prior(rng, prior.outer_prior, algo.pf; ref_state, kwargs...) + x = sample_prior(rng, prior.outer_prior, algo.pf, ref_state; kwargs...) z = initialise(rng, prior.inner_prior, algo.af; new_outer=x, kwargs...) # TODO (RB): determine the correct type for the log_w field or use a NoWeight type return RBParticle(x, z, 0.0, 0) @@ -32,12 +28,12 @@ function predict_particle( algo::RBPF, iter::Integer, particle::RBParticle, - observation; - ref_state, + observation, + ref_state; kwargs..., ) new_x, logw_inc = propogate( - rng, dyn.outer_dyn, algo.pf, iter, particle.x, observation; ref_state, kwargs... + rng, dyn.outer_dyn, algo.pf, iter, particle.x, observation, ref_state; kwargs... 
) new_z = predict( rng, @@ -67,3 +63,41 @@ function update_particle( ) return RBParticle(particle.x, new_z, particle.log_w + log_increment, particle.ancestor) end + +function predictive_state( + rng::AbstractRNG, + dyn::HierarchicalDynamics, + apf::AuxiliaryParticleFilter{<:RBPF}, + iter::Integer, + particle::RBParticle; + kwargs..., +) + rbpf = apf.pf + x_star = predictive_statistic(rng, apf.pp, dyn.outer_dyn, iter, particle.x; kwargs...) + z_star = predict( + rng, + dyn.inner_dyn, + rbpf.af, + iter, + particle.z, + nothing; # no observation available — maybe we should pass this in + prev_outer=particle.x, + new_outer=x_star, + kwargs..., + ) + return RBParticle(x_star, z_star, particle.log_w, particle.ancestor) +end + +function predictive_loglik( + obs::ObservationProcess, + algo::RBPF, + iter::Integer, + p_star::RBParticle, + observation; + kwargs..., +) + _, log_increment = update( + obs, algo.af, iter, p_star.z, observation; new_outer=p_star.x, kwargs... + ) + return log_increment +end diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 6fe29716..0119c65b 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -6,13 +6,13 @@ abstract type AbstractParticle{WT} end # New types # TODO (RB): could the RB particle be a regular particle with a RB state? -struct Particle{ST,WT} <: AbstractParticle{WT} +mutable struct Particle{ST,WT} <: AbstractParticle{WT} state::ST log_w::WT ancestor::Int end -struct RBParticle{XT,ZT,WT} <: AbstractParticle{WT} +mutable struct RBParticle{XT,ZT,WT} <: AbstractParticle{WT} x::XT z::ZT log_w::WT diff --git a/GeneralisedFilters/test/combination_test_script.jl b/GeneralisedFilters/test/combination_test_script.jl index 11d7566e..a5061b29 100644 --- a/GeneralisedFilters/test/combination_test_script.jl +++ b/GeneralisedFilters/test/combination_test_script.jl @@ -15,7 +15,9 @@ println() rng = StableRNG(1234) -model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) +model = GeneralisedFilters.GFTest.create_linear_gaussian_model( + rng, 1, 1; static_arrays=true +) _, _, ys = sample(rng, model, 3) bf = BF(10^6; threshold=0.8) @@ -30,9 +32,11 @@ ws = softmax(log_ws) println("BF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3) println("BF LL: ", @test llkf ≈ llbf atol = 1e-3) -struct OptimalProposal <: AbstractProposal - dyn::LinearGaussianLatentDynamics - obs::LinearGaussianObservationProcess +struct OptimalProposal{ + LD<:LinearGaussianLatentDynamics,OP<:LinearGaussianObservationProcess +} <: AbstractProposal + dyn::LD + obs::OP dummy::Bool # if using dummy hierarchical model end function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...) 
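# Usage sketch for the auxiliary wrapper introduced in this patch, mirroring
# the tests added below:
#
#   apf = AuxiliaryParticleFilter(BF(10^6; threshold=1.0), MeanPredictive())
#   apf_state, ll_apf = GeneralisedFilters.filter(rng, model, apf, ys)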
@@ -108,8 +112,39 @@ println("RBGF LL: ", @test llkf ≈ llrbgf atol = 1e-2) # Hard to verify these are correct until the code is faster and we can run a full loop # For now we just check they run without error -ref_traj = [randn(rng, 1) for _ in 1:3] -GeneralisedFilters.filter(rng, model, bf, ys; ref_state=ref_traj) -GeneralisedFilters.filter(rng, model, gf, ys; ref_state=ref_traj) -GeneralisedFilters.filter(rng, hier_model, rbbf, ys; ref_state=ref_traj) -GeneralisedFilters.filter(rng, hier_model, rbgf, ys; ref_state=ref_traj) +using OffsetArrays +ref_traj = [randn(rng, 1) for _ in 0:3]; +ref_traj = OffsetArray(ref_traj, 0:3); +GeneralisedFilters.filter(rng, model, bf, ys; ref_state=ref_traj); +GeneralisedFilters.filter(rng, model, gf, ys; ref_state=ref_traj); +GeneralisedFilters.filter(rng, hier_model, rbbf, ys; ref_state=ref_traj); +GeneralisedFilters.filter(rng, hier_model, rbgf, ys; ref_state=ref_traj); + +#################################### +#### AUXILIARY PARTICLE FILTERS #### +#################################### + +kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys); + +bf = BF(10^6; threshold=1.0) +abf = AuxiliaryParticleFilter(bf, MeanPredictive()) +abf_state, llabf = GeneralisedFilters.filter(rng, model, abf, ys); +xs = getfield.(abf_state.particles, :state) +log_ws = getfield.(abf_state.particles, :log_w) +ws = softmax(log_ws) +println("ABF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3) +println("ABF LL: ", @test llkf ≈ llabf atol = 1e-3) + +kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys); +rbbf = RBPF(bf, KalmanFilter()) +arbf = AuxiliaryParticleFilter(rbbf, MeanPredictive()) +arbf_state, llarbf = GeneralisedFilters.filter(rng, hier_model, arbf, ys); +xs = getfield.(arbf_state.particles, :x) +zs = getfield.(arbf_state.particles, :z) +log_ws = getfield.(arbf_state.particles, :log_w) +ws = softmax(log_ws) +println("ARBF Outer: ", @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3) +println( + "ARBF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-3 +) +println("ARBF LL: ", @test llkf ≈ llarbf atol = 1e-3) From e48da08823e1412c8b5e611a6b68e84d360b35ce Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 10:16:08 +0100 Subject: [PATCH 04/20] Remove old RB particle type --- GeneralisedFilters/src/containers.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 0119c65b..50fd3c95 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -108,19 +108,6 @@ function mean_cov(state::GaussianDistribution) return state.μ, state.Σ end -## RAO-BLACKWELLISED PARTICLE ############################################################## - -""" - RaoBlackwellisedParticle - -A container for Rao-Blackwellised states, composed of a marginalised state `z` (e.g. a -Gaussian or Categorical distribution) and a singular state `x`. 
-""" -mutable struct RaoBlackwellisedParticle{XT,ZT} - x::XT - z::ZT -end - ## RAO-BLACKWELLISED PARTICLE DISTRIBUTIONS ################################################ mutable struct BatchRaoBlackwellisedParticles{XT,ZT} From ae2d4f132f843833e72d6cf4050861c83fdb9ac1 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 10:16:44 +0100 Subject: [PATCH 05/20] Remove legacy batching interface --- GeneralisedFilters/src/containers.jl | 52 ---------------------------- 1 file changed, 52 deletions(-) diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 50fd3c95..bd98c281 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -107,55 +107,3 @@ end function mean_cov(state::GaussianDistribution) return state.μ, state.Σ end - -## RAO-BLACKWELLISED PARTICLE DISTRIBUTIONS ################################################ - -mutable struct BatchRaoBlackwellisedParticles{XT,ZT} - xs::XT - zs::ZT -end - -# Allow particle to be get and set via tree_states[:, 1:N] = states -function Base.getindex(state::BatchRaoBlackwellisedParticles, i) - return BatchRaoBlackwellisedParticles(state.xs[:, [i]], state.zs[i]) -end -function Base.getindex(state::BatchRaoBlackwellisedParticles, i::AbstractVector) - return BatchRaoBlackwellisedParticles(state.xs[:, i], state.zs[i]) -end -function Base.setindex!( - state::BatchRaoBlackwellisedParticles, value::BatchRaoBlackwellisedParticles, i -) - state.xs[:, i] = value.xs - state.zs[i] = value.zs - return state -end -Base.length(state::BatchRaoBlackwellisedParticles) = size(state.xs, 2) - -## BATCH GAUSSIAN DISTRIBUTION ############################################################# - -mutable struct BatchGaussianDistribution{T} - μs::CuArray{T,2,CUDA.DeviceMemory} - Σs::CuArray{T,3,CUDA.DeviceMemory} -end - -function Base.getindex(d::BatchGaussianDistribution, i) - return BatchGaussianDistribution(d.μs[:, [i]], d.Σs[:, :, [i]]) -end - -function Base.getindex(d::BatchGaussianDistribution, i::AbstractVector) - return BatchGaussianDistribution(d.μs[:, i], d.Σs[:, :, i]) -end - -function Base.setindex!(d::BatchGaussianDistribution, value::BatchGaussianDistribution, i) - d.μs[:, i] = value.μs - d.Σs[:, :, i] = value.Σs - return d -end - -function expand(d::BatchGaussianDistribution{T}, M::Integer) where {T} - new_μs = CuArray{T}(undef, size(d.μs, 1), M) - new_Σs = CuArray{T}(undef, size(d.Σs, 1), size(d.Σs, 2), M) - new_μs[:, 1:size(d.μs, 2)] = d.μs - new_Σs[:, :, 1:size(d.Σs, 3)] = d.Σs - return BatchGaussianDistribution(new_μs, new_Σs) -end From 0cc76962d96827c30a67ce3f3555e43da046a6f8 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 10:30:23 +0100 Subject: [PATCH 06/20] Replace RaoBlackwellisedParticle with HierarchicalState in HSMM forward simulation --- .../examples/trend-inflation/utilities.jl | 2 +- GeneralisedFilters/src/models/hierarchical.jl | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/GeneralisedFilters/examples/trend-inflation/utilities.jl b/GeneralisedFilters/examples/trend-inflation/utilities.jl index 579e8f2a..b8b37514 100644 --- a/GeneralisedFilters/examples/trend-inflation/utilities.jl +++ b/GeneralisedFilters/examples/trend-inflation/utilities.jl @@ -16,7 +16,7 @@ mean_path(paths, states) = _mean_path(identity, paths, states) # for rao blackwellised particles function mean_path( paths::Vector{Vector{T}}, states -) where {T<:GeneralisedFilters.RaoBlackwellisedParticle} +) where 
{T<:GeneralisedFilters.RBParticle} zs = _mean_path(z -> getproperty.(getproperty.(z, :z), :μ), paths, states) xs = _mean_path(x -> getproperty.(x, :x), paths, states) return zs, xs diff --git a/GeneralisedFilters/src/models/hierarchical.jl b/GeneralisedFilters/src/models/hierarchical.jl index 1d336815..ab57b213 100644 --- a/GeneralisedFilters/src/models/hierarchical.jl +++ b/GeneralisedFilters/src/models/hierarchical.jl @@ -44,6 +44,16 @@ function HierarchicalSSM( return HierarchicalSSM(outer_prior, outer_dyn, inner_model) end +""" +A container for a sampled state from a hierarchical SSM, with separation between the outer +and inner dimensions. Note this differs from a RBState in the the inner state is a sample +rather than a conditional distribution. +""" +struct HierarchicalState{XT,ZT} + x::XT + z::ZT +end + function AbstractMCMC.sample( rng::AbstractRNG, model::HierarchicalSSM, T::Integer; kwargs... ) @@ -80,14 +90,14 @@ function SSMProblems.simulate(rng::AbstractRNG, prior::HierarchicalPrior; kwargs x0 = simulate(rng, outer_prior; kwargs...) z0 = simulate(rng, inner_prior; new_outer=x0, kwargs...) # TODO (RB): this isn't really RB at all, just hierarchical state - return RaoBlackwellisedParticle(x0, z0) + return HierarchicalState(x0, z0) end function SSMProblems.simulate( rng::AbstractRNG, proc::HierarchicalDynamics, step::Integer, - prev_state::RaoBlackwellisedParticle; + prev_state::HierarchicalState; kwargs..., ) outer_dyn, inner_dyn = proc.outer_dyn, proc.inner_dyn @@ -95,7 +105,7 @@ function SSMProblems.simulate( z = simulate( rng, inner_dyn, step, prev_state.z; prev_outer=prev_state.x, new_outer=x, kwargs... ) - return RaoBlackwellisedParticle(x, z) + return HierarchicalState(x, z) end struct HierarchicalObservations{OP<:ObservationProcess} <: ObservationProcess @@ -103,7 +113,7 @@ struct HierarchicalObservations{OP<:ObservationProcess} <: ObservationProcess end function SSMProblems.distribution( - obs::HierarchicalObservations, step::Integer, state::RaoBlackwellisedParticle; kwargs... + obs::HierarchicalObservations, step::Integer, state::HierarchicalState; kwargs... ) return distribution(obs.obs, step, state.z; new_outer=state.x, kwargs...) end From 190484a3f992c39a73db4f0468a7e22d6e655afc Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 10:35:49 +0100 Subject: [PATCH 07/20] Remove legacy GPU tests --- GeneralisedFilters/src/algorithms/kalman.jl | 75 --------- GeneralisedFilters/test/batch_kalman_test.jl | 150 ------------------ GeneralisedFilters/test/runtests.jl | 152 ------------------- 3 files changed, 377 deletions(-) delete mode 100644 GeneralisedFilters/test/batch_kalman_test.jl diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 3edfb56e..a47202b6 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -69,81 +69,6 @@ function kalman_update(state, params, observation) return state, ll end -struct BatchKalmanFilter <: AbstractBatchFilter - batch_size::Int -end - -function initialise( - rng::AbstractRNG, - model::LinearGaussianStateSpaceModel, - algo::BatchKalmanFilter; - kwargs..., -) - μ0s, Σ0s = batch_calc_initial(model.prior, algo.batch_size; kwargs...) 
- return BatchGaussianDistribution(μ0s, Σ0s) -end - -function predict( - rng::AbstractRNG, - model::LinearGaussianStateSpaceModel, - algo::BatchKalmanFilter, - iter::Integer, - state::BatchGaussianDistribution, - observation; - kwargs..., -) - μs, Σs = state.μs, state.Σs - As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...) - μ̂s = NNlib.batched_vec(As, μs) .+ bs - Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs - return BatchGaussianDistribution(μ̂s, Σ̂s) -end - -function update( - model::LinearGaussianStateSpaceModel, - algo::BatchKalmanFilter, - iter::Integer, - state::BatchGaussianDistribution, - observation; - kwargs..., -) - # T = Float32 # temporary fix!!! - μs, Σs = state.μs, state.Σs - Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...) - D = size(observation, 1) - - m = NNlib.batched_vec(Hs, μs) .+ cs - y_res = cu(observation) .- m - S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs - - ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs)) - - S_inv = CUDA.similar(S) - d_ipiv, _, d_S = CUDA.CUBLAS.getrf_strided_batched(S, true) - CUDA.CUBLAS.getri_strided_batched!(d_S, S_inv, d_ipiv) - - diags = CuArray{eltype(S)}(undef, size(S, 1), size(S, 3)) - for i in 1:size(S, 1) - diags[i, :] .= d_S[i, i, :] - end - - log_dets = sum(log ∘ abs, diags; dims=1) - - K = NNlib.batched_mul(ΣH_T, S_inv) - - μ_filt = μs .+ NNlib.batched_vec(K, y_res) - Σ_filt = Σs .- NNlib.batched_mul(K, NNlib.batched_mul(Hs, Σs)) - - inv_term = NNlib.batched_vec(S_inv, y_res) - log_likes = -NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term) - log_likes = (log_likes .- (log_dets .+ D * convert(eltype(log_likes), log(2π)))) ./ 2 - - # HACK: only errors seems to be from numerical stability so will just overwrite - log_likes[isnan.(log_likes)] .= -Inf - - return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1) -end - ## KALMAN SMOOTHER ######################################################################### struct KalmanSmoother <: AbstractSmoother end diff --git a/GeneralisedFilters/test/batch_kalman_test.jl b/GeneralisedFilters/test/batch_kalman_test.jl deleted file mode 100644 index 6d353186..00000000 --- a/GeneralisedFilters/test/batch_kalman_test.jl +++ /dev/null @@ -1,150 +0,0 @@ -@testitem "Batch Kalman test" tags = [:gpu] begin - using GeneralisedFilters - using Distributions - using LinearAlgebra - using StableRNGs - - using Random - using SSMProblems - - using CUDA - - rng = StableRNG(1234) - K = 10 - Dx = 2 - Dy = 2 - μ0s = [rand(rng, Dx) for _ in 1:K] - Σ0s = [rand(rng, Dx, Dx) for _ in 1:K] - Σ0s .= Σ0s .* transpose.(Σ0s) - As = [rand(rng, Dx, Dx) for _ in 1:K] - bs = [rand(rng, Dx) for _ in 1:K] - Qs = [rand(rng, Dx, Dx) for _ in 1:K] - Qs .= Qs .* transpose.(Qs) - Hs = [rand(rng, Dy, Dx) for _ in 1:K] - cs = [rand(rng, Dy) for _ in 1:K] - Rs = [rand(rng, Dy, Dy) for _ in 1:K] - Rs .= Rs .* transpose.(Rs) - - models = [ - create_homogeneous_linear_gaussian_model( - μ0s[k], Σ0s[k], As[k], bs[k], Qs[k], Hs[k], cs[k], Rs[k] - ) for k in 1:K - ] - - T = 5 - Ys = [[rand(rng, Dy) for _ in 1:T] for _ in 1:K] - - outputs = [ - GeneralisedFilters.filter(rng, models[k], KalmanFilter(), Ys[k]) for k in 1:K - ] - - states = first.(outputs) - log_likelihoods = last.(outputs) - - struct BatchGaussianPrior{T,MT} <: GaussianPrior - μ0s::CuArray{T,2,MT} - Σ0s::CuArray{T,3,MT} - end - - function BatchGaussianPrior(μ0s::Vector{Vector{T}}, Σ0s::Vector{Matrix{T}}) where 
{T} - μ0s = CuArray(stack(μ0s)) - Σ0s = CuArray(stack(Σ0s)) - return BatchGaussianPrior(μ0s, Σ0s) - end - - function GeneralisedFilters.batch_calc_μ0s( - dyn::BatchGaussianPrior, ::Integer; kwargs... - ) - return dyn.μ0s - end - function GeneralisedFilters.batch_calc_Σ0s( - dyn::BatchGaussianPrior, ::Integer; kwargs... - ) - return dyn.Σ0s - end - - struct BatchLinearGaussianDynamics{T,MT} <: LinearGaussianLatentDynamics - As::CuArray{T,3,MT} - bs::CuArray{T,2,MT} - Qs::CuArray{T,3,MT} - end - - function BatchLinearGaussianDynamics( - As::Vector{Matrix{T}}, bs::Vector{Vector{T}}, Qs::Vector{Matrix{T}} - ) where {T} - As = CuArray(stack(As)) - bs = CuArray(stack(bs)) - Qs = CuArray(stack(Qs)) - return BatchLinearGaussianDynamics(As, bs, Qs) - end - - function GeneralisedFilters.batch_calc_As( - dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs... - ) - return dyn.As - end - function GeneralisedFilters.batch_calc_bs( - dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs... - ) - return dyn.bs - end - function GeneralisedFilters.batch_calc_Qs( - dyn::BatchLinearGaussianDynamics, ::Integer, ::Integer; kwargs... - ) - return dyn.Qs - end - - struct BatchLinearGaussianObservations{T,MT} <: LinearGaussianObservationProcess - Hs::CuArray{T,3,MT} - cs::CuArray{T,2,MT} - Rs::CuArray{T,3,MT} - end - - function BatchLinearGaussianObservations( - Hs::Vector{Matrix{T}}, cs::Vector{Vector{T}}, Rs::Vector{Matrix{T}} - ) where {T} - Hs = CuArray(stack(Hs)) - cs = CuArray(stack(cs)) - Rs = CuArray(stack(Rs)) - return BatchLinearGaussianObservations(Hs, cs, Rs) - end - - function GeneralisedFilters.batch_calc_Hs( - obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs... - ) - return obs.Hs - end - function GeneralisedFilters.batch_calc_cs( - obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs... - ) - return obs.cs - end - function GeneralisedFilters.batch_calc_Rs( - obs::BatchLinearGaussianObservations, ::Integer, ::Integer; kwargs... 
- ) - return obs.Rs - end - - batch_model = GeneralisedFilters.StateSpaceModel( - BatchGaussianPrior(μ0s, Σ0s), - BatchLinearGaussianDynamics(As, bs, Qs), - BatchLinearGaussianObservations(Hs, cs, Rs), - ) - - Ys_batch = Vector{Matrix{Float64}}(undef, T) - for t in 1:T - Ys_batch[t] = stack(Ys[k][t] for k in 1:K) - end - batch_output = GeneralisedFilters.filter( - rng, batch_model, BatchKalmanFilter(K), Ys_batch - ) - - # println("Batch log-likelihood: ", batch_output[2]) - # println("Individual log-likelihoods: ", log_likelihoods) - - # println("Batch states: ", batch_output[1].μs') - # println("Individual states: ", getproperty.(states, :μ)) - - @test Array(batch_output[2])[end] .≈ log_likelihoods[end] rtol = 1e-5 - @test Array(batch_output[1].μs) ≈ stack(getproperty.(states, :μ)) rtol = 1e-5 -end diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 107fe461..c8fa7d69 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -5,7 +5,6 @@ using TestItemRunner @run_package_tests filter = ti -> !(:gpu in ti.tags) include("Aqua.jl") -include("batch_kalman_test.jl") include("resamplers.jl") @testitem "Kalman filter test" begin @@ -552,154 +551,3 @@ end @test state.μ[1] ≈ only(mean(x_trajectories)) rtol = 1e-2 @test state.μ[2] ≈ mean(z_smoothed_means) rtol = 1e-3 end - -@testitem "GPU Conditional Kalman-RBPF execution test" tags = [:gpu] begin - using CUDA - using OffsetArrays - using SSMProblems - using StableRNGs - - SEED = 1234 - D_outer = 2 - D_inner = 3 - D_obs = 2 - K = 5 - T = Float32 - N_particles = 1000 - - rng = StableRNG(1234) - - full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( - rng, D_outer, D_inner, D_obs, T - ) - _, _, ys = sample(rng, full_model, K) - - # Generate random reference trajectory - ref_trajectory = [CuArray(rand(rng, T, D_outer, 1)) for _ in 0:K] - ref_trajectory = OffsetVector(ref_trajectory, -1) - - rbpf = BatchRBPF(BatchKalmanFilter(N_particles), N_particles) - states, ll = GeneralisedFilters.filter(hier_model, rbpf, ys; ref_state=ref_trajectory) - - # Check returned type - @test typeof(ll) == T -end - -@testitem "GPU-RBPF ancestory test" tags = [:gpu] begin - using GeneralisedFilters - using CUDA - using LinearAlgebra - using SSMProblems - using StableRNGs - - SEED = 1234 - D_outer = 2 - D_inner = 3 - D_obs = 2 - K = 5 - T = Float32 - N_particles = 10^5 - - rng = StableRNG(1234) - - full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( - rng, D_outer, D_inner, D_obs, T - ) - _, _, ys = sample(rng, full_model, K) - - # Manually create tree to force expansion on second step - M = N_particles * 2 - 1 - tree = GeneralisedFilters.ParallelParticleTree( - GeneralisedFilters.BatchRaoBlackwellisedParticles( - CuArray{T}(undef, D_outer, N_particles), - GeneralisedFilters.BatchGaussianDistribution( - CuArray{T}(undef, D_inner, N_particles), - CuArray{T}(undef, D_inner, D_inner, N_particles), - ), - ), - M, - ) - - rbpf = BatchRBPF(BatchKalmanFilter(N_particles), N_particles) - cb = GeneralisedFilters.ParallelAncestorCallback(tree) - states, ll = GeneralisedFilters.filter(hier_model, rbpf, ys; callback=cb) - - # TODO: add proper test comparing to dense storage - ancestry = GeneralisedFilters.get_ancestry(tree, K) -end - -@testitem "GPU Conditional Kalman-RBPF validity test" tags = [:gpu, :long] begin - using GeneralisedFilters - using CUDA - using NNlib - using OffsetArrays - using StableRNGs - using StatsBase - - SEED = 1234 - 
D_outer = 1 - D_inner = 1 - D_obs = 1 - K = 3 - t_smooth = 2 - T = Float32 - N_particles = 10000 - N_burnin = 100 - N_sample = 2000 - - rng = StableRNG(1234) - - full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( - rng, D_outer, D_inner, D_obs, T - ) - _, _, ys = sample(rng, full_model, K) - - # Kalman smoother - state, _ = GeneralisedFilters.smooth( - rng, full_model, KalmanSmoother(), ys; t_smooth=t_smooth - ) - - particle_template = GeneralisedFilters.BatchRaoBlackwellisedParticles( - CuArray{T}(undef, D_outer, N_particles), - GeneralisedFilters.BatchGaussianDistribution( - CuArray{T}(undef, D_inner, N_particles), - CuArray{T}(undef, D_inner, D_inner, N_particles), - ), - ) - particle_type = typeof(particle_template) - - N_steps = N_burnin + N_sample - M = floor(Int64, N_particles * log(N_particles)) - rbpf = BatchRBPF(BatchKalmanFilter(N_particles), N_particles; threshold=1.0) - ref_traj = nothing - trajectory_samples = Vector{OffsetArray{particle_type,1,Vector{particle_type}}}( - undef, N_sample - ) - - for i in 1:N_steps - tree = GeneralisedFilters.ParallelParticleTree(deepcopy(particle_template), M) - cb = GeneralisedFilters.ParallelAncestorCallback(tree) - rbpf_state, _ = GeneralisedFilters.filter( - hier_model, rbpf, ys; ref_state=ref_traj, callback=cb - ) - weights = softmax(rbpf_state.log_weights) - ancestors = GeneralisedFilters.sample_ancestors(rng, Multinomial(), weights) - sampled_idx = CUDA.@allowscalar ancestors[1] - global ref_traj = GeneralisedFilters.get_ancestry(tree, sampled_idx, K) - if i > N_burnin - trajectory_samples[i - N_burnin] = ref_traj - end - # Reference trajectory should only be nonlinear state for RBPF - ref_traj = getproperty.(ref_traj, :xs) - end - - # Extract inner and outer trajectories - x_trajectories = getproperty.(getindex.(trajectory_samples, t_smooth), :xs) - z_trajectories = getproperty.(getindex.(trajectory_samples, t_smooth), :zs) - - # Compare to ground truth - CUDA.@allowscalar begin - @test state.μ[1] ≈ only(mean(x_trajectories)) rtol = 1e-1 - @test state.μ[2] ≈ only(mean(getproperty.(z_trajectories, :μs))) rtol = 1e-1 - end -end From 3759cba49aa0373f5b84bffc53c919b25dc1b266 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 11:15:26 +0100 Subject: [PATCH 08/20] Update unit tests to match new syntax and add unit tests for new features --- GeneralisedFilters/src/algorithms/rbpf.jl | 1 + GeneralisedFilters/test/runtests.jl | 250 ++++++++++++++-------- 2 files changed, 158 insertions(+), 93 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index ee89eb66..c1ab8ede 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -32,6 +32,7 @@ function predict_particle( ref_state; kwargs..., ) + # TODO: really we should be conditioning on the current RB state to allow for optimal proposals new_x, logw_inc = propogate( rng, dyn.outer_dyn, algo.pf, iter, particle.x, observation, ref_state; kwargs... 
) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index c8fa7d69..1cb79482 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -116,60 +116,65 @@ end @testitem "Bootstrap filter test" begin using SSMProblems using StableRNGs - using StatsBase + using LogExpFunctions rng = StableRNG(1234) - model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, 10) + model = GeneralisedFilters.GFTest.create_linear_gaussian_model( + rng, 1, 1; static_arrays=true + ) + _, _, ys = sample(rng, model, 3) - bf = BF(2^12; threshold=0.8) + bf = BF(10^6; threshold=0.8) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) - xs = bf_state.particles - ws = weights(bf_state) + xs = getfield.(bf_state.particles, :state) + log_ws = getfield.(bf_state.particles, :log_w) + ws = softmax(log_ws) # Compare log-likelihood and states - @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-2 - @test llkf ≈ llbf atol = 1e-2 + @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3 + @test llkf ≈ llbf atol = 1e-3 end @testitem "Guided filter test" begin using SSMProblems - using StatsBase + using LogExpFunctions using StableRNGs using Distributions using LinearAlgebra - struct LinearGaussianProposal <: GeneralisedFilters.AbstractProposal end - - function SSMProblems.distribution( - model::AbstractStateSpaceModel, - kernel::LinearGaussianProposal, - iter::Integer, - state, - observation; - kwargs..., - ) - A, b, Q = GeneralisedFilters.calc_params(model.dyn, iter; kwargs...) - pred = GeneralisedFilters.GaussianDistribution(A * state + b, Q) - prop, _ = GeneralisedFilters.update(model, KF(), iter, pred, observation; kwargs...) - return MvNormal(prop.μ, hermitianpart(prop.Σ)) + struct OptimalProposal{ + LD<:LinearGaussianLatentDynamics,OP<:LinearGaussianObservationProcess + } <: AbstractProposal + dyn::LD + obs::OP + end + function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...) + A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) + H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...) 
+ Σ = inv(inv(Q) + H' * inv(R) * H) + μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c)) + return MvNormal(μ, Σ) end rng = StableRNG(1234) - model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, 10) + model = GeneralisedFilters.GFTest.create_linear_gaussian_model( + rng, 1, 1; static_arrays=true + ) + _, _, ys = sample(rng, model, 3) - algo = PF(2^10, LinearGaussianProposal(); threshold=0.6) - kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) - pf_states, pf_ll = GeneralisedFilters.filter(rng, model, algo, ys) - xs = pf_states.particles - ws = weights(pf_states) + proposal = OptimalProposal(model.dyn, model.obs) + gf = ParticleFilter(10^6, proposal; threshold=1.0) + gf_state, llgf = GeneralisedFilters.filter(rng, model, gf, ys) + kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) - # Compare log-likelihood and states - @test first(kf_states.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-2 - @test kf_ll ≈ pf_ll rtol = 1e-2 + xs = getfield.(gf_state.particles, :state) + log_ws = getfield.(gf_state.particles, :log_w) + ws = softmax(log_ws) + + @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3 + @test llkf ≈ llgf atol = 1e-3 end @testitem "Forward algorithm test" begin @@ -227,83 +232,137 @@ end @test ll ≈ log(marginal) end -@testitem "Kalman-RBPF test" begin - using LogExpFunctions: softmax +@testitem "Rao-Blackwellised BF test" begin + using Distributions + using GeneralisedFilters + using LinearAlgebra + using LogExpFunctions + using SSMProblems using StableRNGs - using StatsBase - - SEED = 1234 - D_outer = 2 - D_inner = 3 - D_obs = 2 - T = 5 - N_particles = 10^4 - rng = StableRNG(SEED) + rng = StableRNG(1234) full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( - rng, D_outer, D_inner, D_obs + rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, full_model, T) + _, _, ys = sample(rng, hier_model, 3) - # Ground truth Kalman filtering - kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) + bf = BF(10^6; threshold=0.8) + rbbf = RBPF(bf, KalmanFilter()) - # Rao-Blackwellised particle filtering - rbpf = RBPF(KalmanFilter(), N_particles) - states, ll = GeneralisedFilters.filter(rng, hier_model, rbpf, ys) + rbbf_state, llrbbf = GeneralisedFilters.filter(rng, hier_model, rbbf, ys) + xs = getfield.(rbbf_state.particles, :x) + zs = getfield.(rbbf_state.particles, :z) + log_ws = getfield.(rbbf_state.particles, :log_w) + ws = softmax(log_ws) - # Extract final filtered states - xs = map(p -> getproperty(p, :x), states.particles) - zs = map(p -> getproperty(p, :z), states.particles) - ws = weights(states) + kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys) - @test kf_ll ≈ ll rtol = 1e-2 - - # Higher tolerance for outer state since variance is higher - @test first(kf_states.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-1 - @test last(kf_states.μ) ≈ sum(last.(getproperty.(zs, :μ)) .* ws) rtol = 1e-2 + @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3 + @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-3 + @test llkf ≈ llrbbf atol = 1e-3 end -@testitem "GPU Kalman-RBPF test" tags = [:gpu] begin - using CUDA +@testitem "Rao-Blackwellised GF test" begin + using Distributions + using GeneralisedFilters using LinearAlgebra - using NNlib + using LogExpFunctions using SSMProblems using StableRNGs - using StatsBase - SEED = 1234 - D_outer = 2 - D_inner = 3 - D_obs = 2 - T = 5 - N_particles = 10^4 - ET = 
Float32 + struct OverdispersedProposal{LD<:LinearGaussianLatentDynamics} <: AbstractProposal + dyn::LD + k::Float64 + end + function SSMProblems.distribution( + prop::OverdispersedProposal, step::Integer, x, y; kwargs... + ) + A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) + Q = prop.k * Q # overdisperse + μ = A * x + b + return MvNormal(μ, Q) + end - rng = StableRNG(SEED) + rng = StableRNG(1234) full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( - rng, D_outer, D_inner, D_obs, ET + rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, full_model, T) + _, _, ys = sample(rng, hier_model, 3) + + proposal = OverdispersedProposal(dyn(hier_model).outer_dyn, 1.5) + gf = ParticleFilter(10^6, proposal; threshold=1.0) + rbgf = RBPF(gf, KalmanFilter()) + rbgf_state, llrbgf = GeneralisedFilters.filter(rng, hier_model, rbgf, ys) + xs = getfield.(rbgf_state.particles, :x) + zs = getfield.(rbgf_state.particles, :z) + log_ws = getfield.(rbgf_state.particles, :log_w) + ws = softmax(log_ws) + + kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys) + + @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3 + @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-3 + @test llkf ≈ llrbgf atol = 1e-3 +end - # Ground truth Kalman filtering - kf_state, kf_ll = GeneralisedFilters.filter(full_model, KalmanFilter(), ys) +@testitem "ABF test" begin + using Distributions + using GeneralisedFilters + using LinearAlgebra + using LogExpFunctions + using SSMProblems + using StableRNGs - # Rao-Blackwellised particle filtering - rbpf = BatchRBPF(BatchKalmanFilter(N_particles), N_particles) - states, ll = GeneralisedFilters.filter(hier_model, rbpf, ys) + rng = StableRNG(1234) + model = GeneralisedFilters.GFTest.create_linear_gaussian_model( + rng, 1, 1; static_arrays=true + ) + _, _, ys = sample(rng, model, 3) - # Extract final filtered states - xs = states.particles.xs - zs = states.particles.zs - ws = weights(states) + bf = BF(10^6; threshold=1.0) + abf = AuxiliaryParticleFilter(bf, MeanPredictive()) + abf_state, llabf = GeneralisedFilters.filter(rng, model, abf, ys) + kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) - @test kf_ll ≈ ll rtol = 1e-2 - @test first(kf_state.μ) ≈ sum(xs[1, :] .* ws) rtol = 1e-1 - @test last(kf_state.μ) ≈ sum(zs.μs[end, :] .* ws) rtol = 1e-2 - @test eltype(xs) == ET + xs = getfield.(abf_state.particles, :state) + log_ws = getfield.(abf_state.particles, :log_w) + ws = softmax(log_ws) + + @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3 + @test llkf ≈ llabf atol = 1e-3 +end + +@testitem "ARBF test" begin + using Distributions + using GeneralisedFilters + using LinearAlgebra + using LogExpFunctions + using SSMProblems + using StableRNGs + + rng = StableRNG(1234) + + full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( + rng, 1, 1, 1; static_arrays=true + ) + _, _, ys = sample(rng, hier_model, 3) + + bf = BF(10^6; threshold=1.0) + rbbf = RBPF(bf, KalmanFilter()) + arbf = AuxiliaryParticleFilter(rbbf, MeanPredictive()) + arbf_state, llarbf = GeneralisedFilters.filter(rng, hier_model, arbf, ys) + xs = getfield.(arbf_state.particles, :x) + zs = getfield.(arbf_state.particles, :z) + log_ws = getfield.(arbf_state.particles, :log_w) + ws = softmax(log_ws) + + kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys) + + @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3 + @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* 
ws) rtol = 1e-3 + @test llkf ≈ llarbf atol = 1e-3 end @testitem "RBPF ancestory test" begin @@ -330,8 +389,8 @@ end end @testitem "BF on hierarchical model test" begin + using LogExpFunctions using StableRNGs - using StatsBase SEED = 1234 D_outer = 1 @@ -357,7 +416,8 @@ end # Extract final filtered states xs = map(p -> getproperty(p, :x), states.particles) zs = map(p -> getproperty(p, :z), states.particles) - ws = weights(states) + log_ws = getfield.(states.particles, :log_w) + ws = softmax(log_ws) @test kf_ll ≈ ll rtol = 1e-2 @@ -415,7 +475,7 @@ end using LinearAlgebra using LogExpFunctions: softmax, logsumexp using Random: randexp - using StatsBase + using StatsBase: sample, Weights using OffsetArrays @@ -449,7 +509,9 @@ end bf_state, ll = GeneralisedFilters.filter( rng, model, bf, ys; ref_state=ref_traj, callback=cb ) - sampled_idx = sample(rng, 1:length(bf_state), weights(bf_state)) + log_ws = getfield.(bf_state.particles, :log_w) + ws = softmax(log_ws) + sampled_idx = sample(rng, 1:length(bf_state), Weights(ws)) global ref_traj = GeneralisedFilters.get_ancestry(cb.container, sampled_idx) if i > N_burnin push!(trajectory_samples, ref_traj) @@ -473,7 +535,7 @@ end using LinearAlgebra using LogExpFunctions: softmax using Random: randexp - using StatsBase + using StatsBase: sample, Weights using StaticArrays using OffsetArrays @@ -510,7 +572,9 @@ end bf_state, _ = GeneralisedFilters.filter( rng, hier_model, rbpf, ys; ref_state=ref_traj, callback=cb ) - sampled_idx = sample(rng, 1:length(bf_state), StatsBase.weights(bf_state)) + log_ws = getfield.(bf_state.particles, :log_w) + ws = softmax(log_ws) + sampled_idx = sample(rng, 1:length(bf_state), Weights(ws)) global ref_traj = GeneralisedFilters.get_ancestry(cb.container, sampled_idx) if i > N_burnin From 272c2d5ebc2196838811948e9cc183514a73e666 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 11:36:03 +0100 Subject: [PATCH 09/20] Move unit tests proposals to GFTest --- GeneralisedFilters/src/GFTest/GFTest.jl | 1 + GeneralisedFilters/src/GFTest/proposals.jl | 43 ++++++++++++++++++++++ GeneralisedFilters/test/runtests.jl | 40 ++------------------ 3 files changed, 48 insertions(+), 36 deletions(-) create mode 100644 GeneralisedFilters/src/GFTest/proposals.jl diff --git a/GeneralisedFilters/src/GFTest/GFTest.jl b/GeneralisedFilters/src/GFTest/GFTest.jl index 6e3eaeb2..04e54e1a 100644 --- a/GeneralisedFilters/src/GFTest/GFTest.jl +++ b/GeneralisedFilters/src/GFTest/GFTest.jl @@ -12,5 +12,6 @@ using SSMProblems include("utils.jl") include("models/linear_gaussian.jl") include("models/dummy_linear_gaussian.jl") +include("proposals.jl") end diff --git a/GeneralisedFilters/src/GFTest/proposals.jl b/GeneralisedFilters/src/GFTest/proposals.jl new file mode 100644 index 00000000..9b4ca3d3 --- /dev/null +++ b/GeneralisedFilters/src/GFTest/proposals.jl @@ -0,0 +1,43 @@ +using Distributions + +""" + OptimalProposal + +Optimal importance proposal for linear Gaussian state space models. + +A proposal coming from the closed-form distribution `p(x_t | x_{t-1}, y_t)`. This proposal +minimizes the variance of the importance weights. +""" +struct OptimalProposal{ + LD<:LinearGaussianLatentDynamics,OP<:LinearGaussianObservationProcess +} <: AbstractProposal + dyn::LD + obs::OP +end + +function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...) + A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) + H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...) 
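+    # The next two lines are the standard Gaussian conjugacy identity: combining the
+    # transition density N(x_t; A*x + b, Q) with the likelihood N(y; H*x_t + c, R) in
+    # precision (information) form gives p(x_t | x, y) with
+    #     Σ = (Q⁻¹ + Hᵀ R⁻¹ H)⁻¹,  μ = Σ (Q⁻¹ (A*x + b) + Hᵀ R⁻¹ (y - c)).
+    # Explicit inverses are fine for the small, well-conditioned test models used here.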
+ Σ = inv(inv(Q) + H' * inv(R) * H) + μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c)) + return MvNormal(μ, Σ) +end + +""" + OverdispersedProposal + +A proposal that overdisperses the latent dynamics by inflating the covariance. +""" +struct OverdispersedProposal{LD<:LinearGaussianLatentDynamics} <: AbstractProposal + dyn::LD + k::Float64 +end + +function SSMProblems.distribution( + prop::OverdispersedProposal, step::Integer, x, y; kwargs... +) + A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) + Q = prop.k * Q # overdisperse + μ = A * x + b + return MvNormal(μ, Q) +end diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 1cb79482..39a896e4 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -141,22 +141,6 @@ end using SSMProblems using LogExpFunctions using StableRNGs - using Distributions - using LinearAlgebra - - struct OptimalProposal{ - LD<:LinearGaussianLatentDynamics,OP<:LinearGaussianObservationProcess - } <: AbstractProposal - dyn::LD - obs::OP - end - function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...) - A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) - H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...) - Σ = inv(inv(Q) + H' * inv(R) * H) - μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c)) - return MvNormal(μ, Σ) - end rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model( @@ -164,8 +148,8 @@ end ) _, _, ys = sample(rng, model, 3) - proposal = OptimalProposal(model.dyn, model.obs) - gf = ParticleFilter(10^6, proposal; threshold=1.0) + prop = GeneralisedFilters.GFTest.OptimalProposal(model.dyn, model.obs) + gf = ParticleFilter(10^6, prop; threshold=1.0) gf_state, llgf = GeneralisedFilters.filter(rng, model, gf, ys) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) @@ -264,26 +248,10 @@ end end @testitem "Rao-Blackwellised GF test" begin - using Distributions - using GeneralisedFilters - using LinearAlgebra using LogExpFunctions using SSMProblems using StableRNGs - struct OverdispersedProposal{LD<:LinearGaussianLatentDynamics} <: AbstractProposal - dyn::LD - k::Float64 - end - function SSMProblems.distribution( - prop::OverdispersedProposal, step::Integer, x, y; kwargs... - ) - A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) 
- Q = prop.k * Q # overdisperse - μ = A * x + b - return MvNormal(μ, Q) - end - rng = StableRNG(1234) full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( @@ -291,8 +259,8 @@ end ) _, _, ys = sample(rng, hier_model, 3) - proposal = OverdispersedProposal(dyn(hier_model).outer_dyn, 1.5) - gf = ParticleFilter(10^6, proposal; threshold=1.0) + prop = GeneralisedFilters.GFTest.OverdispersedProposal(dyn(hier_model).outer_dyn, 1.5) + gf = ParticleFilter(10^6, prop; threshold=1.0) rbgf = RBPF(gf, KalmanFilter()) rbgf_state, llrbgf = GeneralisedFilters.filter(rng, hier_model, rbgf, ys) xs = getfield.(rbgf_state.particles, :x) From 415ab5590766e44c444720a187b780d49ece94b2 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 11:55:29 +0100 Subject: [PATCH 10/20] Use alternating resampler for unit tests --- GeneralisedFilters/src/GFTest/GFTest.jl | 1 + GeneralisedFilters/src/GFTest/resamplers.jl | 45 +++++++++++++++++++++ GeneralisedFilters/test/runtests.jl | 31 ++++++++------ 3 files changed, 64 insertions(+), 13 deletions(-) create mode 100644 GeneralisedFilters/src/GFTest/resamplers.jl diff --git a/GeneralisedFilters/src/GFTest/GFTest.jl b/GeneralisedFilters/src/GFTest/GFTest.jl index 04e54e1a..3d673c44 100644 --- a/GeneralisedFilters/src/GFTest/GFTest.jl +++ b/GeneralisedFilters/src/GFTest/GFTest.jl @@ -13,5 +13,6 @@ include("utils.jl") include("models/linear_gaussian.jl") include("models/dummy_linear_gaussian.jl") include("proposals.jl") +include("resamplers.jl") end diff --git a/GeneralisedFilters/src/GFTest/resamplers.jl b/GeneralisedFilters/src/GFTest/resamplers.jl new file mode 100644 index 00000000..4a437290 --- /dev/null +++ b/GeneralisedFilters/src/GFTest/resamplers.jl @@ -0,0 +1,45 @@ +""" + AlternatingResampler + +A resampler wrapper that alternates between resampling and not resampling on each step. +This is useful for testing the validity of filters in both cases. + +The resampler maintains internal state to track whether it should resample on the next call. +By default, it resamples on the first call, then skips the second, then resamples on the +third, and so on. 
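+
+# Examples
+
+A minimal sketch of intended usage, mirroring the unit tests in this package (which
+pass the wrapper to `BF` via its `resampler` keyword):
+
+```julia
+resampler = AlternatingResampler()  # wraps Systematic() by default
+bf = BF(1000; resampler=resampler)  # resampling then occurs on every other filter step
+```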
+""" +mutable struct AlternatingResampler <: GeneralisedFilters.AbstractConditionalResampler + resampler::GeneralisedFilters.AbstractResampler + resample_next::Bool + function AlternatingResampler( + resampler::GeneralisedFilters.AbstractResampler=Systematic() + ) + return new(resampler, true) + end +end + +function GeneralisedFilters.resample( + rng::AbstractRNG, + alt_resampler::AlternatingResampler, + state; + ref_state::Union{Nothing,AbstractVector}=nothing, +) + n = length(state.particles) + + if alt_resampler.resample_next + # Resample using wrapped resampler + alt_resampler.resample_next = false + return GeneralisedFilters.resample( + rng, alt_resampler.resampler, state; ref_state + ) + else + # Skip resampling - keep particles with their current weights and set ancestors to + # themselves + alt_resampler.resample_next = true + new_particles = similar(state.particles) + for i in 1:n + new_particles[i] = GeneralisedFilters.set_ancestor(state.particles[i], i) + end + return GeneralisedFilters.ParticleDistribution(new_particles, state.prev_logsumexp) + end +end diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 39a896e4..c23062b8 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -122,9 +122,10 @@ end model = GeneralisedFilters.GFTest.create_linear_gaussian_model( rng, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, model, 3) + _, _, ys = sample(rng, model, 4) - bf = BF(10^6; threshold=0.8) + resampler = GeneralisedFilters.GFTest.AlternatingResampler() + bf = BF(10^6; resampler=resampler) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) @@ -146,10 +147,11 @@ end model = GeneralisedFilters.GFTest.create_linear_gaussian_model( rng, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, model, 3) + _, _, ys = sample(rng, model, 4) prop = GeneralisedFilters.GFTest.OptimalProposal(model.dyn, model.obs) - gf = ParticleFilter(10^6, prop; threshold=1.0) + resampler = GeneralisedFilters.GFTest.AlternatingResampler() + gf = ParticleFilter(10^6, prop; resampler=resampler) gf_state, llgf = GeneralisedFilters.filter(rng, model, gf, ys) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) @@ -229,9 +231,10 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, hier_model, 3) + _, _, ys = sample(rng, hier_model, 4) - bf = BF(10^6; threshold=0.8) + resampler = GeneralisedFilters.GFTest.AlternatingResampler() + bf = BF(10^6; resampler=resampler) rbbf = RBPF(bf, KalmanFilter()) rbbf_state, llrbbf = GeneralisedFilters.filter(rng, hier_model, rbbf, ys) @@ -257,10 +260,11 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, hier_model, 3) + _, _, ys = sample(rng, hier_model, 4) prop = GeneralisedFilters.GFTest.OverdispersedProposal(dyn(hier_model).outer_dyn, 1.5) - gf = ParticleFilter(10^6, prop; threshold=1.0) + resampler = GeneralisedFilters.GFTest.AlternatingResampler() + gf = ParticleFilter(10^6, prop; resampler=resampler) rbgf = RBPF(gf, KalmanFilter()) rbgf_state, llrbgf = GeneralisedFilters.filter(rng, hier_model, rbgf, ys) xs = getfield.(rbgf_state.particles, :x) @@ -287,9 +291,9 @@ end model = GeneralisedFilters.GFTest.create_linear_gaussian_model( rng, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, model, 3) + _, _, ys = 
sample(rng, model, 4) - bf = BF(10^6; threshold=1.0) + bf = BF(10^6; threshold=1.0) # APF needs resampling every step abf = AuxiliaryParticleFilter(bf, MeanPredictive()) abf_state, llabf = GeneralisedFilters.filter(rng, model, abf, ys) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) @@ -315,9 +319,9 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, hier_model, 3) + _, _, ys = sample(rng, hier_model, 4) - bf = BF(10^6; threshold=1.0) + bf = BF(10^6; threshold=1.0) # APF needs resampling every step rbbf = RBPF(bf, KalmanFilter()) arbf = AuxiliaryParticleFilter(rbbf, MeanPredictive()) arbf_state, llarbf = GeneralisedFilters.filter(rng, hier_model, arbf, ys) @@ -358,6 +362,7 @@ end @testitem "BF on hierarchical model test" begin using LogExpFunctions + using SSMProblems using StableRNGs SEED = 1234 @@ -378,7 +383,7 @@ end kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) # Rao-Blackwellised particle filtering - bf = BF(N_particles) + bf = BF(N_particles; threshold=0.8) states, ll = GeneralisedFilters.filter(rng, hier_model, bf, ys) # Extract final filtered states From 6ec276e424a96db77990ef63f1c5f32848abfd8c Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 11:55:54 +0100 Subject: [PATCH 11/20] Remove legacy code --- GeneralisedFilters/src/algorithms/particles.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 823cef72..10fba052 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -191,13 +191,6 @@ function propogate( return new_x, logw_inc end -# function update( -# obs, algo::ParticleFilter, iter::Integer, p::Particle, observation; kwargs... -# ) -# log_increment = SSMProblems.logdensity(obs, iter, p.state, observation; kwargs...) -# return Particle(p.state, p.log_w + log_increment, p.ancestor) -# end - struct LatentProposal <: AbstractProposal end const BootstrapFilter{RS} = ParticleFilter{RS,LatentProposal} From 970cdfc3d400a1ab7648d0d2ceb9e82bc340fdfe Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 12:00:02 +0100 Subject: [PATCH 12/20] Fix unit test for BF on H-SSM --- .../src/algorithms/particles.jl | 31 ++++++++++--------- GeneralisedFilters/test/runtests.jl | 4 +-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 10fba052..d1378ca7 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -244,22 +244,23 @@ function propogate( return new_x, 0.0 end +# TODO: I feel like we shouldn't need to do this conversion. It should be handled by dispatch # Application of particle filter to hierarchical models -# function filter( -# rng::AbstractRNG, -# model::HierarchicalSSM, -# algo::ParticleFilter, -# observations::AbstractVector; -# ref_state::Union{Nothing,AbstractVector}=nothing, -# kwargs..., -# ) -# ssm = StateSpaceModel( -# HierarchicalPrior(model.outer_prior, model.inner_model.prior), -# HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), -# HierarchicalObservations(model.inner_model.obs), -# ) -# return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) 
-# end +function filter( + rng::AbstractRNG, + model::HierarchicalSSM, + algo::ParticleFilter, + observations::AbstractVector; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + ssm = StateSpaceModel( + HierarchicalPrior(model.outer_prior, model.inner_model.prior), + HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), + HierarchicalObservations(model.inner_model.obs), + ) + return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) +end abstract type PredictivePosterior end diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index c23062b8..fe5b96c7 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -387,8 +387,8 @@ end states, ll = GeneralisedFilters.filter(rng, hier_model, bf, ys) # Extract final filtered states - xs = map(p -> getproperty(p, :x), states.particles) - zs = map(p -> getproperty(p, :z), states.particles) + xs = map(p -> getproperty(p.state, :x), states.particles) + zs = map(p -> getproperty(p.state, :z), states.particles) log_ws = getfield.(states.particles, :log_w) ws = softmax(log_ws) From 8b9ecbe87b776744054382ae5661f134934cdd39 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 12:42:48 +0100 Subject: [PATCH 13/20] Formatting and performance tweaks --- GeneralisedFilters/src/GFTest/resamplers.jl | 4 +--- GeneralisedFilters/test/runtests.jl | 4 ++-- SSMProblems/src/utils/forward_simulation.jl | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/GeneralisedFilters/src/GFTest/resamplers.jl b/GeneralisedFilters/src/GFTest/resamplers.jl index 4a437290..a581f26f 100644 --- a/GeneralisedFilters/src/GFTest/resamplers.jl +++ b/GeneralisedFilters/src/GFTest/resamplers.jl @@ -29,9 +29,7 @@ function GeneralisedFilters.resample( if alt_resampler.resample_next # Resample using wrapped resampler alt_resampler.resample_next = false - return GeneralisedFilters.resample( - rng, alt_resampler.resampler, state; ref_state - ) + return GeneralisedFilters.resample(rng, alt_resampler.resampler, state; ref_state) else # Skip resampling - keep particles with their current weights and set ancestors to # themselves diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index fe5b96c7..fbb00feb 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -231,7 +231,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, hier_model, 4) + _, _, ys = sample(rng, full_model, 4) resampler = GeneralisedFilters.GFTest.AlternatingResampler() bf = BF(10^6; resampler=resampler) @@ -260,7 +260,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1; static_arrays=true ) - _, _, ys = sample(rng, hier_model, 4) + _, _, ys = sample(rng, full_model, 4) prop = GeneralisedFilters.GFTest.OverdispersedProposal(dyn(hier_model).outer_dyn, 1.5) resampler = GeneralisedFilters.GFTest.AlternatingResampler() diff --git a/SSMProblems/src/utils/forward_simulation.jl b/SSMProblems/src/utils/forward_simulation.jl index e822cccd..55c299bc 100644 --- a/SSMProblems/src/utils/forward_simulation.jl +++ b/SSMProblems/src/utils/forward_simulation.jl @@ -27,4 +27,4 @@ Simulate a trajectory using the default random number generator. """ function sample(model::AbstractStateSpaceModel, T::Integer; kwargs...) return sample(default_rng(), model, T; kwargs...) 
-end \ No newline at end of file +end From d808f1169879a38bc7a8edd033e44c2660dc5978 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 12:51:00 +0100 Subject: [PATCH 14/20] Update forward algorithm to new interface --- GeneralisedFilters/src/algorithms/forward.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/forward.jl b/GeneralisedFilters/src/algorithms/forward.jl index 5da6bc14..e2c8d144 100644 --- a/GeneralisedFilters/src/algorithms/forward.jl +++ b/GeneralisedFilters/src/algorithms/forward.jl @@ -3,38 +3,36 @@ export ForwardAlgorithm, FW struct ForwardAlgorithm <: AbstractFilter end const FW = ForwardAlgorithm -function initialise( - rng::AbstractRNG, model::DiscreteStateSpaceModel, ::ForwardAlgorithm; kwargs... -) - return calc_α0(model.prior; kwargs...) +function initialise(rng::AbstractRNG, prior::DiscretePrior, ::ForwardAlgorithm; kwargs...) + return calc_α0(prior; kwargs...) end function predict( rng::AbstractRNG, - model::DiscreteStateSpaceModel, + dyn::DiscreteLatentDynamics, filter::ForwardAlgorithm, step::Integer, states::AbstractVector, observation; kwargs..., ) - P = calc_P(model.dyn, step; kwargs...) + P = calc_P(dyn, step; kwargs...) return (states' * P)' end function update( - model::DiscreteStateSpaceModel{T}, + obs::ObservationProcess, filter::ForwardAlgorithm, step::Integer, states::AbstractVector, observation; kwargs..., -) where {T} +) # Compute emission probability vector # TODO: should we define density as part of the interface or run the whole algorithm in # log space? b = map( - x -> exp(SSMProblems.logdensity(model.obs, step, x, observation; kwargs...)), + x -> exp(SSMProblems.logdensity(obs, step, x, observation; kwargs...)), eachindex(states), ) filtered_states = b .* states From 13013e374aaf4c2806daae4d261a51dc6348b0d4 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 13:39:09 +0100 Subject: [PATCH 15/20] Remove temporary test script --- .../test/combination_test_script.jl | 150 ------------------ 1 file changed, 150 deletions(-) delete mode 100644 GeneralisedFilters/test/combination_test_script.jl diff --git a/GeneralisedFilters/test/combination_test_script.jl b/GeneralisedFilters/test/combination_test_script.jl deleted file mode 100644 index a5061b29..00000000 --- a/GeneralisedFilters/test/combination_test_script.jl +++ /dev/null @@ -1,150 +0,0 @@ -using Distributions -using GeneralisedFilters -using LinearAlgebra -using LogExpFunctions -using SSMProblems -using StableRNGs -using StatsBase -using Test - -println() -println("########################") -println("#### STARTING TESTS ####") -println("########################") -println() - -rng = StableRNG(1234) - -model = GeneralisedFilters.GFTest.create_linear_gaussian_model( - rng, 1, 1; static_arrays=true -) -_, _, ys = sample(rng, model, 3) - -bf = BF(10^6; threshold=0.8) -bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) -kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) - -xs = getfield.(bf_state.particles, :state) -log_ws = getfield.(bf_state.particles, :log_w) -ws = softmax(log_ws) - -# Compare log-likelihood and states -println("BF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3) -println("BF LL: ", @test llkf ≈ llbf atol = 1e-3) - -struct OptimalProposal{ - LD<:LinearGaussianLatentDynamics,OP<:LinearGaussianObservationProcess -} <: AbstractProposal - dyn::LD - obs::OP - dummy::Bool # if using dummy hierarchical model -end -function 
SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...) - A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...) - H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...) - Σ = inv(inv(Q) + H' * inv(R) * H) - μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c)) - if prop.dummy - μ = μ[[1]] - Σ = Σ[[1], [1]] - end - return MvNormal(μ, Σ) -end -# Propose from observation distribution -# proposal = PeturbationProposal(only(model.obs.R)) -proposal = OptimalProposal(model.dyn, model.obs, false) -gf = ParticleFilter(10^6, proposal; threshold=1.0) - -gf_state, llgf = GeneralisedFilters.filter(rng, model, gf, ys) -xs = getfield.(gf_state.particles, :state) -log_ws = getfield.(gf_state.particles, :log_w) -ws = softmax(log_ws) - -# Fairly sure this is correct but would be good to confirm (needs to be faster — SArrays) -println("GF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3) -println("GF LL: ", @test llkf ≈ llgf atol = 1e-3) - -############################## -#### RAO-BLACKWELLISATION #### -############################## - -full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( - rng, 1, 1, 1; static_arrays=true -) -_, _, ys = sample(rng, hier_model, 3) - -rbbf = RBPF(bf, KalmanFilter()) - -rbbf_state, llrbbf = GeneralisedFilters.filter(rng, hier_model, rbbf, ys) -xs = getfield.(rbbf_state.particles, :x) -zs = getfield.(rbbf_state.particles, :z) -log_ws = getfield.(rbbf_state.particles, :log_w) -ws = softmax(log_ws) - -kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys) - -println("RBBF Outer: ", @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3) -println( - "RBBF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-3 -) -println("RBBF LL: ", @test llkf ≈ llrbbf atol = 1e-3) - -proposal = OptimalProposal(model.dyn, model.obs, true) -gf = ParticleFilter(10^6, proposal; threshold=1.0) -rbgf = RBPF(gf, KalmanFilter()) -rbgf_state, llrbgf = GeneralisedFilters.filter(rng, hier_model, rbgf, ys) -xs = getfield.(rbgf_state.particles, :x) -zs = getfield.(rbgf_state.particles, :z) -log_ws = getfield.(rbgf_state.particles, :log_w) -ws = softmax(log_ws) - -# Reduce tolerance since this is a bit harder to filter to high precision -println("RBGF Outer: ", @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-2) -println( - "RBGF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-2 -) -println("RBGF LL: ", @test llkf ≈ llrbgf atol = 1e-2) - -################################ -#### REFERENCE TRAJECTORIES #### -################################ - -# Hard to verify these are correct until the code is faster and we can run a full loop -# For now we just check they run without error - -using OffsetArrays -ref_traj = [randn(rng, 1) for _ in 0:3]; -ref_traj = OffsetArray(ref_traj, 0:3); -GeneralisedFilters.filter(rng, model, bf, ys; ref_state=ref_traj); -GeneralisedFilters.filter(rng, model, gf, ys; ref_state=ref_traj); -GeneralisedFilters.filter(rng, hier_model, rbbf, ys; ref_state=ref_traj); -GeneralisedFilters.filter(rng, hier_model, rbgf, ys; ref_state=ref_traj); - -#################################### -#### AUXILIARY PARTICLE FILTERS #### -#################################### - -kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys); - -bf = BF(10^6; threshold=1.0) -abf = AuxiliaryParticleFilter(bf, MeanPredictive()) -abf_state, llabf = GeneralisedFilters.filter(rng, model, abf, ys); -xs = 
getfield.(abf_state.particles, :state)
-log_ws = getfield.(abf_state.particles, :log_w)
-ws = softmax(log_ws)
-println("ABF State: ", @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-3)
-println("ABF LL: ", @test llkf ≈ llabf atol = 1e-3)
-
-kf_state, llkf = GeneralisedFilters.filter(rng, full_model, KF(), ys);
-rbbf = RBPF(bf, KalmanFilter())
-arbf = AuxiliaryParticleFilter(rbbf, MeanPredictive())
-arbf_state, llarbf = GeneralisedFilters.filter(rng, hier_model, arbf, ys);
-xs = getfield.(arbf_state.particles, :x)
-zs = getfield.(arbf_state.particles, :z)
-log_ws = getfield.(arbf_state.particles, :log_w)
-ws = softmax(log_ws)
-println("ARBF Outer: ", @test first(kf_state.μ) ≈ sum(only.(xs) .* ws) rtol = 1e-3)
-println(
-    "ARBF Inner: ", @test last(kf_state.μ) ≈ sum(only.(getfield.(zs, :μ)) .* ws) rtol = 1e-3
-)
-println("ARBF LL: ", @test llkf ≈ llarbf atol = 1e-3)

From ba637e124c51bc849a0c13d5feed43ffd813d806 Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Tue, 30 Sep 2025 13:45:02 +0100
Subject: [PATCH 16/20] Introduce RBState to implement updated particle tree

---
 .../examples/trend-inflation/utilities.jl |  6 ++-
 GeneralisedFilters/src/algorithms/rbpf.jl | 41 +++++++++++++------
 GeneralisedFilters/src/callbacks.jl       | 20 +++++----
 GeneralisedFilters/src/containers.jl      |  7 ++--
 GeneralisedFilters/src/resamplers.jl      |  6 ---
 GeneralisedFilters/test/runtests.jl       | 30 +++++++-------
 research/variational_filter/script.jl     |  3 +-
 7 files changed, 67 insertions(+), 46 deletions(-)

diff --git a/GeneralisedFilters/examples/trend-inflation/utilities.jl b/GeneralisedFilters/examples/trend-inflation/utilities.jl
index b8b37514..95252935 100644
--- a/GeneralisedFilters/examples/trend-inflation/utilities.jl
+++ b/GeneralisedFilters/examples/trend-inflation/utilities.jl
@@ -17,8 +17,10 @@ mean_path(paths, states) = _mean_path(identity, paths, states)
 function mean_path(
     paths::Vector{Vector{T}}, states
 ) where {T<:GeneralisedFilters.RBParticle}
-    zs = _mean_path(z -> getproperty.(getproperty.(z, :z), :μ), paths, states)
-    xs = _mean_path(x -> getproperty.(x, :x), paths, states)
+    zs = _mean_path(
+        z -> getproperty.(getproperty.(getproperty.(z, :state), :z), :μ), paths, states
+    )
+    xs = _mean_path(x -> getproperty.(getproperty.(x, :state), :x), paths, states)
     return zs, xs
 end
 
diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl
index c1ab8ede..e9cfad0d 100644
--- a/GeneralisedFilters/src/algorithms/rbpf.jl
+++ b/GeneralisedFilters/src/algorithms/rbpf.jl
@@ -19,7 +19,7 @@ function initialise_particle(
     x = sample_prior(rng, prior.outer_prior, algo.pf, ref_state; kwargs...)
     z = initialise(rng, prior.inner_prior, algo.af; new_outer=x, kwargs...)
     # TODO (RB): determine the correct type for the log_w field or use a NoWeight type
-    return RBParticle(x, z, 0.0, 0)
+    return Particle(RBState(x, z), 0.0, 0)
 end
 
 function predict_particle(
@@ -34,21 +34,28 @@ function predict_particle(
 )
     # TODO: really we should be conditioning on the current RB state to allow for optimal proposals
     new_x, logw_inc = propogate(
-        rng, dyn.outer_dyn, algo.pf, iter, particle.x, observation, ref_state; kwargs...
+ rng, + dyn.outer_dyn, + algo.pf, + iter, + particle.state.x, + observation, + ref_state; + kwargs..., ) new_z = predict( rng, dyn.inner_dyn, algo.af, iter, - particle.z, + particle.state.z, observation; - prev_outer=particle.x, + prev_outer=particle.state.x, new_outer=new_x, kwargs..., ) - return RBParticle(new_x, new_z, particle.log_w + logw_inc, particle.ancestor) + return Particle(RBState(new_x, new_z), particle.log_w + logw_inc, particle.ancestor) end function update_particle( @@ -60,9 +67,17 @@ function update_particle( kwargs..., ) new_z, log_increment = update( - obs, algo.af, iter, particle.z, observation; new_outer=particle.x, kwargs... + obs, + algo.af, + iter, + particle.state.z, + observation; + new_outer=particle.state.x, + kwargs..., + ) + return Particle( + RBState(particle.state.x, new_z), particle.log_w + log_increment, particle.ancestor ) - return RBParticle(particle.x, new_z, particle.log_w + log_increment, particle.ancestor) end function predictive_state( @@ -74,19 +89,21 @@ function predictive_state( kwargs..., ) rbpf = apf.pf - x_star = predictive_statistic(rng, apf.pp, dyn.outer_dyn, iter, particle.x; kwargs...) + x_star = predictive_statistic( + rng, apf.pp, dyn.outer_dyn, iter, particle.state.x; kwargs... + ) z_star = predict( rng, dyn.inner_dyn, rbpf.af, iter, - particle.z, + particle.state.z, nothing; # no observation available — maybe we should pass this in - prev_outer=particle.x, + prev_outer=particle.state.x, new_outer=x_star, kwargs..., ) - return RBParticle(x_star, z_star, particle.log_w, particle.ancestor) + return Particle(RBState(x_star, z_star), particle.log_w, particle.ancestor) end function predictive_loglik( @@ -98,7 +115,7 @@ function predictive_loglik( kwargs..., ) _, log_increment = update( - obs, algo.af, iter, p_star.z, observation; new_outer=p_star.x, kwargs... + obs, algo.af, iter, p_star.state.z, observation; new_outer=p_star.state.x, kwargs... ) return log_increment end diff --git a/GeneralisedFilters/src/callbacks.jl b/GeneralisedFilters/src/callbacks.jl index f017ba49..145b7cc4 100644 --- a/GeneralisedFilters/src/callbacks.jl +++ b/GeneralisedFilters/src/callbacks.jl @@ -344,7 +344,8 @@ end function (c::ParallelAncestorCallback)( model, filter, step, state, data, ::PostInitCallback; kwargs... ) - @inbounds c.tree.states[1:(filter.N)] = deepcopy(state.particles) + N = num_particles(filter) + @inbounds c.tree.states[1:N] = deepcopy(state.particles) return nothing end @@ -352,7 +353,8 @@ function (c::ParallelAncestorCallback)( model, filter, step, state, data, ::PostUpdateCallback; kwargs... ) # insert! implicitly deepcopies - insert!(c.tree, state.particles, state.ancestors) + particles = state.particles + insert!(c.tree, getfield.(particles, :state), getfield.(particles, :ancestor)) return nothing end @@ -369,15 +371,17 @@ mutable struct AncestorCallback <: AbstractCallback end function (c::AncestorCallback)(model, filter, state, data, ::PostInitCallback; kwargs...) - c.tree = ParticleTree(state.particles, floor(Int64, filter.N * log(filter.N))) + N = num_particles(filter) + c.tree = ParticleTree(getfield.(state.particles, :state), floor(Int64, N * log(N))) return nothing end function (c::AncestorCallback)( model, filter, step, state, data, ::PostPredictCallback; kwargs... 
) - prune!(c.tree, get_offspring(state.ancestors)) - insert!(c.tree, state.particles, state.ancestors) + particles = state.particles + prune!(c.tree, getfield.(particles, :ancestor)) + insert!(c.tree, getfield.(particles, :state), getfield.(particles, :ancestor)) return nothing end @@ -392,14 +396,16 @@ mutable struct ResamplerCallback <: AbstractCallback end function (c::ResamplerCallback)(model, filter, state, data, ::PostInitCallback; kwargs...) - c.tree = ParticleTree(collect(1:N), floor(Int64, filter.N * log(filter.N))) + N = num_particles(filter) + c.tree = ParticleTree(collect(1:N), floor(Int64, N * log(N))) return nothing end function (c::ResamplerCallback)( model, filter, step, state, data, ::PostResampleCallback; kwargs... ) + N = num_particles(filter) prune!(c.tree, get_offspring(state.ancestors)) - insert!(c.tree, collect(1:(filter.N)), state.ancestors) + insert!(c.tree, collect(1:N), state.ancestors) return nothing end diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index bd98c281..577206dd 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -5,20 +5,19 @@ abstract type AbstractParticle{WT} end # New types -# TODO (RB): could the RB particle be a regular particle with a RB state? mutable struct Particle{ST,WT} <: AbstractParticle{WT} state::ST log_w::WT ancestor::Int end -mutable struct RBParticle{XT,ZT,WT} <: AbstractParticle{WT} +mutable struct RBState{XT,ZT} x::XT z::ZT - log_w::WT - ancestor::Int end +const RBParticle{XT,ZT,WT} = Particle{RBState{XT,ZT},WT} + mutable struct ParticleDistribution{WT,PT<:AbstractParticle{WT},VT<:AbstractVector{PT}} particles::VT prev_logsumexp::WT diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index e256c3b9..6128b9ab 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -35,9 +35,6 @@ end function resample_ancestor(particle::Particle, ancestor::Int) return Particle(particle.state, 0.0, ancestor) end -function resample_ancestor(particle::RBParticle, ancestor::Int) - return RBParticle(particle.x, particle.z, 0.0, ancestor) -end ## CONDITIONAL RESAMPLING ################################################################## @@ -77,9 +74,6 @@ end function set_ancestor(particle::Particle, ancestor::Int) return Particle(particle.state, particle.log_w, ancestor) end -function set_ancestor(particle::RBParticle, ancestor::Int) - return RBParticle(particle.x, particle.z, particle.log_w, ancestor) -end ## CATEGORICAL RESAMPLE #################################################################### diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index fbb00feb..9af59de8 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -238,8 +238,8 @@ end rbbf = RBPF(bf, KalmanFilter()) rbbf_state, llrbbf = GeneralisedFilters.filter(rng, hier_model, rbbf, ys) - xs = getfield.(rbbf_state.particles, :x) - zs = getfield.(rbbf_state.particles, :z) + xs = getfield.(getfield.(rbbf_state.particles, :state), :x) + zs = getfield.(getfield.(rbbf_state.particles, :state), :z) log_ws = getfield.(rbbf_state.particles, :log_w) ws = softmax(log_ws) @@ -267,8 +267,8 @@ end gf = ParticleFilter(10^6, prop; resampler=resampler) rbgf = RBPF(gf, KalmanFilter()) rbgf_state, llrbgf = GeneralisedFilters.filter(rng, hier_model, rbgf, ys) - xs = getfield.(rbgf_state.particles, :x) - zs = getfield.(rbgf_state.particles, :z) + xs = 
getfield.(getfield.(rbgf_state.particles, :state), :x) + zs = getfield.(getfield.(rbgf_state.particles, :state), :z) log_ws = getfield.(rbgf_state.particles, :log_w) ws = softmax(log_ws) @@ -325,8 +325,8 @@ end rbbf = RBPF(bf, KalmanFilter()) arbf = AuxiliaryParticleFilter(rbbf, MeanPredictive()) arbf_state, llarbf = GeneralisedFilters.filter(rng, hier_model, arbf, ys) - xs = getfield.(arbf_state.particles, :x) - zs = getfield.(arbf_state.particles, :z) + xs = getfield.(getfield.(arbf_state.particles, :state), :x) + zs = getfield.(getfield.(arbf_state.particles, :state), :z) log_ws = getfield.(arbf_state.particles, :log_w) ws = softmax(log_ws) @@ -352,7 +352,7 @@ end _, _, ys = sample(rng, full_model, T) cb = GeneralisedFilters.AncestorCallback(nothing) - rbpf = RBPF(KalmanFilter(), N_particles) + rbpf = RBPF(BF(N_particles; threshold=0.8), KalmanFilter()) GeneralisedFilters.filter(rng, hier_model, rbpf, ys; callback=cb) # TODO: add proper test comparing to dense storage @@ -554,11 +554,13 @@ end push!(trajectory_samples, deepcopy(ref_traj)) end # Reference trajectory should only be nonlinear state for RBPF - ref_traj = getproperty.(ref_traj, :x) + ref_traj = getproperty.(getproperty.(ref_traj, :state), :x) end # Extract inner and outer trajectories - x_trajectories = getproperty.(getindex.(trajectory_samples, t_smooth), :x) + x_trajectories = getproperty.( + getproperty.(getindex.(trajectory_samples, t_smooth), :state), :x + ) # Manually perform smoothing until we have a cleaner interface A = hier_model.inner_model.dyn.A @@ -567,13 +569,13 @@ end Q = hier_model.inner_model.dyn.Q z_smoothed_means = Vector{T}(undef, N_sample) for i in 1:N_sample - μ = trajectory_samples[i][K].z.μ - Σ = trajectory_samples[i][K].z.Σ + μ = trajectory_samples[i][K].state.z.μ + Σ = trajectory_samples[i][K].state.z.Σ for t in (K - 1):-1:t_smooth - μ_filt = trajectory_samples[i][t].z.μ - Σ_filt = trajectory_samples[i][t].z.Σ - μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].x + μ_filt = trajectory_samples[i][t].state.z.μ + Σ_filt = trajectory_samples[i][t].state.z.Σ + μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].state.x Σ_pred = A * Σ_filt * A' + Q G = Σ_filt * A' * inv(Σ_pred) diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index e92a7550..535a8fe4 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -4,7 +4,8 @@ This example demonstrates the extensibility of GeneralisedFilters with an adaptation of VSMC with a tunable proposal ([Naesseth et al, 2016](https://arxiv.org/pdf/1705.11140)). =# -using Pkg; Pkg.activate("research/variational_filter") +using Pkg; +Pkg.activate("research/variational_filter") using GeneralisedFilters, SSMProblems using PDMats, LinearAlgebra From fd6a9b3469f8219fd79c6d1c7d2dc979474dd575 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 13:51:18 +0100 Subject: [PATCH 17/20] Correct DenseAncestorCallback for new interface --- GeneralisedFilters/src/callbacks.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/GeneralisedFilters/src/callbacks.jl b/GeneralisedFilters/src/callbacks.jl index 145b7cc4..c53b4bdd 100644 --- a/GeneralisedFilters/src/callbacks.jl +++ b/GeneralisedFilters/src/callbacks.jl @@ -79,8 +79,9 @@ end function (c::DenseAncestorCallback)( model, filter, state, data, ::PostInitCallback; kwargs... 
) + particles = state.particles c.container = DenseParticleContainer( - OffsetVector([deepcopy(state.particles)], -1), Vector{Int}[] + OffsetVector([deepcopy(getfield.(particles, :state))], -1), Vector{Int}[] ) return nothing end @@ -88,8 +89,9 @@ end function (c::DenseAncestorCallback)( model, filter, step, state, data, ::PostUpdateCallback; kwargs... ) - push!(c.container.particles, deepcopy(state.particles)) - push!(c.container.ancestors, deepcopy(state.ancestors)) + particles = state.particles + push!(c.container.particles, deepcopy(getfield.(particles, :state))) + push!(c.container.ancestors, deepcopy(getfield.(particles, :ancestor))) return nothing end From 60dbb0b73cd2acbbf95a58b19ead6a37be15f5db Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 14:11:42 +0100 Subject: [PATCH 18/20] Update CSMC tests to new syntax --- GeneralisedFilters/test/runtests.jl | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 9af59de8..a5408faf 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -460,7 +460,7 @@ end T = Float64 N_particles = 10 N_burnin = 1000 - N_sample = 10000 + N_sample = 100000 rng = StableRNG(SEED) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) @@ -484,7 +484,7 @@ end ) log_ws = getfield.(bf_state.particles, :log_w) ws = softmax(log_ws) - sampled_idx = sample(rng, 1:length(bf_state), Weights(ws)) + sampled_idx = sample(rng, 1:N_particles, Weights(ws)) global ref_traj = GeneralisedFilters.get_ancestry(cb.container, sampled_idx) if i > N_burnin push!(trajectory_samples, ref_traj) @@ -497,8 +497,8 @@ end log_recip_likelihood_estimate = logsumexp(-lls) - log(length(lls)) csmc_mean = sum(getindex.(trajectory_samples, t_smooth)) / N_sample - @test csmc_mean ≈ state.μ rtol = 1e-2 - @test log_recip_likelihood_estimate ≈ -ks_ll rtol = 1e-2 + @test csmc_mean ≈ state.μ rtol = 1e-3 + @test log_recip_likelihood_estimate ≈ -ks_ll rtol = 1e-3 end @testitem "RBCSMC test" begin @@ -510,6 +510,7 @@ end using Random: randexp using StatsBase: sample, Weights using StaticArrays + using Statistics using OffsetArrays @@ -536,7 +537,7 @@ end ) N_steps = N_burnin + N_sample - rbpf = RBPF(KalmanFilter(), N_particles; threshold=0.6) + rbpf = RBPF(BF(N_particles; threshold=0.8), KalmanFilter()) ref_traj = nothing trajectory_samples = [] @@ -547,20 +548,18 @@ end ) log_ws = getfield.(bf_state.particles, :log_w) ws = softmax(log_ws) - sampled_idx = sample(rng, 1:length(bf_state), Weights(ws)) + sampled_idx = sample(rng, 1:N_particles, Weights(ws)) global ref_traj = GeneralisedFilters.get_ancestry(cb.container, sampled_idx) if i > N_burnin push!(trajectory_samples, deepcopy(ref_traj)) end # Reference trajectory should only be nonlinear state for RBPF - ref_traj = getproperty.(getproperty.(ref_traj, :state), :x) + ref_traj = getproperty.(ref_traj, :x) end # Extract inner and outer trajectories - x_trajectories = getproperty.( - getproperty.(getindex.(trajectory_samples, t_smooth), :state), :x - ) + x_trajectories = getproperty.(getindex.(trajectory_samples, t_smooth), :x) # Manually perform smoothing until we have a cleaner interface A = hier_model.inner_model.dyn.A @@ -569,13 +568,13 @@ end Q = hier_model.inner_model.dyn.Q z_smoothed_means = Vector{T}(undef, N_sample) for i in 1:N_sample - μ = trajectory_samples[i][K].state.z.μ - Σ = trajectory_samples[i][K].state.z.Σ + μ = 
trajectory_samples[i][K].z.μ + Σ = trajectory_samples[i][K].z.Σ for t in (K - 1):-1:t_smooth - μ_filt = trajectory_samples[i][t].state.z.μ - Σ_filt = trajectory_samples[i][t].state.z.Σ - μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].state.x + μ_filt = trajectory_samples[i][t].z.μ + Σ_filt = trajectory_samples[i][t].z.Σ + μ_pred = A * μ_filt + b + C * trajectory_samples[i][t].x Σ_pred = A * Σ_filt * A' + Q G = Σ_filt * A' * inv(Σ_pred) From 3fed80ef0fc7fd99cce1fdab1d07b253d4ce2672 Mon Sep 17 00:00:00 2001 From: Tim Hargreaves Date: Tue, 30 Sep 2025 14:44:41 +0100 Subject: [PATCH 19/20] Update trend inflation example to new syntax --- .../examples/trend-inflation/Project.toml | 1 + .../examples/trend-inflation/script.jl | 4 ++-- .../examples/trend-inflation/utilities.jl | 16 ++++++++-------- GeneralisedFilters/src/containers.jl | 6 ------ 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/GeneralisedFilters/examples/trend-inflation/Project.toml b/GeneralisedFilters/examples/trend-inflation/Project.toml index e063ab8f..a5f3513f 100644 --- a/GeneralisedFilters/examples/trend-inflation/Project.toml +++ b/GeneralisedFilters/examples/trend-inflation/Project.toml @@ -6,6 +6,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/GeneralisedFilters/examples/trend-inflation/script.jl b/GeneralisedFilters/examples/trend-inflation/script.jl index 68822b64..44f8d733 100644 --- a/GeneralisedFilters/examples/trend-inflation/script.jl +++ b/GeneralisedFilters/examples/trend-inflation/script.jl @@ -135,7 +135,7 @@ sparse_ancestry = GF.AncestorCallback(nothing); states, ll = GF.filter( rng, UCSV(0.2), - RBPF(KalmanFilter(), 2^12; threshold=1.0), + RBPF(BF(2^12), KalmanFilter()), [[pce] for pce in fred_data.value]; callback=sparse_ancestry, ); @@ -238,7 +238,7 @@ sparse_ancestry = GF.AncestorCallback(nothing) states, ll = GF.filter( rng, UCSVO(0.2, 0.05), - RBPF(KalmanFilter(), 2^12; threshold=1.0), + RBPF(BF(2^12), KalmanFilter()), [[pce] for pce in fred_data.value]; callback=sparse_ancestry, ); diff --git a/GeneralisedFilters/examples/trend-inflation/utilities.jl b/GeneralisedFilters/examples/trend-inflation/utilities.jl index 95252935..e3f9e0da 100644 --- a/GeneralisedFilters/examples/trend-inflation/utilities.jl +++ b/GeneralisedFilters/examples/trend-inflation/utilities.jl @@ -1,26 +1,26 @@ using CSV, DataFrames using CairoMakie using Dates +using LogExpFunctions fred_data = CSV.read(joinpath(INFL_PATH, "data.csv"), DataFrame) ## PLOTTING UTILITIES ###################################################################### function _mean_path(f, paths, states) - return mean(map(x -> hcat(f(x)...), paths), StatsBase.weights(states)) + return mean( + map(x -> hcat(f(x)...), paths), + Weights(softmax(getproperty.(states.particles, :log_w))), + ) end # for normal collections mean_path(paths, states) = _mean_path(identity, paths, states) # for rao blackwellised particles -function mean_path( - paths::Vector{Vector{T}}, states -) where {T<:GeneralisedFilters.RBParticle} - zs = _mean_path( - z -> getproperty.(getproperty.(getproperty.(z, :state), :z), :μ), paths, states - ) - xs = _mean_path(x -> 
getproperty.(getproperty.(x, :state), :x), paths, states)
+function mean_path(paths::Vector{Vector{T}}, states) where {T<:GeneralisedFilters.RBState}
+    zs = _mean_path(s -> getproperty.(getproperty.(s, :z), :μ), paths, states)
+    xs = _mean_path(s -> getproperty.(s, :x), paths, states)
     return zs, xs
 end
diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl
index 577206dd..d830522b 100644
--- a/GeneralisedFilters/src/containers.jl
+++ b/GeneralisedFilters/src/containers.jl
@@ -30,12 +30,6 @@ function marginalise!(state::ParticleDistribution)
     return ll_increment
 end
 
-# Old code
-mutable struct ParticleWeights{WT<:Real}
-    log_weights::Vector{WT}
-    prev_logsumexp::WT
-end
-
 # """
 #     ParticleDistribution
 
From 7c87ed420b071904f46222c23bd5a46e292c64ba Mon Sep 17 00:00:00 2001
From: Tim Hargreaves
Date: Tue, 30 Sep 2025 14:59:22 +0100
Subject: [PATCH 20/20] Remove undefined export

---
 GeneralisedFilters/src/algorithms/kalman.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl
index a47202b6..7fe07e82 100644
--- a/GeneralisedFilters/src/algorithms/kalman.jl
+++ b/GeneralisedFilters/src/algorithms/kalman.jl
@@ -1,4 +1,4 @@
-export KalmanFilter, filter, BatchKalmanFilter
+export KalmanFilter, filter
 
 using CUDA: i32
 import PDMats: PDMat
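
Note: taken together, the container changes in this series reduce to one idea: a
Rao-Blackwellised particle is an ordinary `Particle` whose `state` field holds an
`RBState`, so all weight and ancestor bookkeeping is shared with the plain bootstrap
filter. Below is a minimal standalone sketch of that design, not the package source:
the types are restated so the snippet runs on its own, the real `Particle` additionally
subtypes `AbstractParticle{WT}`, and the named tuple `(μ=..., Σ=...)` is only a
placeholder for whatever marginal state the inner Kalman filter maintains.

    # Sketch only: restates the container types unified in this series.
    mutable struct Particle{ST,WT}
        state::ST
        log_w::WT
        ancestor::Int
    end

    mutable struct RBState{XT,ZT}
        x::XT  # sampled (nonlinear) component
        z::ZT  # marginalised component, e.g. a Gaussian filtering state
    end

    # A Rao-Blackwellised particle is just a Particle wrapping an RBState
    const RBParticle{XT,ZT,WT} = Particle{RBState{XT,ZT},WT}

    # One generic method now covers both cases, which is why the RBParticle
    # specialisations of resample_ancestor/set_ancestor are deleted above
    resample_ancestor(p::Particle, ancestor::Int) = Particle(p.state, 0.0, ancestor)

    # Hypothetical usage with a placeholder (μ, Σ) named tuple as the
    # marginalised state
    rb = Particle(RBState([0.1], (μ=[0.0], Σ=ones(1, 1))), 0.0, 1)
    @assert rb isa RBParticle
    @assert resample_ancestor(rb, 7).ancestor == 7

    # Field access nests through :state, matching the updated tests
    xs = getfield.(getfield.([rb], :state), :x)
    log_ws = getfield.([rb], :log_w)

This is also why the callbacks above store `getfield.(particles, :state)`: ancestry
trajectories then hold `RBState` values directly, which is what lets the RBCSMC test
drop its extra `.state` indirection again in PATCH 18/20.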