Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add more surjectors #1

Merged
merged 10 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add more surjectors
  • Loading branch information
dirmeier committed Nov 4, 2022
commit ae0bb62c04dde0fe8ae25e64364162b2e4223e0c
20 changes: 9 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,39 +1,37 @@
# surjectors

[![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/Inference surjection layers/actions/workflows/ci.yaml)
[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml)
[![codecov](https://codecov.io/gh/dirmeier/surjectors/branch/main/graph/badge.svg)](https://codecov.io/gh/dirmeier/surjectors)
[![codacy]()]()
[![documentation](https://readthedocs.org/projects/surjectors/badge/?version=latest)](https://surjectors.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/surjectors.svg?colorB=black&style=flat)](https://pypi.org/project/surjectors/)

> Inference surjection layers
> Surjection layers for density estimation with normalizing flows

## About

TODO
Surjectors is a lightweight library of inference and generative surjection layers, i.e., layers that reduce dimensionality, for density estimation using normalizing flows.
Surjectors builds on Distrax and Haiku.

## Example usage

TODO

## Installation


To install the latest GitHub <RELEASE>, just call the following on the
command line:
To install the latest GitHub <RELEASE>, just call the following on the command line:

```bash
pip install git+https://github.com/dirmeier/Inference surjection layers@<RELEASE>
pip install git+https://github.com/dirmeier/surjectors@<RELEASE>
```

## Contributing

In order to contribute:

1) Fork and download the repository,
2) create a branch with the name of your new feature (something like `issue/fix-bug-related-to-something` or `feature/implement-new-bound`),
3) install `surjectors` and dev dependencies via `poetry install` (you might need to create a new `conda` or `venv` environment, to not break other dependencies),
1) Fork and download the forked repository,
2) create a branch with the name of your new feature (something like `issue/fix-bug-related-to-something` or `feature/implement-new-surjector`),
3) install `surjectors` and dev dependencies via `poetry install` (you might want to create a new `conda` or `venv` environment, to not break other dependencies),
4) develop code, commit changes and push it to your branch,
5) create a PR

File renamed without changes.
32 changes: 32 additions & 0 deletions examples/multivariate_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Example / scratch script: sanity-check distrax's LowerUpperTriangularAffine
# bijector with randomly drawn parameters.
# NOTE(review): no import lines are visible in this hunk -- `hk`, `jax`, `jnp`
# and `LowerUpperTriangularAffine` must be imported at the top of the file;
# verify the file actually has them.
prng = hk.PRNGSequence(jax.random.PRNGKey(42))
matrix = jax.random.uniform(next(prng), (4, 4))
bias = jax.random.normal(next(prng), (4,))
# NOTE(review): this bijector is never used -- it is rebuilt identically below,
# and the draws above only advance the PRNG sequence. Confirm this is intended.
bijector = LowerUpperTriangularAffine(matrix, bias)

#
# def loss():
# x, lc = bijector.inverse_and_log_det(jnp.zeros(4) * 2.1)
# lp = distrax.Normal(jnp.zeros(4)).log_prob(x)
# return -jnp.sum(lp - lc)
#
# print(bijector.matrix)
#
# adam = optax.adam(0.003)
# g = jax.grad(loss)()
#
# print(g)
#

matrix = jax.random.uniform(next(prng), (4, 4))
bias = jax.random.normal(next(prng), (4,))
bijector = LowerUpperTriangularAffine(matrix, bias)

# Build a deterministic matrix with distinct lower/upper parts.
n = jnp.ones((4, 4)) * 3.1
n += jnp.triu(n) * 2

bijector = LowerUpperTriangularAffine(n, jnp.zeros(4))


# NOTE(review): presumably comparing the bijector's forward map against a
# plain matrix product -- but LowerUpperTriangularAffine treats its matrix
# argument as packed LU factors, so these two prints need not agree.
# TODO confirm which behavior the example is meant to demonstrate.
print(bijector.forward(jnp.ones(4)))

