Skip to content

Commit

Permalink
MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (#687)
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry authored Jun 3, 2024
1 parent 360ac3b commit 3fbdac6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax
import jax.numpy as jnp

import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
from blackjax.adaptation.mass_matrix import (
MassMatrixAdaptationState,
Expand Down Expand Up @@ -249,10 +250,11 @@ def window_adaptation(
target_acceptance_rate: float = 0.80,
progress_bar: bool = False,
adaptation_info_fn: Callable = return_all_adapt_info,
integrator=mcmc.integrators.velocity_verlet,
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
algorithms in the HMC family. See Blackjax.hmc_family
algorithms in the HMC fmaily. See Blackjax.hmc_family
Algorithms in the HMC family on a euclidean manifold depend on the value of
at least two parameters: the step size, related to the trajectory
Expand Down Expand Up @@ -294,7 +296,7 @@ def window_adaptation(
"""

mcmc_kernel = algorithm.build_kernel()
mcmc_kernel = algorithm.build_kernel(integrator)

adapt_init, adapt_step, adapt_final = base(
is_mass_matrix_diagonal,
Expand Down

0 comments on commit 3fbdac6

Please sign in to comment.