-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor dynamic_hmc out of hmc.py (#622)
* Refactor dynamic_hmc out of hmc.py * Fix formatting
- Loading branch information
1 parent
93af4cf
commit df5a966
Showing
5 changed files
with
207 additions
and
165 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# Copyright 2020- The Blackjax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Public API for the Dynamic HMC Kernel""" | ||
from typing import Callable, NamedTuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
import blackjax.mcmc.integrators as integrators | ||
from blackjax.base import SamplingAlgorithm | ||
from blackjax.mcmc.hmc import HMCInfo, HMCState | ||
from blackjax.mcmc.hmc import build_kernel as build_static_hmc_kernel | ||
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey | ||
|
||
__all__ = [ | ||
"DynamicHMCState", | ||
"init", | ||
"build_kernel", | ||
"dynamic_hmc", | ||
"halton_sequence", | ||
] | ||
|
||
|
||
class DynamicHMCState(NamedTuple): | ||
"""State of the dynamic HMC algorithm. | ||
Adds a utility array for generating a pseudo or quasi-random sequence of | ||
number of integration steps. | ||
""" | ||
|
||
position: ArrayTree | ||
logdensity: float | ||
logdensity_grad: ArrayTree | ||
random_generator_arg: Array | ||
|
||
|
||
def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): | ||
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) | ||
return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) | ||
|
||
|
||
def build_kernel( | ||
integrator: Callable = integrators.velocity_verlet, | ||
divergence_threshold: float = 1000, | ||
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], | ||
integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), | ||
): | ||
"""Build a Dynamic HMC kernel where the number of integration steps is chosen randomly. | ||
Parameters | ||
---------- | ||
integrator | ||
The symplectic integrator to use to integrate the Hamiltonian dynamics. | ||
divergence_threshold | ||
Value of the difference in energy above which we consider that the transition is divergent. | ||
next_random_arg_fn | ||
Function that generates the next `random_generator_arg` from its previous value. | ||
integration_steps_fn | ||
Function that generates the next pseudo or quasi-random number of integration steps in the | ||
sequence, given the current `random_generator_arg`. Needs to return an `int`. | ||
Returns | ||
------- | ||
A kernel that takes a rng_key and a Pytree that contains the current state | ||
of the chain and that returns a new state of the chain along with | ||
information about the transition. | ||
""" | ||
hmc_base = build_static_hmc_kernel(integrator, divergence_threshold) | ||
|
||
def kernel( | ||
rng_key: PRNGKey, | ||
state: DynamicHMCState, | ||
logdensity_fn: Callable, | ||
step_size: float, | ||
inverse_mass_matrix: Array, | ||
**integration_steps_kwargs, | ||
) -> tuple[DynamicHMCState, HMCInfo]: | ||
"""Generate a new sample with the HMC kernel.""" | ||
num_integration_steps = integration_steps_fn( | ||
state.random_generator_arg, **integration_steps_kwargs | ||
) | ||
hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad) | ||
hmc_proposal, info = hmc_base( | ||
rng_key, | ||
hmc_state, | ||
logdensity_fn, | ||
step_size, | ||
inverse_mass_matrix, | ||
num_integration_steps, | ||
) | ||
next_random_arg = next_random_arg_fn(state.random_generator_arg) | ||
return ( | ||
DynamicHMCState( | ||
hmc_proposal.position, | ||
hmc_proposal.logdensity, | ||
hmc_proposal.logdensity_grad, | ||
next_random_arg, | ||
), | ||
info, | ||
) | ||
|
||
return kernel | ||
|
||
|
||
class dynamic_hmc: | ||
"""Implements the (basic) user interface for the dynamic HMC kernel. | ||
Parameters | ||
---------- | ||
logdensity_fn | ||
The log-density function we wish to draw samples from. | ||
step_size | ||
The value to use for the step size in the symplectic integrator. | ||
inverse_mass_matrix | ||
The value to use for the inverse mass matrix when drawing a value for | ||
the momentum and computing the kinetic energy. | ||
divergence_threshold | ||
The absolute value of the difference in energy between two states above | ||
which we say that the transition is divergent. The default value is | ||
commonly found in other libraries, and yet is arbitrary. | ||
integrator | ||
(algorithm parameter) The symplectic integrator to use to integrate the trajectory. | ||
next_random_arg_fn | ||
Function that generates the next `random_generator_arg` from its previous value. | ||
integration_steps_fn | ||
Function that generates the next pseudo or quasi-random number of integration steps in the | ||
sequence, given the current `random_generator_arg`. | ||
Returns | ||
------- | ||
A ``SamplingAlgorithm``. | ||
""" | ||
|
||
init = staticmethod(init) | ||
build_kernel = staticmethod(build_kernel) | ||
|
||
def __new__( # type: ignore[misc] | ||
cls, | ||
logdensity_fn: Callable, | ||
step_size: float, | ||
inverse_mass_matrix: Array, | ||
*, | ||
divergence_threshold: int = 1000, | ||
integrator: Callable = integrators.velocity_verlet, | ||
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], | ||
integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), | ||
) -> SamplingAlgorithm: | ||
kernel = cls.build_kernel( | ||
integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn | ||
) | ||
|
||
def init_fn(position: ArrayLikeTree, random_generator_arg: Array): | ||
return cls.init(position, logdensity_fn, random_generator_arg) | ||
|
||
def step_fn(rng_key: PRNGKey, state): | ||
return kernel( | ||
rng_key, | ||
state, | ||
logdensity_fn, | ||
step_size, | ||
inverse_mass_matrix, | ||
) | ||
|
||
return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] | ||
|
||
|
||
def halton_sequence(i: Array, max_bits: int = 10) -> float: | ||
bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype) | ||
return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks) | ||
|
||
|
||
def rescale(mu): | ||
# Returns s, such that `round(U(0, 1) * s + 0.5)` has expected value mu. | ||
k = jnp.floor(2 * mu - 1) | ||
x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) | ||
return k + x | ||
|
||
|
||
def halton_trajectory_length( | ||
i: Array, trajectory_length_adjustment: float, max_bits: int = 10 | ||
) -> int: | ||
"""Generate a quasi-random number of integration steps.""" | ||
s = rescale(trajectory_length_adjustment) | ||
return jnp.asarray(jnp.rint(0.5 + halton_sequence(i, max_bits) * s), dtype=int) |
Oops, something went wrong.