print(n @jnp.ones(4) )
3 changes: 0 additions & 3 deletions examples/solar_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,3 @@
jnp.array([549229066, 500358972], dtype=jnp.uint32), 100
)
pns[i] = pn


Distribution
Empty file.
Empty file.
Empty file.
Empty file.
2 changes: 2 additions & 0 deletions surjectors/distributions/conditional_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ConditionalDistribution:
    """Marker base class for context-conditional distributions.

    NOTE(review): no interface is defined yet. Call sites elsewhere in this
    change (e.g. ``Slice``) use ``log_prob(y, context=...)`` and
    ``sample(context=...)`` on decoder objects -- confirm those are the
    methods this base class is meant to declare before fleshing it out.
    """

    pass
File renamed without changes.
34 changes: 13 additions & 21 deletions surjectors/distributions/transformed_distribution.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
from typing import Tuple

import chex
import jax
import jax.numpy as jnp
from distrax._src.bijectors import bijector as bjct_base
from distrax._src.distributions import distribution as dist_base
from distrax._src.utils import conversion
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

PRNGKey = dist_base.PRNGKey
Array = dist_base.Array
DistributionLike = dist_base.DistributionLike
BijectorLike = bjct_base.BijectorLike
from chex import PRNGKey
from distrax import Distribution
Array = chex.Array
from surjectors.surjectors.surjector import Surjector


class TransformedDistribution:
def __init__(self, base_distribution, surjector):
def __init__(self, base_distribution: Distribution, surjector: Surjector):
self.base_distribution = base_distribution
self.surjector = surjector

def log_prob(self, y: Array) -> Array:
x, ildj_y = self.surjector.inverse_and_log_det(y)
def log_prob(self, y: Array) -> jnp.ndarray:
x, lc = self.surjector.inverse_and_likelihood_contribution(y)
lp_x = self.base_distribution.log_prob(x)
lp_y = lp_x + ildj_y
return lp_y
lp = lp_x - lc
return lp

def sample(self, key: PRNGKey, sample_shape=(1,)):
z = self.base_distribution.sample(seed=key, sample_shape=sample_shape)
Expand All @@ -35,6 +27,6 @@ def sample_and_log_prob(self, key: PRNGKey, sample_shape=(1,)):
z, lp_z = self.base_distribution.sample_and_log_prob(
seed=key, sample_shape=sample_shape
)
y, fldj = jax.vmap(self.surjector.forward_and_log_det)(z)
lp_y = jax.vmap(jnp.subtract)(lp_z, fldj)
return y, lp_y
y, fldj = jax.vmap(self.surjector.forward_and_likelihood_contribution)(z)
lp = jax.vmap(jnp.subtract)(lp_z, fldj)
return y, lp
File renamed without changes.
24 changes: 24 additions & 0 deletions surjectors/surjectors/affine_coupling_funnel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import chex
from jax import numpy as jnp

from surjectors.surjectors.surjector import Surjector


class AffineCouplingFunnel(Surjector):
    """Affine-coupling-based funnel (dimensionality-reducing surjection).

    The trailing axis of the input is split into a kept part (``n_keep``
    features) and a dropped part. The kept part is transformed conditioned on
    the dropped part, and the dropped part is scored by the decoder.
    """

    def __init__(self, n_keep, decoder, transform, encoder, kind="inference_surjection"):
        """Construct the funnel.

        Args:
            n_keep: number of leading features kept after the surjection.
            decoder: conditional distribution scoring the dropped features;
                must provide ``log_prob(y, context=...)``.
            transform: callable ``(y_plus, context) -> (z, jac_det)``.
            encoder: optional encoder, forwarded to the base class.
            kind: surjection kind, forwarded to the base class.
                NOTE(review): the default "inference_surjection" does not
                appear in ``surjector._valid_kinds`` ("inference_surjector"
                does) -- confirm which spelling is intended.
        """
        super().__init__(n_keep, decoder, encoder, kind)
        self._transform = transform

    def split_input(self, input):
        """Split the trailing axis into (kept, dropped) parts."""
        # BUG FIX: the previous code passed a tuple of part *sizes* to
        # jnp.split, which interprets a sequence as split *indices* and hence
        # returns three arrays -- breaking the two-way unpacking at every
        # call site. Splitting at the single index `n_keep` yields exactly
        # the two intended parts.
        return jnp.split(input, [self.n_keep], axis=-1)

    def inverse_and_likelihood_contribution(self, y):
        """Map y -> z and return (z, likelihood contribution)."""
        y_plus, y_minus = self.split_input(y)
        # NOTE(review): this assert constrains the layer to
        # n_keep == dim - n_keep; confirm equal-sized halves are intended.
        chex.assert_equal_shape([y_plus, y_minus])
        z, jac_det = self._transform(y_plus, context=y_minus)
        lp = self.decoder.log_prob(y_minus, context=z)
        return z, lp + jac_det

    def forward_and_likelihood_contribution(self, z):
        # Generative direction not implemented yet.
        raise NotImplementedError()
