From ae0bb62c04dde0fe8ae25e64364162b2e4223e0c Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Fri, 4 Nov 2022 23:21:23 +0100 Subject: [PATCH 01/10] Add more surjectors --- README.md | 20 +++---- {surjectors => examples}/data.py | 0 examples/multivariate_gaussian.py | 32 ++++++++++ examples/solar_dynamo.py | 3 - surjectors/bijectors/__init__.py | 0 surjectors/bijectors/coupling.py | 0 surjectors/conditioners/__init__.py | 0 surjectors/conditioners/transformer.py | 0 .../distributions/conditional_distribution.py | 2 + .../{ => distributions}/distribution.py | 0 .../distributions/transformed_distribution.py | 34 ++++------- .../{transform.py => _transform.py} | 0 .../surjectors/affine_coupling_funnel.py | 24 ++++++++ surjectors/surjectors/chain.py | 2 +- surjectors/surjectors/funnel.py | 22 ------- surjectors/surjectors/linear.py | 60 ------------------- surjectors/surjectors/lu_linear.py | 6 +- surjectors/surjectors/mlp.py | 30 ++++++++++ surjectors/surjectors/rq_coupling_funnel.py | 0 surjectors/surjectors/slice.py | 17 ++++-- surjectors/surjectors/surjector.py | 6 +- 21 files changed, 131 insertions(+), 127 deletions(-) rename {surjectors => examples}/data.py (100%) create mode 100644 examples/multivariate_gaussian.py create mode 100644 surjectors/bijectors/__init__.py create mode 100644 surjectors/bijectors/coupling.py create mode 100644 surjectors/conditioners/__init__.py create mode 100644 surjectors/conditioners/transformer.py create mode 100644 surjectors/distributions/conditional_distribution.py rename surjectors/{ => distributions}/distribution.py (100%) rename surjectors/surjectors/{transform.py => _transform.py} (100%) create mode 100644 surjectors/surjectors/affine_coupling_funnel.py delete mode 100644 surjectors/surjectors/funnel.py delete mode 100644 surjectors/surjectors/linear.py create mode 100644 surjectors/surjectors/mlp.py create mode 100644 surjectors/surjectors/rq_coupling_funnel.py diff --git a/README.md b/README.md index 1c7f069..bca5329 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,18 @@ # surjectors [![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept) -[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/Inference surjection layers/actions/workflows/ci.yaml) +[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml) [![codecov](https://codecov.io/gh/dirmeier/surjectors/branch/main/graph/badge.svg)](https://codecov.io/gh/dirmeier/surjectors) [![codacy]()]() [![documentation](https://readthedocs.org/projects/surjectors/badge/?version=latest)](https://surjectors.readthedocs.io/en/latest/?badge=latest) [![version](https://img.shields.io/pypi/v/surjectors.svg?colorB=black&style=flat)](https://pypi.org/project/surjectors/) -> Inference surjection layers +> Surjection layers for density estimation with normalizing flows ## About -TODO +Surjectors is a light-weight library of inference and generative surjection layers, i.e., layers that reduce dimensionality, for density estimation using normalizing flows. +Surjectors builds on Distrax and Haiku. 
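A minimal sketch of how these layers are meant to compose is shown below. It is assembled from the example code added later in this patch series: the `TransformedDistribution`, `Slice`, and decoder-factory names come from those examples, while `hk.nets.MLP` and the toy data are illustrative stand-ins for the hand-rolled conditioner networks used there.

```python
import distrax
import haiku as hk
from jax import numpy as jnp, random

from surjectors.distributions.transformed_distribution import TransformedDistribution
from surjectors.surjectors.slice import Slice

n_data, n_latent = 10, 5


def decoder_fn():
    # conditional decoder for the dropped coordinates, p(y_minus | z)
    net = hk.nets.MLP([16, 16, (n_data - n_latent) * 2])

    def _fn(z):
        mu, log_scale = jnp.split(net(z), 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def flow_fn(method, **kwargs):
    base = distrax.Independent(
        distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
        reinterpreted_batch_ndims=1,
    )
    # Slice keeps n_latent coordinates and scores the rest under the decoder
    td = TransformedDistribution(base, Slice(n_latent, decoder_fn()))
    return td(method, **kwargs)


flow = hk.transform(flow_fn)
y = random.normal(random.PRNGKey(0), (64, n_data))
params = flow.init(random.PRNGKey(1), method="log_prob", y=y)
log_prob = flow.apply(params, random.PRNGKey(2), method="log_prob", y=y)
```

The surjection reduces the 10-dimensional data to a 5-dimensional latent representation while keeping the likelihood computation exact up to the decoder's conditional density, which is the bookkeeping the layers in this patch series implement.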
## Example usage

## Installation

-
-To install the latest GitHub , just call the following on the
-command line:
+To install the latest GitHub version, just call the following on the command line:

```bash
-pip install git+https://github.com/dirmeier/Inference surjection layers@
+pip install git+https://github.com/dirmeier/surjectors@
```

## Contributing

In order to contribute:

-1) Fork and download the repository,
-2) create a branch with the name of your new feature (something like `issue/fix-bug-related-to-something` or `feature/implement-new-bound`),
-3) install `surjectors` and dev dependencies via `poetry install` (you might need to create a new `conda` or `venv` environment, to not break other dependencies),
+1) Fork the repository and download the fork,
+2) create a branch with the name of your new feature (something like `issue/fix-bug-related-to-something` or `feature/implement-new-surjector`),
+3) install `surjectors` and dev dependencies via `poetry install` (you might want to create a new `conda` or `venv` environment so as not to break other dependencies),
4) develop code, commit changes and push it to your branch,
5) create a PR
-
diff --git a/surjectors/data.py b/examples/data.py
similarity index 100%
rename from surjectors/data.py
rename to examples/data.py
diff --git a/examples/multivariate_gaussian.py b/examples/multivariate_gaussian.py
new file mode 100644
index 0000000..25d9ad2
--- /dev/null
+++ b/examples/multivariate_gaussian.py
@@ -0,0 +1,32 @@
+prng = hk.PRNGSequence(jax.random.PRNGKey(42))
+matrix = jax.random.uniform(next(prng), (4, 4))
+bias = jax.random.normal(next(prng), (4,))
+bijector = LowerUpperTriangularAffine(matrix, bias)
+
+#
+# def loss():
+# x, lc = bijector.inverse_and_log_det(jnp.zeros(4) * 2.1)
+# lp = distrax.Normal(jnp.zeros(4)).log_prob(x)
+# return -jnp.sum(lp - lc)
+#
+# print(bijector.matrix)
+#
+# adam = optax.adam(0.003)
+# g = jax.grad(loss)()
+#
+# print(g)
+#
+
+matrix = jax.random.uniform(next(prng), (4, 4))
+bias = jax.random.normal(next(prng), (4,))
+bijector = LowerUpperTriangularAffine(matrix, bias)
+
+n = jnp.ones((4, 4)) * 3.1
+n += jnp.triu(n) * 2
+
+bijector = LowerUpperTriangularAffine(n, jnp.zeros(4))
+
+
+print(bijector.forward(jnp.ones(4)))
+
+print(n @jnp.ones(4) )
diff --git a/examples/solar_dynamo.py b/examples/solar_dynamo.py
index 74ec438..a89a9ab 100644
--- a/examples/solar_dynamo.py
+++ b/examples/solar_dynamo.py
@@ -15,6 +15,3 @@
 jnp.array([549229066, 500358972], dtype=jnp.uint32), 100
 )
 pns[i] = pn
-
-
-Distribution
diff --git a/surjectors/bijectors/__init__.py b/surjectors/bijectors/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/surjectors/bijectors/coupling.py b/surjectors/bijectors/coupling.py
new file mode 100644
index 0000000..e69de29
diff --git a/surjectors/conditioners/__init__.py b/surjectors/conditioners/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/surjectors/conditioners/transformer.py b/surjectors/conditioners/transformer.py
new file mode 100644
index 0000000..e69de29
diff --git a/surjectors/distributions/conditional_distribution.py b/surjectors/distributions/conditional_distribution.py
new file mode 100644
index 0000000..ab030b3
--- /dev/null
+++ b/surjectors/distributions/conditional_distribution.py
@@ -0,0 +1,2 @@
+class ConditionalDistribution:
+    pass
diff --git a/surjectors/distribution.py b/surjectors/distributions/distribution.py
similarity index 100%
rename from surjectors/distribution.py
rename to surjectors/distributions/distribution.py
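The `transformed_distribution.py` rewrite that follows swaps Distrax's `inverse_and_log_det` bookkeeping for the surjectors' `inverse_and_likelihood_contribution`: the log-density of an observation is the base log-density of the pulled-back sample plus the layer's likelihood contribution, which for a pure bijection is just the inverse log-determinant of the Jacobian. A small sanity check of that identity with plain Distrax; the `Normal` base and `ScalarAffine` bijector here are purely illustrative assumptions:

```python
import distrax
import jax.numpy as jnp

base = distrax.Normal(loc=0.0, scale=1.0)
bijector = distrax.ScalarAffine(shift=1.0, scale=2.0)

y = jnp.asarray(0.5)
# pull y back through the bijector; lc is the inverse log-det-Jacobian
z, lc = bijector.inverse_and_log_det(y)
log_prob_y = base.log_prob(z) + lc

# Distrax's own transformed distribution computes the same quantity
reference = distrax.Transformed(base, bijector).log_prob(y)
assert jnp.allclose(log_prob_y, reference)
```

Patch 03 in this series aligns `log_prob` with that sign convention (`lp = lp_x + lc`); for an inference surjection such as `Slice`, the contribution additionally includes the conditional log-probability of the dropped coordinates under the decoder.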
diff --git a/surjectors/distributions/transformed_distribution.py b/surjectors/distributions/transformed_distribution.py index a9765dc..ee53ecd 100644 --- a/surjectors/distributions/transformed_distribution.py +++ b/surjectors/distributions/transformed_distribution.py @@ -1,30 +1,22 @@ -from typing import Tuple - +import chex import jax import jax.numpy as jnp -from distrax._src.bijectors import bijector as bjct_base -from distrax._src.distributions import distribution as dist_base -from distrax._src.utils import conversion -from tensorflow_probability.substrates import jax as tfp - -tfd = tfp.distributions - -PRNGKey = dist_base.PRNGKey -Array = dist_base.Array -DistributionLike = dist_base.DistributionLike -BijectorLike = bjct_base.BijectorLike +from chex import PRNGKey +from distrax import Distribution +Array = chex.Array +from surjectors.surjectors.surjector import Surjector class TransformedDistribution: - def __init__(self, base_distribution, surjector): + def __init__(self, base_distribution: Distribution, surjector: Surjector): self.base_distribution = base_distribution self.surjector = surjector - def log_prob(self, y: Array) -> Array: - x, ildj_y = self.surjector.inverse_and_log_det(y) + def log_prob(self, y: Array) -> jnp.ndarray: + x, lc = self.surjector.inverse_and_likelihood_contribution(y) lp_x = self.base_distribution.log_prob(x) - lp_y = lp_x + ildj_y - return lp_y + lp = lp_x - lc + return lp def sample(self, key: PRNGKey, sample_shape=(1,)): z = self.base_distribution.sample(seed=key, sample_shape=sample_shape) @@ -35,6 +27,6 @@ def sample_and_log_prob(self, key: PRNGKey, sample_shape=(1,)): z, lp_z = self.base_distribution.sample_and_log_prob( seed=key, sample_shape=sample_shape ) - y, fldj = jax.vmap(self.surjector.forward_and_log_det)(z) - lp_y = jax.vmap(jnp.subtract)(lp_z, fldj) - return y, lp_y + y, fldj = jax.vmap(self.surjector.forward_and_likelihood_contribution)(z) + lp = jax.vmap(jnp.subtract)(lp_z, fldj) + return y, lp diff --git a/surjectors/surjectors/transform.py b/surjectors/surjectors/_transform.py similarity index 100% rename from surjectors/surjectors/transform.py rename to surjectors/surjectors/_transform.py diff --git a/surjectors/surjectors/affine_coupling_funnel.py b/surjectors/surjectors/affine_coupling_funnel.py new file mode 100644 index 0000000..cbb7aa4 --- /dev/null +++ b/surjectors/surjectors/affine_coupling_funnel.py @@ -0,0 +1,24 @@ +import chex +from jax import numpy as jnp + +from surjectors.surjectors.surjector import Surjector + + +class AffineCouplingFunnel(Surjector): + def __init__(self, n_keep, decoder, transform, encoder, kind="inference_surjection"): + super().__init__(n_keep, decoder, encoder, kind) + self._transform = transform + + def split_input(self, input): + split_proportions = (self.n_keep, input.shape[-1] - self.n_keep) + return jnp.split(input, split_proportions, axis=-1) + + def inverse_and_likelihood_contribution(self, y): + y_plus, y_minus = self.split_input(y) + chex.assert_equal_shape([y_plus, y_minus]) + z, jac_det = self._transform(y_plus, context=y_minus) + lp = self.decoder.log_prob(y_minus, context=z) + return z, lp + jac_det + + def forward_and_likelihood_contribution(self, z): + raise NotImplementedError() diff --git a/surjectors/surjectors/chain.py b/surjectors/surjectors/chain.py index 24dd6f7..c1f8a3a 100644 --- a/surjectors/surjectors/chain.py +++ b/surjectors/surjectors/chain.py @@ -3,6 +3,7 @@ class Chain(Surjector): def __init__(self, surjectors): + super().__init__(None, None, None, "surjector") 
self._surjectors = surjectors def inverse_and_likelihood_contribution(self, y): @@ -18,4 +19,3 @@ def forward_and_likelihood_contribution(self, z): x, lc = _surjectors.forward_and_log_det(x) log_det += lc return y, log_det - diff --git a/surjectors/surjectors/funnel.py b/surjectors/surjectors/funnel.py deleted file mode 100644 index c53df17..0000000 --- a/surjectors/surjectors/funnel.py +++ /dev/null @@ -1,22 +0,0 @@ -from jax import numpy as jnp - -from surjectors.surjectors.surjector import Surjector - - -class Funnel(Surjector): - def __init__(self, n_keep, decoder, encoder=None, kind="inference_surjection"): - super().__init__(n_keep, decoder, encoder, kind) - - def split_input(self, input): - split_proportions = (self.n_keep, input.shape[-1] - self.n_keep) - return jnp.split(input, split_proportions, axis=-1) - - def inverse_and_likelihood_contribution(self, y): - z, y_minus = self.split_input(y) - lc = self.decoder.log_prob(y_minus, context=z) - return z, lc - - def forward_and_likelihood_contribution(self, z): - y_minus = self.decoder.sample(context=z) - y = jnp.concatenate([z, y_minus], axis=-1) - return y diff --git a/surjectors/surjectors/linear.py b/surjectors/surjectors/linear.py deleted file mode 100644 index a8fbcfb..0000000 --- a/surjectors/surjectors/linear.py +++ /dev/null @@ -1,60 +0,0 @@ -import optax -from distrax import LowerUpperTriangularAffine -from jax import numpy as jnp -import jax -import haiku as hk -import distrax - -from surjectors.surjectors.funnel import Funnel -from surjectors.surjectors.lu_linear import LULinear - - -class MLP(Funnel, hk.Module): - def __init__(self, n_keep, decoder, dtype=jnp.float32): - self._r = LULinear(n_keep, dtype) - self._w_prime = hk.Linear(n_keep, with_bias=True) - - self._decoder = decoder # TODO: should be a conditional gaussian - super().__init__(n_keep, decoder) - - def inverse_and_likelihood_contribution(self, y): - x_plus, x_minus = y[:, :self.n_keep], y[:, self.n_keep:] - z, lc = self._r.inverse_and_likelihood_contribution(x_plus) + self._w_prime(x_minus) - lp = self._decoder.log_prob(x_minus, context=z) - return z, lp + lc - - def forward_and_likelihood_contribution(self, z): - pass - -prng = hk.PRNGSequence(jax.random.PRNGKey(42)) -matrix = jax.random.uniform(next(prng), (4, 4)) -bias = jax.random.normal(next(prng), (4,)) -bijector = LowerUpperTriangularAffine(matrix, bias) - -# -# def loss(): -# x, lc = bijector.inverse_and_log_det(jnp.zeros(4) * 2.1) -# lp = distrax.Normal(jnp.zeros(4)).log_prob(x) -# return -jnp.sum(lp - lc) -# -# print(bijector.matrix) -# -# adam = optax.adam(0.003) -# g = jax.grad(loss)() -# -# print(g) -# - -matrix = jax.random.uniform(next(prng), (4, 4)) -bias = jax.random.normal(next(prng), (4,)) -bijector = LowerUpperTriangularAffine(matrix, bias) - -n = jnp.ones((4, 4)) * 3.1 -n += jnp.triu(n) * 2 - -bijector = LowerUpperTriangularAffine(n, jnp.zeros(4)) - - -print(bijector.forward(jnp.ones(4))) - -print(n @jnp.ones(4) ) \ No newline at end of file diff --git a/surjectors/surjectors/lu_linear.py b/surjectors/surjectors/lu_linear.py index 816a4a8..e104930 100644 --- a/surjectors/surjectors/lu_linear.py +++ b/surjectors/surjectors/lu_linear.py @@ -7,8 +7,11 @@ class LULinear(Surjector): - def __init__(self, n_keep, dtype=jnp.float32): + def __init__(self, n_keep, with_bias=False, dtype=jnp.float32): super().__init__(n_keep, None, None, "bijection", dtype) + if with_bias: + raise NotImplementedError() + n_triangular_entries = ((n_keep - 1) * n_keep) // 2 self._lower_indices = 
np.tril_indices(n_keep, k=-1) @@ -51,4 +54,3 @@ def inverse_and_likelihood_contribution(self, y): def forward_and_likelihood_contribution(self, z): pass - diff --git a/surjectors/surjectors/mlp.py b/surjectors/surjectors/mlp.py new file mode 100644 index 0000000..8087b62 --- /dev/null +++ b/surjectors/surjectors/mlp.py @@ -0,0 +1,30 @@ +import distrax +import haiku as hk +from jax import numpy as jnp + +from surjectors.surjectors.affine_coupling_funnel import Funnel +from surjectors.surjectors.lu_linear import LULinear + + +class MLP(Funnel, hk.Module): + def __init__(self, n_keep, decoder, dtype=jnp.float32): + self._r = LULinear(n_keep, dtype, with_bias=False) + self._w_prime = hk.Linear(n_keep, with_bias=True) + + self._decoder = decoder + super().__init__(n_keep, decoder) + + def inverse_and_likelihood_contribution(self, y): + y_plus, y_minus = self.split_input(y) + z, jac_det = self._r.inverse_and_likelihood_contribution(y_plus) + z += self._w_prime(y_minus) + lp = self._decode(z).log_prob(y_minus) + return z, lp + jac_det + + def _decode(self, array): + mu, log_scale = self._decoder(array) + distr = distrax.MultivariateNormalDiag(mu, jnp.exp(log_scale)) + return distr + + def forward_and_likelihood_contribution(self, z): + pass diff --git a/surjectors/surjectors/rq_coupling_funnel.py b/surjectors/surjectors/rq_coupling_funnel.py new file mode 100644 index 0000000..e69de29 diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index 311d3fe..c814547 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -1,9 +1,18 @@ from jax import numpy as jnp -from surjectors.funnel import Funnel +from surjectors.surjectors.affine_coupling_funnel import Funnel class Slice(Funnel): - def __init__(self, n_keep, kind="inference_surjection"): - # TODO: implement decoder and encoder - super().__init__(kind, decoder, encoder, n_keep) \ No newline at end of file + def __init__(self, n_keep, decoder, encoder=None, kind="inference_surjection"): + super().__init__(n_keep, decoder, encoder, kind) + + def inverse_and_likelihood_contribution(self, y): + z, y_minus = self.split_input(y) + lc = self.decoder.log_prob(y_minus, context=z) + return z, lc + + def forward_and_likelihood_contribution(self, z): + y_minus = self.decoder.sample(context=z) + y = jnp.concatenate([z, y_minus], axis=-1) + return y diff --git a/surjectors/surjectors/surjector.py b/surjectors/surjectors/surjector.py index a4e562f..c424b29 100644 --- a/surjectors/surjectors/surjector.py +++ b/surjectors/surjectors/surjector.py @@ -1,10 +1,10 @@ from abc import abstractmethod from jax import numpy as jnp -from surjectors.surjectors.transform import Transform +from surjectors.surjectors._transform import Transform -_valid_kinds = ["inference_surjector", "generative_surjector", "bijector"] +_valid_kinds = ["inference_surjector", "generative_surjector", "bijector", "surjector"] class Surjector(Transform): @@ -49,4 +49,4 @@ def encoder(self): @property def dtype(self): - return self._dtype \ No newline at end of file + return self._dtype From 0ae779bbd187aecc1c250e78380f20a197ca561b Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sat, 5 Nov 2022 23:09:59 +0100 Subject: [PATCH 02/10] Implement affine coupling funnel --- .../surjectors/affine_coupling_funnel.py | 35 +++++++++++++------ surjectors/surjectors/slice.py | 4 +++ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/surjectors/surjectors/affine_coupling_funnel.py b/surjectors/surjectors/affine_coupling_funnel.py index 
cbb7aa4..0416b5d 100644 --- a/surjectors/surjectors/affine_coupling_funnel.py +++ b/surjectors/surjectors/affine_coupling_funnel.py @@ -1,24 +1,39 @@ import chex +import distrax +import numpy as np +from chex import Array +from distrax import ScalarAffine, MaskedCoupling from jax import numpy as jnp from surjectors.surjectors.surjector import Surjector class AffineCouplingFunnel(Surjector): - def __init__(self, n_keep, decoder, transform, encoder, kind="inference_surjection"): - super().__init__(n_keep, decoder, encoder, kind) - self._transform = transform + def __init__(self, n_keep, decoder, conditioner, kind="inference_surjection"): + super().__init__(n_keep, decoder, None, kind) + self._conditioner = conditioner - def split_input(self, input): - split_proportions = (self.n_keep, input.shape[-1] - self.n_keep) - return jnp.split(input, split_proportions, axis=-1) + def _mask(self, array): + mask = jnp.arange(array.shape[-1]) >= self.n_keep + mask = mask.astype(jnp.bool_) + return mask + + def _inner_bijector(self, mask): + def _bijector_fn(params: Array): + shift, log_scale = jnp.split(params, 2, axis=-1) + return distrax.ScalarAffine(shift, jnp.exp(log_scale)) + + return MaskedCoupling( + mask, self._conditioner, _bijector_fn + ) def inverse_and_likelihood_contribution(self, y): - y_plus, y_minus = self.split_input(y) - chex.assert_equal_shape([y_plus, y_minus]) - z, jac_det = self._transform(y_plus, context=y_minus) - lp = self.decoder.log_prob(y_minus, context=z) + mask = self._mask(y) + faux, jac_det = self._inner_bijector(mask).inverse_and_log_det(y) + z = faux[:, :self.n_keep] + lp = self.decoder.log_prob(faux[:, self.n_keep:], context=z) return z, lp + jac_det def forward_and_likelihood_contribution(self, z): raise NotImplementedError() + diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index c814547..81dec11 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -7,6 +7,10 @@ class Slice(Funnel): def __init__(self, n_keep, decoder, encoder=None, kind="inference_surjection"): super().__init__(n_keep, decoder, encoder, kind) + def split_input(self, input): + split_proportions = (self.n_keep, input.shape[-1] - self.n_keep) + return jnp.split(input, split_proportions, axis=-1) + def inverse_and_likelihood_contribution(self, y): z, y_minus = self.split_input(y) lc = self.decoder.log_prob(y_minus, context=z) From c2b99a32f896bf2619b3628869e1ed8b6e5f96c0 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Fri, 11 Nov 2022 18:31:05 +0100 Subject: [PATCH 03/10] More cleanups --- examples/multivariate_gaussian.py | 218 +++++++++++++++--- examples/solar_dynamo.py | 10 +- examples/{data.py => solar_dynamo_data.py} | 16 +- .../distributions/transformed_distribution.py | 20 +- .../surjectors/affine_coupling_funnel.py | 21 +- surjectors/surjectors/chain.py | 19 +- surjectors/surjectors/funnel.py | 7 + surjectors/surjectors/rq_coupling_funnel.py | 45 ++++ surjectors/surjectors/slice.py | 14 +- surjectors/surjectors/surjector.py | 11 +- 10 files changed, 301 insertions(+), 80 deletions(-) rename examples/{data.py => solar_dynamo_data.py} (85%) create mode 100644 surjectors/surjectors/funnel.py diff --git a/examples/multivariate_gaussian.py b/examples/multivariate_gaussian.py index 25d9ad2..dc18da1 100644 --- a/examples/multivariate_gaussian.py +++ b/examples/multivariate_gaussian.py @@ -1,32 +1,186 @@ -prng = hk.PRNGSequence(jax.random.PRNGKey(42)) -matrix = jax.random.uniform(next(prng), (4, 4)) -bias = jax.random.normal(next(prng), (4,)) 
-bijector = LowerUpperTriangularAffine(matrix, bias) - -# -# def loss(): -# x, lc = bijector.inverse_and_log_det(jnp.zeros(4) * 2.1) -# lp = distrax.Normal(jnp.zeros(4)).log_prob(x) -# return -jnp.sum(lp - lc) -# -# print(bijector.matrix) -# -# adam = optax.adam(0.003) -# g = jax.grad(loss)() -# -# print(g) -# - -matrix = jax.random.uniform(next(prng), (4, 4)) -bias = jax.random.normal(next(prng), (4,)) -bijector = LowerUpperTriangularAffine(matrix, bias) - -n = jnp.ones((4, 4)) * 3.1 -n += jnp.triu(n) * 2 - -bijector = LowerUpperTriangularAffine(n, jnp.zeros(4)) - - -print(bijector.forward(jnp.ones(4))) - -print(n @jnp.ones(4) ) +import distrax +import haiku as hk +import jax +import numpy as np +import optax +import matplotlib.pyplot as plt + +from jax import numpy as jnp +from jax import random + +from surjectors.surjectors.chain import Chain +from surjectors.surjectors.slice import Slice +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) + +from jax import config + +config.update("jax_enable_x64", True) + + +def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): + pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41]) + pz = distrax.MultivariateNormalDiag( + loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + ) + p_loadings = distrax.Normal(0.0, 10.0) + make_noise = distrax.Normal(0.0, 1) + + loadings_sample_key, rng_key = random.split(rng_key, 2) + loadings = p_loadings.sample( + seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) + ) + + def _fn(rng_key): + z_sample_key, noise_sample_key = random.split(rng_key, 2) + z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) + noise = +make_noise.sample( + seed=noise_sample_key, sample_shape=(batch_size, n_dimension) + ) + + # y = (loadings @ z.T).T + noise + y = jnp.concatenate([z, -z], axis=-1) + return y, z + + return _fn, loadings + + +def _get_surjector(n_dimension, n_latent): + def _bijector_conditioner(dim): + return hk.Sequential( + [ + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear(dim * 2), + ] + ) + + def _surjector_conditioner(): + return hk.Sequential( + [ + hk.Linear( + 16, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 16, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear((n_dimension - n_latent) * 2), + ] + ) + + def _decoder_fn(): + decoder_net = _surjector_conditioner() + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(2): + layer = distrax.MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_bijector_conditioner(n_dimension), + ) + layers.append(layer) + + layers.append(Slice(n_latent, _decoder_fn())) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + + for _ in range(2): + layer = distrax.MaskedCoupling( + mask=mask, + bijector=_bijector_fn, 
+ conditioner=_bijector_conditioner(n_latent), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + # return Chain(layers) + return Slice(n_latent, _decoder_fn()) + + def _base_fn(): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_fn(), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def run(key=0, n_iter=1000, batch_size=64, n_data=10, n_latent=5): + rng_seq = hk.PRNGSequence(0) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) + flow = _get_surjector(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, rng): + def loss_fn(params): + lp = flow.apply(params, rng, method="log_prob", y=y_batch) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + y_init, _ = pyz(random.fold_in(next(rng_seq), 0)) + params = flow.init(random.PRNGKey(key), method="log_prob", y=y_init) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + y_batch, _ = pyz(next(rng_seq)) + loss, params, state = step(params, state, y_batch, next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + + y_batch, z_batch = pyz(next(rng_seq)) + y_pred = flow.apply(params, next(rng_seq), method="sample", y=y_batch) + print(y_batch[:5, :]) + print(y_pred[:5, :]) + + +if __name__ == "__main__": + run() diff --git a/examples/solar_dynamo.py b/examples/solar_dynamo.py index a89a9ab..2778a2b 100644 --- a/examples/solar_dynamo.py +++ b/examples/solar_dynamo.py @@ -3,14 +3,12 @@ from jax import random, numpy as jnp import matplotlib.pyplot as plt +from examples.solar_dynamo_data import SolarDynamoSimulator -from surjectors.data import Simulator +simulator = SolarDynamoSimulator() -simulator = Simulator() - -n = 1000 -pns = [None] * n -for i in np.arange(n): +n_iter = 1000 +for i in np.arange(n_iter): p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample( jnp.array([549229066, 500358972], dtype=jnp.uint32), 100 ) diff --git a/examples/data.py b/examples/solar_dynamo_data.py similarity index 85% rename from examples/data.py rename to examples/solar_dynamo_data.py index a14aa38..59f333b 100644 --- a/examples/data.py +++ b/examples/solar_dynamo_data.py @@ -2,17 +2,6 @@ from jax.scipy.special import erf -class Simulator: - def __new__(cls, simulator="solar_dynamo", **kwargs): - if simulator == "solar_dynamo": - return SolarDynamoSimulator(**kwargs) - return StandardSimulator() - - -class StandardSimulator: - pass - - class SolarDynamoSimulator: def __init__(self, **kwargs): self.p0_mean = kwargs.get("p0_mean", 1.0) @@ -24,7 +13,7 @@ def __init__(self, **kwargs): self.alpha1 = kwargs.get("alpha1", None) self.alpha2 = kwargs.get("alpha2", None) - def sample(self, key, len_timeseries=1000): + def sample(self, key, batclen_timeseries=1000): p_key, alpha1_key, alpha2_key, epsilon_key, key = random.split(key, 5) p0 = random.normal(p_key) * self.p0_std + self.p0_mean alpha1 = random.uniform( @@ -42,7 +31,8 @@ def sample(self, key, len_timeseries=1000): return p0, alpha1, alpha2, epsilon_max, batch[0], batch[1] - def babcock_leighton_fn(self, p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8): + @staticmethod + def 
babcock_leighton_fn(p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8): f = 0.5 * (1.0 + erf((p - b_1) / w_1)) * (1.0 - erf((p - b_2) / w_2)) return f diff --git a/surjectors/distributions/transformed_distribution.py b/surjectors/distributions/transformed_distribution.py index ee53ecd..ccaa750 100644 --- a/surjectors/distributions/transformed_distribution.py +++ b/surjectors/distributions/transformed_distribution.py @@ -1,8 +1,11 @@ +from typing import Tuple + import chex import jax import jax.numpy as jnp from chex import PRNGKey from distrax import Distribution + Array = chex.Array from surjectors.surjectors.surjector import Surjector @@ -12,11 +15,18 @@ def __init__(self, base_distribution: Distribution, surjector: Surjector): self.base_distribution = base_distribution self.surjector = surjector - def log_prob(self, y: Array) -> jnp.ndarray: + def __call__(self, method, **kwargs): + return getattr(self, method)(**kwargs) + + def log_prob(self, y: Array) -> Array: + _, lp = self.inverse_and_log_prob(y) + return lp + + def inverse_and_log_prob(self, y: Array) -> Tuple[Array, Array]: x, lc = self.surjector.inverse_and_likelihood_contribution(y) lp_x = self.base_distribution.log_prob(x) - lp = lp_x - lc - return lp + lp = lp_x + lc + return x, lp def sample(self, key: PRNGKey, sample_shape=(1,)): z = self.base_distribution.sample(seed=key, sample_shape=sample_shape) @@ -27,6 +37,8 @@ def sample_and_log_prob(self, key: PRNGKey, sample_shape=(1,)): z, lp_z = self.base_distribution.sample_and_log_prob( seed=key, sample_shape=sample_shape ) - y, fldj = jax.vmap(self.surjector.forward_and_likelihood_contribution)(z) + y, fldj = jax.vmap(self.surjector.forward_and_likelihood_contribution)( + z + ) lp = jax.vmap(jnp.subtract)(lp_z, fldj) return y, lp diff --git a/surjectors/surjectors/affine_coupling_funnel.py b/surjectors/surjectors/affine_coupling_funnel.py index 0416b5d..dad1a4b 100644 --- a/surjectors/surjectors/affine_coupling_funnel.py +++ b/surjectors/surjectors/affine_coupling_funnel.py @@ -1,16 +1,14 @@ -import chex import distrax -import numpy as np from chex import Array -from distrax import ScalarAffine, MaskedCoupling +from distrax import MaskedCoupling from jax import numpy as jnp -from surjectors.surjectors.surjector import Surjector +from surjectors.surjectors.funnel import Funnel -class AffineCouplingFunnel(Surjector): - def __init__(self, n_keep, decoder, conditioner, kind="inference_surjection"): - super().__init__(n_keep, decoder, None, kind) +class AffineCouplingFunnel(Funnel): + def __init__(self, n_keep, decoder, conditioner): + super().__init__(n_keep, decoder, None, "inference_surjection") self._conditioner = conditioner def _mask(self, array): @@ -23,17 +21,14 @@ def _bijector_fn(params: Array): shift, log_scale = jnp.split(params, 2, axis=-1) return distrax.ScalarAffine(shift, jnp.exp(log_scale)) - return MaskedCoupling( - mask, self._conditioner, _bijector_fn - ) + return MaskedCoupling(mask, self._conditioner, _bijector_fn) def inverse_and_likelihood_contribution(self, y): mask = self._mask(y) faux, jac_det = self._inner_bijector(mask).inverse_and_log_det(y) - z = faux[:, :self.n_keep] - lp = self.decoder.log_prob(faux[:, self.n_keep:], context=z) + z = faux[:, : self.n_keep] + lp = self.decoder.log_prob(faux[:, self.n_keep :], context=z) return z, lp + jac_det def forward_and_likelihood_contribution(self, z): raise NotImplementedError() - diff --git a/surjectors/surjectors/chain.py b/surjectors/surjectors/chain.py index c1f8a3a..172fde5 100644 --- a/surjectors/surjectors/chain.py 
+++ b/surjectors/surjectors/chain.py @@ -7,11 +7,22 @@ def __init__(self, surjectors): self._surjectors = surjectors def inverse_and_likelihood_contribution(self, y): - z, log_det = self._surjectors[0].forward_and_log_det(y) + z, lcs = self._inverse_and_log_contribution_dispatch( + self._surjectors[0], y + ) for surjector in self._surjectors[1:]: - x, lc = surjector.inverse_and_likelihood_contribution(z) - log_det += lc - return z, log_det + z, lc = self._inverse_and_log_contribution_dispatch(surjector, z) + lcs += lc + return z, lcs + + @staticmethod + def _inverse_and_log_contribution_dispatch(surjector, y): + if isinstance(surjector, Surjector): + fn = getattr(surjector, "inverse_and_likelihood_contribution") + else: + fn = getattr(surjector, "inverse_and_log_det") + z, lc = fn(y) + return z, lc def forward_and_likelihood_contribution(self, z): y, log_det = self._surjectors[-1].forward_and_log_det(z) diff --git a/surjectors/surjectors/funnel.py b/surjectors/surjectors/funnel.py new file mode 100644 index 0000000..ce5437c --- /dev/null +++ b/surjectors/surjectors/funnel.py @@ -0,0 +1,7 @@ +from surjectors.surjectors.surjector import Surjector + + +class Funnel(Surjector): + def __init__(self, n_keep, decoder, conditioner, encoder, kind): + super().__init__(n_keep, decoder, encoder, kind) + self._conditioner = conditioner diff --git a/surjectors/surjectors/rq_coupling_funnel.py b/surjectors/surjectors/rq_coupling_funnel.py index e69de29..22b31fe 100644 --- a/surjectors/surjectors/rq_coupling_funnel.py +++ b/surjectors/surjectors/rq_coupling_funnel.py @@ -0,0 +1,45 @@ +import distrax +from chex import Array +from distrax import MaskedCoupling +from jax import numpy as jnp + +from surjectors.surjectors.surjector import Surjector + + +class NSFCouplingFunnel(Surjector): + def __init__( + self, + n_keep, + decoder, + conditioner, + range_min, + range_max, + kind="inference_surjection", + ): + super().__init__(n_keep, decoder, None, kind) + self._conditioner = conditioner + self._range_min = range_min + self._range_max = range_max + + def _mask(self, array): + mask = jnp.arange(array.shape[-1]) >= self.n_keep + mask = mask.astype(jnp.bool_) + return mask + + def _inner_bijector(self, mask): + def _bijector_fn(params: Array): + return distrax.RationalQuadraticSpline( + params, range_min=self._range_min, range_max=self._range_max + ) + + return MaskedCoupling(mask, self._conditioner, _bijector_fn) + + def inverse_and_likelihood_contribution(self, y): + mask = self._mask(y) + faux, jac_det = self._inner_bijector(mask).inverse_and_log_det(y) + z = faux[:, : self.n_keep] + lp = self.decoder.log_prob(faux[:, self.n_keep :], context=z) + return z, lp + jac_det + + def forward_and_likelihood_contribution(self, z): + raise NotImplementedError() diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index 81dec11..d308dbd 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -1,19 +1,21 @@ from jax import numpy as jnp -from surjectors.surjectors.affine_coupling_funnel import Funnel +from surjectors.surjectors.funnel import Funnel class Slice(Funnel): - def __init__(self, n_keep, decoder, encoder=None, kind="inference_surjection"): - super().__init__(n_keep, decoder, encoder, kind) + def __init__( + self, n_keep, decoder, encoder=None, kind="inference_surjector" + ): + super().__init__(n_keep, decoder, encoder, None, kind) def split_input(self, input): - split_proportions = (self.n_keep, input.shape[-1] - self.n_keep) - return jnp.split(input, 
split_proportions, axis=-1) + spl = jnp.split(input, [self.n_keep], axis=-1) + return spl def inverse_and_likelihood_contribution(self, y): z, y_minus = self.split_input(y) - lc = self.decoder.log_prob(y_minus, context=z) + lc = self.decoder(z).log_prob(y_minus) return z, lc def forward_and_likelihood_contribution(self, z): diff --git a/surjectors/surjectors/surjector.py b/surjectors/surjectors/surjector.py index c424b29..00ccf19 100644 --- a/surjectors/surjectors/surjector.py +++ b/surjectors/surjectors/surjector.py @@ -4,17 +4,24 @@ from surjectors.surjectors._transform import Transform -_valid_kinds = ["inference_surjector", "generative_surjector", "bijector", "surjector"] +_valid_kinds = [ + "inference_surjector", + "generative_surjector", + "bijector", + "surjector", +] class Surjector(Transform): """ Surjector base class """ + def __init__(self, n_keep, decoder, encoder, kind, dtype=jnp.float32): if kind not in _valid_kinds: raise ValueError( - "'kind' argument needs to be either of: " "/".join(_valid_kinds) + "'kind' argument needs to be either of: " + + "/".join(_valid_kinds) ) if kind == _valid_kinds[1] and encoder is None: raise ValueError( From ff415921d2eb9279d70d94ff8a7d3c215475ea04 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 14 Nov 2022 17:02:13 +0100 Subject: [PATCH 04/10] More additions --- examples/multivariate_gaussian.py | 54 ++++++++++++------- surjectors/bijectors/coupling.py | 0 surjectors/bijectors/masked_coupling.py | 46 ++++++++++++++++ .../distributions/transformed_distribution.py | 25 +++++---- surjectors/surjectors/chain.py | 31 +++++++---- surjectors/surjectors/slice.py | 22 ++++++-- 6 files changed, 135 insertions(+), 43 deletions(-) delete mode 100644 surjectors/bijectors/coupling.py create mode 100644 surjectors/bijectors/masked_coupling.py diff --git a/examples/multivariate_gaussian.py b/examples/multivariate_gaussian.py index dc18da1..75bb918 100644 --- a/examples/multivariate_gaussian.py +++ b/examples/multivariate_gaussian.py @@ -8,6 +8,7 @@ from jax import numpy as jnp from jax import random +from surjectors.bijectors.masked_coupling import MaskedCoupling from surjectors.surjectors.chain import Chain from surjectors.surjectors.slice import Slice from surjectors.distributions.transformed_distribution import ( @@ -35,18 +36,18 @@ def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): def _fn(rng_key): z_sample_key, noise_sample_key = random.split(rng_key, 2) z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) - noise = +make_noise.sample( + noise = make_noise.sample( seed=noise_sample_key, sample_shape=(batch_size, n_dimension) ) - # y = (loadings @ z.T).T + noise - y = jnp.concatenate([z, -z], axis=-1) - return y, z + y = (loadings @ z.T).T + noise + # y = jnp.concatenate([z, z] ,axis=-1) + return y, z, noise return _fn, loadings -def _get_surjector(n_dimension, n_latent): +def _get_slice_surjector(n_dimension, n_latent): def _bijector_conditioner(dim): return hk.Sequential( [ @@ -106,7 +107,7 @@ def _transformation_fn(): mask = mask.astype(bool) for _ in range(2): - layer = distrax.MaskedCoupling( + layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, conditioner=_bijector_conditioner(n_dimension), @@ -120,15 +121,15 @@ def _transformation_fn(): mask = mask.astype(bool) for _ in range(2): - layer = distrax.MaskedCoupling( + layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, conditioner=_bijector_conditioner(n_latent), ) layers.append(layer) mask = jnp.logical_not(mask) - # return Chain(layers) - return 
Slice(n_latent, _decoder_fn()) + #return Slice(n_latent, _decoder_fn()) + return Chain(layers) def _base_fn(): base_distribution = distrax.Independent( @@ -145,15 +146,17 @@ def _flow(method, **kwargs): return td -def run(key=0, n_iter=1000, batch_size=64, n_data=10, n_latent=5): +def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): rng_seq = hk.PRNGSequence(0) pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) - flow = _get_surjector(n_data, n_latent) + flow = surjector_fn(n_data, n_latent) @jax.jit - def step(params, state, y_batch, rng): + def step(params, state, y_batch, noise_batch, rng): def loss_fn(params): - lp = flow.apply(params, rng, method="log_prob", y=y_batch) + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) return -jnp.sum(lp) loss, grads = jax.value_and_grad(loss_fn)(params) @@ -161,26 +164,39 @@ def loss_fn(params): new_params = optax.apply_updates(params, updates) return loss, new_params, new_state - y_init, _ = pyz(random.fold_in(next(rng_seq), 0)) - params = flow.init(random.PRNGKey(key), method="log_prob", y=y_init) + y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + params = flow.init( + random.PRNGKey(key), + method="log_prob", + y=y_init, + x=noise_init + ) adam = optax.adamw(0.001) state = adam.init(params) losses = [0] * n_iter for i in range(n_iter): - y_batch, _ = pyz(next(rng_seq)) - loss, params, state = step(params, state, y_batch, next(rng_seq)) + y_batch, _, noise_batch = pyz(next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, + next(rng_seq)) losses[i] = loss losses = jnp.asarray(losses) plt.plot(losses) plt.show() - y_batch, z_batch = pyz(next(rng_seq)) - y_pred = flow.apply(params, next(rng_seq), method="sample", y=y_batch) + y_batch, z_batch, noise_batch = pyz(next(rng_seq)) + y_pred = flow.apply( + params, next(rng_seq), method="sample", x=noise_batch + ) print(y_batch[:5, :]) print(y_pred[:5, :]) +def run(): + key = random.PRNGKey(0) + train(key, _get_slice_surjector, n_data, n_latent, batch_size, n_iter) + + if __name__ == "__main__": run() diff --git a/surjectors/bijectors/coupling.py b/surjectors/bijectors/coupling.py deleted file mode 100644 index e69de29..0000000 diff --git a/surjectors/bijectors/masked_coupling.py b/surjectors/bijectors/masked_coupling.py new file mode 100644 index 0000000..50d2705 --- /dev/null +++ b/surjectors/bijectors/masked_coupling.py @@ -0,0 +1,46 @@ +from typing import Optional, Tuple + +import distrax +from distrax._src.utils import math +from jax import numpy as jnp + +from surjectors.distributions.transformed_distribution import Array + + +class MaskedCoupling(distrax.MaskedCoupling): + def __init__(self, mask: Array, conditioner, bijector, + event_ndims: Optional[int] = None, inner_event_ndims: int = 0): + super().__init__(mask, conditioner, bijector, event_ndims, + inner_event_ndims) + + def forward_and_log_det(self, z: Array, x: Array = None) -> Tuple[Array, Array]: + self._check_forward_input_shape(z) + masked_z = jnp.where(self._event_mask, z, 0.) 
+ if x is not None: + masked_z = jnp.concatenate([masked_z, x], axis=-1) + params = self._conditioner(masked_z) + y0, log_d = self._inner_bijector(params).forward_and_log_det(z) + y = jnp.where(self._event_mask, z, y0) + logdet = math.sum_last( + jnp.where(self._mask, 0., log_d), + self._event_ndims - self._inner_event_ndims + ) + return y, logdet + + def forward(self, z: Array, x: Array = None) ->Array: + y, log_det = self.forward_and_log_det(z, x) + return y + + def inverse_and_log_det(self, y: Array, x: Array = None) -> Tuple[Array, Array]: + self._check_inverse_input_shape(y) + masked_y = jnp.where(self._event_mask, y, 0.) + if x is not None: + masked_y = jnp.concatenate([masked_y, x], axis=-1) + params = self._conditioner(masked_y) + z0, log_d = self._inner_bijector(params).inverse_and_log_det(y) + z = jnp.where(self._event_mask, y, z0) + logdet = math.sum_last( + jnp.where(self._mask, 0., log_d), + self._event_ndims - self._inner_event_ndims + ) + return z, logdet diff --git a/surjectors/distributions/transformed_distribution.py b/surjectors/distributions/transformed_distribution.py index ccaa750..07c190c 100644 --- a/surjectors/distributions/transformed_distribution.py +++ b/surjectors/distributions/transformed_distribution.py @@ -5,6 +5,7 @@ import jax.numpy as jnp from chex import PRNGKey from distrax import Distribution +import haiku as hk Array = chex.Array from surjectors.surjectors.surjector import Surjector @@ -18,27 +19,31 @@ def __init__(self, base_distribution: Distribution, surjector: Surjector): def __call__(self, method, **kwargs): return getattr(self, method)(**kwargs) - def log_prob(self, y: Array) -> Array: - _, lp = self.inverse_and_log_prob(y) + def log_prob(self, y: Array, x: Array = None) -> Array: + _, lp = self.inverse_and_log_prob(y, x) return lp - def inverse_and_log_prob(self, y: Array) -> Tuple[Array, Array]: - x, lc = self.surjector.inverse_and_likelihood_contribution(y) + def inverse_and_log_prob(self, y: Array, x: Array=None) -> Tuple[Array, Array]: + x, lc = self.surjector.inverse_and_likelihood_contribution(y, x=x) lp_x = self.base_distribution.log_prob(x) lp = lp_x + lc return x, lp - def sample(self, key: PRNGKey, sample_shape=(1,)): - z = self.base_distribution.sample(seed=key, sample_shape=sample_shape) - y = jax.vmap(self.surjector.inverse)(z) + def sample(self, sample_shape=(), x: Array = None): + if x is not None and len(sample_shape): + chex.assert_equal(sample_shape[0], x.shape[0]) + elif x is not None: + sample_shape = (x.shape[0],) + z = self.base_distribution.sample(seed=hk.next_rng_key(), sample_shape=sample_shape) + y = jax.vmap(self.surjector.forward)(z, x) return y - def sample_and_log_prob(self, key: PRNGKey, sample_shape=(1,)): + def sample_and_log_prob(self, sample_shape=(1,), x: Array = None): z, lp_z = self.base_distribution.sample_and_log_prob( - seed=key, sample_shape=sample_shape + seed=hk.next_rng_key(), sample_shape=sample_shape, x=x ) y, fldj = jax.vmap(self.surjector.forward_and_likelihood_contribution)( - z + z, x=x ) lp = jax.vmap(jnp.subtract)(lp_z, fldj) return y, lp diff --git a/surjectors/surjectors/chain.py b/surjectors/surjectors/chain.py index 172fde5..ab79199 100644 --- a/surjectors/surjectors/chain.py +++ b/surjectors/surjectors/chain.py @@ -6,27 +6,40 @@ def __init__(self, surjectors): super().__init__(None, None, None, "surjector") self._surjectors = surjectors - def inverse_and_likelihood_contribution(self, y): + def inverse_and_likelihood_contribution(self, y, x=None): z, lcs = 
self._inverse_and_log_contribution_dispatch( - self._surjectors[0], y + self._surjectors[0], y, x ) for surjector in self._surjectors[1:]: - z, lc = self._inverse_and_log_contribution_dispatch(surjector, z) + z, lc = self._inverse_and_log_contribution_dispatch(surjector, z, x) lcs += lc return z, lcs @staticmethod - def _inverse_and_log_contribution_dispatch(surjector, y): + def _inverse_and_log_contribution_dispatch(surjector, y, x): if isinstance(surjector, Surjector): fn = getattr(surjector, "inverse_and_likelihood_contribution") else: fn = getattr(surjector, "inverse_and_log_det") - z, lc = fn(y) + z, lc = fn(y, x) return z, lc - def forward_and_likelihood_contribution(self, z): - y, log_det = self._surjectors[-1].forward_and_log_det(z) - for _surjectors in reversed(self._surjectors[:-1]): - x, lc = _surjectors.forward_and_log_det(x) + def forward_and_likelihood_contribution(self, z, x=None): + y, log_det = self._surjectors[-1].forward_and_log_det(z, x) + for surjector in reversed(self._surjectors[:-1]): + y, lc = self._forward_and_log_contribution_dispatch(surjector, y, x) log_det += lc return y, log_det + + @staticmethod + def _forward_and_log_contribution_dispatch(surjector, y, x): + if isinstance(surjector, Surjector): + fn = getattr(surjector, "forward_and_likelihood_contribution") + else: + fn = getattr(surjector, "forward_and_log_det") + z, lc = fn(y, x) + return z, lc + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index d308dbd..f8ba507 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -1,5 +1,5 @@ from jax import numpy as jnp - +import haiku as hk from surjectors.surjectors.funnel import Funnel @@ -13,12 +13,24 @@ def split_input(self, input): spl = jnp.split(input, [self.n_keep], axis=-1) return spl - def inverse_and_likelihood_contribution(self, y): + def inverse_and_likelihood_contribution(self, y, x = None): z, y_minus = self.split_input(y) - lc = self.decoder(z).log_prob(y_minus) + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + lc = self.decoder(z_condition).log_prob(y_minus) return z, lc - def forward_and_likelihood_contribution(self, z): - y_minus = self.decoder.sample(context=z) + def forward_and_likelihood_contribution(self, z, x=None): + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + y_minus, lc = self.decoder(z_condition).sample_and_log_prob( + seed=hk.next_rng_key() + ) y = jnp.concatenate([z, y_minus], axis=-1) + return y, lc + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) return y From 4308b3729e0e9e4e760ff97b3e0ed71e9420970b Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 14 Nov 2022 22:57:07 +0100 Subject: [PATCH 05/10] Fix coupling funnel --- examples/multivariate_gaussian.py | 113 +++++++++++++++++- .../{surjectors => bijectors}/lu_linear.py | 0 surjectors/surjectors/_transform.py | 1 - .../surjectors/affine_coupling_funnel.py | 34 ------ .../affine_masked_coupling_funnel.py | 48 ++++++++ surjectors/surjectors/funnel.py | 4 +- surjectors/surjectors/mlp.py | 2 +- surjectors/surjectors/slice.py | 1 + surjectors/surjectors/surjector.py | 4 +- 9 files changed, 165 insertions(+), 42 deletions(-) rename surjectors/{surjectors => bijectors}/lu_linear.py (100%) delete mode 100644 surjectors/surjectors/affine_coupling_funnel.py create mode 100644 
surjectors/surjectors/affine_masked_coupling_funnel.py diff --git a/examples/multivariate_gaussian.py b/examples/multivariate_gaussian.py index 75bb918..b4b93d3 100644 --- a/examples/multivariate_gaussian.py +++ b/examples/multivariate_gaussian.py @@ -9,7 +9,9 @@ from jax import random from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.surjectors.affine_masked_coupling_funnel import AffineCouplingFunnel from surjectors.surjectors.chain import Chain +from surjectors.surjectors.funnel import Funnel from surjectors.surjectors.slice import Slice from surjectors.distributions.transformed_distribution import ( TransformedDistribution, @@ -128,7 +130,7 @@ def _transformation_fn(): ) layers.append(layer) mask = jnp.logical_not(mask) - #return Slice(n_latent, _decoder_fn()) + # return Slice(n_latent, _decoder_fn()) return Chain(layers) def _base_fn(): @@ -146,6 +148,105 @@ def _flow(method, **kwargs): return td +def _get_funnel_surjector(n_dimension, n_latent): + def _bijector_conditioner(dim): + return hk.Sequential( + [ + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear(dim * 2), + ] + ) + + def _surjector_conditioner(): + return hk.Sequential( + [ + hk.Linear( + 16, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 16, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear((n_dimension - n_latent) * 2), + ] + ) + + def _decoder_fn(): + decoder_net = _surjector_conditioner() + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_bijector_conditioner(n_dimension), + ) + layers.append(layer) + + layers.append(AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension))) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_bijector_conditioner(n_latent), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + #return Chain(layers) + return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) + + def _base_fn(): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_fn(), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): rng_seq = hk.PRNGSequence(0) pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) @@ -194,8 +295,14 @@ def loss_fn(params): def run(): - key = random.PRNGKey(0) - train(key, _get_slice_surjector, n_data, n_latent, batch_size, n_iter) + 
train( + key=0, + surjector_fn=_get_funnel_surjector, + n_iter=2000, + batch_size=64, + n_data=10, + n_latent=5 + ) if __name__ == "__main__": diff --git a/surjectors/surjectors/lu_linear.py b/surjectors/bijectors/lu_linear.py similarity index 100% rename from surjectors/surjectors/lu_linear.py rename to surjectors/bijectors/lu_linear.py diff --git a/surjectors/surjectors/_transform.py b/surjectors/surjectors/_transform.py index 03d47fe..d0861a0 100644 --- a/surjectors/surjectors/_transform.py +++ b/surjectors/surjectors/_transform.py @@ -1,6 +1,5 @@ from abc import ABCMeta, abstractmethod -import distrax from distrax._src.utils import jittable diff --git a/surjectors/surjectors/affine_coupling_funnel.py b/surjectors/surjectors/affine_coupling_funnel.py deleted file mode 100644 index dad1a4b..0000000 --- a/surjectors/surjectors/affine_coupling_funnel.py +++ /dev/null @@ -1,34 +0,0 @@ -import distrax -from chex import Array -from distrax import MaskedCoupling -from jax import numpy as jnp - -from surjectors.surjectors.funnel import Funnel - - -class AffineCouplingFunnel(Funnel): - def __init__(self, n_keep, decoder, conditioner): - super().__init__(n_keep, decoder, None, "inference_surjection") - self._conditioner = conditioner - - def _mask(self, array): - mask = jnp.arange(array.shape[-1]) >= self.n_keep - mask = mask.astype(jnp.bool_) - return mask - - def _inner_bijector(self, mask): - def _bijector_fn(params: Array): - shift, log_scale = jnp.split(params, 2, axis=-1) - return distrax.ScalarAffine(shift, jnp.exp(log_scale)) - - return MaskedCoupling(mask, self._conditioner, _bijector_fn) - - def inverse_and_likelihood_contribution(self, y): - mask = self._mask(y) - faux, jac_det = self._inner_bijector(mask).inverse_and_log_det(y) - z = faux[:, : self.n_keep] - lp = self.decoder.log_prob(faux[:, self.n_keep :], context=z) - return z, lp + jac_det - - def forward_and_likelihood_contribution(self, z): - raise NotImplementedError() diff --git a/surjectors/surjectors/affine_masked_coupling_funnel.py b/surjectors/surjectors/affine_masked_coupling_funnel.py new file mode 100644 index 0000000..6b0bc26 --- /dev/null +++ b/surjectors/surjectors/affine_masked_coupling_funnel.py @@ -0,0 +1,48 @@ +import distrax +from chex import Array +from jax import numpy as jnp +import haiku as hk +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.surjectors.funnel import Funnel + + +class AffineCouplingFunnel(Funnel): + def __init__(self, n_keep, decoder, conditioner): + super().__init__(n_keep, decoder, conditioner, None, "inference_surjector") + + def _mask(self, array): + mask = jnp.arange(array.shape[-1]) >= self.n_keep + mask = mask.astype(jnp.bool_) + return mask + + def _inner_bijector(self, mask): + def _bijector_fn(params: Array): + shift, log_scale = jnp.split(params, 2, axis=-1) + return distrax.ScalarAffine(shift, jnp.exp(log_scale)) + + return MaskedCoupling( + mask, self._conditioner, _bijector_fn + ) + + def inverse_and_likelihood_contribution(self, y, x=None): + faux, jac_det = self._inner_bijector(self._mask(y)).inverse_and_log_det(y, x) + z = faux[:, :self.n_keep] + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + lc = self.decoder(z_condition).log_prob(y[:, self.n_keep:]) + return z, lc + jac_det + + def forward_and_likelihood_contribution(self, z, x=None): + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + y_minus, jac_det = 
self.decoder(z_condition).sample_and_log_prob(seed=hk.next_rng_key()) + # TODO need to sort the indexes correctly (?) + z_tilde = jnp.concatenate([z, y_minus], axis=-1) + y, lc = self._inner_bijector(self._mask(z_tilde)).forward_and_log_det(z_tilde, x) + return y, lc + jac_det + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y \ No newline at end of file diff --git a/surjectors/surjectors/funnel.py b/surjectors/surjectors/funnel.py index ce5437c..64f9fc5 100644 --- a/surjectors/surjectors/funnel.py +++ b/surjectors/surjectors/funnel.py @@ -1,7 +1,9 @@ +from abc import ABC + from surjectors.surjectors.surjector import Surjector -class Funnel(Surjector): +class Funnel(Surjector, ABC): def __init__(self, n_keep, decoder, conditioner, encoder, kind): super().__init__(n_keep, decoder, encoder, kind) self._conditioner = conditioner diff --git a/surjectors/surjectors/mlp.py b/surjectors/surjectors/mlp.py index 8087b62..658ade4 100644 --- a/surjectors/surjectors/mlp.py +++ b/surjectors/surjectors/mlp.py @@ -2,7 +2,7 @@ import haiku as hk from jax import numpy as jnp -from surjectors.surjectors.affine_coupling_funnel import Funnel +from surjectors.surjectors.affine_masked_coupling_funnel import Funnel from surjectors.surjectors.lu_linear import LULinear diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index f8ba507..056f448 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -34,3 +34,4 @@ def forward_and_likelihood_contribution(self, z, x=None): def forward(self, z, x=None): y, _ = self.forward_and_likelihood_contribution(z, x) return y + diff --git a/surjectors/surjectors/surjector.py b/surjectors/surjectors/surjector.py index 00ccf19..ee4eee3 100644 --- a/surjectors/surjectors/surjector.py +++ b/surjectors/surjectors/surjector.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import abstractmethod, ABC from jax import numpy as jnp from surjectors.surjectors._transform import Transform @@ -12,7 +12,7 @@ ] -class Surjector(Transform): +class Surjector(Transform, ABC): """ Surjector base class """ From 8a3bf45fe64292acd2c82272eec513ef4f32d675 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 15 Nov 2022 22:40:17 +0100 Subject: [PATCH 06/10] Add affine surjectors --- ...ivariate_gaussian_generative_surjection.py | 275 ++++++++++++++++++ ...ivariate_gaussian_inference_surjection.py} | 22 +- .../distributions/transformed_distribution.py | 3 +- ...ffine_masked_coupling_generative_funnel.py | 50 ++++ ...ffine_masked_coupling_inference_funnel.py} | 7 +- surjectors/surjectors/augment.py | 34 +++ surjectors/surjectors/mlp.py | 2 +- surjectors/surjectors/slice.py | 11 +- 8 files changed, 382 insertions(+), 22 deletions(-) create mode 100644 examples/multivariate_gaussian_generative_surjection.py rename examples/{multivariate_gaussian.py => multivariate_gaussian_inference_surjection.py} (95%) create mode 100644 surjectors/surjectors/affine_masked_coupling_generative_funnel.py rename surjectors/surjectors/{affine_masked_coupling_funnel.py => affine_masked_coupling_inference_funnel.py} (90%) create mode 100644 surjectors/surjectors/augment.py diff --git a/examples/multivariate_gaussian_generative_surjection.py b/examples/multivariate_gaussian_generative_surjection.py new file mode 100644 index 0000000..9c0b2a8 --- /dev/null +++ b/examples/multivariate_gaussian_generative_surjection.py @@ -0,0 +1,275 @@ +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt 
+import numpy as np +import optax +from jax import config +from jax import numpy as jnp +from jax import random + +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_generative_funnel import \ + AffineMaskedCouplingGenerativeFunnel +from surjectors.surjectors.augment import Augment +from surjectors.surjectors.chain import Chain + +config.update("jax_enable_x64", True) + + +def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): + pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41, -2.31, 0.421, 0.1, 3.21, -0.41]) + pz = distrax.MultivariateNormalDiag( + loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + ) + p_loadings = distrax.Normal(0.0, 10.0) + make_noise = distrax.Normal(0.0, 1) + + loadings_sample_key, rng_key = random.split(rng_key, 2) + loadings = p_loadings.sample( + seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) + ) + + def _fn(rng_key): + z_sample_key, noise_sample_key = random.split(rng_key, 2) + z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) + noise = make_noise.sample( + seed=noise_sample_key, sample_shape=(batch_size, n_dimension) + ) + + y = (loadings @ z.T).T + noise + return y, z, noise + return z[:, :n_dimension], z , noise + + return _fn, loadings + + +def _get_slice_surjector(n_dimension, n_latent): + def _conditioner(dim): + return hk.Sequential( + [ + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear(dim * 2), + ] + ) + + def _encoder_fn(): + decoder_net = _conditioner((n_latent - n_dimension)) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_conditioner(n_dimension), + ) + layers.append(layer) + + layers.append( + Augment(n_dimension, _encoder_fn()) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_conditioner(n_latent), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + #return Augment(n_dimension, _encoder_fn()) + return Chain(layers) + + def _base_fn(): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_fn(), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def _conditioner(dim): + return hk.Sequential( + [ + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 32, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + 
jax.nn.gelu, + hk.Linear(dim * 2), + ] + ) + + def _encoder_fn(): + decoder_net = _conditioner((n_latent - n_dimension)) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_conditioner(n_dimension), + ) + layers.append(layer) + + layers.append( + AffineMaskedCouplingGenerativeFunnel( + n_dimension, _encoder_fn(), _conditioner(n_latent) + ) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=_conditioner(n_latent), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + # return AffineMaskedCouplingGenerativeFunnel( + # n_dimension, _encoder_fn(), _conditioner(n_latent) + # ) + + def _base_fn(): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_fn(), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): + rng_seq = hk.PRNGSequence(0) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), 2*batch_size, n_data) + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + params = flow.init( + random.PRNGKey(key), + method="log_prob", + y=y_init, + x=noise_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + y_batch, _, noise_batch = pyz(next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + + y_batch, z_batch, noise_batch = pyz(next(rng_seq)) + y_pred = flow.apply( + params, next(rng_seq), method="sample", x=noise_batch, + ) + print(y_batch[:5, :]) + print(y_pred[:5, :]) + + +def run(): + train( + key=0, + surjector_fn=_get_funnel_surjector, + n_iter=2000, + batch_size=64, + n_data=5, + n_latent=10 + ) + + +if __name__ == "__main__": + run() diff --git a/examples/multivariate_gaussian.py b/examples/multivariate_gaussian_inference_surjection.py similarity index 95% rename from examples/multivariate_gaussian.py rename to examples/multivariate_gaussian_inference_surjection.py index b4b93d3..58077f3 100644 --- a/examples/multivariate_gaussian.py +++ b/examples/multivariate_gaussian_inference_surjection.py @@ -1,23 +1,21 @@ import distrax import haiku as hk import jax +import 
matplotlib.pyplot as plt
import numpy as np
import optax
-import matplotlib.pyplot as plt
-
+from jax import config
from jax import numpy as jnp
from jax import random
from surjectors.bijectors.masked_coupling import MaskedCoupling
-from surjectors.surjectors.affine_masked_coupling_funnel import AffineCouplingFunnel
-from surjectors.surjectors.chain import Chain
-from surjectors.surjectors.funnel import Funnel
-from surjectors.surjectors.slice import Slice
from surjectors.distributions.transformed_distribution import (
    TransformedDistribution,
)
-
-from jax import config
+from surjectors.surjectors.affine_masked_coupling_inference_funnel import \
    AffineMaskedCouplingInferenceFunnel
+from surjectors.surjectors.chain import Chain
+from surjectors.surjectors.slice import Slice
config.update("jax_enable_x64", True)
@@ -215,7 +213,9 @@ def _transformation_fn():
            )
            layers.append(layer)
-        layers.append(AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)))
+        layers.append(
+            AffineMaskedCouplingInferenceFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension))
+        )
        mask = jnp.arange(0, np.prod(n_latent)) % 2
        mask = jnp.reshape(mask, n_latent)
@@ -229,8 +229,8 @@ def _transformation_fn():
            )
            layers.append(layer)
            mask = jnp.logical_not(mask)
-        #return Chain(layers)
-        return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension))
+        return Chain(layers)
+        #return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension))
    def _base_fn():
        base_distribution = distrax.Independent(
diff --git a/surjectors/distributions/transformed_distribution.py b/surjectors/distributions/transformed_distribution.py
index 07c190c..b2d8483 100644
--- a/surjectors/distributions/transformed_distribution.py
+++ b/surjectors/distributions/transformed_distribution.py
@@ -1,11 +1,10 @@
from typing import Tuple
import chex
+import haiku as hk
import jax
import jax.numpy as jnp
-from chex import PRNGKey
from distrax import Distribution
-import haiku as hk
Array = chex.Array
from surjectors.surjectors.surjector import Surjector
diff --git a/surjectors/surjectors/affine_masked_coupling_generative_funnel.py b/surjectors/surjectors/affine_masked_coupling_generative_funnel.py
new file mode 100644
index 0000000..f80ffa6
--- /dev/null
+++ b/surjectors/surjectors/affine_masked_coupling_generative_funnel.py
@@ -0,0 +1,50 @@
+import distrax
+from chex import Array
+from jax import numpy as jnp
+import haiku as hk
+from surjectors.bijectors.masked_coupling import MaskedCoupling
+from surjectors.surjectors.funnel import Funnel
+
+
+class AffineMaskedCouplingGenerativeFunnel(Funnel):
+    def __init__(self, n_keep, encoder, conditioner):
+        super().__init__(n_keep, None, conditioner, encoder, "generative_surjector")
+
+    def _mask(self, array):
+        mask = jnp.arange(array.shape[-1]) >= self.n_keep
+        mask = mask.astype(jnp.bool_)
+        return mask
+
+    def _inner_bijector(self, mask):
+        def _bijector_fn(params: Array):
+            shift, log_scale = jnp.split(params, 2, axis=-1)
+            return distrax.ScalarAffine(shift, jnp.exp(log_scale))
+
+        return MaskedCoupling(
+            mask, self._conditioner, _bijector_fn
+        )
+
+    def inverse_and_likelihood_contribution(self, y, x=None):
+        y_condition = y
+        # TODO
+        if x is not None:
+            y_condition = jnp.concatenate([y, x], axis=-1)
+        z_minus, lc = self.encoder(y_condition).sample_and_log_prob(seed=hk.next_rng_key())
+        input = jnp.concatenate([y, z_minus], axis=-1)
+        # TODO: remove the conditioning here?
+        z, jac_det = self._inner_bijector(self._mask(input)).inverse_and_log_det(input)
+        return z, -lc + jac_det
+
+    def forward_and_likelihood_contribution(self, z, x=None):
+        # TODO: remove the conditioning here?
+        faux, jac_det = self._inner_bijector(self._mask(z)).inverse_and_log_det(z)
+        y = faux[..., :self.n_keep]
+        y_condition = y
+        if x is not None:
+            y_condition = jnp.concatenate([y_condition, x], axis=-1)
+        lc = self.encoder(y_condition).log_prob(faux[..., self.n_keep:])
+        return y, -lc + jac_det
+
+    def forward(self, z, x=None):
+        y, _ = self.forward_and_likelihood_contribution(z, x)
+        return y
diff --git a/surjectors/surjectors/affine_masked_coupling_funnel.py b/surjectors/surjectors/affine_masked_coupling_inference_funnel.py
similarity index 90%
rename from surjectors/surjectors/affine_masked_coupling_funnel.py
rename to surjectors/surjectors/affine_masked_coupling_inference_funnel.py
index 6b0bc26..e8d3804 100644
--- a/surjectors/surjectors/affine_masked_coupling_funnel.py
+++ b/surjectors/surjectors/affine_masked_coupling_inference_funnel.py
@@ -6,7 +6,7 @@ from surjectors.surjectors.funnel import Funnel
-class AffineCouplingFunnel(Funnel):
+class AffineMaskedCouplingInferenceFunnel(Funnel):
    def __init__(self, n_keep, decoder, conditioner):
        super().__init__(n_keep, decoder, conditioner, None, "inference_surjector")
@@ -25,9 +25,9 @@ def _bijector_fn(params: Array):
        )
    def inverse_and_likelihood_contribution(self, y, x=None):
+        # TODO: remove the conditioning here?
        faux, jac_det = self._inner_bijector(self._mask(y)).inverse_and_log_det(y, x)
-        z = faux[:, :self.n_keep]
-        z_condition = z
+        z_condition = z = faux[:, :self.n_keep]
        if x is not None:
            z_condition = jnp.concatenate([z, x], axis=-1)
        lc = self.decoder(z_condition).log_prob(y[:, self.n_keep:])
@@ -40,6 +40,7 @@ def forward_and_likelihood_contribution(self, z, x=None):
        y_minus, jac_det = self.decoder(z_condition).sample_and_log_prob(seed=hk.next_rng_key())
        # TODO need to sort the indexes correctly (?)
        z_tilde = jnp.concatenate([z, y_minus], axis=-1)
+        # TODO: remove the conditioning here?
y, lc = self._inner_bijector(self._mask(z_tilde)).forward_and_log_det(z_tilde, x)
        return y, lc + jac_det
diff --git a/surjectors/surjectors/augment.py b/surjectors/surjectors/augment.py
new file mode 100644
index 0000000..44e15bd
--- /dev/null
+++ b/surjectors/surjectors/augment.py
@@ -0,0 +1,34 @@
+import distrax
+from chex import Array
+from jax import numpy as jnp
+import haiku as hk
+from surjectors.surjectors.funnel import Funnel
+
+
+class Augment(Funnel):
+    def __init__(self, n_keep, encoder):
+        super().__init__(n_keep, None, None, encoder, "generative_surjector")
+
+    def split_input(self, input):
+        spl = jnp.split(input, [self.n_keep], axis=-1)
+        return spl
+
+    def inverse_and_likelihood_contribution(self, y, x: Array = None):
+        z_plus = y_condition = y
+        if x is not None:
+            y_condition = jnp.concatenate([y_condition, x], axis=-1)
+        z_minus, lc = self.encoder(y_condition).sample_and_log_prob(seed=hk.next_rng_key())
+        z = jnp.concatenate([z_plus, z_minus], axis=-1)
+        return z, -lc
+
+    def forward_and_likelihood_contribution(self, z, x=None):
+        z_plus, z_minus = self.split_input(z)
+        y_condition = y = z_plus
+        if x is not None:
+            y_condition = jnp.concatenate([y_condition, x], axis=-1)
+        lc = self.encoder(y_condition).log_prob(z_minus)
+        return y, -lc
+
+    def forward(self, z, x=None):
+        y, _ = self.forward_and_likelihood_contribution(z, x)
+        return y
diff --git a/surjectors/surjectors/mlp.py b/surjectors/surjectors/mlp.py
index 658ade4..d5188c9 100644
--- a/surjectors/surjectors/mlp.py
+++ b/surjectors/surjectors/mlp.py
@@ -2,7 +2,7 @@
import haiku as hk
from jax import numpy as jnp
-from surjectors.surjectors.affine_masked_coupling_funnel import Funnel
+from surjectors.surjectors.affine_masked_coupling_inference_funnel import Funnel
from surjectors.surjectors.lu_linear import LULinear
diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py
index 056f448..7b6b899 100644
--- a/surjectors/surjectors/slice.py
+++ b/surjectors/surjectors/slice.py
@@ -1,19 +1,19 @@
+import distrax
+from chex import Array
from jax import numpy as jnp
import haiku as hk
from surjectors.surjectors.funnel import Funnel
class Slice(Funnel):
-    def __init__(
-        self, n_keep, decoder, encoder=None, kind="inference_surjector"
-    ):
-        super().__init__(n_keep, decoder, encoder, None, kind)
+    def __init__(self, n_keep, decoder):
+        super().__init__(n_keep, decoder, None, None, "inference_surjector")
    def split_input(self, input):
        spl = jnp.split(input, [self.n_keep], axis=-1)
        return spl
-    def inverse_and_likelihood_contribution(self, y, x = None):
+    def inverse_and_likelihood_contribution(self, y, x: Array=None):
        z, y_minus = self.split_input(y)
        z_condition = z
        if x is not None:
@@ -29,6 +29,7 @@ def forward_and_likelihood_contribution(self, z, x=None):
            seed=hk.next_rng_key()
        )
        y = jnp.concatenate([z, y_minus], axis=-1)
+
        return y, lc
    def forward(self, z, x=None):
From cf487e9b2cf6ac2d9d7ab6d122dd4e94a405e6d7 Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Wed, 16 Nov 2022 15:17:47 +0100
Subject: [PATCH 07/10] Reformat some of the things
---
 examples/solar_dynamo.py | 15 --
 ...ivariate_gaussian_generative_surjection.py | 222 ++++++++++++++++++
 ...tivariate_gaussian_inference_surjection.py | 216 +++++++++++++++++
 .../solar_dynamo_data.py | 26 +-
 .../solar_dynamo_generative_surjection.py | 56 +----
 .../solar_dynamo_inference_surjection.py | 104 ++++----
 surjectors/conditioners/mlp_conditioner.py | 19 ++
 7 files changed, 531 insertions(+), 127 deletions(-)
 delete mode 100644 examples/solar_dynamo.py
create mode 100644 experiments/multivariate_gaussian_generative_surjection.py create mode 100644 experiments/multivariate_gaussian_inference_surjection.py rename {examples => experiments}/solar_dynamo_data.py (63%) rename examples/multivariate_gaussian_inference_surjection.py => experiments/solar_dynamo_generative_surjection.py (83%) rename examples/multivariate_gaussian_generative_surjection.py => experiments/solar_dynamo_inference_surjection.py (75%) create mode 100644 surjectors/conditioners/mlp_conditioner.py diff --git a/examples/solar_dynamo.py b/examples/solar_dynamo.py deleted file mode 100644 index 2778a2b..0000000 --- a/examples/solar_dynamo.py +++ /dev/null @@ -1,15 +0,0 @@ -import numpy as np -import jax -from jax import random, numpy as jnp -import matplotlib.pyplot as plt - -from examples.solar_dynamo_data import SolarDynamoSimulator - -simulator = SolarDynamoSimulator() - -n_iter = 1000 -for i in np.arange(n_iter): - p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample( - jnp.array([549229066, 500358972], dtype=jnp.uint32), 100 - ) - pns[i] = pn diff --git a/experiments/multivariate_gaussian_generative_surjection.py b/experiments/multivariate_gaussian_generative_surjection.py new file mode 100644 index 0000000..48ec957 --- /dev/null +++ b/experiments/multivariate_gaussian_generative_surjection.py @@ -0,0 +1,222 @@ +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt +import numpy as np +import optax +from jax import config +from jax import numpy as jnp +from jax import random + +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp_conditioner import mlp_conditioner +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_generative_funnel import \ + AffineMaskedCouplingGenerativeFunnel +from surjectors.surjectors.augment import Augment +from surjectors.surjectors.chain import Chain + +config.update("jax_enable_x64", True) + + +def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): + pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41, + -2.31, 0.421, 0.1, 3.21, -0.41]) + pz = distrax.MultivariateNormalDiag( + loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + ) + p_loadings = distrax.Normal(0.0, 10.0) + make_noise = distrax.Normal(0.0, 1) + + loadings_sample_key, rng_key = random.split(rng_key, 2) + loadings = p_loadings.sample( + seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) + ) + + def _fn(rng_key): + z_sample_key, noise_sample_key = random.split(rng_key, 2) + z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) + noise = make_noise.sample( + seed=noise_sample_key, sample_shape=(batch_size, n_dimension) + ) + + y = (loadings @ z.T).T + noise + #return z[:, :n_dimension], z , noise + return y, z, noise + + return _fn, loadings + + +def _encoder_fn(n_latent, n_dimension): + decoder_net = mlp_conditioner([32, 32, n_latent - n_dimension]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def 
_get_slice_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + Augment(n_dimension, _encoder_fn(n_latent, n_dimension)) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + #return Augment(n_dimension, _encoder_fn()) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + AffineMaskedCouplingGenerativeFunnel( + n_dimension, _encoder_fn(n_latent, n_dimension), mlp_conditioner(n_latent) + ) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + # return AffineMaskedCouplingGenerativeFunnel( + # n_dimension, _encoder_fn(), _conditioner(n_latent) + # ) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): + rng_seq = hk.PRNGSequence(0) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), 2*batch_size, n_data) + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + params = flow.init( + random.PRNGKey(key), + method="log_prob", + y=y_init, + x=noise_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + y_batch, _, noise_batch = pyz(next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + + y_batch, z_batch, noise_batch = pyz(next(rng_seq)) + y_pred = flow.apply( + params, next(rng_seq), method="sample", x=noise_batch, + ) + print(y_batch[:5, :]) + print(y_pred[:5, :]) + + +def run(): + train( + key=0, + surjector_fn=_get_funnel_surjector, + n_iter=2000, + 
batch_size=64, + n_data=5, + n_latent=10 + ) + + +if __name__ == "__main__": + run() diff --git a/experiments/multivariate_gaussian_inference_surjection.py b/experiments/multivariate_gaussian_inference_surjection.py new file mode 100644 index 0000000..1352637 --- /dev/null +++ b/experiments/multivariate_gaussian_inference_surjection.py @@ -0,0 +1,216 @@ +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt +import numpy as np +import optax +from jax import config +from jax import numpy as jnp +from jax import random + +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp_conditioner import mlp_conditioner +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_inference_funnel import \ + AffineMaskedCouplingInferenceFunnel +from surjectors.surjectors.chain import Chain +from surjectors.surjectors.slice import Slice + +config.update("jax_enable_x64", True) + + +def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): + pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41]) + pz = distrax.MultivariateNormalDiag( + loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + ) + p_loadings = distrax.Normal(0.0, 10.0) + make_noise = distrax.Normal(0.0, 1) + + loadings_sample_key, rng_key = random.split(rng_key, 2) + loadings = p_loadings.sample( + seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) + ) + + def _fn(rng_key): + z_sample_key, noise_sample_key = random.split(rng_key, 2) + z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) + noise = make_noise.sample( + seed=noise_sample_key, sample_shape=(batch_size, n_dimension) + ) + + y = (loadings @ z.T).T + noise + # y = jnp.concatenate([z, z] ,axis=-1) + return y, z, noise + + return _fn, loadings + + +def _decoder_fn(n_dimension, n_latent): + decoder_net = mlp_conditioner([32, 32, n_dimension - n_latent]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def _get_slice_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]) + ) + layers.append(layer) + + layers.append(Slice(n_latent, _decoder_fn(n_dimension, n_latent))) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + # return Slice(n_latent, _decoder_fn()) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def 
_transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + AffineMaskedCouplingInferenceFunnel(n_latent, _decoder_fn(n_dimension, n_latent), mlp_conditioner([32, 32, n_dimension])) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + #return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): + rng_seq = hk.PRNGSequence(0) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + params = flow.init( + random.PRNGKey(key), + method="log_prob", + y=y_init, + x=noise_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + y_batch, _, noise_batch = pyz(next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, + next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + + y_batch, z_batch, noise_batch = pyz(next(rng_seq)) + y_pred = flow.apply( + params, next(rng_seq), method="sample", x=noise_batch + ) + print(y_batch[:5, :]) + print(y_pred[:5, :]) + + +def run(): + train( + key=0, + surjector_fn=_get_funnel_surjector, + n_iter=2000, + batch_size=64, + n_data=10, + n_latent=5 + ) + + +if __name__ == "__main__": + run() diff --git a/examples/solar_dynamo_data.py b/experiments/solar_dynamo_data.py similarity index 63% rename from examples/solar_dynamo_data.py rename to experiments/solar_dynamo_data.py index 59f333b..b857b68 100644 --- a/examples/solar_dynamo_data.py +++ b/experiments/solar_dynamo_data.py @@ -1,6 +1,6 @@ from jax import random, lax from jax.scipy.special import erf - +import distrax class SolarDynamoSimulator: def __init__(self, **kwargs): @@ -13,23 +13,23 @@ def __init__(self, **kwargs): self.alpha1 = kwargs.get("alpha1", None) self.alpha2 = kwargs.get("alpha2", None) - def sample(self, key, batclen_timeseries=1000): + def sample(self, key, batch_size, len_timeseries=1000): p_key, alpha1_key, alpha2_key, epsilon_key, key = random.split(key, 5) - p0 = random.normal(p_key) * self.p0_std + self.p0_mean + p0 = random.normal(p_key, shape=(batch_size,)) * self.p0_std + self.p0_mean alpha1 = random.uniform( - alpha1_key, minval=self.alpha1_min, maxval=self.alpha1_max + alpha1_key, shape=(batch_size,), 
minval=self.alpha1_min, maxval=self.alpha1_max ) alpha2 = random.uniform( - alpha2_key, minval=alpha1, maxval=self.alpha2_max + alpha2_key, shape=(batch_size,), minval=alpha1, maxval=self.alpha2_max ) epsilon_max = random.uniform( - epsilon_key, minval=0, maxval=self.epsilon_max + epsilon_key, shape=(batch_size,), minval=0, maxval=self.epsilon_max ) batch = self._sample_timeseries( - key, p0, alpha1, alpha2, epsilon_max, len_timeseries + key, batch_size, p0, alpha1, alpha2, epsilon_max, len_timeseries ) - return p0, alpha1, alpha2, epsilon_max, batch[0], batch[1] + return p0, alpha1, alpha2, epsilon_max, batch[0].T, batch[1].T @staticmethod def babcock_leighton_fn(p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8): @@ -41,13 +41,13 @@ def babcock_leighton(self, p, alpha, epsilon): return p def _sample_timeseries( - self, key, pn, alpha_min, alpha_max, epsilon_max, len_timeseries + self, key, batch_size, pn, alpha_min, alpha_max, epsilon_max, len_timeseries ): - a = random.uniform( - key, minval=alpha_min, maxval=alpha_max, shape=(len_timeseries,) + a = distrax.Uniform(alpha_min, alpha_max).sample( + seed=key, sample_shape=(len_timeseries,) ) - e = random.uniform( - key, minval=0.0, maxval=epsilon_max, shape=(len_timeseries,) + e = distrax.Uniform(0.0, epsilon_max).sample( + seed=key, sample_shape=(len_timeseries,) ) def _fn(fs, arrays): diff --git a/examples/multivariate_gaussian_inference_surjection.py b/experiments/solar_dynamo_generative_surjection.py similarity index 83% rename from examples/multivariate_gaussian_inference_surjection.py rename to experiments/solar_dynamo_generative_surjection.py index 58077f3..cefa9b8 100644 --- a/examples/multivariate_gaussian_inference_surjection.py +++ b/experiments/solar_dynamo_generative_surjection.py @@ -8,12 +8,13 @@ from jax import numpy as jnp from jax import random +from experiments.solar_dynamo_data import SolarDynamoSimulator from surjectors.bijectors.masked_coupling import MaskedCoupling from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) -from surjectors.surjectors.affine_masked_coupling_inference_funnel import \ - AffineMaskedCouplingInferenceFunnel +from surjectors.surjectors.affine_masked_coupling_inference_funnel \ + import AffineMaskedCouplingInferenceFunnel from surjectors.surjectors.chain import Chain from surjectors.surjectors.slice import Slice @@ -21,34 +22,15 @@ def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): - pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41]) - pz = distrax.MultivariateNormalDiag( - loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + simulator = SolarDynamoSimulator() + p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample( + jnp.array([549229066, 500358972], dtype=jnp.uint32), 64 ) - p_loadings = distrax.Normal(0.0, 10.0) - make_noise = distrax.Normal(0.0, 1) - - loadings_sample_key, rng_key = random.split(rng_key, 2) - loadings = p_loadings.sample( - seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) - ) - - def _fn(rng_key): - z_sample_key, noise_sample_key = random.split(rng_key, 2) - z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) - noise = make_noise.sample( - seed=noise_sample_key, sample_shape=(batch_size, n_dimension) - ) - - y = (loadings @ z.T).T + noise - # y = jnp.concatenate([z, z] ,axis=-1) - return y, z, noise - - return _fn, loadings + return simulator.sample def _get_slice_surjector(n_dimension, n_latent): - def _bijector_conditioner(dim): + def _conditioner(dim): return hk.Sequential( [ hk.Linear( @@ -67,27 +49,8 
@@ def _bijector_conditioner(dim): ] ) - def _surjector_conditioner(): - return hk.Sequential( - [ - hk.Linear( - 16, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 16, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear((n_dimension - n_latent) * 2), - ] - ) - def _decoder_fn(): - decoder_net = _surjector_conditioner() + decoder_net = _conditioner() def _fn(z): params = decoder_net(z) @@ -307,3 +270,4 @@ def run(): if __name__ == "__main__": run() + diff --git a/examples/multivariate_gaussian_generative_surjection.py b/experiments/solar_dynamo_inference_surjection.py similarity index 75% rename from examples/multivariate_gaussian_generative_surjection.py rename to experiments/solar_dynamo_inference_surjection.py index 9c0b2a8..cefa9b8 100644 --- a/examples/multivariate_gaussian_generative_surjection.py +++ b/experiments/solar_dynamo_inference_surjection.py @@ -8,43 +8,25 @@ from jax import numpy as jnp from jax import random +from experiments.solar_dynamo_data import SolarDynamoSimulator from surjectors.bijectors.masked_coupling import MaskedCoupling from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) -from surjectors.surjectors.affine_masked_coupling_generative_funnel import \ - AffineMaskedCouplingGenerativeFunnel -from surjectors.surjectors.augment import Augment +from surjectors.surjectors.affine_masked_coupling_inference_funnel \ + import AffineMaskedCouplingInferenceFunnel from surjectors.surjectors.chain import Chain +from surjectors.surjectors.slice import Slice config.update("jax_enable_x64", True) def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): - pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41, -2.31, 0.421, 0.1, 3.21, -0.41]) - pz = distrax.MultivariateNormalDiag( - loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + simulator = SolarDynamoSimulator() + p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample( + jnp.array([549229066, 500358972], dtype=jnp.uint32), 64 ) - p_loadings = distrax.Normal(0.0, 10.0) - make_noise = distrax.Normal(0.0, 1) - - loadings_sample_key, rng_key = random.split(rng_key, 2) - loadings = p_loadings.sample( - seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) - ) - - def _fn(rng_key): - z_sample_key, noise_sample_key = random.split(rng_key, 2) - z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) - noise = make_noise.sample( - seed=noise_sample_key, sample_shape=(batch_size, n_dimension) - ) - - y = (loadings @ z.T).T + noise - return y, z, noise - return z[:, :n_dimension], z , noise - - return _fn, loadings + return simulator.sample def _get_slice_surjector(n_dimension, n_latent): @@ -67,8 +49,8 @@ def _conditioner(dim): ] ) - def _encoder_fn(): - decoder_net = _conditioner((n_latent - n_dimension)) + def _decoder_fn(): + decoder_net = _conditioner() def _fn(z): params = decoder_net(z) @@ -83,34 +65,33 @@ def _bijector_fn(params): def _transformation_fn(): layers = [] - mask = jnp.arange(0, np.prod(n_dimension)) % 2 mask = jnp.reshape(mask, n_dimension) mask = mask.astype(bool) + for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_conditioner(n_dimension), + conditioner=_bijector_conditioner(n_dimension), ) layers.append(layer) - layers.append( - Augment(n_dimension, _encoder_fn()) - ) + layers.append(Slice(n_latent, _decoder_fn())) mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) 
mask = mask.astype(bool) + for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_conditioner(n_latent), + conditioner=_bijector_conditioner(n_latent), ) layers.append(layer) mask = jnp.logical_not(mask) - #return Augment(n_dimension, _encoder_fn()) + # return Slice(n_latent, _decoder_fn()) return Chain(layers) def _base_fn(): @@ -129,7 +110,7 @@ def _flow(method, **kwargs): def _get_funnel_surjector(n_dimension, n_latent): - def _conditioner(dim): + def _bijector_conditioner(dim): return hk.Sequential( [ hk.Linear( @@ -148,8 +129,27 @@ def _conditioner(dim): ] ) - def _encoder_fn(): - decoder_net = _conditioner((n_latent - n_dimension)) + def _surjector_conditioner(): + return hk.Sequential( + [ + hk.Linear( + 16, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear( + 16, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros, + ), + jax.nn.gelu, + hk.Linear((n_dimension - n_latent) * 2), + ] + ) + + def _decoder_fn(): + decoder_net = _surjector_conditioner() def _fn(z): params = decoder_net(z) @@ -164,40 +164,36 @@ def _bijector_fn(params): def _transformation_fn(): layers = [] - mask = jnp.arange(0, np.prod(n_dimension)) % 2 mask = jnp.reshape(mask, n_dimension) mask = mask.astype(bool) + for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_conditioner(n_dimension), + conditioner=_bijector_conditioner(n_dimension), ) layers.append(layer) layers.append( - AffineMaskedCouplingGenerativeFunnel( - n_dimension, _encoder_fn(), _conditioner(n_latent) - ) + AffineMaskedCouplingInferenceFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) ) mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) mask = mask.astype(bool) + for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_conditioner(n_latent), + conditioner=_bijector_conditioner(n_latent), ) layers.append(layer) mask = jnp.logical_not(mask) - return Chain(layers) - # return AffineMaskedCouplingGenerativeFunnel( - # n_dimension, _encoder_fn(), _conditioner(n_latent) - # ) + #return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) def _base_fn(): base_distribution = distrax.Independent( @@ -216,7 +212,7 @@ def _flow(method, **kwargs): def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): rng_seq = hk.PRNGSequence(0) - pyz, loadings = _get_sampler_and_loadings(next(rng_seq), 2*batch_size, n_data) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) flow = surjector_fn(n_data, n_latent) @jax.jit @@ -245,7 +241,8 @@ def loss_fn(params): losses = [0] * n_iter for i in range(n_iter): y_batch, _, noise_batch = pyz(next(rng_seq)) - loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, + next(rng_seq)) losses[i] = loss losses = jnp.asarray(losses) @@ -254,7 +251,7 @@ def loss_fn(params): y_batch, z_batch, noise_batch = pyz(next(rng_seq)) y_pred = flow.apply( - params, next(rng_seq), method="sample", x=noise_batch, + params, next(rng_seq), method="sample", x=noise_batch ) print(y_batch[:5, :]) print(y_pred[:5, :]) @@ -266,10 +263,11 @@ def run(): surjector_fn=_get_funnel_surjector, n_iter=2000, batch_size=64, - n_data=5, - n_latent=10 + n_data=10, + n_latent=5 ) if __name__ == "__main__": run() + diff --git a/surjectors/conditioners/mlp_conditioner.py 
b/surjectors/conditioners/mlp_conditioner.py new file mode 100644 index 0000000..553dc65 --- /dev/null +++ b/surjectors/conditioners/mlp_conditioner.py @@ -0,0 +1,19 @@ +import haiku as hk +import jax +from jax import numpy as jnp + + +def mlp_conditioner( + dims, + activation=jax.nn.gelu, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros +): + dims[-1] = dims[-1] * 2 + + return hk.nets.MLP( + output_sizes=dims, + w_init=w_init, + b_init=b_init, + activation=activation + ) From 283937777ceedd856bc984e6df4b74ee985f700e Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Wed, 16 Nov 2022 21:29:58 +0100 Subject: [PATCH 08/10] Add surjectors with transformers --- ...ivariate_gaussian_generative_surjection.py | 4 +- ...tivariate_gaussian_inference_surjection.py | 11 +- experiments/solar_dynamo_data.py | 15 +- .../solar_dynamo_generative_surjection.py | 213 ++++++------------ .../solar_dynamo_inference_surjection.py | 170 ++++---------- surjectors/bijectors/masked_coupling.py | 3 +- .../{mlp_conditioner.py => mlp.py} | 0 surjectors/conditioners/transformer.py | 73 ++++++ 8 files changed, 211 insertions(+), 278 deletions(-) rename surjectors/conditioners/{mlp_conditioner.py => mlp.py} (100%) diff --git a/experiments/multivariate_gaussian_generative_surjection.py b/experiments/multivariate_gaussian_generative_surjection.py index 48ec957..bdab5fd 100644 --- a/experiments/multivariate_gaussian_generative_surjection.py +++ b/experiments/multivariate_gaussian_generative_surjection.py @@ -9,7 +9,7 @@ from jax import random from surjectors.bijectors.masked_coupling import MaskedCoupling -from surjectors.conditioners.mlp_conditioner import mlp_conditioner +from surjectors.conditioners.mlp import mlp_conditioner from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) @@ -163,7 +163,7 @@ def _flow(method, **kwargs): def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): rng_seq = hk.PRNGSequence(0) - pyz, loadings = _get_sampler_and_loadings(next(rng_seq), 2*batch_size, n_data) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) flow = surjector_fn(n_data, n_latent) @jax.jit diff --git a/experiments/multivariate_gaussian_inference_surjection.py b/experiments/multivariate_gaussian_inference_surjection.py index 1352637..25e4a04 100644 --- a/experiments/multivariate_gaussian_inference_surjection.py +++ b/experiments/multivariate_gaussian_inference_surjection.py @@ -9,7 +9,7 @@ from jax import random from surjectors.bijectors.masked_coupling import MaskedCoupling -from surjectors.conditioners.mlp_conditioner import mlp_conditioner +from surjectors.conditioners.mlp import mlp_conditioner from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) @@ -128,7 +128,11 @@ def _transformation_fn(): layers.append(layer) layers.append( - AffineMaskedCouplingInferenceFunnel(n_latent, _decoder_fn(n_dimension, n_latent), mlp_conditioner([32, 32, n_dimension])) + AffineMaskedCouplingInferenceFunnel( + n_latent, + _decoder_fn(n_dimension, n_latent), + mlp_conditioner([32, 32, n_dimension]) + ) ) mask = jnp.arange(0, np.prod(n_latent)) % 2 @@ -185,8 +189,7 @@ def loss_fn(params): losses = [0] * n_iter for i in range(n_iter): y_batch, _, noise_batch = pyz(next(rng_seq)) - loss, params, state = step(params, state, y_batch, noise_batch, - next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) losses[i] = loss losses = jnp.asarray(losses) diff --git 
a/experiments/solar_dynamo_data.py b/experiments/solar_dynamo_data.py index b857b68..2db2088 100644 --- a/experiments/solar_dynamo_data.py +++ b/experiments/solar_dynamo_data.py @@ -2,7 +2,12 @@ from jax.scipy.special import erf import distrax + class SolarDynamoSimulator: + """Implements Eqn 2 and 3 of + FLUCTUATIONS IN BABCOCK-LEIGHTON DYNAMOS. II. REVISITING THE GNEVYSHEV-OHL RULE + https://iopscience.iop.org/article/10.1086/511177/pdf + """ def __init__(self, **kwargs): self.p0_mean = kwargs.get("p0_mean", 1.0) self.p0_std = kwargs.get("p0_std", 1.0) @@ -25,11 +30,11 @@ def sample(self, key, batch_size, len_timeseries=1000): epsilon_max = random.uniform( epsilon_key, shape=(batch_size,), minval=0, maxval=self.epsilon_max ) - batch = self._sample_timeseries( + f, y, alpha, noise = self._sample_timeseries( key, batch_size, p0, alpha1, alpha2, epsilon_max, len_timeseries ) - return p0, alpha1, alpha2, epsilon_max, batch[0].T, batch[1].T + return (p0, alpha1, alpha2, epsilon_max, f), y, alpha, noise @staticmethod def babcock_leighton_fn(p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8): @@ -46,7 +51,7 @@ def _sample_timeseries( a = distrax.Uniform(alpha_min, alpha_max).sample( seed=key, sample_shape=(len_timeseries,) ) - e = distrax.Uniform(0.0, epsilon_max).sample( + noise = distrax.Uniform(0.0, epsilon_max).sample( seed=key, sample_shape=(len_timeseries,) ) @@ -57,5 +62,5 @@ def _fn(fs, arrays): pn = self.babcock_leighton(pn, alpha, epsilon) return (f, pn), (f, pn) - _, pn = lax.scan(_fn, (pn, pn), (a, e)) - return pn + _, pn = lax.scan(_fn, (pn, pn), (a, noise)) + return pn[0].T, pn[1].T, a.T, noise.T diff --git a/experiments/solar_dynamo_generative_surjection.py b/experiments/solar_dynamo_generative_surjection.py index cefa9b8..8aa162b 100644 --- a/experiments/solar_dynamo_generative_surjection.py +++ b/experiments/solar_dynamo_generative_surjection.py @@ -10,99 +10,88 @@ from experiments.solar_dynamo_data import SolarDynamoSimulator from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp import mlp_conditioner +from surjectors.conditioners.transformer import transformer_conditioner from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) -from surjectors.surjectors.affine_masked_coupling_inference_funnel \ - import AffineMaskedCouplingInferenceFunnel +from surjectors.surjectors.affine_masked_coupling_generative_funnel import \ + AffineMaskedCouplingGenerativeFunnel +from surjectors.surjectors.augment import Augment from surjectors.surjectors.chain import Chain -from surjectors.surjectors.slice import Slice config.update("jax_enable_x64", True) -def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): +def _get_sampler(): simulator = SolarDynamoSimulator() - p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample( - jnp.array([549229066, 500358972], dtype=jnp.uint32), 64 - ) return simulator.sample -def _get_slice_surjector(n_dimension, n_latent): - def _conditioner(dim): - return hk.Sequential( - [ - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear(dim * 2), - ] - ) +def _encoder_fn(n_latent, n_dimension): + decoder_net = mlp_conditioner([32, 32, n_latent - n_dimension]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + 
return _fn + - def _decoder_fn(): - decoder_net = _conditioner() +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) - def _fn(z): - params = decoder_net(z) - mu, log_scale = jnp.split(params, 2, -1) - return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) - return _fn +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution - def _bijector_fn(params): - means, log_scales = jnp.split(params, 2, -1) - return distrax.ScalarAffine(means, jnp.exp(log_scales)) +def _get_slice_surjector(n_dimension, n_latent): def _transformation_fn(): layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 mask = jnp.reshape(mask, n_dimension) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_dimension), + conditioner=mlp_conditioner([32, 32, n_dimension]), ) layers.append(layer) - layers.append(Slice(n_latent, _decoder_fn())) - - mask = jnp.arange(0, np.prod(n_latent)) % 2 - mask = jnp.reshape(mask, n_latent) - mask = mask.astype(bool) + layers.append( + Augment(n_dimension, _encoder_fn(n_latent, n_dimension)) + ) - for _ in range(2): - layer = MaskedCoupling( - mask=mask, + mask = jnp.arange(n_latent) < n_latent - n_dimension + layers.append( + MaskedCoupling( + mask=mask.astype(jnp.bool_), bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_latent), + conditioner=transformer_conditioner(n_latent) + ) + ) + mask = jnp.arange(n_latent) >= n_latent - n_dimension + layers.append( + MaskedCoupling( + mask=mask.astype(jnp.bool_), + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]) ) - layers.append(layer) - mask = jnp.logical_not(mask) - # return Slice(n_latent, _decoder_fn()) - return Chain(layers) - - def _base_fn(): - base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), - reinterpreted_batch_ndims=1, ) - return base_distribution + #return Augment(n_dimension, _encoder_fn()) + return Chain(layers) def _flow(method, **kwargs): - td = TransformedDistribution(_base_fn(), _transformation_fn()) + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) return td(method, **kwargs) td = hk.transform(_flow) @@ -110,100 +99,47 @@ def _flow(method, **kwargs): def _get_funnel_surjector(n_dimension, n_latent): - def _bijector_conditioner(dim): - return hk.Sequential( - [ - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear(dim * 2), - ] - ) - - def _surjector_conditioner(): - return hk.Sequential( - [ - hk.Linear( - 16, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 16, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear((n_dimension - n_latent) * 2), - ] - ) - - def _decoder_fn(): - decoder_net = _surjector_conditioner() - - def _fn(z): - params = decoder_net(z) - mu, log_scale = jnp.split(params, 2, -1) - return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) - - return _fn - - def _bijector_fn(params): - means, log_scales = jnp.split(params, 2, -1) - return 
distrax.ScalarAffine(means, jnp.exp(log_scales)) - def _transformation_fn(): layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 mask = jnp.reshape(mask, n_dimension) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_dimension), + conditioner=mlp_conditioner(n_dimension), ) layers.append(layer) layers.append( - AffineMaskedCouplingInferenceFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) + AffineMaskedCouplingGenerativeFunnel( + n_dimension, _encoder_fn(n_latent, n_dimension), mlp_conditioner(n_latent) + ) ) mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_latent), + conditioner=transformer_conditioner(n_latent), ) layers.append(layer) mask = jnp.logical_not(mask) - return Chain(layers) - #return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) - def _base_fn(): - base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), - reinterpreted_batch_ndims=1, - ) - return base_distribution + return Chain(layers) + # return AffineMaskedCouplingGenerativeFunnel( + # n_dimension, _encoder_fn(), _conditioner(n_latent) + # ) def _flow(method, **kwargs): - td = TransformedDistribution(_base_fn(), _transformation_fn()) + td = TransformedDistribution( + _base_distribution_fn(n_latent), _transformation_fn() + ) return td(method, **kwargs) td = hk.transform(_flow) @@ -212,14 +148,14 @@ def _flow(method, **kwargs): def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): rng_seq = hk.PRNGSequence(0) - pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) + sampler = _get_sampler() flow = surjector_fn(n_data, n_latent) @jax.jit def step(params, state, y_batch, noise_batch, rng): def loss_fn(params): lp = flow.apply( - params, rng, method="log_prob", y=y_batch, x=noise_batch + params, rng, method="log_prob", y=y_batch ) return -jnp.sum(lp) @@ -228,31 +164,25 @@ def loss_fn(params): new_params = optax.apply_updates(params, updates) return loss, new_params, new_state - y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + _, y_init, _, noise_init = sampler(next(rng_seq), batch_size, n_data) params = flow.init( - random.PRNGKey(key), - method="log_prob", - y=y_init, - x=noise_init + next(rng_seq), method="log_prob", y=y_init ) adam = optax.adamw(0.001) state = adam.init(params) losses = [0] * n_iter for i in range(n_iter): - y_batch, _, noise_batch = pyz(next(rng_seq)) - loss, params, state = step(params, state, y_batch, noise_batch, - next(rng_seq)) + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) losses[i] = loss losses = jnp.asarray(losses) plt.plot(losses) plt.show() - y_batch, z_batch, noise_batch = pyz(next(rng_seq)) - y_pred = flow.apply( - params, next(rng_seq), method="sample", x=noise_batch - ) + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + y_pred = flow.apply(params, next(rng_seq), method="sample") print(y_batch[:5, :]) print(y_pred[:5, :]) @@ -260,14 +190,13 @@ def loss_fn(params): def run(): train( key=0, - surjector_fn=_get_funnel_surjector, + surjector_fn=_get_slice_surjector, n_iter=2000, batch_size=64, - n_data=10, - n_latent=5 + n_data=100, + n_latent=110 ) if 
__name__ == "__main__": run() - diff --git a/experiments/solar_dynamo_inference_surjection.py b/experiments/solar_dynamo_inference_surjection.py index cefa9b8..72f3fa4 100644 --- a/experiments/solar_dynamo_inference_surjection.py +++ b/experiments/solar_dynamo_inference_surjection.py @@ -10,6 +10,7 @@ from experiments.solar_dynamo_data import SolarDynamoSimulator from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp import mlp_conditioner from surjectors.distributions.transformed_distribution import ( TransformedDistribution, ) @@ -21,48 +22,36 @@ config.update("jax_enable_x64", True) -def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): +def _get_sampler(): simulator = SolarDynamoSimulator() - p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample( - jnp.array([549229066, 500358972], dtype=jnp.uint32), 64 - ) return simulator.sample -def _get_slice_surjector(n_dimension, n_latent): - def _conditioner(dim): - return hk.Sequential( - [ - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear(dim * 2), - ] - ) +def _decoder_fn(n_dimension, n_latent): + decoder_net = mlp_conditioner([32, 32, n_dimension - n_latent]) - def _decoder_fn(): - decoder_net = _conditioner() + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) - def _fn(z): - params = decoder_net(z) - mu, log_scale = jnp.split(params, 2, -1) - return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + return _fn - return _fn - def _bijector_fn(params): - means, log_scales = jnp.split(params, 2, -1) - return distrax.ScalarAffine(means, jnp.exp(log_scales)) +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + +def _get_slice_surjector(n_dimension, n_latent): def _transformation_fn(): layers = [] mask = jnp.arange(0, np.prod(n_dimension)) % 2 @@ -73,11 +62,11 @@ def _transformation_fn(): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_dimension), + conditioner=mlp_conditioner([32, 32, n_dimension]), ) layers.append(layer) - layers.append(Slice(n_latent, _decoder_fn())) + layers.append(Slice(n_latent, _decoder_fn(n_dimension, n_latent))) mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) @@ -87,22 +76,14 @@ def _transformation_fn(): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_latent), + conditioner=mlp_conditioner([32, 32, n_latent]), ) layers.append(layer) mask = jnp.logical_not(mask) - # return Slice(n_latent, _decoder_fn()) return Chain(layers) - def _base_fn(): - base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), - reinterpreted_batch_ndims=1, - ) - return base_distribution - def _flow(method, **kwargs): - td = TransformedDistribution(_base_fn(), _transformation_fn()) + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) return td(method, **kwargs) td = hk.transform(_flow) @@ -110,58 +91,6 @@ def 
_flow(method, **kwargs): def _get_funnel_surjector(n_dimension, n_latent): - def _bijector_conditioner(dim): - return hk.Sequential( - [ - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 32, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear(dim * 2), - ] - ) - - def _surjector_conditioner(): - return hk.Sequential( - [ - hk.Linear( - 16, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear( - 16, - w_init=hk.initializers.TruncatedNormal(stddev=0.01), - b_init=jnp.zeros, - ), - jax.nn.gelu, - hk.Linear((n_dimension - n_latent) * 2), - ] - ) - - def _decoder_fn(): - decoder_net = _surjector_conditioner() - - def _fn(z): - params = decoder_net(z) - mu, log_scale = jnp.split(params, 2, -1) - return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) - - return _fn - - def _bijector_fn(params): - means, log_scales = jnp.split(params, 2, -1) - return distrax.ScalarAffine(means, jnp.exp(log_scales)) - def _transformation_fn(): layers = [] mask = jnp.arange(0, np.prod(n_dimension)) % 2 @@ -172,12 +101,15 @@ def _transformation_fn(): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_dimension), + conditioner=mlp_conditioner([32, 32, n_dimension]), ) layers.append(layer) layers.append( - AffineMaskedCouplingInferenceFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) + AffineMaskedCouplingInferenceFunnel( + n_latent, + _decoder_fn(n_dimension, n_latent), + mlp_conditioner([32, 32, n_dimension])) ) mask = jnp.arange(0, np.prod(n_latent)) % 2 @@ -188,22 +120,15 @@ def _transformation_fn(): layer = MaskedCoupling( mask=mask, bijector=_bijector_fn, - conditioner=_bijector_conditioner(n_latent), + conditioner=mlp_conditioner([32, 32, n_latent]), ) layers.append(layer) mask = jnp.logical_not(mask) return Chain(layers) - #return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) - - def _base_fn(): - base_distribution = distrax.Independent( - distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), - reinterpreted_batch_ndims=1, - ) - return base_distribution def _flow(method, **kwargs): - td = TransformedDistribution(_base_fn(), _transformation_fn()) + td = TransformedDistribution(_base_distribution_fn(n_latent), + _transformation_fn()) return td(method, **kwargs) td = hk.transform(_flow) @@ -211,8 +136,8 @@ def _flow(method, **kwargs): def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): - rng_seq = hk.PRNGSequence(0) - pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) + rng_seq = hk.PRNGSequence(key) + sampler = _get_sampler() flow = surjector_fn(n_data, n_latent) @jax.jit @@ -228,28 +153,27 @@ def loss_fn(params): new_params = optax.apply_updates(params, updates) return loss, new_params, new_state - y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + _, y_init, _, noise_init = sampler(next(rng_seq), batch_size) + params = flow.init( - random.PRNGKey(key), - method="log_prob", - y=y_init, - x=noise_init + next(rng_seq), method="log_prob", y=y_init, x=noise_init ) adam = optax.adamw(0.001) state = adam.init(params) losses = [0] * n_iter for i in range(n_iter): - y_batch, _, noise_batch = pyz(next(rng_seq)) - loss, params, state = step(params, state, y_batch, noise_batch, - next(rng_seq)) + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size) + loss, 
params, state = step( + params, state, y_batch, noise_batch, next(rng_seq) + ) losses[i] = loss losses = jnp.asarray(losses) plt.plot(losses) plt.show() - y_batch, z_batch, noise_batch = pyz(next(rng_seq)) + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size) y_pred = flow.apply( params, next(rng_seq), method="sample", x=noise_batch ) @@ -260,11 +184,11 @@ def loss_fn(params): def run(): train( key=0, - surjector_fn=_get_funnel_surjector, + surjector_fn=_get_slice_surjector, n_iter=2000, batch_size=64, - n_data=10, - n_latent=5 + n_data=100, + n_latent=10 ) diff --git a/surjectors/bijectors/masked_coupling.py b/surjectors/bijectors/masked_coupling.py index 50d2705..962278b 100644 --- a/surjectors/bijectors/masked_coupling.py +++ b/surjectors/bijectors/masked_coupling.py @@ -10,8 +10,7 @@ class MaskedCoupling(distrax.MaskedCoupling): def __init__(self, mask: Array, conditioner, bijector, event_ndims: Optional[int] = None, inner_event_ndims: int = 0): - super().__init__(mask, conditioner, bijector, event_ndims, - inner_event_ndims) + super().__init__(mask, conditioner, bijector, event_ndims, inner_event_ndims) def forward_and_log_det(self, z: Array, x: Array = None) -> Tuple[Array, Array]: self._check_forward_input_shape(z) diff --git a/surjectors/conditioners/mlp_conditioner.py b/surjectors/conditioners/mlp.py similarity index 100% rename from surjectors/conditioners/mlp_conditioner.py rename to surjectors/conditioners/mlp.py diff --git a/surjectors/conditioners/transformer.py b/surjectors/conditioners/transformer.py index e69de29..af2cf37 100644 --- a/surjectors/conditioners/transformer.py +++ b/surjectors/conditioners/transformer.py @@ -0,0 +1,73 @@ +import dataclasses +from typing import Callable +from typing import Optional + +import haiku as hk +import jax +import numpy as np + + +@dataclasses.dataclass +class _EncoderLayer(hk.Module): + num_heads: int + num_layers: int + key_size: int + dropout_rate: float + widening_factor: int = 4 + initializer: Callable = hk.initializers.TruncatedNormal(stddev=0.01) + name: Optional[str] = None + + def __call__(self, inputs, *, is_training): + dropout_rate = self.dropout_rate if is_training else 0. 
+ #causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len))) + causal_mask=None + model_size=self.key_size * self.num_heads + + h = inputs + for _ in range(self.num_layers): + attn_block = hk.MultiHeadAttention( + num_heads=self.num_heads, + key_size=self.key_size, + model_size=model_size, + w_init=self.initializer, + ) + h_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h) + h_attn = attn_block(h_norm, h_norm, h_norm, mask=causal_mask) + h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn) + h = h + h_attn + + mlp = hk.nets.MLP( + [self.widening_factor * model_size, model_size], + w_init=self.initializer, + activation=jax.nn.gelu + ) + h_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h) + h_dense = mlp(h_norm) + h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense) + h = h + h_dense + + return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h) + + +@dataclasses.dataclass +class _AutoregressiveTransformerEncoder(hk.Module): + linear: hk.Linear + transformer: _EncoderLayer + output_size: int + name: Optional[str] = None + + def __call__(self, inputs, *, is_training=True): + h = self.linear(inputs) + h = self.transformer(h, is_training=is_training) + return hk.Linear(self.output_size)(h) + + +def transformer_conditioner( + output_size, num_heads=2, num_layers=2, key_size=32, dropout_rate=0.1, widening_factor=4 +): + linear = hk.Linear(key_size * num_heads) + encoder = _EncoderLayer( + num_heads, num_layers, key_size, dropout_rate, widening_factor + ) + transformer = _AutoregressiveTransformerEncoder(linear, encoder, output_size * 2) + return transformer From f29038f19e812337b8065b7606497a670038029b Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Thu, 17 Nov 2022 17:23:03 +0100 Subject: [PATCH 09/10] Minor adds --- ...tivariate_gaussian_inference_surjection.py | 96 +++++++++++++------ .../solar_dynamo_generative_surjection.py | 33 ++++--- .../solar_dynamo_inference_surjection.py | 86 ++++++++++++----- surjectors/conditioners/transformer.py | 1 - surjectors/surjectors/slice.py | 2 +- 5 files changed, 150 insertions(+), 68 deletions(-) diff --git a/experiments/multivariate_gaussian_inference_surjection.py b/experiments/multivariate_gaussian_inference_surjection.py index 25e4a04..8bfead4 100644 --- a/experiments/multivariate_gaussian_inference_surjection.py +++ b/experiments/multivariate_gaussian_inference_surjection.py @@ -21,8 +21,11 @@ config.update("jax_enable_x64", True) -def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): - pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41]) +def _get_sampler(rng_key, batch_size, n_dimension, n_latent): + means_sample_key, rng_key = random.split(rng_key, 2) + pz_mean = distrax.Normal(0.0, 10.0).sample( + seed=means_sample_key, sample_shape=(n_latent) + ) pz = distrax.MultivariateNormalDiag( loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) ) @@ -87,12 +90,13 @@ def _transformation_fn(): ) layers.append(layer) - layers.append(Slice(n_latent, _decoder_fn(n_dimension, n_latent))) + layers.append( + Slice(n_latent, _decoder_fn(n_dimension, n_latent)) + ) mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, @@ -101,7 +105,7 @@ def _transformation_fn(): ) layers.append(layer) mask = jnp.logical_not(mask) - # return Slice(n_latent, _decoder_fn()) + return Chain(layers) def _flow(method, **kwargs): @@ -126,6 +130,7 @@ def _transformation_fn(): 
conditioner=mlp_conditioner([32, 32, n_dimension]), ) layers.append(layer) + mask = jnp.logical_not(mask) layers.append( AffineMaskedCouplingInferenceFunnel( @@ -138,7 +143,6 @@ def _transformation_fn(): mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, @@ -147,7 +151,7 @@ def _transformation_fn(): ) layers.append(layer) mask = jnp.logical_not(mask) - #return AffineCouplingFunnel(n_latent, _decoder_fn(), _bijector_conditioner(n_dimension)) + return Chain(layers) def _flow(method, **kwargs): @@ -158,9 +162,33 @@ def _flow(method, **kwargs): return td -def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): - rng_seq = hk.PRNGSequence(0) - pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) +def _get_bijector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(4): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_dimension), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(rng_seq, sampler, surjector_fn, n_data, n_latent, batch_size, n_iter): flow = surjector_fn(n_data, n_latent) @jax.jit @@ -176,9 +204,9 @@ def loss_fn(params): new_params = optax.apply_updates(params, updates) return loss, new_params, new_state - y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + y_init, _, noise_init = sampler(next(rng_seq)) params = flow.init( - random.PRNGKey(key), + next(rng_seq), method="log_prob", y=y_init, x=noise_init @@ -188,7 +216,7 @@ def loss_fn(params): losses = [0] * n_iter for i in range(n_iter): - y_batch, _, noise_batch = pyz(next(rng_seq)) + y_batch, _, noise_batch = sampler(next(rng_seq)) loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) losses[i] = loss @@ -196,23 +224,37 @@ def loss_fn(params): plt.plot(losses) plt.show() - y_batch, z_batch, noise_batch = pyz(next(rng_seq)) - y_pred = flow.apply( - params, next(rng_seq), method="sample", x=noise_batch - ) - print(y_batch[:5, :]) - print(y_pred[:5, :]) + return flow, params + + +def evaluate(rng_seq, params, model, sampler, batch_size, n_data): + y_batch, _, noise_batch = sampler(next(rng_seq)) + lp = model.apply(params, next(rng_seq), method="log_prob", y=y_batch, x=noise_batch) + print("\tPPLP: {:.3f}".format(jnp.mean(lp))) def run(): - train( - key=0, - surjector_fn=_get_funnel_surjector, - n_iter=2000, - batch_size=64, - n_data=10, - n_latent=5 - ) + n_iter = 2000 + batch_size = 64 + n_data, n_latent = 100, 75 + sampler, _ = _get_sampler(random.PRNGKey(0), batch_size, n_data, n_latent) + for method, _fn in [ + ["Slice", _get_slice_surjector], + ["Funnel", _get_funnel_surjector], + ["Bijector", _get_bijector] + ]: + print(f"Doing {method}") + rng_seq = hk.PRNGSequence(0) + model, params = train( + rng_seq=rng_seq, + sampler=sampler, + surjector_fn=_fn, + n_iter=n_iter, + batch_size=batch_size, + n_data=n_data, + n_latent=n_latent + ) + evaluate(rng_seq, params, model, sampler, batch_size, n_data) if __name__ == "__main__": diff --git a/experiments/solar_dynamo_generative_surjection.py 
b/experiments/solar_dynamo_generative_surjection.py index 8aa162b..24144cd 100644 --- a/experiments/solar_dynamo_generative_surjection.py +++ b/experiments/solar_dynamo_generative_surjection.py @@ -146,9 +146,7 @@ def _flow(method, **kwargs): return td -def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): - rng_seq = hk.PRNGSequence(0) - sampler = _get_sampler() +def train(rng_seq, sampler, surjector_fn, n_data, n_latent, batch_size, n_iter): flow = surjector_fn(n_data, n_latent) @jax.jit @@ -180,22 +178,29 @@ def loss_fn(params): losses = jnp.asarray(losses) plt.plot(losses) plt.show() + return flow, params + +def evaluate(rng_seq, params, model, sampler, batch_size, n_data): _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) - y_pred = flow.apply(params, next(rng_seq), method="sample") - print(y_batch[:5, :]) - print(y_pred[:5, :]) + lp = model.apply(params, next(rng_seq), method="log_prob", y=y_batch) + print("PPLP: {:.3f}".format(lp / batch_size)) def run(): - train( - key=0, - surjector_fn=_get_slice_surjector, - n_iter=2000, - batch_size=64, - n_data=100, - n_latent=110 - ) + sampler = _get_sampler() + for _fn in [_get_slice_surjector, _get_funnel_surjector]: + rng_seq = hk.PRNGSequence(0) + model, params = train( + rng_seq=rng_seq, + sampler=sampler, + surjector_fn=_fn, + n_iter=2000, + batch_size=64, + n_data=100, + n_latent=110 + ) + evaluate(rng_seq, params, model, sampler, 64, 100) if __name__ == "__main__": diff --git a/experiments/solar_dynamo_inference_surjection.py b/experiments/solar_dynamo_inference_surjection.py index 72f3fa4..83aad0b 100644 --- a/experiments/solar_dynamo_inference_surjection.py +++ b/experiments/solar_dynamo_inference_surjection.py @@ -54,10 +54,10 @@ def _base_distribution_fn(n_latent): def _get_slice_surjector(n_dimension, n_latent): def _transformation_fn(): layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 mask = jnp.reshape(mask, n_dimension) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, @@ -66,12 +66,13 @@ def _transformation_fn(): ) layers.append(layer) - layers.append(Slice(n_latent, _decoder_fn(n_dimension, n_latent))) + layers.append( + Slice(n_latent, _decoder_fn(n_dimension, n_latent)) + ) mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, @@ -93,10 +94,10 @@ def _flow(method, **kwargs): def _get_funnel_surjector(n_dimension, n_latent): def _transformation_fn(): layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 mask = jnp.reshape(mask, n_dimension) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, @@ -115,7 +116,6 @@ def _transformation_fn(): mask = jnp.arange(0, np.prod(n_latent)) % 2 mask = jnp.reshape(mask, n_latent) mask = mask.astype(bool) - for _ in range(2): layer = MaskedCoupling( mask=mask, @@ -135,9 +135,33 @@ def _flow(method, **kwargs): return td -def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): - rng_seq = hk.PRNGSequence(key) - sampler = _get_sampler() +def _get_bijector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(4): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + return Chain(layers) + + def _flow(method, **kwargs): + td = 
TransformedDistribution( + _base_distribution_fn(n_dimension), + _transformation_fn() + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(rng_seq, sampler, surjector_fn, n_data, n_latent, batch_size, n_iter): flow = surjector_fn(n_data, n_latent) @jax.jit @@ -153,8 +177,7 @@ def loss_fn(params): new_params = optax.apply_updates(params, updates) return loss, new_params, new_state - _, y_init, _, noise_init = sampler(next(rng_seq), batch_size) - + _, y_init, _, noise_init = sampler(next(rng_seq), batch_size, n_data) params = flow.init( next(rng_seq), method="log_prob", y=y_init, x=noise_init ) @@ -163,7 +186,7 @@ def loss_fn(params): losses = [0] * n_iter for i in range(n_iter): - _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size) + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) loss, params, state = step( params, state, y_batch, noise_batch, next(rng_seq) ) @@ -172,24 +195,37 @@ def loss_fn(params): losses = jnp.asarray(losses) plt.plot(losses) plt.show() + return flow, params - _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size) - y_pred = flow.apply( - params, next(rng_seq), method="sample", x=noise_batch - ) - print(y_batch[:5, :]) - print(y_pred[:5, :]) + +def evaluate(rng_seq, params, model, sampler, batch_size, n_data): + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + lp = model.apply(params, next(rng_seq), method="log_prob", y=y_batch, x=noise_batch) + print("\tPPLP: {:.3f}".format(jnp.mean(lp))) def run(): - train( - key=0, - surjector_fn=_get_slice_surjector, - n_iter=2000, - batch_size=64, - n_data=100, - n_latent=10 - ) + n_iter = 2000 + batch_size = 64 + n_data, n_latent = 100, 75 + sampler = _get_sampler() + for method, _fn in [ + ["Slice", _get_slice_surjector], + ["Funnel", _get_funnel_surjector], + ["Bijector", _get_bijector] + ]: + print(f"Doing {method}") + rng_seq = hk.PRNGSequence(0) + model, params = train( + rng_seq=rng_seq, + sampler=sampler, + surjector_fn=_fn, + n_iter=n_iter, + batch_size=batch_size, + n_data=n_data, + n_latent=n_latent + ) + evaluate(rng_seq, params, model, sampler, batch_size, n_data) if __name__ == "__main__": diff --git a/surjectors/conditioners/transformer.py b/surjectors/conditioners/transformer.py index af2cf37..c1e79a3 100644 --- a/surjectors/conditioners/transformer.py +++ b/surjectors/conditioners/transformer.py @@ -4,7 +4,6 @@ import haiku as hk import jax -import numpy as np @dataclasses.dataclass diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index 7b6b899..455d4ad 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -7,7 +7,7 @@ class Slice(Funnel): def __init__(self, n_keep, decoder): - super().__init__(n_keep, decoder, None, None, "inference_surkector") + super().__init__(n_keep, decoder, None, None, "inference_surjector") def split_input(self, input): spl = jnp.split(input, [self.n_keep], axis=-1) From a76b22c1e799fbe4185eefb03935abac6da32b20 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 28 Nov 2022 09:19:13 +0100 Subject: [PATCH 10/10] Some experiments --- experiments/solar_dynamo_generative_surjection.py | 4 +++- surjectors/surjectors/slice.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/experiments/solar_dynamo_generative_surjection.py b/experiments/solar_dynamo_generative_surjection.py index 24144cd..61e917b 100644 --- a/experiments/solar_dynamo_generative_surjection.py +++ 
b/experiments/solar_dynamo_generative_surjection.py @@ -115,7 +115,9 @@ def _transformation_fn(): layers.append( AffineMaskedCouplingGenerativeFunnel( - n_dimension, _encoder_fn(n_latent, n_dimension), mlp_conditioner(n_latent) + n_dimension, + _encoder_fn(n_latent, n_dimension), + mlp_conditioner(n_latent) ) ) diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index 455d4ad..8e74675 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -35,4 +35,3 @@ def forward_and_likelihood_contribution(self, z, x=None): def forward(self, z, x=None): y, _ = self.forward_and_likelihood_contribution(z, x) return y -
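
A minimal sketch of the experiment pattern the scripts in this patch series converge on: stack `MaskedCoupling` bijections with a `Slice` inference surjection, wrap them in the library's `TransformedDistribution`, transform the whole flow with Haiku, and fit it by maximizing `log_prob` with `optax.adamw`. This is illustrative only, not code from the patches: the conditioner and decoder networks are plain `hk.nets.MLP` stand-ins (the body of `surjectors.conditioners.mlp.mlp_conditioner` is not shown in the diffs), the data is synthetic noise rather than the Gaussian or solar-dynamo simulators, and the sizes and iteration counts are placeholders.

```python
import distrax
import haiku as hk
import jax
import jax.numpy as jnp
import optax

from surjectors.bijectors.masked_coupling import MaskedCoupling
from surjectors.distributions.transformed_distribution import (
    TransformedDistribution,
)
from surjectors.surjectors.chain import Chain
from surjectors.surjectors.slice import Slice

n_dimension, n_latent = 10, 5


def _bijector_fn(params):
    # Affine coupling transform: split the conditioner output into shift and log-scale.
    means, log_scales = jnp.split(params, 2, -1)
    return distrax.ScalarAffine(means, jnp.exp(log_scales))


def _decoder_fn(n_dimension, n_latent):
    # Decoder of the Slice surjection: models the sliced-off coordinates given the kept ones.
    decoder_net = hk.nets.MLP([32, 32, 2 * (n_dimension - n_latent)])

    def _fn(z):
        mu, log_scale = jnp.split(decoder_net(z), 2, -1)
        return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))

    return _fn


def _flow(method, **kwargs):
    mask = (jnp.arange(n_dimension) % 2).astype(bool)
    layers = [
        MaskedCoupling(
            mask=mask,
            bijector=_bijector_fn,
            conditioner=hk.nets.MLP([32, 32, 2 * n_dimension]),
        ),
        Slice(n_latent, _decoder_fn(n_dimension, n_latent)),
    ]
    base = distrax.Independent(
        distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
        reinterpreted_batch_ndims=1,
    )
    td = TransformedDistribution(base, Chain(layers))
    return td(method, **kwargs)


flow = hk.transform(_flow)
rng_seq = hk.PRNGSequence(0)

# Synthetic stand-in data; the experiment scripts draw batches from their simulators instead.
y = jax.random.normal(next(rng_seq), (64, n_dimension))
params = flow.init(next(rng_seq), method="log_prob", y=y)

adam = optax.adamw(0.001)
opt_state = adam.init(params)


@jax.jit
def step(params, opt_state, y_batch, rng):
    # One maximum-likelihood update: negative log-probability of the batch.
    def loss_fn(params):
        lp = flow.apply(params, rng, method="log_prob", y=y_batch)
        return -jnp.sum(lp)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, new_opt_state = adam.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, new_opt_state


for _ in range(100):
    loss, params, opt_state = step(params, opt_state, y, next(rng_seq))
```

The funnel variants in the experiments follow the same recipe with `AffineMaskedCouplingInferenceFunnel` (or its generative counterpart) in place of `Slice`, and the generative solar-dynamo script swaps the MLP conditioner for the `transformer_conditioner` added in the same patch.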