Skip to content

Instantly share code, notes, and snippets.

@antotocar34
Created January 15, 2023 15:44
Show Gist options
  • Save antotocar34/ebfc1a8a53ca17af2383c4af2f0c2e52 to your computer and use it in GitHub Desktop.
Save antotocar34/ebfc1a8a53ca17af2383c4af2f0c2e52 to your computer and use it in GitHub Desktop.
SMC bug
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax.scipy.stats import multivariate_normal
# Pin JAX to the CPU backend for this reproduction script.
jax.config.update("jax_platform_name", "cpu")
import blackjax
import blackjax.smc.resampling as resampling
# Fixed root PRNG key; the sampling runs further down the script draw their
# randomness from it.
key = jax.random.PRNGKey(42)
def V(x):
    """Double-well potential ``5 * (sum(x**2) - 1)**2``.

    It is zero wherever the squared norm of ``x`` equals one and grows as
    ``x`` moves away from that sphere. The sum runs over every element of ``x``.
    """
    squared_norm = jnp.sum(jnp.square(x))
    return 5 * (squared_norm - 1) ** 2
def prior_log_prob(x):
    """Log-density of the standard normal prior N(0, I) evaluated at ``x``.

    The dimension is read from ``x.shape[0]``, so ``x`` is expected to be a
    vector. The mean is the zero vector and the covariance is the identity of
    that size.
    """
    dim = x.shape[0]
    zero_mean = jnp.zeros((dim,))
    identity_cov = jnp.eye(dim)
    return multivariate_normal.logpdf(x, zero_mean, identity_cov)
def inference_loop(rng_key, mcmc_kernel, initial_state, num_samples):
    """Apply ``mcmc_kernel`` ``num_samples`` times and return every visited state.

    Args:
        rng_key: JAX PRNG key; it is split into one key per transition.
        mcmc_kernel: callable ``(key, state) -> (new_state, info)``; ``info``
            is discarded.
        initial_state: state the chain starts from (not included in the output).
        num_samples: number of transitions to run.

    Returns:
        The post-transition states, stacked along a new leading axis of
        length ``num_samples`` by ``jax.lax.scan``.
    """

    def _transition(carry_state, step_key):
        next_state, _info = mcmc_kernel(step_key, carry_state)
        # Carry the new state forward and also emit it as this step's output.
        return next_state, next_state

    step_keys = jax.random.split(rng_key, num_samples)
    _last_state, trace = jax.lax.scan(jax.jit(_transition), initial_state, step_keys)
    return trace
def full_logdensity(x):
    """Unnormalized target log-density: the log-prior minus the potential ``V``."""
    log_prior = prior_log_prob(x)
    return log_prior - V(x)
# Identity inverse mass matrix for the 1-D target, shared by both samplers here.
inv_mass_matrix = jnp.eye(1)
n_samples = 10_000

# Baseline 1: HMC with a fixed trajectory of 50 integration steps.
hmc_parameters = dict(
    step_size=1e-4, inverse_mass_matrix=inv_mass_matrix, num_integration_steps=50
)
hmc = blackjax.hmc(full_logdensity, **hmc_parameters)
hmc_state = hmc.init(jnp.ones((1,)))
# Give each chain its own key derived from `key`. Passing `key` itself to both
# runs would feed HMC and NUTS the identical sequence of per-step keys, and JAX
# keys must not be reused for randomness that is meant to be independent.
hmc_key = jax.random.fold_in(key, 0)
hmc_samples = inference_loop(hmc_key, hmc.step, hmc_state, n_samples)

# Baseline 2: NUTS with the same step size and inverse mass matrix.
nuts_parameters = dict(step_size=1e-4, inverse_mass_matrix=inv_mass_matrix)
nuts = blackjax.nuts(full_logdensity, **nuts_parameters)
nuts_state = nuts.init(jnp.ones((1,)))
nuts_key = jax.random.fold_in(key, 1)
nuts_samples = inference_loop(nuts_key, nuts.step, nuts_state, n_samples)
def smc_inference_loop(rng_key, smc_kernel, initial_state):
    """Run the adaptive tempered SMC algorithm until lambda reaches 1.

    The kernel is stepped inside ``jax.lax.while_loop`` for as long as
    ``state.lmbda < 1``. Each step consumes a fresh subkey split off the
    running key.

    Args:
        rng_key: JAX PRNG key driving the whole loop.
        smc_kernel: callable ``(key, state) -> (new_state, info)``; ``info``
            is discarded.
        initial_state: SMC state exposing a ``lmbda`` tempering parameter.

    Returns:
        ``(n_iter, final_state)``: the number of kernel steps taken and the
        state reached when ``lmbda`` first stops being below 1.
    """

    def _not_done(loop_carry):
        _count, smc_state, _key = loop_carry
        return smc_state.lmbda < 1

    def _advance(loop_carry):
        count, smc_state, loop_key = loop_carry
        loop_key, step_key = jax.random.split(loop_key, 2)
        smc_state, _info = smc_kernel(step_key, smc_state)
        return count + 1, smc_state, loop_key

    start = (0, initial_state, rng_key)
    n_iter, final_state, _ = jax.lax.while_loop(_not_done, _advance, start)
    return n_iter, final_state
def loglikelihood(x):
    """Tempered 'likelihood' term of the target: the negative potential ``-V(x)``."""
    return -V(x)


# Rebinds `hmc_parameters`: the SMC mutation kernel uses a single integration
# step per HMC move.
hmc_parameters = dict(
    step_size=1e-4, inverse_mass_matrix=inv_mass_matrix, num_integration_steps=1
)
tempered = blackjax.adaptive_tempered_smc(
    prior_log_prob,
    loglikelihood,
    blackjax.hmc,
    hmc_parameters,
    resampling.systematic,
    0.5,  # presumably the target ESS for the adaptive schedule; confirm in blackjax docs
    mcmc_iter=1,
)

# Draw 10 particles from the N(0, I_1) prior, giving an array of shape (10, 1).
# Pass that array directly as the particle positions. Each particle is then a
# length-1 vector, which is what `prior_log_prob` (reads `x.shape[0]`), `V`
# (computes `x**2`) and the 1x1 `inv_mass_matrix` operate on. Per-particle dicts
# such as {"mean1": v[0], "mean2": v[1]} do not fit that: v has a single entry,
# and none of those functions accept a dict.
# NOTE(review): blackjax's SMC presumably maps over the leading particle axis of
# the positions; see the blackjax tempered-SMC docs.
initial_positions = jax.random.multivariate_normal(
    jax.random.PRNGKey(0), jnp.zeros([1]), jnp.eye(1), (10,)
)
initial_smc_state = tempered.init(initial_positions)
n_iter, smc_samples = smc_inference_loop(key, tempered.step, initial_smc_state)
print("Number of steps in the adaptive algorithm: ", n_iter.item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment