diff --git a/pymc_experimental/tests/model/transforms/test_autoreparam.py b/pymc_experimental/tests/model/transforms/test_autoreparam.py index 9749894e..b2ea245a 100644 --- a/pymc_experimental/tests/model/transforms/test_autoreparam.py +++ b/pymc_experimental/tests/model/transforms/test_autoreparam.py @@ -70,9 +70,9 @@ def test_multilevel(): # multilevel modelling a = pm.Normal("a") s = pm.HalfNormal("s") - a_g = pm.Normal("a_g", a, s, dims="level") + a_g = pm.Normal("a_g", a, s, shape=(2,), dims="level") s_g = pm.HalfNormal("s_g") - a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level")) + a_ig = pm.Normal("a_ig", a_g, s_g, shape=(2, 2), dims=("county", "level")) model_r, vip = vip_reparametrize(model, ["a_g", "a_ig"]) assert "a_g" in vip.get_lambda()