diff --git a/README.md b/README.md
index 1c7f069..bca5329 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,18 @@
 # surjectors

 [![status](http://www.repostatus.org/badges/latest/concept.svg)](http://www.repostatus.org/#concept)
-[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/Inference surjection layers/actions/workflows/ci.yaml)
+[![ci](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/surjectors/actions/workflows/ci.yaml)
 [![codecov](https://codecov.io/gh/dirmeier/surjectors/branch/main/graph/badge.svg)](https://codecov.io/gh/dirmeier/surjectors)
 [![codacy]()]()
 [![documentation](https://readthedocs.org/projects/surjectors/badge/?version=latest)](https://surjectors.readthedocs.io/en/latest/?badge=latest)
 [![version](https://img.shields.io/pypi/v/surjectors.svg?colorB=black&style=flat)](https://pypi.org/project/surjectors/)

-> Inference surjection layers
+> Surjection layers for density estimation with normalizing flows

 ## About

-TODO
+Surjectors is a lightweight library of inference and generative surjection layers, i.e., layers that reduce or increase dimensionality, for density estimation using normalizing flows.
+Surjectors builds on Distrax and Haiku.

 ## Example usage

@@ -19,21 +20,18 @@ TODO

 ## Installation

-
-To install the latest GitHub , just call the following on the
-command line:
+To install the latest version from GitHub, just call the following on the command line:

 ```bash
-pip install git+https://github.com/dirmeier/Inference surjection layers@
+pip install git+https://github.com/dirmeier/surjectors@
 ```

 ## Contributing

 In order to contribute:

-1) Fork and download the repository,
-2) create a branch with the name of your new feature (something like `issue/fix-bug-related-to-something` or `feature/implement-new-bound`),
-3) install `surjectors` and dev dependencies via `poetry install` (you might need to create a new `conda` or `venv` environment, to not break other dependencies),
+1) Fork the repository and download your fork,
+2) create a branch with the name of your new feature (something like `issue/fix-bug-related-to-something` or `feature/implement-new-surjector`),
+3) install `surjectors` and dev dependencies via `poetry install` (you might want to create a new `conda` or `venv` environment so as not to break other dependencies),
 4) develop code, commit changes and push it to your branch,
 5) create a PR
-
diff --git a/examples/solar_dynamo.py b/examples/solar_dynamo.py
deleted file mode 100644
index 74ec438..0000000
--- a/examples/solar_dynamo.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import numpy as np
-import jax
-from jax import random, numpy as jnp
-import matplotlib.pyplot as plt
-
-
-from surjectors.data import Simulator
-
-simulator = Simulator()
-
-n = 1000
-pns = [None] * n
-for i in np.arange(n):
-    p0, alpha1, alpha2, epsilon_max, f, pn = simulator.sample(
-        jnp.array([549229066, 500358972], dtype=jnp.uint32), 100
-    )
-    pns[i] = pn
-
-
-Distribution
diff --git a/experiments/multivariate_gaussian_generative_surjection.py b/experiments/multivariate_gaussian_generative_surjection.py
new file mode 100644
index 0000000..bdab5fd
--- /dev/null
+++ b/experiments/multivariate_gaussian_generative_surjection.py
@@ -0,0 +1,222 @@
+import distrax
+import haiku as hk
+import jax
+import matplotlib.pyplot as plt
+import numpy as np
+import optax
+from jax import config
+from jax import numpy as jnp
+from jax import random
+
+from surjectors.bijectors.masked_coupling import MaskedCoupling
+from surjectors.conditioners.mlp import mlp_conditioner +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_generative_funnel import \ + AffineMaskedCouplingGenerativeFunnel +from surjectors.surjectors.augment import Augment +from surjectors.surjectors.chain import Chain + +config.update("jax_enable_x64", True) + + +def _get_sampler_and_loadings(rng_key, batch_size, n_dimension): + pz_mean = jnp.array([-2.31, 0.421, 0.1, 3.21, -0.41, + -2.31, 0.421, 0.1, 3.21, -0.41]) + pz = distrax.MultivariateNormalDiag( + loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + ) + p_loadings = distrax.Normal(0.0, 10.0) + make_noise = distrax.Normal(0.0, 1) + + loadings_sample_key, rng_key = random.split(rng_key, 2) + loadings = p_loadings.sample( + seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) + ) + + def _fn(rng_key): + z_sample_key, noise_sample_key = random.split(rng_key, 2) + z = pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) + noise = make_noise.sample( + seed=noise_sample_key, sample_shape=(batch_size, n_dimension) + ) + + y = (loadings @ z.T).T + noise + #return z[:, :n_dimension], z , noise + return y, z, noise + + return _fn, loadings + + +def _encoder_fn(n_latent, n_dimension): + decoder_net = mlp_conditioner([32, 32, n_latent - n_dimension]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def _get_slice_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + Augment(n_dimension, _encoder_fn(n_latent, n_dimension)) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + #return Augment(n_dimension, _encoder_fn()) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + AffineMaskedCouplingGenerativeFunnel( + n_dimension, _encoder_fn(n_latent, n_dimension), mlp_conditioner(n_latent) + ) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) 
+ for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + # return AffineMaskedCouplingGenerativeFunnel( + # n_dimension, _encoder_fn(), _conditioner(n_latent) + # ) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(key, surjector_fn, n_data, n_latent, batch_size, n_iter): + rng_seq = hk.PRNGSequence(0) + pyz, loadings = _get_sampler_and_loadings(next(rng_seq), batch_size, n_data) + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + y_init, _, noise_init = pyz(random.fold_in(next(rng_seq), 0)) + params = flow.init( + random.PRNGKey(key), + method="log_prob", + y=y_init, + x=noise_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + y_batch, _, noise_batch = pyz(next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + + y_batch, z_batch, noise_batch = pyz(next(rng_seq)) + y_pred = flow.apply( + params, next(rng_seq), method="sample", x=noise_batch, + ) + print(y_batch[:5, :]) + print(y_pred[:5, :]) + + +def run(): + train( + key=0, + surjector_fn=_get_funnel_surjector, + n_iter=2000, + batch_size=64, + n_data=5, + n_latent=10 + ) + + +if __name__ == "__main__": + run() diff --git a/experiments/multivariate_gaussian_inference_surjection.py b/experiments/multivariate_gaussian_inference_surjection.py new file mode 100644 index 0000000..8bfead4 --- /dev/null +++ b/experiments/multivariate_gaussian_inference_surjection.py @@ -0,0 +1,261 @@ +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt +import numpy as np +import optax +from jax import config +from jax import numpy as jnp +from jax import random + +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp import mlp_conditioner +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_inference_funnel import \ + AffineMaskedCouplingInferenceFunnel +from surjectors.surjectors.chain import Chain +from surjectors.surjectors.slice import Slice + +config.update("jax_enable_x64", True) + + +def _get_sampler(rng_key, batch_size, n_dimension, n_latent): + means_sample_key, rng_key = random.split(rng_key, 2) + pz_mean = distrax.Normal(0.0, 10.0).sample( + seed=means_sample_key, sample_shape=(n_latent) + ) + pz = distrax.MultivariateNormalDiag( + loc=pz_mean, scale_diag=jnp.ones_like(pz_mean) + ) + p_loadings = distrax.Normal(0.0, 10.0) + make_noise = distrax.Normal(0.0, 1) + + loadings_sample_key, rng_key = random.split(rng_key, 2) + loadings = p_loadings.sample( + seed=loadings_sample_key, sample_shape=(n_dimension, len(pz_mean)) + ) + + def _fn(rng_key): + z_sample_key, noise_sample_key = random.split(rng_key, 2) + z 
= pz.sample(seed=z_sample_key, sample_shape=(batch_size,)) + noise = make_noise.sample( + seed=noise_sample_key, sample_shape=(batch_size, n_dimension) + ) + + y = (loadings @ z.T).T + noise + # y = jnp.concatenate([z, z] ,axis=-1) + return y, z, noise + + return _fn, loadings + + +def _decoder_fn(n_dimension, n_latent): + decoder_net = mlp_conditioner([32, 32, n_dimension - n_latent]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def _get_slice_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]) + ) + layers.append(layer) + + layers.append( + Slice(n_latent, _decoder_fn(n_dimension, n_latent)) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + layers.append( + AffineMaskedCouplingInferenceFunnel( + n_latent, + _decoder_fn(n_dimension, n_latent), + mlp_conditioner([32, 32, n_dimension]) + ) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_bijector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + + for _ in range(4): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_dimension), _transformation_fn()) + return td(method, **kwargs) + + td 
= hk.transform(_flow) + return td + + +def train(rng_seq, sampler, surjector_fn, n_data, n_latent, batch_size, n_iter): + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + y_init, _, noise_init = sampler(next(rng_seq)) + params = flow.init( + next(rng_seq), + method="log_prob", + y=y_init, + x=noise_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + y_batch, _, noise_batch = sampler(next(rng_seq)) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + + return flow, params + + +def evaluate(rng_seq, params, model, sampler, batch_size, n_data): + y_batch, _, noise_batch = sampler(next(rng_seq)) + lp = model.apply(params, next(rng_seq), method="log_prob", y=y_batch, x=noise_batch) + print("\tPPLP: {:.3f}".format(jnp.mean(lp))) + + +def run(): + n_iter = 2000 + batch_size = 64 + n_data, n_latent = 100, 75 + sampler, _ = _get_sampler(random.PRNGKey(0), batch_size, n_data, n_latent) + for method, _fn in [ + ["Slice", _get_slice_surjector], + ["Funnel", _get_funnel_surjector], + ["Bijector", _get_bijector] + ]: + print(f"Doing {method}") + rng_seq = hk.PRNGSequence(0) + model, params = train( + rng_seq=rng_seq, + sampler=sampler, + surjector_fn=_fn, + n_iter=n_iter, + batch_size=batch_size, + n_data=n_data, + n_latent=n_latent + ) + evaluate(rng_seq, params, model, sampler, batch_size, n_data) + + +if __name__ == "__main__": + run() diff --git a/experiments/solar_dynamo_data.py b/experiments/solar_dynamo_data.py new file mode 100644 index 0000000..2db2088 --- /dev/null +++ b/experiments/solar_dynamo_data.py @@ -0,0 +1,66 @@ +from jax import random, lax +from jax.scipy.special import erf +import distrax + + +class SolarDynamoSimulator: + """Implements Eqn 2 and 3 of + FLUCTUATIONS IN BABCOCK-LEIGHTON DYNAMOS. II. 
REVISITING THE GNEVYSHEV-OHL RULE + https://iopscience.iop.org/article/10.1086/511177/pdf + """ + def __init__(self, **kwargs): + self.p0_mean = kwargs.get("p0_mean", 1.0) + self.p0_std = kwargs.get("p0_std", 1.0) + self.alpha1_min = kwargs.get("alpha1_min", 1.3) + self.alpha1_max = kwargs.get("alpha1_max", 1.5) + self.alpha2_max = kwargs.get("alpha2_max", 1.65) + self.epsilon_max = kwargs.get("epsilon_max", 0.5) + self.alpha1 = kwargs.get("alpha1", None) + self.alpha2 = kwargs.get("alpha2", None) + + def sample(self, key, batch_size, len_timeseries=1000): + p_key, alpha1_key, alpha2_key, epsilon_key, key = random.split(key, 5) + p0 = random.normal(p_key, shape=(batch_size,)) * self.p0_std + self.p0_mean + alpha1 = random.uniform( + alpha1_key, shape=(batch_size,), minval=self.alpha1_min, maxval=self.alpha1_max + ) + alpha2 = random.uniform( + alpha2_key, shape=(batch_size,), minval=alpha1, maxval=self.alpha2_max + ) + epsilon_max = random.uniform( + epsilon_key, shape=(batch_size,), minval=0, maxval=self.epsilon_max + ) + f, y, alpha, noise = self._sample_timeseries( + key, batch_size, p0, alpha1, alpha2, epsilon_max, len_timeseries + ) + + return (p0, alpha1, alpha2, epsilon_max, f), y, alpha, noise + + @staticmethod + def babcock_leighton_fn(p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8): + f = 0.5 * (1.0 + erf((p - b_1) / w_1)) * (1.0 - erf((p - b_2) / w_2)) + return f + + def babcock_leighton(self, p, alpha, epsilon): + p = alpha * self.babcock_leighton_fn(p) * p + epsilon + return p + + def _sample_timeseries( + self, key, batch_size, pn, alpha_min, alpha_max, epsilon_max, len_timeseries + ): + a = distrax.Uniform(alpha_min, alpha_max).sample( + seed=key, sample_shape=(len_timeseries,) + ) + noise = distrax.Uniform(0.0, epsilon_max).sample( + seed=key, sample_shape=(len_timeseries,) + ) + + def _fn(fs, arrays): + alpha, epsilon = arrays + f, pn = fs + f = self.babcock_leighton_fn(pn) + pn = self.babcock_leighton(pn, alpha, epsilon) + return (f, pn), (f, pn) + + _, pn = lax.scan(_fn, (pn, pn), (a, noise)) + return pn[0].T, pn[1].T, a.T, noise.T diff --git a/experiments/solar_dynamo_generative_surjection.py b/experiments/solar_dynamo_generative_surjection.py new file mode 100644 index 0000000..61e917b --- /dev/null +++ b/experiments/solar_dynamo_generative_surjection.py @@ -0,0 +1,209 @@ +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt +import numpy as np +import optax +from jax import config +from jax import numpy as jnp +from jax import random + +from experiments.solar_dynamo_data import SolarDynamoSimulator +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp import mlp_conditioner +from surjectors.conditioners.transformer import transformer_conditioner +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_generative_funnel import \ + AffineMaskedCouplingGenerativeFunnel +from surjectors.surjectors.augment import Augment +from surjectors.surjectors.chain import Chain + +config.update("jax_enable_x64", True) + + +def _get_sampler(): + simulator = SolarDynamoSimulator() + return simulator.sample + + +def _encoder_fn(n_latent, n_dimension): + decoder_net = mlp_conditioner([32, 32, n_latent - n_dimension]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _bijector_fn(params): + means, log_scales = 
jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def _get_slice_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + Augment(n_dimension, _encoder_fn(n_latent, n_dimension)) + ) + + mask = jnp.arange(n_latent) < n_latent - n_dimension + layers.append( + MaskedCoupling( + mask=mask.astype(jnp.bool_), + bijector=_bijector_fn, + conditioner=transformer_conditioner(n_latent) + ) + ) + mask = jnp.arange(n_latent) >= n_latent - n_dimension + layers.append( + MaskedCoupling( + mask=mask.astype(jnp.bool_), + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]) + ) + ) + #return Augment(n_dimension, _encoder_fn()) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner(n_dimension), + ) + layers.append(layer) + + layers.append( + AffineMaskedCouplingGenerativeFunnel( + n_dimension, + _encoder_fn(n_latent, n_dimension), + mlp_conditioner(n_latent) + ) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=transformer_conditioner(n_latent), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + + return Chain(layers) + # return AffineMaskedCouplingGenerativeFunnel( + # n_dimension, _encoder_fn(), _conditioner(n_latent) + # ) + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_latent), _transformation_fn() + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(rng_seq, sampler, surjector_fn, n_data, n_latent, batch_size, n_iter): + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + _, y_init, _, noise_init = sampler(next(rng_seq), batch_size, n_data) + params = flow.init( + next(rng_seq), method="log_prob", y=y_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + loss, params, state = step(params, state, y_batch, noise_batch, next(rng_seq)) + losses[i] = loss + + losses = jnp.asarray(losses) + 
plt.plot(losses) + plt.show() + return flow, params + + +def evaluate(rng_seq, params, model, sampler, batch_size, n_data): + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + lp = model.apply(params, next(rng_seq), method="log_prob", y=y_batch) + print("PPLP: {:.3f}".format(lp / batch_size)) + + +def run(): + sampler = _get_sampler() + for _fn in [_get_slice_surjector, _get_funnel_surjector]: + rng_seq = hk.PRNGSequence(0) + model, params = train( + rng_seq=rng_seq, + sampler=sampler, + surjector_fn=_fn, + n_iter=2000, + batch_size=64, + n_data=100, + n_latent=110 + ) + evaluate(rng_seq, params, model, sampler, 64, 100) + + +if __name__ == "__main__": + run() diff --git a/experiments/solar_dynamo_inference_surjection.py b/experiments/solar_dynamo_inference_surjection.py new file mode 100644 index 0000000..83aad0b --- /dev/null +++ b/experiments/solar_dynamo_inference_surjection.py @@ -0,0 +1,233 @@ +import distrax +import haiku as hk +import jax +import matplotlib.pyplot as plt +import numpy as np +import optax +from jax import config +from jax import numpy as jnp +from jax import random + +from experiments.solar_dynamo_data import SolarDynamoSimulator +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.conditioners.mlp import mlp_conditioner +from surjectors.distributions.transformed_distribution import ( + TransformedDistribution, +) +from surjectors.surjectors.affine_masked_coupling_inference_funnel \ + import AffineMaskedCouplingInferenceFunnel +from surjectors.surjectors.chain import Chain +from surjectors.surjectors.slice import Slice + +config.update("jax_enable_x64", True) + + +def _get_sampler(): + simulator = SolarDynamoSimulator() + return simulator.sample + + +def _decoder_fn(n_dimension, n_latent): + decoder_net = mlp_conditioner([32, 32, n_dimension - n_latent]) + + def _fn(z): + params = decoder_net(z) + mu, log_scale = jnp.split(params, 2, -1) + return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale))) + + return _fn + + +def _bijector_fn(params): + means, log_scales = jnp.split(params, 2, -1) + return distrax.ScalarAffine(means, jnp.exp(log_scales)) + + +def _base_distribution_fn(n_latent): + base_distribution = distrax.Independent( + distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)), + reinterpreted_batch_ndims=1, + ) + return base_distribution + + +def _get_slice_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + Slice(n_latent, _decoder_fn(n_dimension, n_latent)) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_funnel_surjector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = 
mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + + layers.append( + AffineMaskedCouplingInferenceFunnel( + n_latent, + _decoder_fn(n_dimension, n_latent), + mlp_conditioner([32, 32, n_dimension])) + ) + + mask = jnp.arange(0, np.prod(n_latent)) % 2 + mask = jnp.reshape(mask, n_latent) + mask = mask.astype(bool) + for _ in range(2): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_latent]), + ) + layers.append(layer) + mask = jnp.logical_not(mask) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution(_base_distribution_fn(n_latent), + _transformation_fn()) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def _get_bijector(n_dimension, n_latent): + def _transformation_fn(): + layers = [] + mask = jnp.arange(0, np.prod(n_dimension)) % 2 + mask = jnp.reshape(mask, n_dimension) + mask = mask.astype(bool) + for _ in range(4): + layer = MaskedCoupling( + mask=mask, + bijector=_bijector_fn, + conditioner=mlp_conditioner([32, 32, n_dimension]), + ) + layers.append(layer) + return Chain(layers) + + def _flow(method, **kwargs): + td = TransformedDistribution( + _base_distribution_fn(n_dimension), + _transformation_fn() + ) + return td(method, **kwargs) + + td = hk.transform(_flow) + return td + + +def train(rng_seq, sampler, surjector_fn, n_data, n_latent, batch_size, n_iter): + flow = surjector_fn(n_data, n_latent) + + @jax.jit + def step(params, state, y_batch, noise_batch, rng): + def loss_fn(params): + lp = flow.apply( + params, rng, method="log_prob", y=y_batch, x=noise_batch + ) + return -jnp.sum(lp) + + loss, grads = jax.value_and_grad(loss_fn)(params) + updates, new_state = adam.update(grads, state, params) + new_params = optax.apply_updates(params, updates) + return loss, new_params, new_state + + _, y_init, _, noise_init = sampler(next(rng_seq), batch_size, n_data) + params = flow.init( + next(rng_seq), method="log_prob", y=y_init, x=noise_init + ) + adam = optax.adamw(0.001) + state = adam.init(params) + + losses = [0] * n_iter + for i in range(n_iter): + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + loss, params, state = step( + params, state, y_batch, noise_batch, next(rng_seq) + ) + losses[i] = loss + + losses = jnp.asarray(losses) + plt.plot(losses) + plt.show() + return flow, params + + +def evaluate(rng_seq, params, model, sampler, batch_size, n_data): + _, y_batch, _, noise_batch = sampler(next(rng_seq), batch_size, n_data) + lp = model.apply(params, next(rng_seq), method="log_prob", y=y_batch, x=noise_batch) + print("\tPPLP: {:.3f}".format(jnp.mean(lp))) + + +def run(): + n_iter = 2000 + batch_size = 64 + n_data, n_latent = 100, 75 + sampler = _get_sampler() + for method, _fn in [ + ["Slice", _get_slice_surjector], + ["Funnel", _get_funnel_surjector], + ["Bijector", _get_bijector] + ]: + print(f"Doing {method}") + rng_seq = hk.PRNGSequence(0) + model, params = train( + rng_seq=rng_seq, + sampler=sampler, + surjector_fn=_fn, + n_iter=n_iter, + batch_size=batch_size, + n_data=n_data, + n_latent=n_latent + ) + evaluate(rng_seq, params, model, sampler, batch_size, n_data) + + +if __name__ == "__main__": + run() + diff --git a/surjectors/bijectors/__init__.py b/surjectors/bijectors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/surjectors/surjectors/lu_linear.py 
b/surjectors/bijectors/lu_linear.py similarity index 93% rename from surjectors/surjectors/lu_linear.py rename to surjectors/bijectors/lu_linear.py index 816a4a8..e104930 100644 --- a/surjectors/surjectors/lu_linear.py +++ b/surjectors/bijectors/lu_linear.py @@ -7,8 +7,11 @@ class LULinear(Surjector): - def __init__(self, n_keep, dtype=jnp.float32): + def __init__(self, n_keep, with_bias=False, dtype=jnp.float32): super().__init__(n_keep, None, None, "bijection", dtype) + if with_bias: + raise NotImplementedError() + n_triangular_entries = ((n_keep - 1) * n_keep) // 2 self._lower_indices = np.tril_indices(n_keep, k=-1) @@ -51,4 +54,3 @@ def inverse_and_likelihood_contribution(self, y): def forward_and_likelihood_contribution(self, z): pass - diff --git a/surjectors/bijectors/masked_coupling.py b/surjectors/bijectors/masked_coupling.py new file mode 100644 index 0000000..962278b --- /dev/null +++ b/surjectors/bijectors/masked_coupling.py @@ -0,0 +1,45 @@ +from typing import Optional, Tuple + +import distrax +from distrax._src.utils import math +from jax import numpy as jnp + +from surjectors.distributions.transformed_distribution import Array + + +class MaskedCoupling(distrax.MaskedCoupling): + def __init__(self, mask: Array, conditioner, bijector, + event_ndims: Optional[int] = None, inner_event_ndims: int = 0): + super().__init__(mask, conditioner, bijector, event_ndims, inner_event_ndims) + + def forward_and_log_det(self, z: Array, x: Array = None) -> Tuple[Array, Array]: + self._check_forward_input_shape(z) + masked_z = jnp.where(self._event_mask, z, 0.) + if x is not None: + masked_z = jnp.concatenate([masked_z, x], axis=-1) + params = self._conditioner(masked_z) + y0, log_d = self._inner_bijector(params).forward_and_log_det(z) + y = jnp.where(self._event_mask, z, y0) + logdet = math.sum_last( + jnp.where(self._mask, 0., log_d), + self._event_ndims - self._inner_event_ndims + ) + return y, logdet + + def forward(self, z: Array, x: Array = None) ->Array: + y, log_det = self.forward_and_log_det(z, x) + return y + + def inverse_and_log_det(self, y: Array, x: Array = None) -> Tuple[Array, Array]: + self._check_inverse_input_shape(y) + masked_y = jnp.where(self._event_mask, y, 0.) 
+ if x is not None: + masked_y = jnp.concatenate([masked_y, x], axis=-1) + params = self._conditioner(masked_y) + z0, log_d = self._inner_bijector(params).inverse_and_log_det(y) + z = jnp.where(self._event_mask, y, z0) + logdet = math.sum_last( + jnp.where(self._mask, 0., log_d), + self._event_ndims - self._inner_event_ndims + ) + return z, logdet diff --git a/surjectors/conditioners/__init__.py b/surjectors/conditioners/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/surjectors/conditioners/mlp.py b/surjectors/conditioners/mlp.py new file mode 100644 index 0000000..553dc65 --- /dev/null +++ b/surjectors/conditioners/mlp.py @@ -0,0 +1,19 @@ +import haiku as hk +import jax +from jax import numpy as jnp + + +def mlp_conditioner( + dims, + activation=jax.nn.gelu, + w_init=hk.initializers.TruncatedNormal(stddev=0.01), + b_init=jnp.zeros +): + dims[-1] = dims[-1] * 2 + + return hk.nets.MLP( + output_sizes=dims, + w_init=w_init, + b_init=b_init, + activation=activation + ) diff --git a/surjectors/conditioners/transformer.py b/surjectors/conditioners/transformer.py new file mode 100644 index 0000000..c1e79a3 --- /dev/null +++ b/surjectors/conditioners/transformer.py @@ -0,0 +1,72 @@ +import dataclasses +from typing import Callable +from typing import Optional + +import haiku as hk +import jax + + +@dataclasses.dataclass +class _EncoderLayer(hk.Module): + num_heads: int + num_layers: int + key_size: int + dropout_rate: float + widening_factor: int = 4 + initializer: Callable = hk.initializers.TruncatedNormal(stddev=0.01) + name: Optional[str] = None + + def __call__(self, inputs, *, is_training): + dropout_rate = self.dropout_rate if is_training else 0. + #causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len))) + causal_mask=None + model_size=self.key_size * self.num_heads + + h = inputs + for _ in range(self.num_layers): + attn_block = hk.MultiHeadAttention( + num_heads=self.num_heads, + key_size=self.key_size, + model_size=model_size, + w_init=self.initializer, + ) + h_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h) + h_attn = attn_block(h_norm, h_norm, h_norm, mask=causal_mask) + h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn) + h = h + h_attn + + mlp = hk.nets.MLP( + [self.widening_factor * model_size, model_size], + w_init=self.initializer, + activation=jax.nn.gelu + ) + h_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h) + h_dense = mlp(h_norm) + h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense) + h = h + h_dense + + return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h) + + +@dataclasses.dataclass +class _AutoregressiveTransformerEncoder(hk.Module): + linear: hk.Linear + transformer: _EncoderLayer + output_size: int + name: Optional[str] = None + + def __call__(self, inputs, *, is_training=True): + h = self.linear(inputs) + h = self.transformer(h, is_training=is_training) + return hk.Linear(self.output_size)(h) + + +def transformer_conditioner( + output_size, num_heads=2, num_layers=2, key_size=32, dropout_rate=0.1, widening_factor=4 +): + linear = hk.Linear(key_size * num_heads) + encoder = _EncoderLayer( + num_heads, num_layers, key_size, dropout_rate, widening_factor + ) + transformer = _AutoregressiveTransformerEncoder(linear, encoder, output_size * 2) + return transformer diff --git a/surjectors/data.py b/surjectors/data.py deleted file mode 100644 index a14aa38..0000000 --- a/surjectors/data.py +++ /dev/null @@ -1,71 +0,0 @@ -from jax import random, lax -from jax.scipy.special 
import erf - - -class Simulator: - def __new__(cls, simulator="solar_dynamo", **kwargs): - if simulator == "solar_dynamo": - return SolarDynamoSimulator(**kwargs) - return StandardSimulator() - - -class StandardSimulator: - pass - - -class SolarDynamoSimulator: - def __init__(self, **kwargs): - self.p0_mean = kwargs.get("p0_mean", 1.0) - self.p0_std = kwargs.get("p0_std", 1.0) - self.alpha1_min = kwargs.get("alpha1_min", 1.3) - self.alpha1_max = kwargs.get("alpha1_max", 1.5) - self.alpha2_max = kwargs.get("alpha2_max", 1.65) - self.epsilon_max = kwargs.get("epsilon_max", 0.5) - self.alpha1 = kwargs.get("alpha1", None) - self.alpha2 = kwargs.get("alpha2", None) - - def sample(self, key, len_timeseries=1000): - p_key, alpha1_key, alpha2_key, epsilon_key, key = random.split(key, 5) - p0 = random.normal(p_key) * self.p0_std + self.p0_mean - alpha1 = random.uniform( - alpha1_key, minval=self.alpha1_min, maxval=self.alpha1_max - ) - alpha2 = random.uniform( - alpha2_key, minval=alpha1, maxval=self.alpha2_max - ) - epsilon_max = random.uniform( - epsilon_key, minval=0, maxval=self.epsilon_max - ) - batch = self._sample_timeseries( - key, p0, alpha1, alpha2, epsilon_max, len_timeseries - ) - - return p0, alpha1, alpha2, epsilon_max, batch[0], batch[1] - - def babcock_leighton_fn(self, p, b_1=0.6, w_1=0.2, b_2=1.0, w_2=0.8): - f = 0.5 * (1.0 + erf((p - b_1) / w_1)) * (1.0 - erf((p - b_2) / w_2)) - return f - - def babcock_leighton(self, p, alpha, epsilon): - p = alpha * self.babcock_leighton_fn(p) * p + epsilon - return p - - def _sample_timeseries( - self, key, pn, alpha_min, alpha_max, epsilon_max, len_timeseries - ): - a = random.uniform( - key, minval=alpha_min, maxval=alpha_max, shape=(len_timeseries,) - ) - e = random.uniform( - key, minval=0.0, maxval=epsilon_max, shape=(len_timeseries,) - ) - - def _fn(fs, arrays): - alpha, epsilon = arrays - f, pn = fs - f = self.babcock_leighton_fn(pn) - pn = self.babcock_leighton(pn, alpha, epsilon) - return (f, pn), (f, pn) - - _, pn = lax.scan(_fn, (pn, pn), (a, e)) - return pn diff --git a/surjectors/distributions/conditional_distribution.py b/surjectors/distributions/conditional_distribution.py new file mode 100644 index 0000000..ab030b3 --- /dev/null +++ b/surjectors/distributions/conditional_distribution.py @@ -0,0 +1,2 @@ +class ConditionalDistribution: + pass diff --git a/surjectors/distribution.py b/surjectors/distributions/distribution.py similarity index 100% rename from surjectors/distribution.py rename to surjectors/distributions/distribution.py diff --git a/surjectors/distributions/transformed_distribution.py b/surjectors/distributions/transformed_distribution.py index a9765dc..b2d8483 100644 --- a/surjectors/distributions/transformed_distribution.py +++ b/surjectors/distributions/transformed_distribution.py @@ -1,40 +1,48 @@ from typing import Tuple +import chex +import haiku as hk import jax import jax.numpy as jnp -from distrax._src.bijectors import bijector as bjct_base -from distrax._src.distributions import distribution as dist_base -from distrax._src.utils import conversion -from tensorflow_probability.substrates import jax as tfp +from distrax import Distribution -tfd = tfp.distributions - -PRNGKey = dist_base.PRNGKey -Array = dist_base.Array -DistributionLike = dist_base.DistributionLike -BijectorLike = bjct_base.BijectorLike +Array = chex.Array +from surjectors.surjectors.surjector import Surjector class TransformedDistribution: - def __init__(self, base_distribution, surjector): + def __init__(self, base_distribution: 
Distribution, surjector: Surjector): self.base_distribution = base_distribution self.surjector = surjector - def log_prob(self, y: Array) -> Array: - x, ildj_y = self.surjector.inverse_and_log_det(y) - lp_x = self.base_distribution.log_prob(x) - lp_y = lp_x + ildj_y - return lp_y + def __call__(self, method, **kwargs): + return getattr(self, method)(**kwargs) + + def log_prob(self, y: Array, x: Array = None) -> Array: + _, lp = self.inverse_and_log_prob(y, x) + return lp - def sample(self, key: PRNGKey, sample_shape=(1,)): - z = self.base_distribution.sample(seed=key, sample_shape=sample_shape) - y = jax.vmap(self.surjector.inverse)(z) + def inverse_and_log_prob(self, y: Array, x: Array=None) -> Tuple[Array, Array]: + x, lc = self.surjector.inverse_and_likelihood_contribution(y, x=x) + lp_x = self.base_distribution.log_prob(x) + lp = lp_x + lc + return x, lp + + def sample(self, sample_shape=(), x: Array = None): + if x is not None and len(sample_shape): + chex.assert_equal(sample_shape[0], x.shape[0]) + elif x is not None: + sample_shape = (x.shape[0],) + z = self.base_distribution.sample(seed=hk.next_rng_key(), sample_shape=sample_shape) + y = jax.vmap(self.surjector.forward)(z, x) return y - def sample_and_log_prob(self, key: PRNGKey, sample_shape=(1,)): + def sample_and_log_prob(self, sample_shape=(1,), x: Array = None): z, lp_z = self.base_distribution.sample_and_log_prob( - seed=key, sample_shape=sample_shape + seed=hk.next_rng_key(), sample_shape=sample_shape, x=x + ) + y, fldj = jax.vmap(self.surjector.forward_and_likelihood_contribution)( + z, x=x ) - y, fldj = jax.vmap(self.surjector.forward_and_log_det)(z) - lp_y = jax.vmap(jnp.subtract)(lp_z, fldj) - return y, lp_y + lp = jax.vmap(jnp.subtract)(lp_z, fldj) + return y, lp diff --git a/surjectors/surjectors/transform.py b/surjectors/surjectors/_transform.py similarity index 96% rename from surjectors/surjectors/transform.py rename to surjectors/surjectors/_transform.py index 03d47fe..d0861a0 100644 --- a/surjectors/surjectors/transform.py +++ b/surjectors/surjectors/_transform.py @@ -1,6 +1,5 @@ from abc import ABCMeta, abstractmethod -import distrax from distrax._src.utils import jittable diff --git a/surjectors/surjectors/affine_masked_coupling_generative_funnel.py b/surjectors/surjectors/affine_masked_coupling_generative_funnel.py new file mode 100644 index 0000000..f80ffa6 --- /dev/null +++ b/surjectors/surjectors/affine_masked_coupling_generative_funnel.py @@ -0,0 +1,50 @@ +import distrax +from chex import Array +from jax import numpy as jnp +import haiku as hk +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.surjectors.funnel import Funnel + + +class AffineMaskedCouplingGenerativeFunnel(Funnel): + def __init__(self, n_keep, encoder, conditioner): + super().__init__(n_keep, None, conditioner, encoder, "generative_surjector") + + def _mask(self, array): + mask = jnp.arange(array.shape[-1]) >= self.n_keep + mask = mask.astype(jnp.bool_) + return mask + + def _inner_bijector(self, mask): + def _bijector_fn(params: Array): + shift, log_scale = jnp.split(params, 2, axis=-1) + return distrax.ScalarAffine(shift, jnp.exp(log_scale)) + + return MaskedCoupling( + mask, self._conditioner, _bijector_fn + ) + + def inverse_and_likelihood_contribution(self, y, x=None): + y_condition = y + # TODO + if x is not None: + y_condition = jnp.concatenate([y, x], axis=-1) + z_minus, lc = self.encoder(y_condition).sample_and_log_prob(seed=hk.next_rng_key()) + input = jnp.concatenate([y, z_minus], axis=-1) + # TODO: 
remote the conditioning here? + z, jac_det = self._inner_bijector(self._mask(input)).inverse_and_log_det(input) + return z, -lc + jac_det + + def forward_and_likelihood_contribution(self, z, x=None): + # TODO: remote the conditioning here? + faux, jac_det = self._inner_bijector(self._mask(z)).inverse_and_log_det(z) + y = faux[..., :self.n_keep] + y_condition = y + if x is not None: + y_condition = jnp.concatenate([y_condition, x], axis=-1) + lc = self.encoder(y_condition).log_prob(faux[..., self.n_keep:]) + return y, -lc + jac_det + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y diff --git a/surjectors/surjectors/affine_masked_coupling_inference_funnel.py b/surjectors/surjectors/affine_masked_coupling_inference_funnel.py new file mode 100644 index 0000000..e8d3804 --- /dev/null +++ b/surjectors/surjectors/affine_masked_coupling_inference_funnel.py @@ -0,0 +1,49 @@ +import distrax +from chex import Array +from jax import numpy as jnp +import haiku as hk +from surjectors.bijectors.masked_coupling import MaskedCoupling +from surjectors.surjectors.funnel import Funnel + + +class AffineMaskedCouplingInferenceFunnel(Funnel): + def __init__(self, n_keep, decoder, conditioner): + super().__init__(n_keep, decoder, conditioner, None, "inference_surjector") + + def _mask(self, array): + mask = jnp.arange(array.shape[-1]) >= self.n_keep + mask = mask.astype(jnp.bool_) + return mask + + def _inner_bijector(self, mask): + def _bijector_fn(params: Array): + shift, log_scale = jnp.split(params, 2, axis=-1) + return distrax.ScalarAffine(shift, jnp.exp(log_scale)) + + return MaskedCoupling( + mask, self._conditioner, _bijector_fn + ) + + def inverse_and_likelihood_contribution(self, y, x=None): + # TODO: remote the conditioning here? + faux, jac_det = self._inner_bijector(self._mask(y)).inverse_and_log_det(y, x) + z_condition = z = faux[:, :self.n_keep] + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + lc = self.decoder(z_condition).log_prob(y[:, self.n_keep:]) + return z, lc + jac_det + + def forward_and_likelihood_contribution(self, z, x=None): + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + y_minus, jac_det = self.decoder(z_condition).sample_and_log_prob(seed=hk.next_rng_key()) + # TODO need to sort the indexes correctly (?) + z_tilde = jnp.concatenate([z, y_minus], axis=-1) + # TODO: remote the conditioning here? 
+ y, lc = self._inner_bijector(self._mask(z_tilde)).forward_and_log_det(z_tilde, x) + return y, lc + jac_det + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y \ No newline at end of file diff --git a/surjectors/surjectors/augment.py b/surjectors/surjectors/augment.py new file mode 100644 index 0000000..44e15bd --- /dev/null +++ b/surjectors/surjectors/augment.py @@ -0,0 +1,34 @@ +import distrax +from chex import Array +from jax import numpy as jnp +import haiku as hk +from surjectors.surjectors.funnel import Funnel + + +class Augment(Funnel): + def __init__(self, n_keep, encoder): + super().__init__(n_keep, None, None, encoder, "generative_surjector") + + def split_input(self, input): + spl = jnp.split(input, [self.n_keep], axis=-1) + return spl + + def inverse_and_likelihood_contribution(self, y, x: Array = None): + z_plus = y_condition = y + if x is not None: + y_condition = jnp.concatenate([y_condition, x], axis=-1) + z_minus, lc = self.encoder(y_condition).sample_and_log_prob(seed=hk.next_rng_key()) + z = jnp.concatenate([z_plus, z_minus], axis=-1) + return z, -lc + + def forward_and_likelihood_contribution(self, z, x=None): + z_plus, z_minus = self.split_input(z) + y_condition = y = z_plus + if x is not None: + y_condition = jnp.concatenate([y_condition, x], axis=-1) + lc = self.encoder(y_condition).log_prob(z_minus) + return y, -lc + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y diff --git a/surjectors/surjectors/chain.py b/surjectors/surjectors/chain.py index 24dd6f7..ab79199 100644 --- a/surjectors/surjectors/chain.py +++ b/surjectors/surjectors/chain.py @@ -3,19 +3,43 @@ class Chain(Surjector): def __init__(self, surjectors): + super().__init__(None, None, None, "surjector") self._surjectors = surjectors - def inverse_and_likelihood_contribution(self, y): - z, log_det = self._surjectors[0].forward_and_log_det(y) + def inverse_and_likelihood_contribution(self, y, x=None): + z, lcs = self._inverse_and_log_contribution_dispatch( + self._surjectors[0], y, x + ) for surjector in self._surjectors[1:]: - x, lc = surjector.inverse_and_likelihood_contribution(z) - log_det += lc - return z, log_det + z, lc = self._inverse_and_log_contribution_dispatch(surjector, z, x) + lcs += lc + return z, lcs + + @staticmethod + def _inverse_and_log_contribution_dispatch(surjector, y, x): + if isinstance(surjector, Surjector): + fn = getattr(surjector, "inverse_and_likelihood_contribution") + else: + fn = getattr(surjector, "inverse_and_log_det") + z, lc = fn(y, x) + return z, lc - def forward_and_likelihood_contribution(self, z): - y, log_det = self._surjectors[-1].forward_and_log_det(z) - for _surjectors in reversed(self._surjectors[:-1]): - x, lc = _surjectors.forward_and_log_det(x) + def forward_and_likelihood_contribution(self, z, x=None): + y, log_det = self._surjectors[-1].forward_and_log_det(z, x) + for surjector in reversed(self._surjectors[:-1]): + y, lc = self._forward_and_log_contribution_dispatch(surjector, y, x) log_det += lc return y, log_det + @staticmethod + def _forward_and_log_contribution_dispatch(surjector, y, x): + if isinstance(surjector, Surjector): + fn = getattr(surjector, "forward_and_likelihood_contribution") + else: + fn = getattr(surjector, "forward_and_log_det") + z, lc = fn(y, x) + return z, lc + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y diff --git a/surjectors/surjectors/funnel.py b/surjectors/surjectors/funnel.py 
index c53df17..64f9fc5 100644 --- a/surjectors/surjectors/funnel.py +++ b/surjectors/surjectors/funnel.py @@ -1,22 +1,9 @@ -from jax import numpy as jnp +from abc import ABC from surjectors.surjectors.surjector import Surjector -class Funnel(Surjector): - def __init__(self, n_keep, decoder, encoder=None, kind="inference_surjection"): +class Funnel(Surjector, ABC): + def __init__(self, n_keep, decoder, conditioner, encoder, kind): super().__init__(n_keep, decoder, encoder, kind) - - def split_input(self, input): - split_proportions = (self.n_keep, input.shape[-1] - self.n_keep) - return jnp.split(input, split_proportions, axis=-1) - - def inverse_and_likelihood_contribution(self, y): - z, y_minus = self.split_input(y) - lc = self.decoder.log_prob(y_minus, context=z) - return z, lc - - def forward_and_likelihood_contribution(self, z): - y_minus = self.decoder.sample(context=z) - y = jnp.concatenate([z, y_minus], axis=-1) - return y + self._conditioner = conditioner diff --git a/surjectors/surjectors/linear.py b/surjectors/surjectors/linear.py deleted file mode 100644 index a8fbcfb..0000000 --- a/surjectors/surjectors/linear.py +++ /dev/null @@ -1,60 +0,0 @@ -import optax -from distrax import LowerUpperTriangularAffine -from jax import numpy as jnp -import jax -import haiku as hk -import distrax - -from surjectors.surjectors.funnel import Funnel -from surjectors.surjectors.lu_linear import LULinear - - -class MLP(Funnel, hk.Module): - def __init__(self, n_keep, decoder, dtype=jnp.float32): - self._r = LULinear(n_keep, dtype) - self._w_prime = hk.Linear(n_keep, with_bias=True) - - self._decoder = decoder # TODO: should be a conditional gaussian - super().__init__(n_keep, decoder) - - def inverse_and_likelihood_contribution(self, y): - x_plus, x_minus = y[:, :self.n_keep], y[:, self.n_keep:] - z, lc = self._r.inverse_and_likelihood_contribution(x_plus) + self._w_prime(x_minus) - lp = self._decoder.log_prob(x_minus, context=z) - return z, lp + lc - - def forward_and_likelihood_contribution(self, z): - pass - -prng = hk.PRNGSequence(jax.random.PRNGKey(42)) -matrix = jax.random.uniform(next(prng), (4, 4)) -bias = jax.random.normal(next(prng), (4,)) -bijector = LowerUpperTriangularAffine(matrix, bias) - -# -# def loss(): -# x, lc = bijector.inverse_and_log_det(jnp.zeros(4) * 2.1) -# lp = distrax.Normal(jnp.zeros(4)).log_prob(x) -# return -jnp.sum(lp - lc) -# -# print(bijector.matrix) -# -# adam = optax.adam(0.003) -# g = jax.grad(loss)() -# -# print(g) -# - -matrix = jax.random.uniform(next(prng), (4, 4)) -bias = jax.random.normal(next(prng), (4,)) -bijector = LowerUpperTriangularAffine(matrix, bias) - -n = jnp.ones((4, 4)) * 3.1 -n += jnp.triu(n) * 2 - -bijector = LowerUpperTriangularAffine(n, jnp.zeros(4)) - - -print(bijector.forward(jnp.ones(4))) - -print(n @jnp.ones(4) ) \ No newline at end of file diff --git a/surjectors/surjectors/mlp.py b/surjectors/surjectors/mlp.py new file mode 100644 index 0000000..d5188c9 --- /dev/null +++ b/surjectors/surjectors/mlp.py @@ -0,0 +1,30 @@ +import distrax +import haiku as hk +from jax import numpy as jnp + +from surjectors.surjectors.affine_masked_coupling_inference_funnel import Funnel +from surjectors.surjectors.lu_linear import LULinear + + +class MLP(Funnel, hk.Module): + def __init__(self, n_keep, decoder, dtype=jnp.float32): + self._r = LULinear(n_keep, dtype, with_bias=False) + self._w_prime = hk.Linear(n_keep, with_bias=True) + + self._decoder = decoder + super().__init__(n_keep, decoder) + + def inverse_and_likelihood_contribution(self, y): + 
y_plus, y_minus = self.split_input(y) + z, jac_det = self._r.inverse_and_likelihood_contribution(y_plus) + z += self._w_prime(y_minus) + lp = self._decode(z).log_prob(y_minus) + return z, lp + jac_det + + def _decode(self, array): + mu, log_scale = self._decoder(array) + distr = distrax.MultivariateNormalDiag(mu, jnp.exp(log_scale)) + return distr + + def forward_and_likelihood_contribution(self, z): + pass diff --git a/surjectors/surjectors/rq_coupling_funnel.py b/surjectors/surjectors/rq_coupling_funnel.py new file mode 100644 index 0000000..22b31fe --- /dev/null +++ b/surjectors/surjectors/rq_coupling_funnel.py @@ -0,0 +1,45 @@ +import distrax +from chex import Array +from distrax import MaskedCoupling +from jax import numpy as jnp + +from surjectors.surjectors.surjector import Surjector + + +class NSFCouplingFunnel(Surjector): + def __init__( + self, + n_keep, + decoder, + conditioner, + range_min, + range_max, + kind="inference_surjection", + ): + super().__init__(n_keep, decoder, None, kind) + self._conditioner = conditioner + self._range_min = range_min + self._range_max = range_max + + def _mask(self, array): + mask = jnp.arange(array.shape[-1]) >= self.n_keep + mask = mask.astype(jnp.bool_) + return mask + + def _inner_bijector(self, mask): + def _bijector_fn(params: Array): + return distrax.RationalQuadraticSpline( + params, range_min=self._range_min, range_max=self._range_max + ) + + return MaskedCoupling(mask, self._conditioner, _bijector_fn) + + def inverse_and_likelihood_contribution(self, y): + mask = self._mask(y) + faux, jac_det = self._inner_bijector(mask).inverse_and_log_det(y) + z = faux[:, : self.n_keep] + lp = self.decoder.log_prob(faux[:, self.n_keep :], context=z) + return z, lp + jac_det + + def forward_and_likelihood_contribution(self, z): + raise NotImplementedError() diff --git a/surjectors/surjectors/slice.py b/surjectors/surjectors/slice.py index 311d3fe..8e74675 100644 --- a/surjectors/surjectors/slice.py +++ b/surjectors/surjectors/slice.py @@ -1,9 +1,37 @@ +import distrax +from chex import Array from jax import numpy as jnp - -from surjectors.funnel import Funnel +import haiku as hk +from surjectors.surjectors.funnel import Funnel class Slice(Funnel): - def __init__(self, n_keep, kind="inference_surjection"): - # TODO: implement decoder and encoder - super().__init__(kind, decoder, encoder, n_keep) \ No newline at end of file + def __init__(self, n_keep, decoder): + super().__init__(n_keep, decoder, None, None, "inference_surjector") + + def split_input(self, input): + spl = jnp.split(input, [self.n_keep], axis=-1) + return spl + + def inverse_and_likelihood_contribution(self, y, x: Array=None): + z, y_minus = self.split_input(y) + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + lc = self.decoder(z_condition).log_prob(y_minus) + return z, lc + + def forward_and_likelihood_contribution(self, z, x=None): + z_condition = z + if x is not None: + z_condition = jnp.concatenate([z, x], axis=-1) + y_minus, lc = self.decoder(z_condition).sample_and_log_prob( + seed=hk.next_rng_key() + ) + y = jnp.concatenate([z, y_minus], axis=-1) + + return y, lc + + def forward(self, z, x=None): + y, _ = self.forward_and_likelihood_contribution(z, x) + return y diff --git a/surjectors/surjectors/surjector.py b/surjectors/surjectors/surjector.py index a4e562f..ee4eee3 100644 --- a/surjectors/surjectors/surjector.py +++ b/surjectors/surjectors/surjector.py @@ -1,20 +1,27 @@ -from abc import abstractmethod +from abc import abstractmethod, ABC 
from jax import numpy as jnp -from surjectors.surjectors.transform import Transform +from surjectors.surjectors._transform import Transform -_valid_kinds = ["inference_surjector", "generative_surjector", "bijector"] +_valid_kinds = [ + "inference_surjector", + "generative_surjector", + "bijector", + "surjector", +] -class Surjector(Transform): +class Surjector(Transform, ABC): """ Surjector base class """ + def __init__(self, n_keep, decoder, encoder, kind, dtype=jnp.float32): if kind not in _valid_kinds: raise ValueError( - "'kind' argument needs to be either of: " "/".join(_valid_kinds) + "'kind' argument needs to be either of: " + + "/".join(_valid_kinds) ) if kind == _valid_kinds[1] and encoder is None: raise ValueError( @@ -49,4 +56,4 @@ def encoder(self): @property def dtype(self): - return self._dtype \ No newline at end of file + return self._dtype
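
The README's "Example usage" section is still a TODO. For reference, the sketch below shows how the modules added in this patch fit together, modeled on `experiments/multivariate_gaussian_inference_surjection.py`. Module paths, constructor signatures, and the `method="log_prob"` calling convention are assumed from the new files as they appear in this diff and may still change, so treat this as an illustrative sketch rather than the final public API.

```python
import distrax
import haiku as hk
from jax import numpy as jnp, random

from surjectors.bijectors.masked_coupling import MaskedCoupling
from surjectors.conditioners.mlp import mlp_conditioner
from surjectors.distributions.transformed_distribution import (
    TransformedDistribution,
)
from surjectors.surjectors.chain import Chain
from surjectors.surjectors.slice import Slice

n_dimension, n_latent = 10, 5


def _bijector_fn(params):
    # affine coupling: the conditioner output is split into shift and log-scale
    means, log_scales = jnp.split(params, 2, -1)
    return distrax.ScalarAffine(means, jnp.exp(log_scales))


def _decoder_fn(z):
    # conditional Gaussian over the dimensions that the Slice surjection drops
    params = mlp_conditioner([32, n_dimension - n_latent])(z)
    mu, log_scale = jnp.split(params, 2, -1)
    return distrax.Independent(distrax.Normal(mu, jnp.exp(log_scale)))


def _flow(method, **kwargs):
    layers = [
        MaskedCoupling(
            mask=(jnp.arange(n_dimension) % 2).astype(bool),
            bijector=_bijector_fn,
            conditioner=mlp_conditioner([32, n_dimension]),
        ),
        # inference surjection: keep n_latent dimensions, decode the rest
        Slice(n_latent, _decoder_fn),
        MaskedCoupling(
            mask=(jnp.arange(n_latent) % 2).astype(bool),
            bijector=_bijector_fn,
            conditioner=mlp_conditioner([32, n_latent]),
        ),
    ]
    base = distrax.Independent(
        distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
        reinterpreted_batch_ndims=1,
    )
    td = TransformedDistribution(base, Chain(layers))
    return td(method, **kwargs)


flow = hk.transform(_flow)
y = random.normal(random.PRNGKey(1), (64, n_dimension))
params = flow.init(random.PRNGKey(0), method="log_prob", y=y)
log_prob = flow.apply(params, random.PRNGKey(2), method="log_prob", y=y)
```

A generative surjection is assembled the same way, with `Augment` (or `AffineMaskedCouplingGenerativeFunnel`) in place of `Slice`, as in `experiments/multivariate_gaussian_generative_surjection.py`.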