Commit 8f61c0e: Add more documentation

dirmeier committed Oct 13, 2023
1 parent 8899d2b · commit 8f61c0e
Showing 39 changed files with 233 additions and 84 deletions.
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
@@ -72,7 +72,12 @@ repos:
     args: ['--ignore=E501,E203,E302,E402,E731,W503']

 - repo: https://github.com/jorisroovers/gitlint
-  rev: v0.18.0
+  rev: v0.19.1
   hooks:
   - id: gitlint
+  - id: gitlint-ci

+- repo: https://github.com/pycqa/pydocstyle
+  rev: 6.1.1
+  hooks:
+  - id: pydocstyle
37 changes: 37 additions & 0 deletions docs/api_gen/surjectors.Chain.rst
@@ -0,0 +1,37 @@
surjectors.Chain
================

.. currentmodule:: surjectors

.. autoclass:: Chain


.. automethod:: __init__


.. rubric:: Methods

.. autosummary::

~Chain.__init__
~Chain.forward
~Chain.forward_and_likelihood_contribution
~Chain.inverse
~Chain.inverse_and_likelihood_contribution
~Chain.tree_flatten
~Chain.tree_unflatten





.. rubric:: Attributes

.. autosummary::

~Chain.decoder
~Chain.dtype
~Chain.encoder
~Chain.n_keep


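The reference above lists the Chain API without a usage snippet. As a minimal sketch, the Slice arguments below are copied from the docstring example this commit adds to transformed_distribution.py and are otherwise illustrative:

```python
# A minimal sketch of composing transforms with Chain; Slice(10) and
# Slice(5) mirror the docstring example added in this commit.
import distrax
from jax import numpy as jnp

from surjectors import Chain, Slice, TransformedDistribution

a = Slice(10)
b = Slice(5)
ab = Chain([a, b])  # applies the component transforms in sequence

pushforward = TransformedDistribution(
    distrax.Normal(jnp.zeros(5), jnp.ones(5)), ab
)
```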
26 changes: 26 additions & 0 deletions docs/api_gen/surjectors.TransformedDistribution.rst
@@ -0,0 +1,26 @@
surjectors.TransformedDistribution
==================================

.. currentmodule:: surjectors

.. autoclass:: TransformedDistribution


.. automethod:: __init__


.. rubric:: Methods

.. autosummary::

~TransformedDistribution.__init__
~TransformedDistribution.inverse_and_log_prob
~TransformedDistribution.log_prob
~TransformedDistribution.sample
~TransformedDistribution.sample_and_log_prob






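The methods above are typically reached through the class's `__call__(method, **kwargs)` dispatcher inside an `hk.transform`, following the `_flow(method, **kwargs)` pattern in examples/conditional_density_estimation.py. A hedged sketch; the Slice-based transform is illustrative only:

```python
# A sketch of driving TransformedDistribution via its __call__ dispatcher
# inside hk.transform; the sample_shape keyword comes from the
# sample_and_log_prob signature visible in this commit.
import distrax
import haiku as hk
from jax import numpy as jnp
from jax import random

from surjectors import Chain, Slice, TransformedDistribution


def _flow(method, **kwargs):
    # the transform reuses the Slice/Chain setup from the docstring
    # example added in this commit
    td = TransformedDistribution(
        distrax.Normal(jnp.zeros(5), jnp.ones(5)),
        Chain([Slice(10), Slice(5)]),
    )
    # __call__ dispatches to the method of the same name
    return td(method, **kwargs)


flow = hk.transform(_flow)
# sampling draws from hk.next_rng_key(), so an rng key is passed to apply
params = flow.init(random.PRNGKey(0), method="sample_and_log_prob")
y, lp = flow.apply(
    params, random.PRNGKey(1), method="sample_and_log_prob", sample_shape=(4,)
)
```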
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -15,9 +15,10 @@
     "sphinx.ext.napoleon",
     "sphinx.ext.viewcode",
     "sphinx_autodoc_typehints",
+    "sphinx_copybutton",
     "sphinx_math_dollar",
     "IPython.sphinxext.ipython_console_highlighting",
-    'sphinx_design'
+    'sphinx_design',
 ]


1 change: 1 addition & 0 deletions docs/index.rst
@@ -104,4 +104,5 @@ Surjectors is licensed under the Apache 2.0 License.
    :hidden:

    surjectors
+   surjectors.nn
    surjectors.util
18 changes: 18 additions & 0 deletions docs/surjectors.nn.rst
@@ -0,0 +1,18 @@
``surjectors.nn``
=================

.. currentmodule:: surjectors.nn

.. automodule:: surjectors.nn

.. autosummary::
MADE
make_mlp
make_transformer


.. autoclass:: MADE

.. autofunction:: make_mlp

.. autofunction:: make_transformer
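For context, a short sketch of the `make_mlp` factory documented above; the dims, `w_init`, and `b_init` arguments mirror calls in this commit's tests, and, as with any haiku module, the conditioner is assumed to be created inside an `hk.transform`-ed function (`make_transformer(output_size, num_heads=2, num_layers=2, ...)` follows the same pattern):

```python
# A sketch of make_mlp usage; argument values mirror
# masked_coupling_test.py in this commit, sizes are arbitrary.
import haiku as hk
from jax import numpy as jnp
from jax import random

from surjectors.nn import make_mlp


def _conditioner(z):
    # haiku modules must be created inside an hk.transform-ed function
    mlp = make_mlp(
        [8, 8, 2],
        w_init=hk.initializers.TruncatedNormal(stddev=1.0),
        b_init=jnp.ones,
    )
    return mlp(z)


fn = hk.transform(_conditioner)
z = jnp.ones((4, 5))
params = fn.init(random.PRNGKey(0), z)
out = fn.apply(params, None, z)  # final layer size 2, so out has shape (4, 2)
```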
4 changes: 2 additions & 2 deletions docs/surjectors.rst
@@ -3,8 +3,6 @@

 .. currentmodule:: surjectors

-.. automodule:: surjectors
-

 General
 -------

@@ -17,11 +15,13 @@ TransformedDistribution
 ~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: TransformedDistribution
+   :members: __init__

 Chain
 ~~~~~

 .. autoclass:: Chain
+   :members: __init__

 Bijective layers
 ----------------
6 changes: 3 additions & 3 deletions examples/conditional_density_estimation.py
@@ -8,13 +8,13 @@
 from matplotlib import pyplot as plt

 from surjectors import (
+    Chain,
     MaskedAutoregressive,
     MaskedCoupling,
     Permutation,
     TransformedDistribution,
-    Chain
 )
-from surjectors.conditioners import MADE, mlp_conditioner
+from surjectors.nn import make_mlp, MADE
 from surjectors.util import (
     as_batch_iterator,
     make_alternating_binary_mask,

@@ -39,7 +39,7 @@ def _flow(method, **kwargs):
         bijector=_bijector_fn,
         conditioner=hk.Sequential(
             [
-                mlp_conditioner([8, 8, dim * 2]),
+                make_mlp([8, 8, dim * 2]),
                 hk.Reshape((dim, dim)),
             ]
         ),
16 changes: 16 additions & 0 deletions surjectors/__init__.py
@@ -27,3 +27,19 @@
     RationalQuadraticSplineMaskedCouplingInferenceFunnel,
 )
 from surjectors._src.surjectors.slice import Slice
+
+__all__ = [
+    "LULinear",
+    "MaskedAutoregressive",
+    "MaskedCoupling",
+    "Permutation",
+    "TransformedDistribution",
+    "AffineMaskedAutoregressiveInferenceFunnel",
+    "AffineMaskedCouplingGenerativeFunnel",
+    "AffineMaskedCouplingInferenceFunnel",
+    "Augment",
+    "Chain",
+    "MLPFunnel",
+    "Slice",
+    "RationalQuadraticSplineMaskedCouplingInferenceFunnel",
+]
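With the implementation modules moved under `_src`, the names in `__all__` are imported from the package roots rather than from submodules; each line below mirrors an import that appears in this commit's example or test files:

```python
# Public import paths after the _src refactor.
from surjectors import Chain, MaskedCoupling, TransformedDistribution
from surjectors.nn import MADE, make_mlp
from surjectors.util import make_alternating_binary_mask, unstack
```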
2 changes: 1 addition & 1 deletion surjectors/_src/bijectors/_bijector.py
@@ -1,7 +1,7 @@
 from chex import Array
 from jax import numpy as jnp

-from surjectors.surjectors.surjector import Surjector
+from surjectors._src.surjectors.surjector import Surjector


 # pylint: disable=too-many-arguments
2 changes: 1 addition & 1 deletion surjectors/_src/bijectors/lu_linear.py
@@ -3,7 +3,7 @@
 import numpy as np
 from jax import numpy as jnp

-from surjectors.surjectors.surjector import Surjector
+from surjectors._src.surjectors.surjector import Surjector


 # pylint: disable=arguments-differ
6 changes: 3 additions & 3 deletions surjectors/_src/bijectors/masked_autoregressive.py
@@ -3,9 +3,9 @@
 from distrax._src.utils import math
 from jax import numpy as jnp

-from surjectors.bijectors._bijector import _Bijector
-from surjectors.conditioners import MADE
-from surjectors.distributions.transformed_distribution import Array
+from surjectors._src.bijectors._bijector import _Bijector
+from surjectors._src.conditioners.nn.made import MADE
+from surjectors._src.distributions.transformed_distribution import Array


 # pylint: disable=too-many-arguments, arguments-renamed
4 changes: 2 additions & 2 deletions surjectors/_src/bijectors/masked_autoregressive_test.py
@@ -8,8 +8,8 @@
 from jax import random

 from surjectors import TransformedDistribution
-from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
-from surjectors.conditioners import MADE
+from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive
+from surjectors._src.conditioners.nn.made import MADE
 from surjectors.util import unstack
4 changes: 2 additions & 2 deletions surjectors/_src/bijectors/masked_coupling.py
@@ -4,8 +4,8 @@
 from distrax._src.utils import math
 from jax import numpy as jnp

-from surjectors.distributions.transformed_distribution import Array
-from surjectors.surjectors.surjector import Surjector
+from surjectors._src.distributions.transformed_distribution import Array
+from surjectors._src.surjectors.surjector import Surjector


 # pylint: disable=too-many-arguments, arguments-renamed
6 changes: 3 additions & 3 deletions surjectors/_src/bijectors/masked_coupling_test.py
@@ -10,7 +10,7 @@
 from jax import random

 import surjectors
-from surjectors.conditioners.mlp import mlp_conditioner
+from surjectors._src.conditioners.mlp import make_mlp
 from surjectors.util import make_alternating_binary_mask

@@ -76,7 +76,7 @@ def _transformation_fn(n_dimension):
     layer = flow_ctor(
         mask=mask,
         bijector=_bijector_fn,
-        conditioner=mlp_conditioner(
+        conditioner=make_mlp(
             [8, n_dim * 2],
             w_init=hk.initializers.TruncatedNormal(stddev=1.0),
             b_init=jnp.ones,

@@ -102,7 +102,7 @@ def _flow(**kwargs):
     layer = surjectors.MaskedCoupling(
         mask=mask,
         bijector=_bijector_fn,
-        conditioner=mlp_conditioner([8, 8, n_dim * 2]),
+        conditioner=make_mlp([8, 8, n_dim * 2]),
     )
     layers.append(layer)
     chain = surjectors.Chain(layers)
3 changes: 0 additions & 3 deletions surjectors/_src/conditioners/__init__.py
@@ -1,3 +0,0 @@
-from .mlp import mlp_conditioner
-from .nn.made import MADE
-from .transformer import transformer_conditioner
2 changes: 1 addition & 1 deletion surjectors/_src/conditioners/mlp.py
@@ -3,7 +3,7 @@
 from jax import numpy as jnp


-def mlp_conditioner(
+def make_mlp(
     dims,
     activation=jax.nn.gelu,
     w_init=hk.initializers.TruncatedNormal(stddev=0.01),
2 changes: 1 addition & 1 deletion surjectors/_src/conditioners/nn/made.py
@@ -7,7 +7,7 @@
     _make_dense_autoregressive_masks,
 )

-from surjectors.conditioners.nn.masked_linear import MaskedLinear
+from surjectors._src.conditioners.nn.masked_linear import MaskedLinear


 # pylint: disable=too-many-arguments, arguments-renamed
2 changes: 1 addition & 1 deletion surjectors/_src/conditioners/nn/made_test.py
@@ -6,7 +6,7 @@
 from jax import numpy as jnp
 from jax import random

-from surjectors.conditioners import MADE
+from surjectors._src.conditioners.nn.made import MADE
 from surjectors.util import unstack
2 changes: 1 addition & 1 deletion surjectors/_src/conditioners/transformer.py
@@ -65,7 +65,7 @@ def __call__(self, inputs, *, is_training=True):


 # pylint: disable=too-many-arguments
-def transformer_conditioner(
+def make_transformer(
     output_size,
     num_heads=2,
     num_layers=2,
40 changes: 31 additions & 9 deletions surjectors/_src/distributions/transformed_distribution.py
@@ -3,21 +3,43 @@
 import chex
 import distrax
 import haiku as hk
-from chex import Array
+from jax import Array
 from distrax import Distribution

-from surjectors.surjectors.surjector import Surjector
+from surjectors._src.surjectors._transform import Transform


 class TransformedDistribution:
     """
     Distribution of a random variable transformed by a surjective or
-    bijectiive function
+    bijective function.
+
+    Can be used to define a pushforward measure.
+
+    Examples:
+        >>> import distrax
+        >>> from jax import numpy as jnp
+        >>> from surjectors import Slice, Chain, TransformedDistribution
+        >>>
+        >>> a = Slice(10)
+        >>> b = Slice(5)
+        >>> ab = Chain([a, b])
+        >>>
+        >>> TransformedDistribution(
+        ...     distrax.Normal(jnp.zeros(5), jnp.ones(5)),
+        ...     Chain([a, b])
+        ... )
     """

-    def __init__(self, base_distribution: Distribution, surjector: Surjector):
+    def __init__(self, base_distribution: Distribution, transform: Transform):
+        """
+        Constructs a TransformedDistribution.
+
+        Args:
+            base_distribution: a distribution object
+            transform: some transformation
+        """
+
         self.base_distribution = base_distribution
-        self.surjector = surjector
+        self.transform = transform

     def __call__(self, method, **kwargs):
         return getattr(self, method)(**kwargs)

@@ -67,10 +89,10 @@ def inverse_and_log_prob(
         chex.assert_equal_rank([y, x])
         chex.assert_axis_dimension(y, 0, x.shape[0])

-        if isinstance(self.surjector, distrax.Bijector):
-            z, lc = self.surjector.inverse_and_log_det(y)
+        if isinstance(self.transform, distrax.Bijector):
+            z, lc = self.transform.inverse_and_log_det(y)
         else:
-            z, lc = self.surjector.inverse_and_likelihood_contribution(y, x=x)
+            z, lc = self.transform.inverse_and_likelihood_contribution(y, x=x)
         lp_z = self.base_distribution.log_prob(z)
         lp = lp_z + lc
         return z, lp

@@ -124,6 +146,6 @@ def sample_and_log_prob(self, sample_shape=(), x: Array = None):
             seed=hk.next_rng_key(),
             sample_shape=sample_shape,
         )
-        y, fldj = self.surjector.forward_and_likelihood_contribution(z, x=x)
+        y, fldj = self.transform.forward_and_likelihood_contribution(z, x=x)
         lp = lp_z - fldj
         return y, lp
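The two sign conventions above (`lp = lp_z + lc` on the inverse pass, `lp = lp_z - fldj` on the sampling pass) are the usual change-of-variables bookkeeping; as a sketch, for a bijective f with y = f(z):

```latex
% Change of variables for y = f(z): the inverse pass adds the inverse
% log-determinant (lc), the sampling pass subtracts the forward one (fldj).
\log p_Y(y) = \log p_Z(z) + \log\left|\det J_{f^{-1}}(y)\right|
            = \log p_Z(z) - \log\left|\det J_{f}(z)\right|
```

For surjective layers the determinant term generalizes to the likelihood contribution returned by `inverse_and_likelihood_contribution` and `forward_and_likelihood_contribution`.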
@@ -3,8 +3,8 @@
 from chex import Array
 from jax import numpy as jnp

-from surjectors.bijectors.masked_autoregressive import MaskedAutoregressive
-from surjectors.surjectors.funnel import Funnel
+from surjectors._src.bijectors.masked_autoregressive import MaskedAutoregressive
+from surjectors._src.surjectors.funnel import Funnel
 from surjectors.util import unstack