Added more examples #19

Merged: 3 commits, Jul 28, 2023
Changes from 1 commit
Added more examples
dirmeier committed Jul 28, 2023
commit 393b13324880f7aeac2ab50d921b78c366b92ec5
6 changes: 4 additions & 2 deletions README.md
@@ -10,15 +10,17 @@
Surjectors is a light-weight library of inference and generative surjection layers, i.e., layers that reduce or increase dimensionality, for density estimation using normalizing flows.
Surjectors builds on Distrax and Haiku and is fully compatible with both of them.
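
For orientation, the sketch below shows how a single dimensionality-reducing funnel layer can be composed into a flow. It is a minimal sketch modelled on the unit test added in this pull request; helper names such as `_decoder_fn` and `_flow` are purely illustrative.

```python
import distrax
import haiku as hk
from jax import numpy as jnp
from jax import random

from surjectors import TransformedDistribution
from surjectors.conditioners.mlp import mlp_conditioner
from surjectors.surjectors.affine_masked_coupling_inference_funnel import (
    AffineMaskedCouplingInferenceFunnel,
)


def _decoder_fn(n_dim):
    # MLP that outputs the parameters of a Gaussian over the dropped dimensions.
    net = mlp_conditioner([4, 4, n_dim * 2])

    def _fn(z):
        mu, log_scale = jnp.split(net(z), 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def _flow(method, **kwargs):
    n_dimension, n_latent = 4, 2
    # Inference funnel: reduces the data from n_dimension to n_latent dimensions.
    funnel = AffineMaskedCouplingInferenceFunnel(
        n_latent,
        _decoder_fn(n_dimension - n_latent),
        mlp_conditioner([4, 4, n_latent * 2]),
    )
    base = distrax.Independent(
        distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
        reinterpreted_batch_ndims=1,
    )
    return TransformedDistribution(base, funnel)(method, **kwargs)


flow = hk.transform(_flow)
y = random.normal(random.PRNGKey(1), shape=(10, 4))
params = flow.init(random.PRNGKey(0), method="log_prob", y=y)
print(flow.apply(params, None, method="log_prob", y=y))
```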

## Example usage
## Examples

TODO
You can find several self-contained examples of how to use the algorithms in the `examples` folder.

## Installation

Make sure you have a working `JAX` installation. Depending on whether you want to use CPU, GPU, or TPU,
please follow [these instructions](https://github.com/google/jax#installation).
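
A quick sanity check that the installation works is to ask JAX which backend devices it sees, for instance:

```python
import jax

print(jax.__version__)
# Lists the devices (CPU/GPU/TPU) JAX has picked up.
print(jax.devices())
```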

You can install

To install the latest GitHub <RELEASE>, just call the following on the command line:

```bash
116 changes: 116 additions & 0 deletions examples/autoregressive_inference_surjection.py
@@ -0,0 +1,116 @@
from collections import namedtuple

import distrax
import haiku as hk
import jax
import numpy as np
import optax
from jax import numpy as jnp
from jax import random as jr
from matplotlib import pyplot as plt

from surjectors import (
    AffineMaskedAutoregressiveInferenceFunnel,
    Chain,
    MaskedAutoregressive,
    TransformedDistribution,
)
from surjectors.conditioners import MADE, mlp_conditioner
from surjectors.util import as_batch_iterator, unstack


def _decoder_fn(n_dim):
    # Conditional decoder used by the inference funnel: an MLP that outputs the
    # parameters of a diagonal Gaussian over the dimensions the funnel drops.
    decoder_net = mlp_conditioner([4, 4, n_dim * 2])

    def _fn(z):
        params = decoder_net(z)
        mu, log_scale = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def _made_bijector_fn(params):
    means, log_scales = unstack(params, -1)
    return distrax.Inverse(distrax.ScalarAffine(means, jnp.exp(log_scales)))


def make_model(n_dimensions):
    def _flow(**kwargs):
        n_dim = n_dimensions
        layers = []
        for i in range(3):
            if i != 1:
                # Dimensionality-reducing surjection: keeps n_dim / 2 dimensions
                # and models the dropped half with the decoder.
                layer = AffineMaskedAutoregressiveInferenceFunnel(
                    n_keep=int(n_dim / 2),
                    decoder=_decoder_fn(int(n_dim / 2)),
                    conditioner=MADE(int(n_dim / 2), [8, 8], 2),
                )
                n_dim = int(n_dim / 2)
            else:
                # Plain bijective masked autoregressive layer in between.
                layer = MaskedAutoregressive(
                    conditioner=MADE(n_dim, [8, 8], 2),
                    bijector_fn=_made_bijector_fn,
                )
            layers.append(layer)
        chain = Chain(layers)

        # The base distribution lives in the reduced latent space.
        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        td = TransformedDistribution(base_distribution, chain)
        return td.log_prob(**kwargs)

    td = hk.transform(_flow)
    td = hk.without_apply_rng(td)
    return td


def train(rng_seq, data, model, max_n_iter=1000):
    train_iter = as_batch_iterator(next(rng_seq), data, 100, True)
    params = model.init(next(rng_seq), **train_iter(0))

    optimizer = optax.adam(1e-4)
    state = optimizer.init(params)

    @jax.jit
    def step(params, state, **batch):
        def loss_fn(params):
            lp = model.apply(params, **batch)
            return -jnp.sum(lp)

        # Maximum likelihood: minimize the negative log-probability of the batch.
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, new_state = optimizer.update(grads, state, params)
        new_params = optax.apply_updates(params, updates)
        return loss, new_params, new_state

    losses = np.zeros(max_n_iter)
    for i in range(max_n_iter):
        train_loss = 0.0
        for j in range(train_iter.num_batches):
            batch = train_iter(j)
            batch_loss, params, state = step(params, state, **batch)
            train_loss += batch_loss
        losses[i] = train_loss

    return params, losses


def run():
    n, p = 1000, 20
    rng_seq = hk.PRNGSequence(2)
    y = jr.normal(next(rng_seq), shape=(n, p))
    data = namedtuple("named_dataset", "y")(y)

    model = make_model(p)
    params, losses = train(rng_seq, data, model)
    plt.plot(losses)
    plt.show()

    y = jr.normal(next(rng_seq), shape=(10, p))
    print(model.apply(params, **{"y": y}))


if __name__ == "__main__":
    run()
@@ -1,46 +1,62 @@
from collections import namedtuple

import distrax
import haiku as hk
import jax
import numpy as np
import optax
from jax import numpy as jnp
from jax import random
from jax import random as jr
from matplotlib import pyplot as plt

from surjectors import Chain, MaskedCoupling, TransformedDistribution
from surjectors.conditioners import mlp_conditioner
from surjectors.util import (
    as_batch_iterator,
    make_alternating_binary_mask,
    named_dataset,
from surjectors import (
    AffineMaskedCouplingInferenceFunnel,
    Chain,
    MaskedAutoregressive,
    TransformedDistribution,
)
from surjectors.conditioners import MADE, mlp_conditioner
from surjectors.util import as_batch_iterator, unstack


def _decoder_fn(n_dim):
    decoder_net = mlp_conditioner([4, 4, n_dim * 2])

    def _fn(z):
        params = decoder_net(z)
        mu, log_scale = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

def simulator_fn(seed, theta):
    p_noise = distrax.Normal(jnp.zeros_like(theta), 1.0)
    noise = p_noise.sample(seed=seed)
    return theta + 0.1 * noise
    return _fn


def make_model(dim):
    def _bijector_fn(params):
        means, log_scales = jnp.split(params, 2, -1)
        return distrax.ScalarAffine(means, jnp.exp(log_scales))
def _made_bijector_fn(params):
    means, log_scales = unstack(params, -1)
    return distrax.Inverse(distrax.ScalarAffine(means, jnp.exp(log_scales)))


def make_model(n_dimensions):
    def _flow(**kwargs):
        n_dim = n_dimensions
        layers = []
        for i in range(2):
            mask = make_alternating_binary_mask(dim, i % 2 == 0)
            layer = MaskedCoupling(
                mask=mask,
                bijector=_bijector_fn,
                conditioner=mlp_conditioner([8, 8, dim * 2]),
            )
        for i in range(3):
            if i != 1:
                layer = AffineMaskedCouplingInferenceFunnel(
                    n_keep=int(n_dim / 2),
                    decoder=_decoder_fn(int(n_dim / 2)),
                    conditioner=mlp_conditioner([8, 8, n_dim * 2]),
                )
                n_dim = int(n_dim / 2)
            else:
                layer = MaskedAutoregressive(
                    conditioner=MADE(n_dim, [8, 8], 2),
                    bijector_fn=_made_bijector_fn,
                )
            layers.append(layer)
        chain = Chain(layers)

        base_distribution = distrax.Independent(
            distrax.Normal(jnp.zeros(dim), jnp.ones(dim)),
            distrax.Normal(jnp.zeros(n_dim), jnp.ones(n_dim)),
            reinterpreted_batch_ndims=1,
        )
        td = TransformedDistribution(base_distribution, chain)
@@ -82,21 +98,18 @@ def loss_fn(params):


def run():
    n = 1000
    prior = distrax.Uniform(jnp.full(2, -2), jnp.full(2, 2))
    theta = prior.sample(seed=random.PRNGKey(0), sample_shape=(n,))
    likelihood = distrax.MultivariateNormalDiag(theta, jnp.ones_like(theta))
    y = likelihood.sample(seed=random.PRNGKey(1))
    data = named_dataset(y, theta)

    model = make_model(2)
    params, losses = train(hk.PRNGSequence(2), data, model)
    n, p = 1000, 20
    rng_seq = hk.PRNGSequence(2)
    y = jr.normal(next(rng_seq), shape=(n, p))
    data = namedtuple("named_dataset", "y")(y)

    model = make_model(p)
    params, losses = train(rng_seq, data, model)
    plt.plot(losses)
    plt.show()

    theta = jnp.ones((5, 2))
    data = jnp.repeat(jnp.arange(5), 2).reshape(-1, 2)
    print(model.apply(params, **{"y": data, "x": theta}))
    y = jr.normal(next(rng_seq), shape=(10, p))
    print(model.apply(params, **{"y": y}))


if __name__ == "__main__":
5 changes: 4 additions & 1 deletion surjectors/__init__.py
@@ -2,7 +2,7 @@
surjectors: Surjection layers for density estimation with normalizing flows
"""

__version__ = "0.2.2"
__version__ = "0.2.3"

from surjectors.bijectors.lu_linear import LULinear
from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
@@ -11,6 +11,9 @@
from surjectors.distributions.transformed_distribution import (
    TransformedDistribution,
)
from surjectors.surjectors.affine_masked_autoregressive_inference_funnel import (  # noqa: E501
    AffineMaskedAutoregressiveInferenceFunnel,
)
from surjectors.surjectors.affine_masked_coupling_generative_funnel import (
    AffineMaskedCouplingGenerativeFunnel,
)
@@ -13,7 +13,7 @@
)


def _conditional_fn(n_dim):
def _decoder_fn(n_dim):
    decoder_net = mlp_conditioner([4, 4, n_dim * 2])

    def _fn(z):
@@ -35,7 +35,7 @@ def _base_distribution_fn(n_latent):
def _get_funnel_surjector(n_latent, n_dimension):
    return AffineMaskedAutoregressiveInferenceFunnel(
        n_latent,
        _conditional_fn(n_dimension - n_latent),
        _decoder_fn(n_dimension - n_latent),
        MADE(n_latent, [4, 4], 2),
    )

@@ -0,0 +1,73 @@
# pylint: skip-file

import distrax
import haiku as hk
from jax import numpy as jnp
from jax import random

from surjectors import TransformedDistribution
from surjectors.conditioners.mlp import mlp_conditioner
from surjectors.surjectors.affine_masked_coupling_inference_funnel import (  # noqa: E501
    AffineMaskedCouplingInferenceFunnel,
)


def _decoder_fn(n_dim):
    decoder_net = mlp_conditioner([4, 4, n_dim * 2])

    def _fn(z):
        params = decoder_net(z)
        mu, log_scale = jnp.split(params, 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def _base_distribution_fn(n_latent):
    base_distribution = distrax.Independent(
        distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
        reinterpreted_batch_ndims=1,
    )
    return base_distribution


def _get_funnel_surjector(n_latent, n_dimension):
    return AffineMaskedCouplingInferenceFunnel(
        n_latent,
        _decoder_fn(n_dimension - n_latent),
        mlp_conditioner([4, 4, n_latent * 2]),
    )


def make_surjector(n_dimension, n_latent):
    def _transformation_fn(n_dimension):
        funnel = _get_funnel_surjector(n_latent, n_dimension)
        return funnel

    def _flow(method, **kwargs):
        td = TransformedDistribution(
            _base_distribution_fn(n_latent), _transformation_fn(n_dimension)
        )
        return td(method, **kwargs)

    td = hk.transform(_flow)
    return td


def test_affine_masked_coupling_inference_funnel():
    n_dimension, n_latent = 4, 2
    y = random.normal(random.PRNGKey(1), shape=(10, n_dimension))

    flow = make_surjector(n_dimension, n_latent)
    params = flow.init(random.PRNGKey(0), method="log_prob", y=y)
    _ = flow.apply(params, None, method="log_prob", y=y)


def test_conditional_affine_masked_coupling_inference_funnel():
    n_dimension, n_latent = 4, 2
    y = random.normal(random.PRNGKey(1), shape=(10, n_dimension))
    x = random.normal(random.PRNGKey(1), shape=(10, 2))

    flow = make_surjector(n_dimension, n_latent)
    params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x)
    _ = flow.apply(params, None, method="log_prob", y=y, x=x)