From eba931bd9e54793352c4cfca67dff49025d1ad09 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 17 Sep 2025 14:56:20 +1000 Subject: [PATCH 1/3] convert to JAX --- lectures/imp_sample.md | 284 ++++++++++++++++++++++++++--------------- 1 file changed, 182 insertions(+), 102 deletions(-) diff --git a/lectures/imp_sample.md b/lectures/imp_sample.md index 09ae67576..1ec0815fd 100644 --- a/lectures/imp_sample.md +++ b/lectures/imp_sample.md @@ -4,9 +4,9 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.10.3 + jupytext_version: 1.16.6 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -19,9 +19,9 @@ kernelspec: ## Overview -In {doc}`this lecture ` we described a peculiar property of a likelihood ratio process, namely, that its mean equals one for all $t \geq 0$ despite it's converging to zero almost surely. +In {doc}`this lecture ` we described a peculiar property of a likelihood ratio process, namely, that its mean equals one for all $t \geq 0$ despite its converging to zero almost surely. -While it is easy to verify that peculiar properly analytically (i.e., in population), it is challenging to use a computer simulation to verify it via an application of a law of large numbers that entails studying sample averages of repeated simulations. +While it is easy to verify that peculiar property analytically (i.e., in population), it is challenging to use a computer simulation to verify it via an application of a law of large numbers that entails studying sample averages of repeated simulations. To confront this challenge, this lecture puts __importance sampling__ to work to accelerate convergence of sample averages to population means. @@ -30,10 +30,13 @@ We use importance sampling to estimate the mean of a cumulative likelihood rati We start by importing some Python packages. ```{code-cell} ipython3 -import numpy as np -from numba import jit, vectorize, prange +import jax +import jax.numpy as jnp +import jax.random as jr import matplotlib.pyplot as plt -from math import gamma +from jax.scipy.special import gammaln +from typing import NamedTuple +from functools import partial ``` ## Mathematical expectation of likelihood ratio @@ -58,34 +61,46 @@ Our goal is to approximate the mathematical expectation $E \left[ L\left(\omega^ In {doc}`this lecture `, we showed that $E \left[ L\left(\omega^t\right) \right]$ equals $1$ for all $t$. -We want to check out how well this holds if we replace $E$ by with sample averages from simulations. +We want to check out how well this holds if we replace $E$ with sample averages from simulations. This turns out to be easier said than done because for Beta distributions assumed above, $L\left(\omega^t\right)$ has a very skewed distribution with a very long tail as $t \rightarrow \infty$. This property makes it difficult efficiently and accurately to estimate the mean by standard Monte Carlo simulation methods. -In this lecture we explore how a standard Monte Carlo method fails and how **importance sampling** -provides a more computationally efficient way to approximate the mean of the cumulative likelihood ratio. +In this lecture we explore how a standard Monte Carlo method fails. + +We also show how **importance sampling** provides a more computationally efficient way to approximate the mean of the cumulative likelihood ratio. We first take a look at the density functions `f` and `g` . ```{code-cell} ipython3 -# Parameters in the two beta distributions. 
-F_a, F_b = 1, 1 -G_a, G_b = 3, 1.2 - -@vectorize -def p(w, a, b): - r = gamma(a + b) / (gamma(a) * gamma(b)) - return r * w ** (a-1) * (1 - w) ** (b-1) - -# The two density functions. -f = jit(lambda w: p(w, F_a, F_b)) -g = jit(lambda w: p(w, G_a, G_b)) +# Parameters for the model +class ImpSampleParams(NamedTuple): + F_a: float = 1.0 # Beta parameters for f + F_b: float = 1.0 + G_a: float = 3.0 # Beta parameters for g + G_b: float = 1.2 + +params = ImpSampleParams() + +@jax.jit +def beta_pdf(w, a, b): + """Beta probability density function.""" + log_beta_const = gammaln(a) + gammaln(b) - gammaln(a + b) + log_pdf = (a - 1) * jnp.log(w) + (b - 1) * jnp.log(1 - w) - log_beta_const + return jnp.exp(log_pdf) + +@jax.jit +def f(w, params=params): + return beta_pdf(w, params.F_a, params.F_b) + +@jax.jit +def g(w, params=params): + return beta_pdf(w, params.G_a, params.G_b) ``` ```{code-cell} ipython3 -w_range = np.linspace(1e-2, 1-1e-5, 1000) +w_range = jnp.linspace(1e-2, 1-1e-5, 1000) plt.plot(w_range, g(w_range), label='g') plt.plot(w_range, f(w_range), label='f') @@ -98,7 +113,9 @@ plt.show() The likelihood ratio is `l(w)=f(w)/g(w)`. ```{code-cell} ipython3 -l = jit(lambda w: f(w) / g(w)) +@jax.jit +def l(w): + return f(w) / g(w) ``` ```{code-cell} ipython3 @@ -120,14 +137,11 @@ We illustrate this numerically below. We circumvent the issue by using a _change of distribution_ called **importance sampling**. -Instead of drawing from $g$ to generate data during the simulation, we use an alternative -distribution $h$ to generate draws of $\omega$. +Instead of drawing from $g$ to generate data during the simulation, we use an alternative distribution $h$ to generate draws of $\omega$. -The idea is to design $h$ so that it oversamples the region of $\Omega$ where -$\ell \left(\omega_t\right)$ has large values but low density under $g$. +The idea is to design $h$ so that it oversamples the region of $\Omega$ where $\ell \left(\omega_t\right)$ has large values but low density under $g$. -After we construct a sample in this way, we must then weight each realization by the likelihood ratio of $g$ and $h$ when we compute the empirical mean -of the likelihood ratio. +After we construct a sample in this way, we must then weight each realization by the likelihood ratio of $g$ and $h$ when we compute the empirical mean of the likelihood ratio. By doing this, we properly account for the fact that we are using $h$ and not $g$ to simulate data. @@ -162,15 +176,17 @@ Since we must use an $h$ that has larger mass in parts of the distribution to wh The plots compare $g$ and $h$. ```{code-cell} ipython3 -g_a, g_b = G_a, G_b +g_a, g_b = params.G_a, params.G_b h_a, h_b = 0.5, 0.5 + +key = jr.PRNGKey(0) ``` ```{code-cell} ipython3 -w_range = np.linspace(1e-5, 1-1e-5, 1000) +w_range = jnp.linspace(1e-5, 1-1e-5, 1000) plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})') -plt.plot(w_range, p(w_range, 0.5, 0.5), label=f'h=Beta({h_a}, {h_b})') +plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5), label=f'h=Beta({h_a}, {h_b})') plt.title('real data generating process $g$ and importance distribution $h$') plt.legend() plt.ylim([0., 3.]) @@ -196,26 +212,46 @@ Here $\frac{p\left(\omega_{i,t}^q\right)}{q\left(\omega_{i,t}^q\right)}$ is the Below we prepare a Python function for computing the importance sampling estimates given any beta distributions $p$, $q$. 
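
The function relies on two standard JAX building blocks: `jax.lax.fori_loop` accumulates the product over $t$ while threading the PRNG key through the loop carry, and `jax.vmap` evaluates many independent sample paths in parallel, one key per path. Here is a minimal sketch of that pattern in isolation (the function name and the Beta parameters in it are purely illustrative):

```{code-cell} ipython3
def product_of_draws(key, n):
    """Multiply together n Beta draws, threading the PRNG key through the loop."""

    def body(t, carry):
        total, k = carry
        k, subkey = jr.split(k)
        total = total * jr.beta(subkey, 2.0, 3.0)
        return (total, k)

    total, _ = jax.lax.fori_loop(0, n, body, (1.0, key))
    return total

# One key per simulated path; vmap maps the function over the key batch
keys = jr.split(jr.PRNGKey(1), 5)
jax.vmap(product_of_draws, in_axes=(0, None))(keys, 4)
```

The estimator below uses exactly this structure, with the cumulative likelihood ratio and the importance weight carried through the loop.
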
```{code-cell} ipython3 -@jit(parallel=True) -def estimate(p_a, p_b, q_a, q_b, T=1, N=10000): - - μ_L = 0 - for i in prange(N): - - L = 1 - weight = 1 - for t in range(T): - w = np.random.beta(q_a, q_b) - l = f(w) / g(w) - - L *= l - weight *= p(w, p_a, p_b) / p(w, q_a, q_b) - - μ_L += L * weight - - μ_L /= N - - return μ_L +@jax.jit +def estimate_single_path(key, p_a, p_b, q_a, q_b, T): + """ + Estimation for a single sample path. + """ + + def loop_body(i, carry): + L, weight, key_state = carry + key_state, subkey = jr.split(key_state) + w = jr.beta(subkey, q_a, q_b) + + # Compute likelihood ratio using f/g functions + likelihood_ratio = f(w) / g(w) + L = L * likelihood_ratio + + # Importance sampling weight with beta_pdf + p_w = beta_pdf(w, p_a, p_b) + q_w = beta_pdf(w, q_a, q_b) + weight = weight * (p_w / q_w) + + return (L, weight, key_state) + + # Use fori_loop for dynamic T values + final_L, final_weight, _ = jax.lax.fori_loop( + 0, T, loop_body, (1.0, 1.0, key) + ) + return final_L * final_weight + +@partial(jax.jit, static_argnames=['N']) +def estimate(key, p_a, p_b, q_a, q_b, T=1, N=10000): + """Estimation of a batch of sample paths.""" + keys = jr.split(key, N) + + # Use vmap for vectorized computation + estimates = jax.vmap( + estimate_single_path, + in_axes=(0, *[None]*5) + )(keys, p_a, p_b, q_a, q_b, T) + + return jnp.mean(estimates) ``` Consider the case when $T=1$, which amounts to approximating $E_0\left[\ell\left(\omega\right)\right]$ @@ -223,13 +259,15 @@ Consider the case when $T=1$, which amounts to approximating $E_0\left[\ell\lef For the standard Monte Carlo estimate, we can set $p=g$ and $q=g$. ```{code-cell} ipython3 -estimate(g_a, g_b, g_a, g_b, T=1, N=10000) +key, subkey = jr.split(key) +estimate(subkey, g_a, g_b, g_a, g_b, T=1, N=10000) ``` For our importance sampling estimate, we set $q = h$. ```{code-cell} ipython3 -estimate(g_a, g_b, h_a, h_b, T=1, N=10000) +key, subkey = jr.split(key) +estimate(subkey, g_a, g_b, h_a, h_b, T=1, N=10000) ``` Evidently, even at $T=1$, our importance sampling estimate is closer to $1$ than is the Monte Carlo estimate. @@ -240,11 +278,13 @@ Setting $T=10$, we find that the Monte Carlo method severely underestimates the still produces an estimate close to its theoretical value of unity. ```{code-cell} ipython3 -estimate(g_a, g_b, g_a, g_b, T=10, N=10000) +key, subkey = jr.split(key) +estimate(subkey, g_a, g_b, g_a, g_b, T=10, N=10000) ``` ```{code-cell} ipython3 -estimate(g_a, g_b, h_a, h_b, T=10, N=10000) +key, subkey = jr.split(key) +estimate(subkey, g_a, g_b, h_a, h_b, T=10, N=10000) ``` The Monte Carlo method underestimates because the likelihood ratio $L(\omega^T) = \prod_{t=1}^T \frac{f(\omega_t)}{g(\omega_t)}$ has a highly skewed distribution under $g$. @@ -264,16 +304,22 @@ We next study the bias and efficiency of the Monte Carlo and importance sampling The code below produces distributions of estimates using both Monte Carlo and importance sampling methods. 
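
Before looking at those distributions, a short back-of-the-envelope calculation based on the Beta specifications above indicates why the two methods behave so differently.

A single-period draw of the standard Monte Carlo estimator is $\ell\left(\omega\right)$ with $\omega \sim g$, while a single-period draw of the importance sampling estimator is $\ell\left(\omega\right)\frac{g\left(\omega\right)}{h\left(\omega\right)} = \frac{f\left(\omega\right)}{h\left(\omega\right)}$ with $\omega \sim h$. Both have mean $1$, so their variances are governed by the second moments

$$
E^{g}\left[\ell\left(\omega\right)^{2}\right]=\int_{0}^{1}\frac{f\left(\omega\right)^{2}}{g\left(\omega\right)}d\omega ,
\qquad
E^{h}\left[\left(\frac{f\left(\omega\right)}{h\left(\omega\right)}\right)^{2}\right]=\int_{0}^{1}\frac{f\left(\omega\right)^{2}}{h\left(\omega\right)}d\omega .
$$

Because $L\left(\omega^{T}\right)$ is a product of independent one-period terms, the second moment of a simulated path at horizon $T$ is the relevant one-period second moment raised to the power $T$, and the variance of a path is that power minus one.

With our parameters $f$ is uniform while $g\left(\omega\right) \propto \omega^{2}\left(1-\omega\right)^{0.2}$ vanishes rapidly as $\omega \to 0$, so the first integral diverges: under $g$ almost every path makes $L\left(\omega^{T}\right)$ tiny while a vanishingly small fraction makes it enormous, and an average over finitely many paths typically falls well short of $1$. With $h=\text{Beta}(0.5, 0.5)$ the second integral equals $\pi^{2}/8 \approx 1.23$, so the variance of an importance-sampled path grows only like $\left(\pi^{2}/8\right)^{T}-1$.
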
```{code-cell} ipython3 -@jit(parallel=True) -def simulate(p_a, p_b, q_a, q_b, N_simu, T=1): - - μ_L_p = np.empty(N_simu) - μ_L_q = np.empty(N_simu) - - for i in prange(N_simu): - μ_L_p[i] = estimate(p_a, p_b, p_a, p_b, T=T) - μ_L_q[i] = estimate(p_a, p_b, q_a, q_b, T=T) - +@partial(jax.jit, static_argnames=['N_simu', 'N_samples']) +def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1, N_samples=1000): + """Simulation for both Monte Carlo and importance sampling.""" + keys = jr.split(key, 2 * N_simu) + keys_p = keys[:N_simu] + keys_q = keys[N_simu:] + + def run_monte_carlo(key_batch): + return estimate(key_batch, p_a, p_b, p_a, p_b, T, N_samples) + + def run_importance_sampling(key_batch): + return estimate(key_batch, p_a, p_b, q_a, q_b, T, N_samples) + + μ_L_p = jax.vmap(run_monte_carlo)(keys_p) + μ_L_q = jax.vmap(run_importance_sampling)(keys_q) + return μ_L_p, μ_L_q ``` @@ -283,17 +329,18 @@ We simulate $1000$ times for each method. ```{code-cell} ipython3 N_simu = 1000 -μ_L_p, μ_L_q = simulate(g_a, g_b, h_a, h_b, N_simu) +key, subkey = jr.split(key) +μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, h_a, h_b, N_simu) ``` ```{code-cell} ipython3 # standard Monte Carlo (mean and std) -np.nanmean(μ_L_p), np.nanvar(μ_L_p) +jnp.nanmean(μ_L_p), jnp.nanvar(μ_L_p) ``` ```{code-cell} ipython3 # importance sampling (mean and std) -np.nanmean(μ_L_q), np.nanvar(μ_L_q) +jnp.nanmean(μ_L_q), jnp.nanvar(μ_L_q) ``` Although both methods tend to provide a mean estimate of ${E} \left[\ell\left(\omega\right)\right]$ close to $1$, the importance sampling estimates have smaller variance. @@ -301,17 +348,44 @@ Although both methods tend to provide a mean estimate of ${E} \left[\ell\left(\o Next, we present distributions of estimates for $\hat{E} \left[L\left(\omega^t\right)\right]$, in cases for $T=1, 5, 10, 20$. 
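
The dispersion of these estimates is governed by the one-period second moments $\int_{0}^{1} f\left(\omega\right)^{2}/g\left(\omega\right)d\omega$ (standard Monte Carlo) and $\int_{0}^{1} f\left(\omega\right)^{2}/h\left(\omega\right)d\omega$ (importance sampling), raised to the power $T$. The cell below approximates them with a rough Riemann sum; it is only a sketch (the grid and the $10^{-6}$ cutoff are arbitrary choices) and reuses the `f`, `g`, `beta_pdf`, `h_a`, and `h_b` objects defined earlier.

```{code-cell} ipython3
# Riemann-sum approximations of the one-period second moments
w_grid = jnp.linspace(1e-6, 1 - 1e-6, 200_000)
dw = w_grid[1] - w_grid[0]

# ∫ f²/g: second moment of ℓ(ω) under g (standard Monte Carlo)
mc_moment = jnp.sum(f(w_grid)**2 / g(w_grid)) * dw

# ∫ f²/h: second moment of f(ω)/h(ω) under h (importance sampling)
is_moment = jnp.sum(f(w_grid)**2 / beta_pdf(w_grid, h_a, h_b)) * dw

mc_moment, is_moment
```

The first number is enormous and keeps growing as the lower cutoff shrinks toward zero (the integral diverges), while the second settles near $\pi^{2}/8 \approx 1.23$; raising each to the power $T$ indicates how quickly the dispersion of the two estimators separates, which is what the histograms below display.
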
```{code-cell} ipython3 -fig, axs = plt.subplots(2, 2, figsize=(14, 10)) +T_values = [1, 5, 10, 20] + +def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu, T_list, N_samples=1000): + """Simulation for multiple T values.""" + n_T = len(T_list) + keys = jr.split(key, n_T) + + results = [] + for i, T in enumerate(T_list): + result = simulate(keys[i], p_a, p_b, q_a, q_b, N_simu, T, N_samples) + results.append(result) + + # Stack results into arrays for consistency + μ_L_p_all = jnp.stack([r[0] for r in results]) + μ_L_q_all = jnp.stack([r[1] for r in results]) + + return μ_L_p_all, μ_L_q_all + +# Run all simulations at once +key, subkey = jr.split(key) +all_results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values, N_samples=1000) + +# Extract results +μ_L_p_all, μ_L_q_all = all_results -μ_range = np.linspace(0, 2, 100) +fig, axs = plt.subplots(2, 2, figsize=(14, 10)) +μ_range = jnp.linspace(0, 2, 100) -for i, t in enumerate([1, 5, 10, 20]): +for i, t in enumerate(T_values): row = i // 2 col = i % 2 - - μ_L_p, μ_L_q = simulate(g_a, g_b, h_a, h_b, N_simu, T=t) - μ_hat_p, μ_hat_q = np.nanmean(μ_L_p), np.nanmean(μ_L_q) - σ_hat_p, σ_hat_q = np.nanvar(μ_L_p), np.nanvar(μ_L_q) + + # Get results for this T value + μ_L_p = μ_L_p_all[i] + μ_L_q = μ_L_q_all[i] + + μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q) + σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q) axs[row, col].set_xlabel('$μ_L$') axs[row, col].set_ylabel('frequency') @@ -322,7 +396,7 @@ for i, t in enumerate([1, 5, 10, 20]): for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p], [n_q, bins_q, μ_hat_q, σ_hat_q]]: - idx = np.argmax(n) + idx = jnp.argmax(n) axs[row, col].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}') plt.show() @@ -351,12 +425,13 @@ $$ $$ ```{code-cell} ipython3 -μ_L_p, μ_L_q = simulate(g_a, g_b, F_a, F_b, N_simu) +key, subkey = jr.split(key) +μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a, params.F_b, N_simu) ``` ```{code-cell} ipython3 # importance sampling (mean and std) -np.nanmean(μ_L_q), np.nanvar(μ_L_q) +jnp.nanmean(μ_L_q), jnp.nanvar(μ_L_q) ``` We could also use other distributions as our importance distribution. @@ -369,12 +444,12 @@ b_list = [0.5, 1.2, 5.] 
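# Each pair (a_list[i], b_list[i]) defines a candidate importance
# distribution h_{i+1} = Beta(a_list[i], b_list[i]); the three candidates
# are plotted against g and compared below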
``` ```{code-cell} ipython3 -w_range = np.linspace(1e-5, 1-1e-5, 1000) +w_range = jnp.linspace(1e-5, 1-1e-5, 1000) plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})') -plt.plot(w_range, p(w_range, a_list[0], b_list[0]), label=f'$h_1$=Beta({a_list[0]},{b_list[0]})') -plt.plot(w_range, p(w_range, a_list[1], b_list[1]), label=f'$h_2$=Beta({a_list[1]},{b_list[1]})') -plt.plot(w_range, p(w_range, a_list[2], b_list[2]), label=f'$h_3$=Beta({a_list[2]},{b_list[2]})') +plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]), label=f'$h_1$=Beta({a_list[0]},{b_list[0]})') +plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]), label=f'$h_2$=Beta({a_list[1]},{b_list[1]})') +plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]), label=f'$h_3$=Beta({a_list[2]},{b_list[2]})') plt.title('real data generating process $g$ and importance distribution $h$') plt.legend() plt.ylim([0., 3.]) @@ -404,16 +479,20 @@ We first simulate a plot the distribution of estimates for $\hat{E} \left[L\left h_a = a_list[1] h_b = b_list[1] -fig, axs = plt.subplots(1,2, figsize=(14, 10)) - -μ_range = np.linspace(0, 2, 100) - -for i, t in enumerate([1, 20]): +T_values_h2 = [1, 20] +key, subkey = jr.split(key) +all_results_h2 = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values_h2, N_samples=1000) +μ_L_p_all_h2, μ_L_q_all_h2 = all_results_h2 +fig, axs = plt.subplots(1,2, figsize=(14, 10)) +μ_range = jnp.linspace(0, 2, 100) - μ_L_p, μ_L_q = simulate(g_a, g_b, h_a, h_b, N_simu, T=t) - μ_hat_p, μ_hat_q = np.nanmean(μ_L_p), np.nanmean(μ_L_q) - σ_hat_p, σ_hat_q = np.nanvar(μ_L_p), np.nanvar(μ_L_q) +for i, t in enumerate(T_values_h2): + μ_L_p = μ_L_p_all_h2[i] + μ_L_q = μ_L_q_all_h2[i] + + μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q) + σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q) axs[i].set_xlabel('$μ_L$') axs[i].set_ylabel('frequency') @@ -424,7 +503,7 @@ for i, t in enumerate([1, 20]): for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p], [n_q, bins_q, μ_hat_q, σ_hat_q]]: - idx = np.argmax(n) + idx = jnp.argmax(n) axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}') plt.show() @@ -438,16 +517,17 @@ Even at $T=20$, the mean is very close to $1$ and the variance is small. 
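
Before turning to $h_3$, the same rough second-moment check (a Riemann-sum sketch reusing `f`, `beta_pdf`, `a_list`, and `b_list` from earlier cells) suggests what to expect from each candidate.

```{code-cell} ipython3
# Approximate ∫ f²/h for each candidate importance density
w_grid = jnp.linspace(1e-6, 1 - 1e-6, 200_000)
dw = w_grid[1] - w_grid[0]

for i, (a, b) in enumerate(zip(a_list, b_list)):
    m = jnp.sum(f(w_grid)**2 / beta_pdf(w_grid, a, b)) * dw
    print(f"h_{i+1} = Beta({a}, {b}): one-period second moment ≈ {float(m):.3g}")
```

The values for $h_1$ and $h_2$ are close to one, in line with the tight distributions of estimates we have already seen. The $h_3$ value, by contrast, blows up: $h_3$ places almost no mass near $\omega = 1$ even though $f$ does, so rare draws in that region carry huge weights $f\left(\omega\right)/h_3\left(\omega\right)$ and the second-moment integral diverges there. We should therefore expect the estimates generated with $h_3$ below to be noticeably more dispersed.
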
h_a = a_list[2] h_b = b_list[2] -fig, axs = plt.subplots(1,2, figsize=(14, 10)) - -μ_range = np.linspace(0, 2, 100) - -for i, t in enumerate([1, 20]): +T_list = [1, 20] +key, subkey = jr.split(key) +results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_list, N_samples=1000) +fig, axs = plt.subplots(1, 2, figsize=(14, 10)) +μ_range = jnp.linspace(0, 2, 100) - μ_L_p, μ_L_q = simulate(g_a, g_b, h_a, h_b, N_simu, T=t) - μ_hat_p, μ_hat_q = np.nanmean(μ_L_p), np.nanmean(μ_L_q) - σ_hat_p, σ_hat_q = np.nanvar(μ_L_p), np.nanvar(μ_L_q) +for i, t in enumerate(T_list): + μ_L_p, μ_L_q = results[i] + μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q) + σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q) axs[i].set_xlabel('$μ_L$') axs[i].set_ylabel('frequency') @@ -458,7 +538,7 @@ for i, t in enumerate([1, 20]): for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p], [n_q, bins_q, μ_hat_q, σ_hat_q]]: - idx = np.argmax(n) + idx = jnp.argmax(n) axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}') plt.show() From 60826d6f605a7a286608bad3a6c8dafbec4ed8c0 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Fri, 17 Oct 2025 15:38:03 +1100 Subject: [PATCH 2/3] updates --- lectures/imp_sample.md | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/lectures/imp_sample.md b/lectures/imp_sample.md index 1ec0815fd..ace15bc8d 100644 --- a/lectures/imp_sample.md +++ b/lectures/imp_sample.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.6 + jupytext_version: 1.17.1 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -19,7 +19,7 @@ kernelspec: ## Overview -In {doc}`this lecture ` we described a peculiar property of a likelihood ratio process, namely, that its mean equals one for all $t \geq 0$ despite its converging to zero almost surely. +In {doc}`likelihood_ratio_process` we described a peculiar property of a likelihood ratio process, namely, that its mean equals one for all $t \geq 0$ despite its converging to zero almost surely. While it is easy to verify that peculiar property analytically (i.e., in population), it is challenging to use a computer simulation to verify it via an application of a law of large numbers that entails studying sample averages of repeated simulations. @@ -37,11 +37,14 @@ import matplotlib.pyplot as plt from jax.scipy.special import gammaln from typing import NamedTuple from functools import partial + +# Set JAX to use 64-bit floats +jax.config.update("jax_enable_x64", True) ``` ## Mathematical expectation of likelihood ratio -In {doc}`this lecture `, we studied a likelihood ratio $\ell \left(\omega_t\right)$ +In {doc}`likelihood_ratio_process`, we studied a likelihood ratio $\ell \left(\omega_t\right)$ $$ \ell \left( \omega_t \right) = \frac{f\left(\omega_t\right)}{g\left(\omega_t\right)} @@ -59,7 +62,7 @@ $$ Our goal is to approximate the mathematical expectation $E \left[ L\left(\omega^t\right) \right]$ well. -In {doc}`this lecture `, we showed that $E \left[ L\left(\omega^t\right) \right]$ equals $1$ for all $t$. +In {doc}`likelihood_ratio_process`, we showed that $E \left[ L\left(\omega^t\right) \right]$ equals $1$ for all $t$. We want to check out how well this holds if we replace $E$ with sample averages from simulations. 
@@ -183,11 +186,16 @@ key = jr.PRNGKey(0) ``` ```{code-cell} ipython3 +--- +mystnb: + figure: + caption: 'Real data generating process $g$ and importance distribution $h$' + name: fig_imp_real +--- w_range = jnp.linspace(1e-5, 1-1e-5, 1000) plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})') plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5), label=f'h=Beta({h_a}, {h_b})') -plt.title('real data generating process $g$ and importance distribution $h$') plt.legend() plt.ylim([0., 3.]) plt.show() @@ -450,7 +458,6 @@ plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})') plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]), label=f'$h_1$=Beta({a_list[0]},{b_list[0]})') plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]), label=f'$h_2$=Beta({a_list[1]},{b_list[1]})') plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]), label=f'$h_3$=Beta({a_list[2]},{b_list[2]})') -plt.title('real data generating process $g$ and importance distribution $h$') plt.legend() plt.ylim([0., 3.]) plt.show() From 58d1cd2237a47a9402a49ceb7cf1f6cfdb3b0f35 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Fri, 17 Oct 2025 16:04:00 +1100 Subject: [PATCH 3/3] update pep8 --- lectures/imp_sample.md | 169 +++++++++++++++++++++++++++-------------- 1 file changed, 113 insertions(+), 56 deletions(-) diff --git a/lectures/imp_sample.md b/lectures/imp_sample.md index ace15bc8d..c80e4fb61 100644 --- a/lectures/imp_sample.md +++ b/lectures/imp_sample.md @@ -81,7 +81,7 @@ We first take a look at the density functions `f` and `g` . class ImpSampleParams(NamedTuple): F_a: float = 1.0 # Beta parameters for f F_b: float = 1.0 - G_a: float = 3.0 # Beta parameters for g + G_a: float = 3.0 # Beta parameters for g G_b: float = 1.2 params = ImpSampleParams() @@ -89,8 +89,10 @@ params = ImpSampleParams() @jax.jit def beta_pdf(w, a, b): """Beta probability density function.""" - log_beta_const = gammaln(a) + gammaln(b) - gammaln(a + b) - log_pdf = (a - 1) * jnp.log(w) + (b - 1) * jnp.log(1 - w) - log_beta_const + log_beta_const = (gammaln(a) + gammaln(b) - + gammaln(a + b)) + log_pdf = ((a - 1) * jnp.log(w) + (b - 1) * + jnp.log(1 - w) - log_beta_const) return jnp.exp(log_pdf) @jax.jit @@ -194,8 +196,10 @@ mystnb: --- w_range = jnp.linspace(1e-5, 1-1e-5, 1000) -plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})') -plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5), label=f'h=Beta({h_a}, {h_b})') +plt.plot(w_range, g(w_range), + label=f'g=Beta({g_a}, {g_b})') +plt.plot(w_range, beta_pdf(w_range, 0.5, 0.5), + label=f'h=Beta({h_a}, {h_b})') plt.legend() plt.ylim([0., 3.]) plt.show() @@ -230,18 +234,18 @@ def estimate_single_path(key, p_a, p_b, q_a, q_b, T): L, weight, key_state = carry key_state, subkey = jr.split(key_state) w = jr.beta(subkey, q_a, q_b) - + # Compute likelihood ratio using f/g functions likelihood_ratio = f(w) / g(w) L = L * likelihood_ratio - + # Importance sampling weight with beta_pdf p_w = beta_pdf(w, p_a, p_b) q_w = beta_pdf(w, q_a, q_b) weight = weight * (p_w / q_w) - + return (L, weight, key_state) - + # Use fori_loop for dynamic T values final_L, final_weight, _ = jax.lax.fori_loop( 0, T, loop_body, (1.0, 1.0, key) @@ -252,13 +256,13 @@ def estimate_single_path(key, p_a, p_b, q_a, q_b, T): def estimate(key, p_a, p_b, q_a, q_b, T=1, N=10000): """Estimation of a batch of sample paths.""" keys = jr.split(key, N) - + # Use vmap for vectorized computation estimates = jax.vmap( - estimate_single_path, + estimate_single_path, in_axes=(0, *[None]*5) )(keys, p_a, p_b, q_a, q_b, T) - + return 
jnp.mean(estimates) ``` @@ -313,21 +317,24 @@ The code below produces distributions of estimates using both Monte Carlo and i ```{code-cell} ipython3 @partial(jax.jit, static_argnames=['N_simu', 'N_samples']) -def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1, N_samples=1000): +def simulate(key, p_a, p_b, q_a, q_b, N_simu, T=1, + N_samples=1000): """Simulation for both Monte Carlo and importance sampling.""" keys = jr.split(key, 2 * N_simu) keys_p = keys[:N_simu] keys_q = keys[N_simu:] - + def run_monte_carlo(key_batch): - return estimate(key_batch, p_a, p_b, p_a, p_b, T, N_samples) - + return estimate(key_batch, p_a, p_b, p_a, p_b, T, + N_samples) + def run_importance_sampling(key_batch): - return estimate(key_batch, p_a, p_b, q_a, q_b, T, N_samples) - + return estimate(key_batch, p_a, p_b, q_a, q_b, T, + N_samples) + μ_L_p = jax.vmap(run_monte_carlo)(keys_p) μ_L_q = jax.vmap(run_importance_sampling)(keys_q) - + return μ_L_p, μ_L_q ``` @@ -358,25 +365,31 @@ Next, we present distributions of estimates for $\hat{E} \left[L\left(\omega^t\r ```{code-cell} ipython3 T_values = [1, 5, 10, 20] -def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu, T_list, N_samples=1000): +def simulate_multiple_T(key, p_a, p_b, q_a, q_b, N_simu, + T_list, N_samples=1000): """Simulation for multiple T values.""" n_T = len(T_list) keys = jr.split(key, n_T) - + results = [] for i, T in enumerate(T_list): - result = simulate(keys[i], p_a, p_b, q_a, q_b, N_simu, T, N_samples) + result = simulate(keys[i], + p_a, p_b, q_a, q_b, N_simu, T, + N_samples) results.append(result) # Stack results into arrays for consistency μ_L_p_all = jnp.stack([r[0] for r in results]) μ_L_q_all = jnp.stack([r[1] for r in results]) - + return μ_L_p_all, μ_L_q_all # Run all simulations at once key, subkey = jr.split(key) -all_results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values, N_samples=1000) +all_results = simulate_multiple_T(subkey, + g_a, g_b, h_a, h_b, + N_simu, T_values, + N_samples=1000) # Extract results μ_L_p_all, μ_L_q_all = all_results @@ -387,25 +400,36 @@ fig, axs = plt.subplots(2, 2, figsize=(14, 10)) for i, t in enumerate(T_values): row = i // 2 col = i % 2 - + # Get results for this T value μ_L_p = μ_L_p_all[i] μ_L_q = μ_L_q_all[i] - - μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q) - σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q) + + μ_hat_p = jnp.nanmean(μ_L_p) + μ_hat_q = jnp.nanmean(μ_L_q) + σ_hat_p = jnp.nanvar(μ_L_p) + σ_hat_q = jnp.nanvar(μ_L_q) axs[row, col].set_xlabel('$μ_L$') axs[row, col].set_ylabel('frequency') axs[row, col].set_title(f'$T$={t}') - n_p, bins_p, _ = axs[row, col].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating') - n_q, bins_q, _ = axs[row, col].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h$ generating') + n_p, bins_p, _ = axs[row, col].hist( + μ_L_p, bins=μ_range, + color='r', alpha=0.5, label='$g$ generating') + n_q, bins_q, _ = axs[row, col].hist( + μ_L_q, bins=μ_range, + color='b', alpha=0.5, label='$h$ generating') axs[row, col].legend(loc=4) - for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p], - [n_q, bins_q, μ_hat_q, σ_hat_q]]: + for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, + σ_hat_p], + [n_q, bins_q, μ_hat_q, + σ_hat_q]]: idx = jnp.argmax(n) - axs[row, col].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}') + axs[row, col].text( + bins[idx], n[idx], + r'$\hat{μ}$=' + f'{μ_hat:.4g}' + + r', $\hat{σ}=$' + f'{σ_hat:.4g}') plt.show() ``` @@ -434,7 +458,8 @@ $$ 
```{code-cell} ipython3 key, subkey = jr.split(key) -μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a, params.F_b, N_simu) +μ_L_p, μ_L_q = simulate(subkey, g_a, g_b, params.F_a, + params.F_b, N_simu) ``` ```{code-cell} ipython3 @@ -454,10 +479,14 @@ b_list = [0.5, 1.2, 5.] ```{code-cell} ipython3 w_range = jnp.linspace(1e-5, 1-1e-5, 1000) -plt.plot(w_range, g(w_range), label=f'g=Beta({g_a}, {g_b})') -plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]), label=f'$h_1$=Beta({a_list[0]},{b_list[0]})') -plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]), label=f'$h_2$=Beta({a_list[1]},{b_list[1]})') -plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]), label=f'$h_3$=Beta({a_list[2]},{b_list[2]})') +plt.plot(w_range, g(w_range), + label=f'g=Beta({g_a}, {g_b})') +plt.plot(w_range, beta_pdf(w_range, a_list[0], b_list[0]), + label=f'$h_1$=Beta({a_list[0]},{b_list[0]})') +plt.plot(w_range, beta_pdf(w_range, a_list[1], b_list[1]), + label=f'$h_2$=Beta({a_list[1]},{b_list[1]})') +plt.plot(w_range, beta_pdf(w_range, a_list[2], b_list[2]), + label=f'$h_3$=Beta({a_list[2]},{b_list[2]})') plt.legend() plt.ylim([0., 3.]) plt.show() @@ -488,30 +517,44 @@ h_b = b_list[1] T_values_h2 = [1, 20] key, subkey = jr.split(key) -all_results_h2 = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_values_h2, N_samples=1000) +all_results_h2 = simulate_multiple_T(subkey, + g_a, g_b, h_a, h_b, + N_simu, T_values_h2, + N_samples=1000) μ_L_p_all_h2, μ_L_q_all_h2 = all_results_h2 -fig, axs = plt.subplots(1,2, figsize=(14, 10)) +fig, axs = plt.subplots(1, 2, figsize=(14, 10)) μ_range = jnp.linspace(0, 2, 100) for i, t in enumerate(T_values_h2): μ_L_p = μ_L_p_all_h2[i] μ_L_q = μ_L_q_all_h2[i] - - μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q) - σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q) + + μ_hat_p = jnp.nanmean(μ_L_p) + μ_hat_q = jnp.nanmean(μ_L_q) + σ_hat_p = jnp.nanvar(μ_L_p) + σ_hat_q = jnp.nanvar(μ_L_q) axs[i].set_xlabel('$μ_L$') axs[i].set_ylabel('frequency') axs[i].set_title(f'$T$={t}') - n_p, bins_p, _ = axs[i].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating') - n_q, bins_q, _ = axs[i].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h_2$ generating') + n_p, bins_p, _ = axs[i].hist( + μ_L_p, bins=μ_range, + color='r', alpha=0.5, label='$g$ generating') + n_q, bins_q, _ = axs[i].hist( + μ_L_q, bins=μ_range, + color='b', alpha=0.5, label='$h_2$ generating') axs[i].legend(loc=4) - for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p], - [n_q, bins_q, μ_hat_q, σ_hat_q]]: + for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, + σ_hat_p], + [n_q, bins_q, μ_hat_q, + σ_hat_q]]: idx = jnp.argmax(n) - axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}') + axs[i].text( + bins[idx], n[idx], + r'$\hat{μ}$=' + f'{μ_hat:.4g}' + + r', $\hat{σ}=$' + f'{σ_hat:.4g}') plt.show() ``` @@ -526,27 +569,41 @@ h_b = b_list[2] T_list = [1, 20] key, subkey = jr.split(key) -results = simulate_multiple_T(subkey, g_a, g_b, h_a, h_b, N_simu, T_list, N_samples=1000) +results = simulate_multiple_T(subkey, + g_a, g_b, h_a, h_b, + N_simu, T_list, + N_samples=1000) fig, axs = plt.subplots(1, 2, figsize=(14, 10)) μ_range = jnp.linspace(0, 2, 100) for i, t in enumerate(T_list): μ_L_p, μ_L_q = results[i] - μ_hat_p, μ_hat_q = jnp.nanmean(μ_L_p), jnp.nanmean(μ_L_q) - σ_hat_p, σ_hat_q = jnp.nanvar(μ_L_p), jnp.nanvar(μ_L_q) + μ_hat_p = jnp.nanmean(μ_L_p) + μ_hat_q = jnp.nanmean(μ_L_q) + σ_hat_p = jnp.nanvar(μ_L_p) + σ_hat_q = 
jnp.nanvar(μ_L_q) axs[i].set_xlabel('$μ_L$') axs[i].set_ylabel('frequency') axs[i].set_title(f'$T$={t}') - n_p, bins_p, _ = axs[i].hist(μ_L_p, bins=μ_range, color='r', alpha=0.5, label='$g$ generating') - n_q, bins_q, _ = axs[i].hist(μ_L_q, bins=μ_range, color='b', alpha=0.5, label='$h_3$ generating') + n_p, bins_p, _ = axs[i].hist( + μ_L_p, bins=μ_range, + color='r', alpha=0.5, label='$g$ generating') + n_q, bins_q, _ = axs[i].hist( + μ_L_q, bins=μ_range, + color='b', alpha=0.5, label='$h_3$ generating') axs[i].legend(loc=4) - for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, σ_hat_p], - [n_q, bins_q, μ_hat_q, σ_hat_q]]: + for n, bins, μ_hat, σ_hat in [[n_p, bins_p, μ_hat_p, + σ_hat_p], + [n_q, bins_q, μ_hat_q, + σ_hat_q]]: idx = jnp.argmax(n) - axs[i].text(bins[idx], n[idx], r'$\hat{μ}$='+f'{μ_hat:.4g}'+r', $\hat{σ}=$'+f'{σ_hat:.4g}') + axs[i].text( + bins[idx], n[idx], + r'$\hat{μ}$=' + f'{μ_hat:.4g}' + + r', $\hat{σ}=$' + f'{σ_hat:.4g}') plt.show() ```