2 changes: 1 addition & 1 deletion surjectors/surjectors/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

class Chain(Surjector):
def __init__(self, surjectors):
    # A Chain composes a sequence of surjectors/bijectors into one transform.
    # The base-class slots (n_keep, decoder, encoder) are unused for a chain;
    # kind "surjector" marks this as a generic composed transform.
    super().__init__(None, None, None, "surjector")
    # NOTE(review): presumably an ordered iterable of Surjector instances,
    # applied in sequence -- confirm expected ordering (data -> latent?).
    self._surjectors = surjectors

def inverse_and_likelihood_contribution(self, y):
Expand All @@ -18,4 +19,3 @@ def forward_and_likelihood_contribution(self, z):
x, lc = _surjectors.forward_and_log_det(x)
log_det += lc
return y, log_det

22 changes: 0 additions & 22 deletions surjectors/surjectors/funnel.py

This file was deleted.

60 changes: 0 additions & 60 deletions surjectors/surjectors/linear.py

This file was deleted.

6 changes: 4 additions & 2 deletions surjectors/surjectors/lu_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@


class LULinear(Surjector):
def __init__(self, n_keep, dtype=jnp.float32):
def __init__(self, n_keep, with_bias=False, dtype=jnp.float32):
super().__init__(n_keep, None, None, "bijection", dtype)
if with_bias:
raise NotImplementedError()

n_triangular_entries = ((n_keep - 1) * n_keep) // 2

self._lower_indices = np.tril_indices(n_keep, k=-1)
Expand Down Expand Up @@ -51,4 +54,3 @@ def inverse_and_likelihood_contribution(self, y):

def forward_and_likelihood_contribution(self, z):
pass

30 changes: 30 additions & 0 deletions surjectors/surjectors/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import distrax
import haiku as hk
from jax import numpy as jnp

from surjectors.surjectors.affine_coupling_funnel import Funnel
from surjectors.surjectors.lu_linear import LULinear


