Commit
Add Distrax classes and methods to surjectors (#34)
dirmeier authored Aug 17, 2024
1 parent 68c2f8a commit 376e59a
Showing 7 changed files with 99 additions and 24 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/examples.yaml
@@ -0,0 +1,41 @@
+name: examples
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  precommit:
+    name: Pre-commit checks
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+
+  examples:
+    runs-on: ubuntu-latest
+    needs:
+      - precommit
+    strategy:
+      matrix:
+        python-version: [3.11]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install hatch matplotlib
+      - name: Build package
+        run: |
+          pip install jaxlib jax
+          pip install .
+      - name: Run tests
+        run: |
+          python examples/autoregressive_inference_surjection.py --n-iter 10
+          python examples/conditional_density_estimation.py --n-iter 10 --model coupling
+          python examples/conditional_density_estimation.py --n-iter 10 --model autoregressive
+          python examples/coupling_inference_surjection.py --n-iter 10
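Note that every example invocation passes `--n-iter 10`: the scripts gain this flag in the hunks below, so CI exercises each example end-to-end in seconds rather than training for the default 1,000 iterations.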
10 changes: 7 additions & 3 deletions examples/autoregressive_inference_surjection.py
@@ -1,3 +1,4 @@
+import argparse
 from collections import namedtuple
 
 import distrax
@@ -99,14 +100,14 @@ def loss_fn(params):
     return params, losses
 
 
-def run():
+def run(n_iter):
     n, p = 1000, 20
     rng_seq = hk.PRNGSequence(2)
     y = jr.normal(next(rng_seq), shape=(n, p))
     data = namedtuple("named_dataset", "y")(y)
 
     model = make_model(p)
-    params, losses = train(rng_seq, data, model)
+    params, losses = train(rng_seq, data, model, n_iter)
     plt.plot(losses)
     plt.show()
@@ -115,4 +116,7 @@ def run():
 
 
 if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    args = parser.parse_args()
+    run(args.n_iter)
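The same `--n-iter` plumbing lands in examples/coupling_inference_surjection.py further down. The updated `train` itself sits outside the shown hunks; judging from the call site, `n_iter` bounds the number of optimization steps. A self-contained toy sketch of that pattern (the optimizer, loss, and function body are illustrative assumptions, not code from this commit):

import jax
import jax.numpy as jnp
import optax


def toy_train(params, n_iter):
    # Quadratic toy loss standing in for the flow's negative log-likelihood.
    def loss_fn(p):
        return jnp.sum((p - 3.0) ** 2)

    opt = optax.adam(1e-2)
    opt_state = opt.init(params)
    losses = []
    for _ in range(n_iter):  # n_iter caps the training loop, as in the examples
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        losses.append(loss)
    return params, jnp.asarray(losses)


params, losses = toy_train(jnp.zeros(2), n_iter=10)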
20 changes: 13 additions & 7 deletions examples/conditional_density_estimation.py
@@ -1,3 +1,5 @@
+import argparse
+
 import distrax
 import haiku as hk
 import jax
@@ -49,9 +51,7 @@ def _flow(method, **kwargs):
         layer = MaskedAutoregressive(
             bijector_fn=_bijector_fn,
             conditioner=MADE(
-                2,
-                [32, 32, 2 * 2],
-                2,
+                2, [32, 32], 2,
                 w_init=hk.initializers.TruncatedNormal(0.01),
                 b_init=jnp.zeros,
             ),
@@ -104,7 +104,7 @@ def loss_fn(params):
     return params, losses
 
 
-def run():
+def run(n_iter, model):
     n = 10000
     thetas = distrax.Normal(jnp.zeros(2), jnp.full(2, 10)).sample(
         seed=random.PRNGKey(0), sample_shape=(n,)
@@ -114,8 +114,8 @@ def run():
     )
     data = named_dataset(y, thetas)
 
-    model = make_model(2)
-    params, losses = train(hk.PRNGSequence(2), data, model)
+    model = make_model(2, model)
+    params, losses = train(hk.PRNGSequence(2), data, model, n_iter)
     samples = model.apply(
         params,
         random.PRNGKey(2),
@@ -129,4 +129,10 @@ def run():
 
 
 if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    parser.add_argument("--model", type=str, default="coupling")
+    args = parser.parse_args()
+    run(args.n_iter, args.model)
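One possible hardening, since `--model` is only meaningful as "coupling" or "autoregressive" (the two values the new CI workflow exercises): constraining the flag with argparse's `choices` would reject typos at parse time. A hypothetical variant, not part of the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n-iter", type=int, default=1_000)
parser.add_argument(
    "--model",
    type=str,
    choices=["coupling", "autoregressive"],  # the two values CI exercises
    default="coupling",
)
args = parser.parse_args(["--model", "autoregressive"])  # demo argv
print(args.model)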


12 changes: 9 additions & 3 deletions examples/coupling_inference_surjection.py
@@ -1,3 +1,4 @@
+import argparse
 from collections import namedtuple
 
 import distrax
@@ -99,14 +100,14 @@ def loss_fn(params):
     return params, losses
 
 
-def run():
+def run(n_iter):
     n, p = 1000, 20
     rng_seq = hk.PRNGSequence(2)
     y = jr.normal(next(rng_seq), shape=(n, p))
     data = namedtuple("named_dataset", "y")(y)
 
     model = make_model(p)
-    params, losses = train(rng_seq, data, model)
+    params, losses = train(rng_seq, data, model, n_iter)
     plt.plot(losses)
     plt.show()
@@ -115,4 +116,9 @@ def run():
 
 
 if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n-iter", type=int, default=1_000)
+    args = parser.parse_args()
+    run(args.n_iter)


5 changes: 4 additions & 1 deletion surjectors/__init__.py
@@ -1,6 +1,8 @@
 """surjectors: Surjection layers for density estimation with normalizing flows."""
 
-__version__ = "0.3.2"
+__version__ = "0.3.3"
+
+from distrax import ScalarAffine
 
 from surjectors._src.bijectors.affine_masked_autoregressive import (
     AffineMaskedAutoregressive,
@@ -60,4 +62,5 @@
     "MLPInferenceFunnel",
     "Slice",
     # "Augment",
+    "ScalarAffine",
 ]
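Re-exporting ScalarAffine means user code that builds a `bijector_fn`, as in examples/conditional_density_estimation.py above, no longer needs a separate distrax import. A minimal sketch of that usage (the parameter splitting is illustrative, not from the commit):

import jax.numpy as jnp
from surjectors import ScalarAffine


def _bijector_fn(params):
    # Split the conditioner output into shift and log-scale halves.
    shift, log_scale = jnp.split(params, 2, axis=-1)
    return ScalarAffine(shift, jnp.exp(log_scale))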
11 changes: 6 additions & 5 deletions surjectors/_src/bijectors/permutation.py
@@ -1,9 +1,10 @@
-import distrax
 from jax import numpy as jnp
 
+from surjectors._src.bijectors.bijector import Bijector
+
 
 # pylint: disable=arguments-renamed
-class Permutation(distrax.Bijector):
+class Permutation(Bijector):
     """Permute the dimensions of a vector.
 
     Args:
@@ -20,13 +21,13 @@ class Permutation(distrax.Bijector):
     """
 
     def __init__(self, permutation, event_ndims_in: int):
-        super().__init__(event_ndims_in)
         self.permutation = permutation
+        self.event_ndims_in = event_ndims_in
 
-    def _forward_and_likelihood_contribution(self, z):
+    def _forward_and_likelihood_contribution(self, z, **kwargs):
         return z[..., self.permutation], jnp.full(jnp.shape(z)[:-1], 0.0)
 
-    def _inverse_and_likelihood_contribution(self, y):
+    def _inverse_and_likelihood_contribution(self, y, **kwargs):
         size = self.permutation.size
         permutation_inv = (
             jnp.zeros(size, dtype=jnp.result_type(int))
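The last hunk is cut off mid-expression. For reference, a standard JAX idiom for inverting a permutation, which the elided lines presumably resemble, scatters `arange` through the permutation:

import jax.numpy as jnp

permutation = jnp.array([2, 0, 1])
permutation_inv = (
    jnp.zeros(permutation.size, dtype=jnp.result_type(int))
    .at[permutation]
    .set(jnp.arange(permutation.size))
)
# Applying the permutation and then its inverse restores the original order.
x = jnp.array([10.0, 20.0, 30.0])
assert jnp.all(x[permutation][permutation_inv] == x)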
24 changes: 19 additions & 5 deletions surjectors/_src/distributions/transformed_distribution.py
@@ -1,8 +1,11 @@
+from typing import Union
+
 import chex
 import distrax
 import haiku as hk
 from distrax import Distribution
 from jax import Array
+from tensorflow_probability.substrates.jax import distributions as tfd
 
 from surjectors._src.surjectors.surjector import Surjector
@@ -31,7 +34,11 @@ class TransformedDistribution:
     >>> )
     """
 
-    def __init__(self, base_distribution: Distribution, transform: Surjector):
+    def __init__(
+        self,
+        base_distribution: Union[Distribution, tfd.Distribution],
+        transform: Surjector,
+    ):
         self.base_distribution = base_distribution
         self.transform = transform
 
@@ -118,10 +125,17 @@ def sample_and_log_prob(self, sample_shape=(), x: Array = None):
         if x is not None:
             chex.assert_equal(sample_shape[0], x.shape[0])
 
-        z, lp_z = self.base_distribution.sample_and_log_prob(
-            seed=hk.next_rng_key(),
-            sample_shape=sample_shape,
-        )
+        try:
+            z, lp_z = self.base_distribution.sample_and_log_prob(
+                seed=hk.next_rng_key(),
+                sample_shape=sample_shape,
+            )
+        except AttributeError:
+            z, lp_z = self.base_distribution.experimental_sample_and_log_prob(
+                seed=hk.next_rng_key(),
+                sample_shape=sample_shape,
+            )
 
         y, fldj = self.transform.forward_and_likelihood_contribution(z, x=x)
         lp = lp_z - fldj
         return y, lp
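The try/except exists because the two supported base-distribution families name this operation differently: distrax distributions provide `sample_and_log_prob`, while distributions from TensorFlow Probability's JAX substrate raise `AttributeError` there and expose `experimental_sample_and_log_prob` instead. Both calls shown standalone, outside surjectors, for illustration:

import distrax
import jax
from tensorflow_probability.substrates.jax import distributions as tfd

key = jax.random.PRNGKey(0)

# distrax API
z, lp = distrax.Normal(0.0, 1.0).sample_and_log_prob(
    seed=key, sample_shape=(3,)
)

# TFP-on-JAX API
z, lp = tfd.Normal(0.0, 1.0).experimental_sample_and_log_prob(
    sample_shape=(3,), seed=key
)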
