Refactor dynamic_hmc out of hmc.py (#622)
* Refactor dynamic_hmc out of hmc.py

* Fix formatting
junpenglao committed Mar 12, 2024
1 parent 93af4cf commit df5a966
Showing 5 changed files with 207 additions and 165 deletions.
3 changes: 2 additions & 1 deletion blackjax/__init__.py
@@ -8,9 +8,10 @@
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
 from .mcmc.barker import barker_proposal
+from .mcmc.dynamic_hmc import dynamic_hmc
 from .mcmc.elliptical_slice import elliptical_slice
 from .mcmc.ghmc import ghmc
-from .mcmc.hmc import dynamic_hmc, hmc
+from .mcmc.hmc import hmc
 from .mcmc.mala import mala
 from .mcmc.marginal_latent_gaussian import mgrad_gaussian
 from .mcmc.mclmc import mclmc
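The public entry point is unchanged by this refactor: `blackjax.dynamic_hmc` now resolves through the new `blackjax.mcmc.dynamic_hmc` module rather than `blackjax.mcmc.hmc`. A minimal sketch of the import paths after this commit:

import blackjax
from blackjax.mcmc import dynamic_hmc

# The top-level name and the new module's attribute are the same class.
assert blackjax.dynamic_hmc is dynamic_hmc.dynamic_hmc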
13 changes: 4 additions & 9 deletions blackjax/adaptation/chees_adaptation.py
@@ -8,7 +8,7 @@
 import numpy as np
 import optax
 
-import blackjax.mcmc.hmc as hmc
+import blackjax.mcmc.dynamic_hmc as dynamic_hmc
 import blackjax.optimizers.dual_averaging as dual_averaging
 from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
 from blackjax.base import AdaptationAlgorithm
@@ -370,7 +370,7 @@ def run(
             next_random_arg_fn = lambda key: jax.random.split(key)[1]
             init_random_arg = key_init
         else:
-            jitter_gn = lambda i: _halton_sequence(
+            jitter_gn = lambda i: dynamic_hmc.halton_sequence(
                 i, np.ceil(np.log2(num_steps + max_sampling_steps))
             ) * jitter_amount + (1.0 - jitter_amount)
             next_random_arg_fn = lambda i: i + 1
@@ -382,7 +382,7 @@ def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
                 dtype=int,
             )
 
-        step_fn = hmc.build_dynamic_kernel(
+        step_fn = dynamic_hmc.build_kernel(
             next_random_arg_fn=next_random_arg_fn,
             integration_steps_fn=integration_steps_fn,
         )
@@ -420,7 +420,7 @@ def one_step(carry, rng_key):
         )
 
         batch_init = jax.vmap(
-            lambda p: hmc.init_dynamic(p, logdensity_fn, init_random_arg)
+            lambda p: dynamic_hmc.init(p, logdensity_fn, init_random_arg)
         )
         init_states = batch_init(positions)
         init_adaptation_state = init(init_random_arg, step_size)
@@ -446,8 +446,3 @@ def one_step(carry, rng_key):
         return AdaptationResults(last_states, parameters), info
 
     return AdaptationAlgorithm(run)  # type: ignore[arg-type]
-
-
-def _halton_sequence(i, max_bits=10):
-    bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype)
-    return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks)
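Since the private `_halton_sequence` helper is deleted here, callers now go through the public `halton_sequence` in the new module. A small sketch of the relocated helper, assuming a scalar integer index as in the adaptation loop above:

import jax.numpy as jnp
import blackjax.mcmc.dynamic_hmc as dynamic_hmc

i = jnp.asarray(0)  # position in the quasi-random sequence
u = dynamic_hmc.halton_sequence(i, max_bits=10)  # 0.5, the first term of the base-2 van der Corput sequence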
198 changes: 198 additions & 0 deletions blackjax/mcmc/dynamic_hmc.py
@@ -0,0 +1,198 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Dynamic HMC Kernel"""
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.mcmc.integrators as integrators
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.hmc import HMCInfo, HMCState
from blackjax.mcmc.hmc import build_kernel as build_static_hmc_kernel
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
    "DynamicHMCState",
    "init",
    "build_kernel",
    "dynamic_hmc",
    "halton_sequence",
]


class DynamicHMCState(NamedTuple):
    """State of the dynamic HMC algorithm.

    Adds a utility array for generating a pseudo or quasi-random sequence of
    number of integration steps.

    """

    position: ArrayTree
    logdensity: float
    logdensity_grad: ArrayTree
    random_generator_arg: Array


def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array):
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg)


def build_kernel(
    integrator: Callable = integrators.velocity_verlet,
    divergence_threshold: float = 1000,
    next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
    integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
):
    """Build a Dynamic HMC kernel where the number of integration steps is chosen randomly.

    Parameters
    ----------
    integrator
        The symplectic integrator to use to integrate the Hamiltonian dynamics.
    divergence_threshold
        Value of the difference in energy above which we consider that the transition is divergent.
    next_random_arg_fn
        Function that generates the next `random_generator_arg` from its previous value.
    integration_steps_fn
        Function that generates the next pseudo or quasi-random number of integration steps in
        the sequence, given the current `random_generator_arg`. Needs to return an `int`.

    Returns
    -------
    A kernel that takes a rng_key and a Pytree that contains the current state
    of the chain and that returns a new state of the chain along with
    information about the transition.

    """
    hmc_base = build_static_hmc_kernel(integrator, divergence_threshold)

    def kernel(
        rng_key: PRNGKey,
        state: DynamicHMCState,
        logdensity_fn: Callable,
        step_size: float,
        inverse_mass_matrix: Array,
        **integration_steps_kwargs,
    ) -> tuple[DynamicHMCState, HMCInfo]:
        """Generate a new sample with the HMC kernel."""
        num_integration_steps = integration_steps_fn(
            state.random_generator_arg, **integration_steps_kwargs
        )
        hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad)
        hmc_proposal, info = hmc_base(
            rng_key,
            hmc_state,
            logdensity_fn,
            step_size,
            inverse_mass_matrix,
            num_integration_steps,
        )
        next_random_arg = next_random_arg_fn(state.random_generator_arg)
        return (
            DynamicHMCState(
                hmc_proposal.position,
                hmc_proposal.logdensity,
                hmc_proposal.logdensity_grad,
                next_random_arg,
            ),
            info,
        )

    return kernel


class dynamic_hmc:
    """Implements the (basic) user interface for the dynamic HMC kernel.

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    step_size
        The value to use for the step size in the symplectic integrator.
    inverse_mass_matrix
        The value to use for the inverse mass matrix when drawing a value for
        the momentum and computing the kinetic energy.
    divergence_threshold
        The absolute value of the difference in energy between two states above
        which we say that the transition is divergent. The default value is
        commonly found in other libraries, and yet is arbitrary.
    integrator
        (algorithm parameter) The symplectic integrator to use to integrate the trajectory.
    next_random_arg_fn
        Function that generates the next `random_generator_arg` from its previous value.
    integration_steps_fn
        Function that generates the next pseudo or quasi-random number of integration steps in
        the sequence, given the current `random_generator_arg`.

    Returns
    -------
    A ``SamplingAlgorithm``.

    """

    init = staticmethod(init)
    build_kernel = staticmethod(build_kernel)

    def __new__(  # type: ignore[misc]
        cls,
        logdensity_fn: Callable,
        step_size: float,
        inverse_mass_matrix: Array,
        *,
        divergence_threshold: int = 1000,
        integrator: Callable = integrators.velocity_verlet,
        next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
        integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
    ) -> SamplingAlgorithm:
        kernel = cls.build_kernel(
            integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
        )

        def init_fn(position: ArrayLikeTree, random_generator_arg: Array):
            return cls.init(position, logdensity_fn, random_generator_arg)

        def step_fn(rng_key: PRNGKey, state):
            return kernel(
                rng_key,
                state,
                logdensity_fn,
                step_size,
                inverse_mass_matrix,
            )

        return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]


def halton_sequence(i: Array, max_bits: int = 10) -> float:
    bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype)
    return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks)


def rescale(mu):
    # Returns s, such that `round(U(0, 1) * s + 0.5)` has expected value mu.
    k = jnp.floor(2 * mu - 1)
    x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu)
    return k + x


def halton_trajectory_length(
    i: Array, trajectory_length_adjustment: float, max_bits: int = 10
) -> int:
    """Generate a quasi-random number of integration steps."""
    s = rescale(trajectory_length_adjustment)
    return jnp.asarray(jnp.rint(0.5 + halton_sequence(i, max_bits) * s), dtype=int)
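Taken together, a minimal end-to-end sketch of the new module's user interface. The target density, step size, and mass matrix below are illustrative choices, not values from this commit. Here `halton_trajectory_length(i, 10.0)` draws between 1 and 19 steps: `rescale(10.0)` evaluates to 19, the scale s for which `round(U(0, 1) * s + 0.5)` has mean 10.

import jax
import jax.numpy as jnp

import blackjax
from blackjax.mcmc.dynamic_hmc import halton_trajectory_length

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # standard normal target (illustrative)

# Drive the number of integration steps with the quasi-random Halton
# sequence instead of the default pseudo-random draw.
algo = blackjax.dynamic_hmc(
    logdensity_fn,
    step_size=0.25,
    inverse_mass_matrix=jnp.ones(2),
    next_random_arg_fn=lambda i: i + 1,  # advance the Halton index
    integration_steps_fn=lambda i: halton_trajectory_length(i, 10.0),
)

rng_key = jax.random.PRNGKey(0)
state = algo.init(jnp.zeros(2), random_generator_arg=jnp.asarray(0))
step = jax.jit(algo.step)

for _ in range(100):
    rng_key, sample_key = jax.random.split(rng_key)
    state, info = step(sample_key, state)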