diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index b65e5b6e4..33f6060c2 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -22,7 +22,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, noneuclidean_mclachlan from blackjax.types import ArrayLike, PRNGKey -from blackjax.util import generate_unit_vector, pytree_size +from blackjax.util import generate_unit_vector __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -44,12 +44,12 @@ class MCLMCInfo(NamedTuple): energy_change: float -def init(x_initial: ArrayLike, logdensity_fn, rng_key): - l, g = jax.value_and_grad(logdensity_fn)(x_initial) +def init(position: ArrayLike, logdensity_fn, rng_key): + l, g = jax.value_and_grad(logdensity_fn)(position) return IntegratorState( - position=x_initial, - momentum=generate_unit_vector(rng_key, x_initial), + position=position, + momentum=generate_unit_vector(rng_key, position), logdensity=l, logdensity_grad=g, ) @@ -83,8 +83,6 @@ def kernel( state, step_size ) - dim = pytree_size(position) - # Langevin-like noise momentum, dim = partially_refresh_momentum( momentum=momentum, rng_key=rng_key, L=L, step_size=step_size @@ -95,6 +93,7 @@ def kernel( ), MCLMCInfo( logdensity=logdensity, energy_change=kinetic_change - logdensity + state.logdensity, + # TODO: Potential bug here, see #625 kinetic_change=kinetic_change * (dim - 1), ) diff --git a/blackjax/util.py b/blackjax/util.py index 9fb461c8d..6f1d49072 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -1,6 +1,6 @@ """Utility functions for BlackJax.""" from functools import partial -from typing import Union +from typing import Callable, Union import jax.numpy as jnp from jax import jit, lax @@ -8,7 +8,7 @@ from jax.random import normal, split from jax.tree_util import tree_leaves -from blackjax.base import Info, State +from blackjax.base import Info, SamplingAlgorithm, State, VIAlgorithm from blackjax.progress_bar import progress_bar_scan from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -141,35 +141,39 @@ def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree: def run_inference_algorithm( - rng_key, - initial_state_or_position, - inference_algorithm, - num_steps, + rng_key: PRNGKey, + initial_state_or_position: ArrayLikeTree, + inference_algorithm: Union[SamplingAlgorithm, VIAlgorithm], + num_steps: int, progress_bar: bool = False, - transform=lambda x: x, + transform: Callable = lambda x: x, ) -> tuple[State, State, Info]: """Wrapper to run an inference algorithm. Parameters ---------- - rng_key : PRNGKey + rng_key The random state used by JAX's random numbers generator. - initial_state_or_position: ArrayLikeTree + initial_state_or_position The initial state OR the initial position of the inference algorithm. If an initial position is passed in, the function will automatically convert it into an initial state. - inference_algorithm : Union[SamplingAlgorithm, VIAlgorithm] + inference_algorithm One of blackjax's sampling algorithms or variational inference algorithms. - num_steps : int - Number of learning steps. - transform: - a transformation of the sequence of states to be returned. By default, the states are returned as is. + num_steps + Number of MCMC steps. + progress_bar + Whether to display a progress bar. + transform + A transformation of the trace of states to be returned. This is useful for + computing determinstic variables, or returning a subset of the states. + By default, the states are returned as is. Returns ------- Tuple[State, State, Info] 1. The final state of the inference algorithm. - 2. The history of states of the inference algorithm. - 3. The history of the info of the inference algorithm. + 2. The trace of states of the inference algorithm (contains the MCMC samples). + 3. The trace of the info of the inference algorithm for diagnostics. """ try: initial_state = inference_algorithm.init(initial_state_or_position) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 2a9fd07c5..879eba550 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -78,7 +78,7 @@ def run_mclmc(self, logdensity_fn, num_steps, initial_position, key): init_key, tune_key, run_key = jax.random.split(key, 3) initial_state = blackjax.mcmc.mclmc.init( - x_initial=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key + position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) kernel = blackjax.mcmc.mclmc.build_kernel(