Commit

Improve example (#9)
dirmeier authored Jan 23, 2023
1 parent 0fc53a1 commit aea66c6
Showing 1 changed file with 34 additions and 18 deletions.
examples/surjector_example.py
@@ -1,3 +1,19 @@
+"""
+Bayesian Neural Network
+=======================
+This example implements the training and prediction of a
+Bayesian Neural Network.
+Predictions from a Haiku MLP for the same data are shown
+as a reference.
+References
+----------
+[1] Blundell C., Cornebise J., Kavukcuoglu K., Wierstra D.
+"Weight Uncertainty in Neural Networks".
+ICML, 2015.
+"""
+
+
 import distrax
 import haiku as hk
 import jax
@@ -19,15 +35,6 @@
     named_dataset,
 )
 
-n = 1000
-
-prior = distrax.Uniform(jnp.full(2, -2), jnp.full(2, 2))
-theta = prior.sample(seed=random.PRNGKey(0), sample_shape=(n,))
-
-likelihood = distrax.MultivariateNormalDiag(theta, jnp.ones_like(theta))
-y = likelihood.sample(seed=random.PRNGKey(1))
-data = named_dataset(y, theta)
-
 
 def simulator_fn(seed, theta):
     p_noise = distrax.Normal(jnp.zeros_like(theta), 1.0)
@@ -57,14 +64,14 @@ def _flow(method, **kwargs):
             reinterpreted_batch_ndims=1,
         )
         td = TransformedDistribution(base_distribution, chain)
-        return td(method, **kwargs)
+        return td.log_prob(method, **kwargs)
 
     td = hk.transform(_flow)
     td = hk.without_apply_rng(td)
     return td
 
 
-def train(rng_seq, max_n_iter=1000):
+def train(rng_seq, data, model, max_n_iter=1000):
     train_iter = as_batch_iterator(next(rng_seq), data, 100, True)
     params = model.init(next(rng_seq), method="log_prob", **train_iter(0))
@@ -94,14 +101,23 @@ def loss_fn(params):
     return params, losses
 
 
-model = make_model(2)
-params, losses = train(hk.PRNGSequence(2))
+def run():
+    n = 1000
+    prior = distrax.Uniform(jnp.full(2, -2), jnp.full(2, 2))
+    theta = prior.sample(seed=random.PRNGKey(0), sample_shape=(n,))
+    likelihood = distrax.MultivariateNormalDiag(theta, jnp.ones_like(theta))
+    y = likelihood.sample(seed=random.PRNGKey(1))
+    data = named_dataset(y, theta)
 
-plt.plot(losses)
-plt.show()
+    model = make_model(2)
+    params, losses = train(hk.PRNGSequence(2), data, model)
+    plt.plot(losses)
+    plt.show()
 
+    theta = jnp.ones((5, 2))
+    data = jnp.repeat(jnp.arange(5), 2).reshape(-1, 2)
+    print(model.apply(params, method="log_prob", **{"y": data, "x": theta}))
 
-theta = jnp.ones((5, 2))
-data = jnp.repeat(jnp.arange(5), 2).reshape(-1, 2)
-
-print(model.apply(params, method="log_prob", **{"y": data, "x": theta}))
+if __name__ == "__main__":
+    run()

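For context when skimming the diff: after this change the example no longer executes at import time. Data simulation, training, and evaluation all live in run(), and train() receives the data set and the model explicitly. Below is a minimal sketch of the resulting call pattern; the module name surjector_example and the placeholder data are illustrative assumptions, not part of the commit.

import haiku as hk
import jax.numpy as jnp

import surjector_example as ex  # assumed import path, for illustration only

# Placeholder data set with the same structure that run() simulates above.
data = ex.named_dataset(jnp.zeros((1000, 2)), jnp.zeros((1000, 2)))

# train() now takes the data set and the model explicitly.
model = ex.make_model(2)
params, losses = ex.train(hk.PRNGSequence(2), data, model)

# Evaluate log-probabilities for new (y, theta) pairs, as at the end of run().
theta = jnp.ones((5, 2))
y = jnp.repeat(jnp.arange(5), 2).reshape(-1, 2)
print(model.apply(params, method="log_prob", y=y, x=theta))

Running the file directly (python examples/surjector_example.py) goes through the same path via the new __main__ guard.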