Skip to content

Commit

Permalink
jax.tree_map -> jax.tree.map
Browse files Browse the repository at this point in the history
Also ignore warning from upstream dependency
  • Loading branch information
jessegrabowski authored and ricardoV94 committed Apr 16, 2024
1 parent dc05dcc commit 774ddff
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pymc_experimental/tests/test_blackjax_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_blackjax_particles_from_pymc_population_univariate():
model = fast_model()
population = {"x": np.array([2, 3, 4])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
jax.tree_map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])
jax.tree.map(np.testing.assert_allclose, blackjax_particles, [np.array([[2], [3], [4]])])


def test_blackjax_particles_from_pymc_population_multivariate():
Expand All @@ -144,7 +144,7 @@ def test_blackjax_particles_from_pymc_population_multivariate():

population = {"x": np.array([0.34614613, 1.09163261, -0.44526825]), "z": np.array([1, 2, 3])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)
jax.tree_map(
jax.tree.map(
np.testing.assert_allclose,
blackjax_particles,
[np.array([[0.34614613], [1.09163261], [-0.44526825]]), np.array([[1], [2], [3]])],
Expand All @@ -168,7 +168,7 @@ def test_blackjax_particles_from_pymc_population_multivariable():
population = {"x": np.array([[2, 3], [5, 6], [7, 9]]), "z": np.array([11, 12, 13])}
blackjax_particles = blackjax_particles_from_pymc_population(model, population)

jax.tree_map(
jax.tree.map(
np.testing.assert_allclose,
blackjax_particles,
[np.array([[2, 3], [5, 6], [7, 9]]), np.array([[11], [12], [13]])],
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_get_jaxified_logprior():
"""
logprior = get_jaxified_logprior(fast_model())
for point in [-0.5, 0.0, 0.5]:
jax.tree_map(
jax.tree.map(
np.testing.assert_allclose,
jax.vmap(logprior)([np.array([point])]),
np.log(scipy.stats.norm(0, 1).pdf(point)),
Expand All @@ -212,7 +212,7 @@ def test_get_jaxified_loglikelihood():
"""
loglikelihood = get_jaxified_loglikelihood(fast_model())
for point in [-0.5, 0.0, 0.5]:
jax.tree_map(
jax.tree.map(
np.testing.assert_allclose,
jax.vmap(loglikelihood)([np.array([point])]),
np.log(scipy.stats.norm(point, 1).pdf(0)),
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ filterwarnings =[

# JAX issues an over-eager warning if os.fork() is called when the JAX module is loaded, even if JAX isn't being used
'ignore:os\.fork\(\) was called\.:RuntimeWarning',

# Warning coming from blackjax
'ignore:jax\.tree_map is deprecated:DeprecationWarning',
]

[tool.black]
Expand Down

0 comments on commit 774ddff

Please sign in to comment.