From 393b13324880f7aeac2ab50d921b78c366b92ec5 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Fri, 28 Jul 2023 15:11:24 +0200 Subject: [PATCH 1/3] Added more examples --- README.md | 6 +- .../autoregressive_inference_surjection.py | 116 ++++++++++++++++++ ...le.py => coupling_inference_surjection.py} | 83 +++++++------ surjectors/__init__.py | 5 +- ...ed_autoregressive_inference_funnel_test.py | 4 +- ...e_masked_coupling_inference_funnel_test.py | 73 +++++++++++ 6 files changed, 247 insertions(+), 40 deletions(-) create mode 100644 examples/autoregressive_inference_surjection.py rename examples/{surjector_example.py => coupling_inference_surjection.py} (50%) create mode 100644 surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py diff --git a/README.md b/README.md index b8e3819..042023d 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,17 @@ Surjectors is a light-weight library of inference and generative surjection layers, i.e., layers that reduce or increase dimensionality, for density estimation using normalizing flows. Surjectors builds on Distrax and Haiku and is fully compatible with both of them. -## Example usage +## Examples -TODO +You can find several self-contained examples on how to use the algorithms in `examples`. ## Installation Make sure to have a working `JAX` installation. Depending whether you want to use CPU/GPU/TPU, please follow [these instructions](https://github.com/google/jax#installation). +You can install + To install the latest GitHub , just call the following on the command line: ```bash diff --git a/examples/autoregressive_inference_surjection.py b/examples/autoregressive_inference_surjection.py new file mode 100644 index 0000000..cd2adbf --- /dev/null +++ b/examples/autoregressive_inference_surjection.py @@ -0,0 +1,116 @@ +from collections import namedtuple + +import distrax +import haiku as hk +import jax +import numpy as np +import optax +from jax import numpy as jnp +from jax import random as jr +from matplotlib import pyplot as plt + +from surjectors import ( + AffineMaskedAutoregressiveInferenceFunnel, + Chain, + MaskedAutoregressive, + TransformedDistribution, +) +from surjectors.conditioners import MADE, mlp_conditioner +from surjectors.util import as_batch_iterator, unstack + + +def _decoder_fn(n_dim): + decoder_net = mlp_conditioner([4, 4, n_dim * 2]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _made_bijector_fn(params): + means, log_scales = unstack(params, -1) + return distrax.Inverse(distrax.ScalarAffine(means, jnp.exp(log_scales))) + + +def make_model(n_dimensions): + def _flow(**kwargs): + n_dim = n_dimensions + layers = [] + for i in range(3): + if i != 1: + layer = AffineMaskedAutoregressiveInferenceFunnel( + n_keep=int(n_dim / 2), + decoder=_decoder_fn(int(n_dim / 2)), + conditioner=MADE(int(n_dim / 2), [8, 8], 2), + ) + n_dim = int(n_dim / 2) + else: + layer = MaskedAutoregressive( + conditioner=MADE(n_dim, [8, 8], 2), + bijector_fn=_made_bijector_fn, + ) + layers.append(layer) + chain = Chain(layers) + + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)), + reinterpreted_batch_ndims=1, + ) + td = TransformedDistribution(base_distribution, chain) + return td.log_prob(**kwargs) + + td = hk.transform(_flow) + td = hk.without_apply_rng(td) + return td + + +def train(rng_seq, data, model, max_n_iter=1000): + train_iter = 
as_batch_iterator(next(rng_seq), data, 100, True) + params = model.init(next(rng_seq), **train_iter(0)) + + optimizer = optax.adam(1e-4) + state = optimizer.init(params) + + @jax.jit + def step(params, state, **batch): + def loss_fn(params): + lp = model.apply(params, **batch) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = optimizer.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + losses = np.zeros(max_n_iter) + for i in range(max_n_iter): + train_loss = 0.0 + for j in range(train_iter.num_batches): + batch = train_iter(j) + batch_loss, params, state = step(params, state, **batch) + train_loss += batch_loss + losses[i] = train_loss + + return params, losses + + +def run(): + n, p = 1000, 20 + rng_seq = hk.PRNGSequence(2) + y = jr.normal(next(rng_seq), shape=(n, p)) + data = namedtuple("named_dataset", "y")(y) + + model = make_model(p) + params, losses = train(rng_seq, data, model) + plt.plot(losses) + plt.show() + + y = jr.normal(next(rng_seq), shape=(10, p)) + print(model.apply(params, **{"y": y})) + + +if __name__ == "__main__": + run() diff --git a/examples/surjector_example.py b/examples/coupling_inference_surjection.py similarity index 50% rename from examples/surjector_example.py rename to examples/coupling_inference_surjection.py index 61859a4..35a43b3 100644 --- a/examples/surjector_example.py +++ b/examples/coupling_inference_surjection.py @@ -1,46 +1,62 @@ +from collections import namedtuple + import distrax import haiku as hk import jax import numpy as np import optax from jax import numpy as jnp -from jax import random +from jax import random as jr from matplotlib import pyplot as plt -from surjectors import Chain, MaskedCoupling, TransformedDistribution -from surjectors.conditioners import mlp_conditioner -from surjectors.util import ( - as_batch_iterator, - make_alternating_binary_mask, - named_dataset, +from surjectors import ( + AffineMaskedCouplingInferenceFunnel, + Chain, + MaskedAutoregressive, + TransformedDistribution, ) +from surjectors.conditioners import MADE, mlp_conditioner +from surjectors.util import as_batch_iterator, unstack + + +def _decoder_fn(n_dim): + decoder_net = mlp_conditioner([4, 4, n_dim * 2]) + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) -def simulator_fn(seed, theta): - p_noise = distrax.Normal(jnp.zeros_like(theta), 1.0) - noise = p_noise.sample(seed=seed) - return theta + 0.1 * noise + return _fn -def make_model(dim): - def _bijector_fn(params): - means, log_scales = jnp.split(params, 2, -1) - return distrax.ScalarAffine(means, jnp.exp(log_scales)) +def _made_bijector_fn(params): + means, log_scales = unstack(params, -1) + return distrax.Inverse(distrax.ScalarAffine(means, jnp.exp(log_scales))) + +def make_model(n_dimensions): def _flow(**kwargs): + n_dim = n_dimensions layers = [] - for i in range(2): - mask = make_alternating_binary_mask(dim, i % 2 == 0) - layer = MaskedCoupling( - mask=mask, - bijector=_bijector_fn, - conditioner=mlp_conditioner([8, 8, dim * 2]), - ) + for i in range(3): + if i != 1: + layer = AffineMaskedCouplingInferenceFunnel( + n_keep=int(n_dim / 2), + decoder=_decoder_fn(int(n_dim / 2)), + conditioner=mlp_conditioner([8, 8, n_dim * 2]), + ) + n_dim = int(n_dim / 2) + else: + layer = MaskedAutoregressive( + conditioner=MADE(n_dim, [8, 8], 2), + bijector_fn=_made_bijector_fn, + ) 
layers.append(layer) chain = Chain(layers) base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(dim), jnp.ones(dim)), + distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)), reinterpreted_batch_ndims=1, ) td = TransformedDistribution(base_distribution, chain) @@ -82,21 +98,18 @@ def loss_fn(params): def run(): - n = 1000 - prior = distrax.Uniform(jnp.full(2, -2), jnp.full(2, 2)) - theta = prior.sample(seed=random.PRNGKey(0), sample_shape=(n,)) - likelihood = distrax.MultivariateNormalDiag(theta, jnp.ones_like(theta)) - y = likelihood.sample(seed=random.PRNGKey(1)) - data = named_dataset(y, theta) - - model = make_model(2) - params, losses = train(hk.PRNGSequence(2), data, model) + n, p = 1000, 20 + rng_seq = hk.PRNGSequence(2) + y = jr.normal(next(rng_seq), shape=(n, p)) + data = namedtuple("named_dataset", "y")(y) + + model = make_model(p) + params, losses = train(rng_seq, data, model) plt.plot(losses) plt.show() - theta = jnp.ones((5, 2)) - data = jnp.repeat(jnp.arange(5), 2).reshape(-1, 2) - print(model.apply(params, **{"y": data, "x": theta})) + y = jr.normal(next(rng_seq), shape=(10, p)) + print(model.apply(params, **{"y": y})) if __name__ == "__main__": diff --git a/surjectors/__init__.py b/surjectors/__init__.py index 00f31c3..ae7dc56 100644 --- a/surjectors/__init__.py +++ b/surjectors/__init__.py @@ -2,7 +2,7 @@ surjectors: Surjection layers for density estimation with normalizing flows """ -__version__ = "0.2.2" +__version__ = "0.2.3" from surjectors.bijectors.lu_linear import LULinear from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive @@ -11,6 +11,9 @@ from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) +from surjectors.surjectors.affine_masked_autoregressive_inference_funnel import ( # noqa: E501 + AffineMaskedAutoregressiveInferenceFunnel, +) from surjectors.surjectors.affine_masked_coupling_generative_funnel import ( AffineMaskedCouplingGenerativeFunnel, ) diff --git a/surjectors/surjectors/affine_masked_autoregressive_inference_funnel_test.py b/surjectors/surjectors/affine_masked_autoregressive_inference_funnel_test.py index 86ba5e6..11662cb 100644 --- a/surjectors/surjectors/affine_masked_autoregressive_inference_funnel_test.py +++ b/surjectors/surjectors/affine_masked_autoregressive_inference_funnel_test.py @@ -13,7 +13,7 @@ ) -def _conditional_fn(n_dim): +def _decoder_fn(n_dim): decoder_net = mlp_conditioner([4, 4, n_dim * 2]) def _fn(z): @@ -35,7 +35,7 @@ def _base_distribution_fn(n_latent): def _get_funnel_surjector(n_latent, n_dimension): return AffineMaskedAutoregressiveInferenceFunnel( n_latent, - _conditional_fn(n_dimension - n_latent), + _decoder_fn(n_dimension - n_latent), MADE(n_latent, [4, 4], 2), ) diff --git a/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py b/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py new file mode 100644 index 0000000..10c725a --- /dev/null +++ b/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py @@ -0,0 +1,73 @@ +# pylint: skip-file + +import distrax +import haiku as hk +from jax import numpy as jnp +from jax import random + +from surjectors import TransformedDistribution +from surjectors.conditioners.mlp import mlp_conditioner +from surjectors.surjectors.affine_masked_coupling_inference_funnel import ( # noqa: E501 + AffineMaskedCouplingInferenceFunnel, +) + + +def _decoder_fn(n_dim): + decoder_net = mlp_conditioner([4, 4, n_dim * 2]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = 
jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def _get_funnel_surjector(n_latent, n_dimension): + return AffineMaskedCouplingInferenceFunnel( + n_latent, + _decoder_fn(n_dimension - n_latent), + mlp_conditioner([4, 4, n_latent * 2]), + ) + + +def make_surjector(n_dimension, n_latent): + def _transformation_fn(n_dimension): + funnel = _get_funnel_surjector(n_latent, n_dimension) + return funnel + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_latent), _transformation_fn(n_dimension) + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def test_affine_masked_coupling_inference_funnel(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + + flow = make_surjector(n_dimension, n_latent) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y) + _ = flow.apply(params, None, method="log_prob", y=y) + + +def test_conditional_affine_masked_coupling_inference_funnel(): + n_dimension, n_latent = 4, 2 + y = random.normal(random.PRNGKey(1), shape=(10, n_dimension)) + x = random.normal(random.PRNGKey(1), shape=(10, 2)) + + flow = make_surjector(n_dimension, n_latent) + params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x) + _ = flow.apply(params, None, method="log_prob", y=y, x=x) From 5590d0789189be1c5701035d95723cfd503999c8 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Fri, 28 Jul 2023 15:23:20 +0200 Subject: [PATCH 2/3] Added more examples --- .../surjectors/affine_masked_coupling_inference_funnel_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py b/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py index 10c725a..a2856ad 100644 --- a/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py +++ b/surjectors/surjectors/affine_masked_coupling_inference_funnel_test.py @@ -35,7 +35,7 @@ def _get_funnel_surjector(n_latent, n_dimension): return AffineMaskedCouplingInferenceFunnel( n_latent, _decoder_fn(n_dimension - n_latent), - mlp_conditioner([4, 4, n_latent * 2]), + mlp_conditioner([4, 4, n_dimension * 2]), ) From eb68d9df20b1a5998033836294da28215a87664e Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Fri, 28 Jul 2023 15:27:14 +0200 Subject: [PATCH 3/3] Added more examples --- surjectors/conditioners/nn/masked_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/surjectors/conditioners/nn/masked_linear.py b/surjectors/conditioners/nn/masked_linear.py index 965bcb2..8e60311 100644 --- a/surjectors/conditioners/nn/masked_linear.py +++ b/surjectors/conditioners/nn/masked_linear.py @@ -7,7 +7,7 @@ from jax import numpy as jnp -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments,too-few-public-methods class MaskedLinear(hk.Linear): """ Linear layer that masks some weights out
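
For reference, here is a minimal usage sketch of the new coupling inference funnel, assembled directly from what this patch itself adds: the top-level imports come from the updated `coupling_inference_surjection.py` example, the constructor arguments and the `log_prob` call mirror the new test file, and the conditioner width uses the fix from the second commit. Nothing beyond the patch is assumed; layer sizes are the test's toy values.

```python
import distrax
import haiku as hk
from jax import numpy as jnp
from jax import random

from surjectors import AffineMaskedCouplingInferenceFunnel, TransformedDistribution
from surjectors.conditioners import mlp_conditioner

n_dimension, n_latent = 4, 2


def _decoder_fn(n_dim):
    # decoder for the dropped coordinates, p(y_dropped | z), parameterized by an MLP
    decoder_net = mlp_conditioner([4, 4, n_dim * 2])

    def _fn(z):
        mu, log_scale = jnp.split(decoder_net(z), 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def _flow(method, **kwargs):
    # keep n_latent of the n_dimension inputs, decode the rest
    funnel = AffineMaskedCouplingInferenceFunnel(
        n_keep=n_latent,
        decoder=_decoder_fn(n_dimension - n_latent),
        conditioner=mlp_conditioner([4, 4, n_dimension * 2]),
    )
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
        reinterpreted_batch_ndims=1,
    )
    td = TransformedDistribution(base_distribution, funnel)
    return td(method, **kwargs)


td = hk.transform(_flow)
y = random.normal(random.PRNGKey(1), shape=(10, n_dimension))
params = td.init(random.PRNGKey(0), method="log_prob", y=y)
log_prob = td.apply(params, None, method="log_prob", y=y)
```

As in the added test, the base distribution lives in the reduced `n_latent`-dimensional space: the funnel keeps `n_keep` coordinates and accounts for the dropped ones through the decoder's conditional density.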