Add CDE example and refactor modules (#18)
* Fix forward and likelihood contribution

* Fix forward surjection

* Add conditional density estimator example

* Refactor MADE modules
dirmeier authored Apr 13, 2023
1 parent d28c6cf commit 3a1bbec
Showing 15 changed files with 162 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,5 +1,5 @@
 PKG_VERSION=`hatch version`
 
 tag:
-	git tag -a v${PKG_VERSION} -m ${PKG_VERSION}
+	git tag -a v${PKG_VERSION} -m v${PKG_VERSION}
 	git push --tag
132 changes: 132 additions & 0 deletions examples/conditional_density_estimation.py
@@ -0,0 +1,132 @@
import distrax
import haiku as hk
import jax
import numpy as np
import optax
from jax import numpy as jnp
from jax import random
from matplotlib import pyplot as plt

from surjectors import (
Chain,
MaskedAutoregressive,
MaskedCoupling,
Permutation,
TransformedDistribution,
)
from surjectors.conditioners import MADE, mlp_conditioner
from surjectors.util import (
as_batch_iterator,
make_alternating_binary_mask,
named_dataset,
unstack,
)


def make_model(dim, model="coupling"):
def _bijector_fn(params):
means, log_scales = unstack(params, -1)
return distrax.ScalarAffine(means, jnp.exp(log_scales))

def _flow(method, **kwargs):
layers = []
order = jnp.arange(2)
for i in range(2):
if model == "coupling":
mask = make_alternating_binary_mask(2, i % 2 == 0)
layer = MaskedCoupling(
mask=mask,
bijector=_bijector_fn,
conditioner=hk.Sequential(
[
mlp_conditioner([8, 8, dim * 2]),
hk.Reshape((dim, dim)),
]
),
)
layers.append(layer)
else:
layer = MaskedAutoregressive(
bijector_fn=_bijector_fn,
conditioner=MADE(
2,
[32, 32, 2 * 2],
2,
w_init=hk.initializers.TruncatedNormal(0.01),
b_init=jnp.zeros,
),
)
order = order[::-1]
layers.append(layer)
layers.append(Permutation(order, 1))
if model != "coupling":
layers = layers[:-1]
chain = Chain(layers)

base_distribution = distrax.Independent(
distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
reinterpreted_batch_ndims=1,
)
td = TransformedDistribution(base_distribution, chain)
return td(method=method, **kwargs)

td = hk.transform(_flow)
return td


def train(rng_seq, data, model, max_n_iter=1000):
train_iter = as_batch_iterator(next(rng_seq), data, 100, True)
params = model.init(next(rng_seq), method="log_prob", **train_iter(0))

optimizer = optax.adam(1e-4)
state = optimizer.init(params)

@jax.jit
def step(params, state, **batch):
def loss_fn(params):
lp = model.apply(params, None, method="log_prob", **batch)
return -jnp.sum(lp)

loss, grads = jax.value_and_grad(loss_fn)(params)
updates, new_state = optimizer.update(grads, state, params)
new_params = optax.apply_updates(params, updates)
return loss, new_params, new_state

losses = np.zeros(max_n_iter)
for i in range(max_n_iter):
train_loss = 0.0
for j in range(train_iter.num_batches):
batch = train_iter(j)
batch_loss, params, state = step(params, state, **batch)
train_loss += batch_loss
losses[i] = train_loss

return params, losses


def run():
n = 10000
thetas = distrax.Normal(jnp.zeros(2), jnp.full(2, 10)).sample(
seed=random.PRNGKey(0), sample_shape=(n,)
)
y = 2 * thetas + distrax.Normal(jnp.zeros_like(thetas), 0.1).sample(
seed=random.PRNGKey(1)
)
data = named_dataset(y, thetas)

model = make_model(2)
params, losses = train(hk.PRNGSequence(2), data, model)
samples = model.apply(
params,
random.PRNGKey(2),
method="sample",
x=jnp.full_like(thetas, -2.0),
)

plt.hist(samples[:, 0])
plt.hist(samples[:, 1])
plt.show()


if __name__ == "__main__":
run()
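
For reference, make_model above selects the MADE-based masked autoregressive branch whenever the model argument is anything other than "coupling". A minimal sketch of running that variant, assuming the definitions and the data built in run() above are in scope (the string "maf" is only an illustrative choice):

# Sketch only: reuses make_model, train, and data from the example above.
maf = make_model(2, model="maf")                   # any value other than "coupling"
params, losses = train(hk.PRNGSequence(3), data, maf)
samples = maf.apply(
    params,
    random.PRNGKey(4),
    method="sample",
    x=jnp.full((10000, 2), -2.0),                  # same conditioning value as in run()
)
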
24 changes: 2 additions & 22 deletions examples/surjector_example.py
@@ -1,19 +1,3 @@
"""
Bayesian Neural Network
=======================
This example implements the training and prediction of a
Bayesian Neural Network.
Predictions from a Haiku MLP fro the same data are shown
as a reference.
References
----------
[1] Blundell C., Cornebise J., Kavukcuoglu K., Wierstra D.
"Weight Uncertainty in Neural Networks".
ICML, 2015.
"""


import distrax
import haiku as hk
import jax
@@ -23,12 +7,8 @@
 from jax import random
 from matplotlib import pyplot as plt
 
-from surjectors import (
-    Chain,
-    MaskedCoupling,
-    TransformedDistribution,
-    mlp_conditioner,
-)
+from surjectors import Chain, MaskedCoupling, TransformedDistribution
+from surjectors.conditioners import mlp_conditioner
 from surjectors.util import (
     as_batch_iterator,
     make_alternating_binary_mask,
5 changes: 3 additions & 2 deletions surjectors/__init__.py
@@ -2,11 +2,12 @@
surjectors: Surjection layers for density estimation with normalizing flows
"""

__version__ = "0.2.1"
__version__ = "0.2.2"

from surjectors.bijectors.lu_linear import LULinear
from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
from surjectors.bijectors.masked_coupling import MaskedCoupling
from surjectors.conditioners import mlp_conditioner, transformer_conditioner
from surjectors.bijectors.permutation import Permutation
from surjectors.distributions.transformed_distribution import (
TransformedDistribution,
)
2 changes: 1 addition & 1 deletion surjectors/bijectors/masked_autoregressive.py
@@ -4,8 +4,8 @@
 from jax import numpy as jnp
 
 from surjectors.bijectors._bijector import _Bijector
+from surjectors.conditioners import MADE
 from surjectors.distributions.transformed_distribution import Array
-from surjectors.nn.made import MADE
 
 
 # pylint: disable=too-many-arguments, arguments-renamed
2 changes: 1 addition & 1 deletion surjectors/bijectors/masked_autoregressive_test.py
@@ -9,7 +9,7 @@

 from surjectors import TransformedDistribution
 from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
-from surjectors.nn.made import MADE
+from surjectors.conditioners import MADE
 from surjectors.util import unstack


1 change: 1 addition & 0 deletions surjectors/conditioners/__init__.py
@@ -1,2 +1,3 @@
 from .mlp import mlp_conditioner
+from .nn.made import MADE
 from .transformer import transformer_conditioner
5 changes: 4 additions & 1 deletion surjectors/conditioners/mlp.py
@@ -19,5 +19,8 @@ def mlp_conditioner(
"""

return hk.nets.MLP(
output_sizes=dims, w_init=w_init, b_init=b_init, activation=activation
output_sizes=dims,
w_init=w_init,
b_init=b_init,
activation=activation,
)
File renamed without changes.

@@ -7,7 +7,7 @@
     _make_dense_autoregressive_masks,
 )
 
-from surjectors.nn.masked_linear import MaskedLinear
+from surjectors.conditioners.nn.masked_linear import MaskedLinear
 
 
 # pylint: disable=too-many-arguments, arguments-renamed

@@ -6,7 +6,7 @@
 from jax import numpy as jnp
 from jax import random
 
-from surjectors.nn.made import MADE
+from surjectors.conditioners import MADE
 from surjectors.util import unstack


File renamed without changes.
4 changes: 3 additions & 1 deletion surjectors/distributions/transformed_distribution.py
@@ -115,7 +115,9 @@ def sample_and_log_prob(self, sample_shape=(), x: Array = None):
             transformation, the second one is its log probability
         """
 
-        if x is not None and len(sample_shape) > 0:
+        if x is not None and len(sample_shape) == 0:
+            sample_shape = (x.shape[0],)
+        if x is not None:
             chex.assert_equal(sample_shape[0], x.shape[0])
 
         z, lp_z = self.base_distribution.sample_and_log_prob(
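
In effect, when a conditioning array x is passed to sample_and_log_prob without an explicit sample_shape, the sample shape is now filled in from x's leading dimension, so one sample is drawn per conditioning row before the shape assertion runs. A standalone sketch of the new rule with illustrative values (not code from the repository):

from jax import numpy as jnp

x = jnp.zeros((100, 2))        # 100 conditioning rows
sample_shape = ()

# the fix: an empty sample_shape is inferred from the leading dimension of x
if x is not None and len(sample_shape) == 0:
    sample_shape = (x.shape[0],)
assert sample_shape == (100,)  # one draw per conditioning row
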

@@ -6,8 +6,8 @@
 from jax import random
 
 from surjectors import TransformedDistribution
+from surjectors.conditioners import MADE
 from surjectors.conditioners.mlp import mlp_conditioner
-from surjectors.nn.made import MADE
 from surjectors.surjectors.affine_masked_autoregressive_inference_funnel import ( # noqa: E501
     AffineMaskedAutoregressiveInferenceFunnel,
 )
17 changes: 11 additions & 6 deletions surjectors/surjectors/chain.py
@@ -34,21 +34,26 @@ def _inverse_and_log_contribution_dispatch(surjector, y, x):
         return z, lc
 
     def forward_and_likelihood_contribution(self, z, x=None, **kwargs):
-        y, log_det = self._surjectors[-1].forward_and_log_det(z, x)
+        y, log_det = self._forward_and_log_contribution_dispatch(
+            self._surjectors[-1], z, x
+        )
         for surjector in reversed(self._surjectors[:-1]):
             y, lc = self._forward_and_log_contribution_dispatch(surjector, y, x)
             log_det += lc
         return y, log_det
 
     @staticmethod
-    def _forward_and_log_contribution_dispatch(surjector, y, x):
+    def _forward_and_log_contribution_dispatch(surjector, z, x):
         if isinstance(surjector, Surjector):
-            fn = getattr(surjector, "forward_and_likelihood_contribution")
-            z, lc = fn(y, x)
+            if hasattr(surjector, "forward_and_likelihood_contribution"):
+                fn = getattr(surjector, "forward_and_likelihood_contribution")
+            else:
+                fn = getattr(surjector, "forward_and_log_det")
+            y, lc = fn(z, x)
         else:
             fn = getattr(surjector, "forward_and_log_det")
-            z, lc = fn(y)
-        return z, lc
+            y, lc = fn(z)
+        return y, lc
 
     def forward(self, z, x=None):
         y, _ = self.forward_and_likelihood_contribution(z, x)
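
The dispatch above now routes the last element of the chain through the same helper as the rest, and falls back to the distrax-style forward_and_log_det when a Surjector subclass does not define forward_and_likelihood_contribution. A simplified standalone sketch of that rule with hypothetical toy layers (the real method additionally checks isinstance(surjector, Surjector) to decide whether the conditioner x is passed):

class ToyBijector:
    # distrax-style layer: no conditioner, returns (y, log_det)
    def forward_and_log_det(self, z):
        return z + 1.0, 0.0


class ToySurjector:
    # surjectors-style layer: takes the conditioner x
    def forward_and_likelihood_contribution(self, z, x):
        return z * 2.0, 0.0


def dispatch(layer, z, x):
    # prefer the surjector API, otherwise fall back to forward_and_log_det
    if hasattr(layer, "forward_and_likelihood_contribution"):
        return layer.forward_and_likelihood_contribution(z, x)
    return layer.forward_and_log_det(z)


y, lc = dispatch(ToySurjector(), 1.0, x=None)  # surjector path, x is forwarded
y, lc = dispatch(ToyBijector(), 1.0, x=None)   # falls back to forward_and_log_det
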
