1 change: 1 addition & 0 deletions GeneralisedFilters/examples/trend-inflation/Project.toml
@@ -6,6 +6,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
4 changes: 2 additions & 2 deletions GeneralisedFilters/examples/trend-inflation/script.jl
@@ -135,7 +135,7 @@ sparse_ancestry = GF.AncestorCallback(nothing);
 states, ll = GF.filter(
     rng,
     UCSV(0.2),
-    RBPF(KalmanFilter(), 2^12; threshold=1.0),
+    RBPF(BF(2^12), KalmanFilter()),
     [[pce] for pce in fred_data.value];
     callback=sparse_ancestry,
 );
@@ -238,7 +238,7 @@ sparse_ancestry = GF.AncestorCallback(nothing)
 states, ll = GF.filter(
     rng,
     UCSVO(0.2, 0.05),
-    RBPF(KalmanFilter(), 2^12; threshold=1.0),
+    RBPF(BF(2^12), KalmanFilter()),
     [[pce] for pce in fred_data.value];
     callback=sparse_ancestry,
 );
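Both `GF.filter` calls in this example switch to the new `RBPF` constructor, which takes the outer particle filter first and the inner analytic filter second, replacing the old `(inner_filter, n_particles; threshold)` form. A minimal sketch of the new call shape, with the model and observations left hypothetical:

```julia
# New constructor order: outer bootstrap filter with 2^12 particles, inner
# Kalman filter for the conditionally linear-Gaussian states.
using GeneralisedFilters
using Random: MersenneTwister

rng = MersenneTwister(1234)
algo = RBPF(BF(2^12), KalmanFilter())

# With `model` a Rao-Blackwellisable SSM and `ys` a vector of observations:
# states, ll = GeneralisedFilters.filter(rng, model, algo, ys)
```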
14 changes: 8 additions & 6 deletions GeneralisedFilters/examples/trend-inflation/utilities.jl
@@ -1,24 +1,26 @@
 using CSV, DataFrames
 using CairoMakie
 using Dates
+using LogExpFunctions
 
 fred_data = CSV.read(joinpath(INFL_PATH, "data.csv"), DataFrame)
 
 ## PLOTTING UTILITIES ######################################################################
 
 function _mean_path(f, paths, states)
-    return mean(map(x -> hcat(f(x)...), paths), StatsBase.weights(states))
+    return mean(
+        map(x -> hcat(f(x)...), paths),
+        Weights(softmax(getproperty.(states.particles, :log_w))),
+    )
 end
 
 # for normal collections
 mean_path(paths, states) = _mean_path(identity, paths, states)
 
 # for rao blackwellised particles
-function mean_path(
-    paths::Vector{Vector{T}}, states
-) where {T<:GeneralisedFilters.RaoBlackwellisedParticle}
-    zs = _mean_path(z -> getproperty.(getproperty.(z, :z), :μ), paths, states)
-    xs = _mean_path(x -> getproperty.(x, :x), paths, states)
+function mean_path(paths::Vector{Vector{T}}, states) where {T<:GeneralisedFilters.RBState}
+    zs = _mean_path(s -> getproperty.(getproperty.(s, :z), :μ), paths, states)
+    xs = _mean_path(s -> getproperty.(s, :x), paths, states)
     return zs, xs
 end
 
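The reweighting above replaces `StatsBase.weights` on raw weights with a `softmax` over per-particle log-weights. The idea in isolation, with illustrative values rather than the package's particle types:

```julia
# Standalone sketch: normalise log-weights with softmax, then take a weighted
# mean of per-particle summaries via StatsBase.Weights.
using LogExpFunctions: softmax
using StatsBase: Weights
using Statistics: mean

log_w = [-0.3, -1.2, -2.5]      # hypothetical particle log-weights
paths = [1.0, 2.0, 3.0]         # hypothetical per-particle summaries

w = Weights(softmax(log_w))     # nonnegative, sums to one
mean(paths, w)                  # softmax-weighted posterior mean
```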
2 changes: 2 additions & 0 deletions GeneralisedFilters/src/GFTest/GFTest.jl
@@ -12,5 +12,7 @@ using SSMProblems
 include("utils.jl")
 include("models/linear_gaussian.jl")
 include("models/dummy_linear_gaussian.jl")
+include("proposals.jl")
+include("resamplers.jl")
 
 end
16 changes: 15 additions & 1 deletion GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
@@ -1,10 +1,13 @@
+using StaticArrays
+
 function create_linear_gaussian_model(
     rng::AbstractRNG,
     Dx::Integer,
     Dy::Integer,
     T::Type{<:Real}=Float64,
     process_noise_scale=T(0.1),
-    obs_noise_scale=T(1.0),
+    obs_noise_scale=T(1.0);
+    static_arrays::Bool=false,
 )
     μ0 = rand(rng, T, Dx)
     Σ0 = rand_cov(rng, T, Dx)
@@ -15,6 +18,17 @@ function create_linear_gaussian_model(
     c = rand(rng, T, Dy)
     R = rand_cov(rng, T, Dy; scale=obs_noise_scale)
 
+    if static_arrays
+        μ0 = SVector{Dx,T}(μ0)
+        Σ0 = SMatrix{Dx,Dx,T}(Σ0)
+        A = SMatrix{Dx,Dx,T}(A)
+        b = SVector{Dx,T}(b)
+        Q = SMatrix{Dx,Dx,T}(Q)
+        H = SMatrix{Dy,Dx,T}(H)
+        c = SVector{Dy,T}(c)
+        R = SMatrix{Dy,Dy,T}(R)
+    end
+
     return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
 end
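The new `static_arrays` keyword converts every model parameter to an `SVector`/`SMatrix`, keeping small fixed-size models off the heap. A hedged usage sketch, assuming `GFTest` is reachable as a submodule with this helper:

```julia
# Hypothetical call: 2-dimensional latent state, 1-dimensional observation,
# with all parameters stored as StaticArrays.
using Random: MersenneTwister
import GeneralisedFilters

rng = MersenneTwister(1234)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(
    rng, 2, 1, Float64; static_arrays=true
)
```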
43 changes: 43 additions & 0 deletions GeneralisedFilters/src/GFTest/proposals.jl
@@ -0,0 +1,43 @@
+using Distributions
+
+"""
+    OptimalProposal
+
+Optimal importance proposal for linear Gaussian state space models.
+
+A proposal coming from the closed-form distribution `p(x_t | x_{t-1}, y_t)`. This proposal
+minimizes the variance of the importance weights.
+"""
+struct OptimalProposal{
+    LD<:LinearGaussianLatentDynamics,OP<:LinearGaussianObservationProcess
+} <: AbstractProposal
+    dyn::LD
+    obs::OP
+end
+
+function SSMProblems.distribution(prop::OptimalProposal, step::Integer, x, y; kwargs...)
+    A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...)
+    H, c, R = GeneralisedFilters.calc_params(prop.obs, step; kwargs...)
+    Σ = inv(inv(Q) + H' * inv(R) * H)
+    μ = Σ * (inv(Q) * (A * x + b) + H' * inv(R) * (y - c))
+    return MvNormal(μ, Σ)
+end
+
+"""
+    OverdispersedProposal
+
+A proposal that overdisperses the latent dynamics by inflating the covariance.
+"""
+struct OverdispersedProposal{LD<:LinearGaussianLatentDynamics} <: AbstractProposal
+    dyn::LD
+    k::Float64
+end
+
+function SSMProblems.distribution(
+    prop::OverdispersedProposal, step::Integer, x, y; kwargs...
+)
+    A, b, Q = GeneralisedFilters.calc_params(prop.dyn, step; kwargs...)
+    Q = prop.k * Q # overdisperse
+    μ = A * x + b
+    return MvNormal(μ, Q)
+end
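`OptimalProposal` samples from the closed-form conditional `p(x_t | x_{t-1}, y_t)`, whose precision is the sum of the transition precision and the back-projected observation precision. The same algebra in a standalone form, with hypothetical one-dimensional parameters:

```julia
# Standalone check of the closed form used in SSMProblems.distribution above.
using Distributions, LinearAlgebra

A, b, Q = [0.9;;], [0.0], [0.1;;]   # transition: x_t = A x_{t-1} + b + N(0, Q)
H, c, R = [1.0;;], [0.0], [0.5;;]   # observation: y_t = H x_t + c + N(0, R)
x_prev, y = [0.2], [1.0]

Σ = inv(inv(Q) + H' * inv(R) * H)                            # proposal covariance
μ = Σ * (inv(Q) * (A * x_prev + b) + H' * inv(R) * (y - c))  # proposal mean
q = MvNormal(μ, Symmetric(Σ))       # the optimal proposal p(x_t | x_{t-1}, y_t)
```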
43 changes: 43 additions & 0 deletions GeneralisedFilters/src/GFTest/resamplers.jl
@@ -0,0 +1,43 @@
+"""
+    AlternatingResampler
+
+A resampler wrapper that alternates between resampling and not resampling on each step.
+This is useful for testing the validity of filters in both cases.
+
+The resampler maintains internal state to track whether it should resample on the next call.
+By default, it resamples on the first call, then skips the second, then resamples on the
+third, and so on.
+"""
+mutable struct AlternatingResampler <: GeneralisedFilters.AbstractConditionalResampler
+    resampler::GeneralisedFilters.AbstractResampler
+    resample_next::Bool
+    function AlternatingResampler(
+        resampler::GeneralisedFilters.AbstractResampler=Systematic()
+    )
+        return new(resampler, true)
+    end
+end
+
+function GeneralisedFilters.resample(
+    rng::AbstractRNG,
+    alt_resampler::AlternatingResampler,
+    state;
+    ref_state::Union{Nothing,AbstractVector}=nothing,
+)
+    n = length(state.particles)
+
+    if alt_resampler.resample_next
+        # Resample using wrapped resampler
+        alt_resampler.resample_next = false
+        return GeneralisedFilters.resample(rng, alt_resampler.resampler, state; ref_state)
+    else
+        # Skip resampling - keep particles with their current weights and set ancestors to
+        # themselves
+        alt_resampler.resample_next = true
+        new_particles = similar(state.particles)
+        for i in 1:n
+            new_particles[i] = GeneralisedFilters.set_ancestor(state.particles[i], i)
+        end
+        return GeneralisedFilters.ParticleDistribution(new_particles, state.prev_logsumexp)
+    end
+end
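`AlternatingResampler` deterministically toggles its internal flag, so a single filter run exercises both the resampling branch and the skip branch of a conditional resampler. A hedged usage sketch; whether the bootstrap filter accepts a `resampler` keyword as written below is an assumption to check against the package:

```julia
using GeneralisedFilters

# Wraps Systematic() by default; resamples on steps 1, 3, 5, ...
alt = GeneralisedFilters.GFTest.AlternatingResampler()

# Hypothetical wiring into a bootstrap filter:
# algo = BF(256; resampler=alt)
```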
7 changes: 4 additions & 3 deletions GeneralisedFilters/src/GeneralisedFilters.jl
@@ -3,6 +3,7 @@ module GeneralisedFilters
 using AbstractMCMC: AbstractMCMC, AbstractSampler
 import Distributions: MvNormal
 import Random: AbstractRNG, default_rng, rand
+import SSMProblems: prior, dyn, obs
 using OffsetArrays
 using SSMProblems
 using StatsBase
@@ -66,7 +67,7 @@ function filter(
     kwargs...,
 )
     # draw from the prior
-    init_state = initialise(rng, model, algo; kwargs...)
+    init_state = initialise(rng, prior(model), algo; kwargs...)
     callback(model, algo, init_state, observations, PostInit; kwargs...)
 
     # iterations starts here for type stability
@@ -117,10 +118,10 @@ function move(
     callback::CallbackType=nothing,
     kwargs...,
 )
-    state = predict(rng, model, algo, iter, state, observation; kwargs...)
+    state = predict(rng, dyn(model), algo, iter, state, observation; kwargs...)
     callback(model, algo, iter, state, observation, PostPredict; kwargs...)
 
-    state, ll_increment = update(model, algo, iter, state, observation; kwargs...)
+    state, ll_increment = update(obs(model), algo, iter, state, observation; kwargs...)
     callback(model, algo, iter, state, observation, PostUpdate; kwargs...)
 
     return state, ll_increment
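The core loop now pulls out model components with the SSMProblems accessors, so `initialise`, `predict`, and `update` dispatch on `prior(model)`, `dyn(model)`, and `obs(model)` rather than on the full model type. A sketch of the method signatures a new filter implements after this refactor; the `My*` types are stand-ins, not package types:

```julia
import GeneralisedFilters as GF
import GeneralisedFilters: initialise, predict, update

# Stand-in component types, just to show the dispatch pattern.
struct MyPrior end
struct MyDynamics end
struct MyObservation end
struct MyFilter <: GF.AbstractFilter end

initialise(rng, prior::MyPrior, ::MyFilter; kwargs...) = nothing           # initial state
predict(rng, dyn::MyDynamics, ::MyFilter, iter, state, y; kwargs...) = state
update(obs::MyObservation, ::MyFilter, iter, state, y; kwargs...) = (state, 0.0)
```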
16 changes: 7 additions & 9 deletions GeneralisedFilters/src/algorithms/forward.jl
@@ -3,38 +3,36 @@ export ForwardAlgorithm, FW
 struct ForwardAlgorithm <: AbstractFilter end
 const FW = ForwardAlgorithm
 
-function initialise(
-    rng::AbstractRNG, model::DiscreteStateSpaceModel, ::ForwardAlgorithm; kwargs...
-)
-    return calc_α0(model.prior; kwargs...)
+function initialise(rng::AbstractRNG, prior::DiscretePrior, ::ForwardAlgorithm; kwargs...)
+    return calc_α0(prior; kwargs...)
 end
 
 function predict(
     rng::AbstractRNG,
-    model::DiscreteStateSpaceModel,
+    dyn::DiscreteLatentDynamics,
     filter::ForwardAlgorithm,
     step::Integer,
     states::AbstractVector,
     observation;
     kwargs...,
 )
-    P = calc_P(model.dyn, step; kwargs...)
+    P = calc_P(dyn, step; kwargs...)
     return (states' * P)'
 end
 
 function update(
-    model::DiscreteStateSpaceModel{T},
+    obs::ObservationProcess,
     filter::ForwardAlgorithm,
     step::Integer,
     states::AbstractVector,
     observation;
     kwargs...,
-) where {T}
+)
     # Compute emission probability vector
     # TODO: should we define density as part of the interface or run the whole algorithm in
     # log space?
     b = map(
-        x -> exp(SSMProblems.logdensity(model.obs, step, x, observation; kwargs...)),
+        x -> exp(SSMProblems.logdensity(obs, step, x, observation; kwargs...)),
         eachindex(states),
     )
     filtered_states = b .* states
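The TODO above asks whether the forward recursion should run in log space. A sketch of what that variant could look like using `logsumexp`; this is an illustration of the alternative, not code from the package:

```julia
# Log-space forward update: log α'_i ∝ log b_i + log α_i, normalised with
# logsumexp; the normaliser is the log-likelihood increment for this step.
using LogExpFunctions: logsumexp

function log_forward_update(log_b::AbstractVector, log_states::AbstractVector)
    log_filtered = log_b .+ log_states
    ll_increment = logsumexp(log_filtered)
    return log_filtered .- ll_increment, ll_increment
end

log_forward_update(log.([0.4, 0.6]), log.([0.5, 0.5]))
```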
91 changes: 7 additions & 84 deletions GeneralisedFilters/src/algorithms/kalman.jl
@@ -1,4 +1,4 @@
-export KalmanFilter, filter, BatchKalmanFilter
+export KalmanFilter, filter
 using CUDA: i32
 import PDMats: PDMat
 
@@ -8,23 +8,21 @@ struct KalmanFilter <: AbstractFilter end
 
 KF() = KalmanFilter()
 
-function initialise(
-    rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs...
-)
-    μ0, Σ0 = calc_initial(model.prior; kwargs...)
+function initialise(rng::AbstractRNG, prior::GaussianPrior, filter::KalmanFilter; kwargs...)
+    μ0, Σ0 = calc_initial(prior; kwargs...)
     return GaussianDistribution(μ0, Σ0)
 end
 
 function predict(
     rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel,
+    dyn::LinearGaussianLatentDynamics,
     algo::KalmanFilter,
     iter::Integer,
     state::GaussianDistribution,
     observation=nothing;
     kwargs...,
 )
-    params = calc_params(model.dyn, iter; kwargs...)
+    params = calc_params(dyn, iter; kwargs...)
     state = kalman_predict(state, params)
     return state
 end
@@ -39,14 +37,14 @@ function kalman_predict(state, params)
 end
 
 function update(
-    model::LinearGaussianStateSpaceModel,
+    obs::LinearGaussianObservationProcess,
     algo::KalmanFilter,
     iter::Integer,
     state::GaussianDistribution,
     observation::AbstractVector;
     kwargs...,
 )
-    params = calc_params(model.obs, iter; kwargs...)
+    params = calc_params(obs, iter; kwargs...)
     state, ll = kalman_update(state, params, observation)
     return state, ll
 end
@@ -71,81 +69,6 @@ function kalman_update(state, params, observation)
     return state, ll
 end
 
-struct BatchKalmanFilter <: AbstractBatchFilter
-    batch_size::Int
-end
-
-function initialise(
-    rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel,
-    algo::BatchKalmanFilter;
-    kwargs...,
-)
-    μ0s, Σ0s = batch_calc_initial(model.prior, algo.batch_size; kwargs...)
-    return BatchGaussianDistribution(μ0s, Σ0s)
-end
-
-function predict(
-    rng::AbstractRNG,
-    model::LinearGaussianStateSpaceModel,
-    algo::BatchKalmanFilter,
-    iter::Integer,
-    state::BatchGaussianDistribution,
-    observation;
-    kwargs...,
-)
-    μs, Σs = state.μs, state.Σs
-    As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...)
-    μ̂s = NNlib.batched_vec(As, μs) .+ bs
-    Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs
-    return BatchGaussianDistribution(μ̂s, Σ̂s)
-end
-
-function update(
-    model::LinearGaussianStateSpaceModel,
-    algo::BatchKalmanFilter,
-    iter::Integer,
-    state::BatchGaussianDistribution,
-    observation;
-    kwargs...,
-)
-    # T = Float32 # temporary fix!!!
-    μs, Σs = state.μs, state.Σs
-    Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...)
-    D = size(observation, 1)
-
-    m = NNlib.batched_vec(Hs, μs) .+ cs
-    y_res = cu(observation) .- m
-    S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs
-
-    ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))
-
-    S_inv = CUDA.similar(S)
-    d_ipiv, _, d_S = CUDA.CUBLAS.getrf_strided_batched(S, true)
-    CUDA.CUBLAS.getri_strided_batched!(d_S, S_inv, d_ipiv)
-
-    diags = CuArray{eltype(S)}(undef, size(S, 1), size(S, 3))
-    for i in 1:size(S, 1)
-        diags[i, :] .= d_S[i, i, :]
-    end
-
-    log_dets = sum(log ∘ abs, diags; dims=1)
-
-    K = NNlib.batched_mul(ΣH_T, S_inv)
-
-    μ_filt = μs .+ NNlib.batched_vec(K, y_res)
-    Σ_filt = Σs .- NNlib.batched_mul(K, NNlib.batched_mul(Hs, Σs))
-
-    inv_term = NNlib.batched_vec(S_inv, y_res)
-    log_likes = -NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term)
-    log_likes = (log_likes .- (log_dets .+ D * convert(eltype(log_likes), log(2π)))) ./ 2
-
-    # HACK: only errors seems to be from numerical stability so will just overwrite
-    log_likes[isnan.(log_likes)] .= -Inf
-
-    return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1)
-end
-
 ## KALMAN SMOOTHER #########################################################################
 
 struct KalmanSmoother <: AbstractSmoother end
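With `BatchKalmanFilter` removed, the plain `KalmanFilter` is the remaining path here, and it composes with the refactored component-level `initialise`/`predict`/`update`. A hedged smoke-test sketch; the `GFTest` helper path and the simulation call are assumptions to check:

```julia
using GeneralisedFilters
using Random: MersenneTwister

rng = MersenneTwister(1234)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 2, 1)

# Assumed SSMProblems simulation interface; check the exact signature:
# x0, xs, ys = rand(rng, model, 20)
# state, ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys)
```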