Add superclass for several bijections
dirmeier committed Oct 13, 2023
1 parent 8f61c0e commit 66ee284
Showing 31 changed files with 758 additions and 391 deletions.
47 changes: 19 additions & 28 deletions docs/index.rst
@@ -17,34 +17,25 @@ Surjectors makes use of
Example usage
-------------

You can, for instance, construct a simple neural process like this:

.. code-block:: python

    from jax import random as jr
    from ramsey import NP, MLP
    from ramsey.data import sample_from_sine_function

    def get_neural_process():
        dim = 128
        np = NP(
            decoder=MLP([dim] * 3 + [2]),
            latent_encoder=(
                MLP([dim] * 3), MLP([dim, dim * 2])
            )
        )
        return np

    key = jr.PRNGKey(23)
    data = sample_from_sine_function(key)
    neural_process = get_neural_process()
    params = neural_process.init(key, x_context=data.x, y_context=data.y, x_target=data.x)
The neural process takes a decoder and a set of two latent encoders as arguments. All of these are typically MLPs, but
Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can initialize
its parameters just like in Flax.
You can, for instance, construct a simple normalizing flow like this:

>>> import distrax
>>> from jax import random as jr, numpy as jnp
>>> from surjectors import Slice, LULinear, Chain
>>> from surjectors import TransformedDistribution
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
...     return _fn
>>>
>>> base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(1))
>>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, flow)

The flow is constructed from three objects: a base distribution, a transformation, and a transformed distribution that represents the pushforward of the base distribution under the transformation.
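To make the pushforward idea concrete, here is a minimal, self-contained sketch of the change-of-variables computation that a transformed distribution performs. This is plain Python with invented helper names for illustration, not the surjectors API:

```python
import math

def standard_normal_logpdf(x):
    # log density of N(0, 1)
    return -0.5 * (x * x + math.log(2.0 * math.pi))

def affine_pushforward_log_prob(y, shift=1.0, scale=2.0):
    # Change of variables for y = scale * x + shift:
    # log p_Y(y) = log p_X(f^{-1}(y)) - log |scale|
    x = (y - shift) / scale
    return standard_normal_logpdf(x) - math.log(abs(scale))

# Log density of the pushforward at y = 1.0 (which maps back to x = 0.0)
print(affine_pushforward_log_prob(1.0))
```

A transformed distribution bundles this kind of bookkeeping: it evaluates the base density at the inverse-transformed point and adds the log-Jacobian correction of the transformation.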

Installation
------------
92 changes: 65 additions & 27 deletions docs/surjectors.rst
@@ -3,6 +3,30 @@

.. currentmodule:: surjectors

From a computational perspective, a normalizing flow consists of three components:

- A base distribution, for which we use the probability distributions from `Distrax <https://github.com/google-deepmind/distrax>`_.
- A forward transformation $f$ whose Jacobian determinant can be evaluated efficiently. These are the bijectors and surjectors below.
- A transformed distribution that represents the pushforward from the base distribution to the distribution induced by the transformation.

Hence, every normalizing flow can be constructed by defining these three components. See the example below.

>>> import distrax
>>> from jax import random as jr, numpy as jnp
>>> from surjectors import Slice, LULinear, Chain
>>> from surjectors import TransformedDistribution
>>> from surjectors.nn import make_mlp
>>>
>>> def decoder_fn(n_dim):
...     def _fn(z):
...         params = make_mlp([4, 4, n_dim * 2])(z)
...         mu, log_scale = jnp.split(params, 2, -1)
...         return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))
...     return _fn
>>>
>>> base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(1))
>>> flow = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, flow)


General
-------
@@ -15,7 +39,7 @@ TransformedDistribution
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: TransformedDistribution
:members: __init__
:members:

Chain
~~~~~
@@ -31,60 +55,74 @@ Bijective layers
MaskedCoupling
Permutation

MaskedAutoregressive
~~~~~~~~~~~~~~~~~~~~
Autoregressive bijections
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedAutoregressive
:members: __init__

MaskedCoupling
~~~~~~~~~~~~~~
Coupling bijections
~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedCoupling
:members: __init__

Permutation
~~~~~~~~~~~
Other bijections
~~~~~~~~~~~~~~~~

.. autoclass:: Permutation
:members: __init__

Inference surjection layers
---------------------------

.. autosummary::
AffineMaskedAutoregressiveInferenceFunnel
MaskedCouplingInferenceFunnel
AffineMaskedCouplingInferenceFunnel
LULinear
MLPFunnel
RationalQuadraticSplineMaskedCouplingInferenceFunnel
Slice

MaskedAutoregressiveInferenceFunnel
AffineMaskedAutoregressiveInferenceFunnel
RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel

AffineMaskedAutoregressiveInferenceFunnel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
LULinear
MLPInferenceFunnel
Slice

.. autoclass:: AffineMaskedAutoregressiveInferenceFunnel

Coupling inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: MaskedCouplingInferenceFunnel
:members: __init__

AffineMaskedCouplingInferenceFunnel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: AffineMaskedCouplingInferenceFunnel
:members: __init__

LULinear
~~~~~~~~
.. autoclass:: RationalQuadraticSplineMaskedCouplingInferenceFunnel
:members: __init__

.. autoclass:: LULinear
Autoregressive inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

MLPFunnel
~~~~~~~~~
.. autoclass:: MaskedAutoregressiveInferenceFunnel
:members: __init__

.. autoclass:: AffineMaskedAutoregressiveInferenceFunnel
:members: __init__

.. autoclass:: MLPFunnel
.. autoclass:: RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel
:members: __init__

RationalQuadraticSplineMaskedCouplingInferenceFunnel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Other inference surjections
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RationalQuadraticSplineMaskedCouplingInferenceFunnel
.. autoclass:: LULinear
:members: __init__

Slice
~~~~~
.. autoclass:: MLPInferenceFunnel
:members: __init__

.. autoclass:: Slice
:members: __init__
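Unlike bijections, the inference surjections documented above reduce dimensionality, so their density accounting combines the base density of the kept dimensions with a decoder likelihood of the dropped ones. A simplified, hypothetical sketch of that accounting for a slice-style funnel (plain Python with scalar standard normals, not the surjectors API):

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    # log density of N(mu, sigma^2)
    z = (x - mu) / sigma
    return -0.5 * (z * z + math.log(2.0 * math.pi)) - math.log(sigma)

def slice_funnel_log_prob(y, n_keep):
    # Keep the first n_keep coordinates as the latent z ...
    z, dropped = y[:n_keep], y[n_keep:]
    # ... score z under the (standard normal) base distribution ...
    base = sum(normal_logpdf(zi) for zi in z)
    # ... and score the dropped coordinates under a decoder q(dropped | z).
    # The "decoder" here is a fixed standard normal, purely for illustration.
    decoder = sum(normal_logpdf(di) for di in dropped)
    return base + decoder

print(slice_funnel_log_prob([0.0, 0.0, 0.0, 0.0], n_keep=2))
```

In the library, the decoder is a learned conditional distribution (as in the ``decoder_fn`` of the doctest above) rather than a fixed density.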
8 changes: 5 additions & 3 deletions examples/autoregressive_inference_surjection.py
@@ -13,14 +13,14 @@
AffineMaskedAutoregressiveInferenceFunnel,
Chain,
MaskedAutoregressive,
TransformedDistribution,
TransformedDistribution, Permutation,
)
from surjectors.conditioners import MADE, mlp_conditioner
from surjectors.nn import MADE, make_mlp
from surjectors.util import as_batch_iterator, unstack


def _decoder_fn(n_dim):
decoder_net = mlp_conditioner([4, 4, n_dim * 2])
decoder_net = make_mlp([4, 4, n_dim * 2])

def _fn(z):
params = decoder_net(z)
@@ -53,6 +53,8 @@ def _flow(**kwargs):
bijector_fn=_made_bijector_fn,
)
layers.append(layer)
# TODO(simon): needs to change order
# layers.append(Permutation(order, 1))
chain = Chain(layers)

base_distribution = distrax.Independent(
10 changes: 4 additions & 6 deletions examples/conditional_density_estimation.py
@@ -37,12 +37,10 @@ def _flow(method, **kwargs):
layer = MaskedCoupling(
mask=mask,
bijector=_bijector_fn,
conditioner=hk.Sequential(
[
make_mlp([8, 8, dim * 2]),
hk.Reshape((dim, dim)),
]
),
conditioner=hk.Sequential([
make_mlp([8, 8, dim * 2]),
hk.Reshape((dim, dim)),
]),
)
layers.append(layer)
else:
24 changes: 13 additions & 11 deletions examples/coupling_inference_surjection.py
@@ -12,15 +12,15 @@
from surjectors import (
AffineMaskedCouplingInferenceFunnel,
Chain,
MaskedAutoregressive,
MaskedCoupling,
TransformedDistribution,
)
from surjectors.conditioners import MADE, mlp_conditioner
from surjectors.util import as_batch_iterator, unstack
from surjectors.nn import make_mlp
from surjectors.util import as_batch_iterator, make_alternating_binary_mask


def _decoder_fn(n_dim):
decoder_net = mlp_conditioner([4, 4, n_dim * 2])
decoder_net = make_mlp([4, 4, n_dim * 2])

def _fn(z):
params = decoder_net(z)
@@ -30,9 +30,9 @@ def _fn(z):
return _fn


def _made_bijector_fn(params):
means, log_scales = unstack(params, -1)
return distrax.Inverse(distrax.ScalarAffine(means, jnp.exp(log_scales)))
def bijector_fn(params):
shift, log_scale = jnp.split(params, 2, axis=-1)
return distrax.ScalarAffine(shift, jnp.exp(log_scale))


def make_model(n_dimensions):
@@ -44,13 +44,15 @@ def _flow(**kwargs):
layer = AffineMaskedCouplingInferenceFunnel(
n_keep=int(n_dim / 2),
decoder=_decoder_fn(int(n_dim / 2)),
conditioner=mlp_conditioner([8, 8, n_dim * 2]),
conditioner=make_mlp([8, 8, n_dim * 2]),
)
n_dim = int(n_dim / 2)
else:
layer = MaskedAutoregressive(
conditioner=MADE(n_dim, [8, 8], 2),
bijector_fn=_made_bijector_fn,
mask = make_alternating_binary_mask(n_dim, i % 2 == 0)
layer = MaskedCoupling(
mask=mask,
bijector=bijector_fn,
conditioner=make_mlp([8, 8, n_dim * 2]),
)
layers.append(layer)
chain = Chain(layers)
28 changes: 22 additions & 6 deletions surjectors/__init__.py
@@ -11,35 +11,51 @@
from surjectors._src.distributions.transformed_distribution import (
TransformedDistribution,
)
from surjectors._src.surjectors.masked_autoregressive_inference_funnel import ( # noqa: E501
MaskedAutoregressiveInferenceFunnel,
)
from surjectors._src.surjectors.affine_masked_autoregressive_inference_funnel import ( # noqa: E501
AffineMaskedAutoregressiveInferenceFunnel,
)
from surjectors._src.surjectors.affine_masked_coupling_generative_funnel import ( # noqa: E501
AffineMaskedCouplingGenerativeFunnel,
)
from surjectors._src.surjectors.rq_masked_autoregressive_inference_funnel import ( # noqa: E501
RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel,
)
from surjectors._src.surjectors.affine_masked_coupling_inference_funnel import (
AffineMaskedCouplingInferenceFunnel,
)
from surjectors._src.surjectors.masked_coupling_inference_funnel import (
MaskedCouplingInferenceFunnel,
)
from surjectors._src.surjectors.augment import Augment
from surjectors._src.surjectors.chain import Chain
from surjectors._src.surjectors.mlp import MLPFunnel
from surjectors._src.surjectors.mlp import MLPInferenceFunnel
from surjectors._src.surjectors.rq_masked_coupling_inference_funnel import (
RationalQuadraticSplineMaskedCouplingInferenceFunnel,
)
from surjectors._src.surjectors.slice import Slice

__all__ = [
"LULinear",

"MaskedAutoregressive",
"MaskedAutoregressiveInferenceFunnel",
"AffineMaskedAutoregressiveInferenceFunnel",
"RationalQuadraticSplineMaskedAutoregressiveInferenceFunnel",

"MaskedCoupling",
"MaskedCouplingInferenceFunnel",
"AffineMaskedCouplingInferenceFunnel",
"AffineMaskedCouplingGenerativeFunnel",
"RationalQuadraticSplineMaskedCouplingInferenceFunnel",

"Permutation",
"TransformedDistribution",
"AffineMaskedAutoregressiveInferenceFunnel",
"AffineMaskedCouplingGenerativeFunnel",
"AffineMaskedCouplingInferenceFunnel",

"Augment",
"Chain",
"MLPFunnel",
"MLPInferenceFunnel",
"Slice",
"RationalQuadraticSplineMaskedCouplingInferenceFunnel",
]
9 changes: 9 additions & 0 deletions surjectors/_src/_transform.py
@@ -0,0 +1,9 @@
from abc import ABCMeta

from distrax._src.utils import jittable


class Transform(jittable.Jittable, metaclass=ABCMeta):
"""
Transformation of a random variable.
"""