class MLP(Funnel, hk.Module):
    """MLP funnel: an LU-linear map of the kept features plus a linear
    embedding of the dropped features, with the dropped features scored by a
    Gaussian decoder.

    NOTE(review): the module ``affine_coupling_funnel`` defines the class
    ``AffineCouplingFunnel``, not ``Funnel`` -- the import at the top of this
    file must be updated (or aliased) for this base class to resolve.
    NOTE(review): ``hk.Module.__init__`` is never reached on this MRO --
    confirm Haiku module/parameter registration still works as intended.
    """

    def __init__(self, n_keep, decoder, dtype=jnp.float32):
        """Construct the MLP funnel.

        Args:
            n_keep: number of leading features kept after the surjection.
            decoder: callable mapping z to ``(mu, log_scale)`` of the
                decoding Gaussian.
            dtype: dtype of the LU-linear parameters.
        """
        # BUG FIX: LULinear's signature is (n_keep, with_bias=False,
        # dtype=...); the previous call LULinear(n_keep, dtype,
        # with_bias=False) passed `dtype` positionally into `with_bias` and
        # then repeated the keyword, raising TypeError.
        self._r = LULinear(n_keep, with_bias=False, dtype=dtype)
        self._w_prime = hk.Linear(n_keep, with_bias=True)

        self._decoder = decoder
        # BUG FIX: Funnel.__init__ takes (n_keep, decoder, transform,
        # encoder, kind); the previous call omitted the required `transform`
        # and `encoder` positionals. This layer has neither, so pass None.
        super().__init__(n_keep, decoder, None, None)

    def inverse_and_likelihood_contribution(self, y):
        """Map y -> z and return (z, likelihood contribution)."""
        y_plus, y_minus = self.split_input(y)
        z, jac_det = self._r.inverse_and_likelihood_contribution(y_plus)
        z += self._w_prime(y_minus)
        lp = self._decode(z).log_prob(y_minus)
        return z, lp + jac_det

    def _decode(self, array):
        """Build the decoding Gaussian conditioned on `array`."""
        mu, log_scale = self._decoder(array)
        return distrax.MultivariateNormalDiag(mu, jnp.exp(log_scale))

    def forward_and_likelihood_contribution(self, z):
        # BUG FIX: previously `pass`, which silently returned None to callers
        # expecting a (y, contribution) pair. Fail loudly instead, consistent
        # with AffineCouplingFunnel.
        raise NotImplementedError()
Empty file.
17 changes: 13 additions & 4 deletions surjectors/surjectors/slice.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from jax import numpy as jnp

from surjectors.funnel import Funnel
from surjectors.surjectors.affine_coupling_funnel import Funnel


class Slice(Funnel):
    """Slice surjection: keeps the leading ``n_keep`` features and scores the
    dropped trailing features with a conditional decoder.

    Note: the superseded ``__init__`` left over from the previous revision
    (which referenced undefined ``decoder``/``encoder`` names and was shadowed
    by the new one) has been removed.
    """

    def __init__(self, n_keep, decoder, encoder=None, kind="inference_surjection"):
        """Construct the slice surjection.

        Args:
            n_keep: number of leading features kept after the surjection.
            decoder: conditional distribution over the dropped features;
                must provide ``log_prob(y, context=...)`` and
                ``sample(context=...)``.
            encoder: optional encoder, forwarded to the base class.
            kind: surjection kind, forwarded to the base class.
        """
        # BUG FIX: the base class Funnel (AffineCouplingFunnel) takes
        # (n_keep, decoder, transform, encoder, kind); the previous call
        # passed `encoder` into the `transform` slot and `kind` into the
        # `encoder` slot. A slice has no coupling transform, so pass None.
        super().__init__(n_keep, decoder, None, encoder, kind)

    def inverse_and_likelihood_contribution(self, y):
        """Drop the trailing features and return (z, likelihood contribution)."""
        z, y_minus = self.split_input(y)
        lc = self.decoder.log_prob(y_minus, context=z)
        return z, lc

    def forward_and_likelihood_contribution(self, z):
        """Sample the dropped features and return (y, likelihood contribution)."""
        y_minus = self.decoder.sample(context=z)
        # BUG FIX: previously only `y` was returned, but every call site
        # (Chain, TransformedDistribution.sample_and_log_prob) unpacks a
        # (output, contribution) pair.
        lc = self.decoder.log_prob(y_minus, context=z)
        y = jnp.concatenate([z, y_minus], axis=-1)
        return y, lc
6 changes: 3 additions & 3 deletions surjectors/surjectors/surjector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import abstractmethod

from jax import numpy as jnp
from surjectors.surjectors.transform import Transform
from surjectors.surjectors._transform import Transform


_valid_kinds = ["inference_surjector", "generative_surjector", "bijector"]
_valid_kinds = ["inference_surjector", "generative_surjector", "bijector", "surjector"]


class Surjector(Transform):
Expand Down Expand Up @@ -49,4 +49,4 @@ def encoder(self):

@property
def dtype(self):
return self._dtype
return self._dtype