Add an example for NeuTra HMC (#248)
* initialize the work on neutra hmc

* sampling and plot neutra

* add training plot

* fix init step size

* update plots

* tuning numbers to get nicer plot?

* add args to tune
fehiepsi authored and martinjankowiak committed Jul 19, 2019
1 parent 38df2f0 commit 781baf5
Showing 3 changed files with 163 additions and 9 deletions.
143 changes: 143 additions & 0 deletions examples/neutra.py
@@ -0,0 +1,143 @@
import argparse

from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt
import seaborn as sns

from jax import jit, random, vmap
from jax.config import config as jax_config
from jax.experimental import optimizers
import jax.numpy as np
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map

from numpyro.contrib.autoguide import AutoIAFNormal
from numpyro.diagnostics import summary
import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc
from numpyro.svi import elbo, svi
from numpyro.util import fori_collect

"""
This example illustrates how to use a trained AutoIAFNormal autoguide to transform a multimodal
posterior into a Gaussian-like one. The learned transform is then used to improve the mixing
rate of the NUTS sampler, following [1].

[1] Hoffman, M. et al. (2019), "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using
    Neural Transport" (https://arxiv.org/abs/1903.03704).
"""


def dual_moon_pe(x):
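    # a ring-shaped term keeps mass near the circle of radius 2, while two Gaussian
    # bumps along the first coordinate carve it into two crescent-shaped modes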
    term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
    term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2
    return term1 - logsumexp(term2, axis=-1)


def dual_moon_model():
    x = sample('x', dist.Uniform(-10 * np.ones(2), 10 * np.ones(2)))
    pe = dual_moon_pe(x)
    sample('log_density', dist.Delta(log_density=-pe), obs=0.)


def make_transformed_pe(potential_fn, transform, unpack_fn):
    def transformed_potential_fn(z):
        # NB: currently, intermediates for ComposeTransform are None, so this has no effect;
        # see https://github.com/pyro-ppl/numpyro/issues/242
        u, intermediates = transform.call_with_intermediates(z)
        logdet = transform.log_abs_det_jacobian(z, u, intermediates=intermediates)
        # the pullback density is q(z) = p(T(z)) * |det J_T(z)|, so its potential
        # energy is U(T(z)) - log|det J_T(z)|
        return potential_fn(unpack_fn(u)) - logdet

    return transformed_potential_fn


def main(args):
    jax_config.update('jax_platform_name', args.device)

    print("Start vanilla HMC...")
    # TODO: set progbar=True when https://github.com/google/jax/issues/939 is resolved
    vanilla_samples = mcmc(args.num_warmup, args.num_samples, init_params=np.array([2., 0.]),
                           potential_fn=dual_moon_pe, progbar=False)

    opt_init, opt_update, get_params = optimizers.adam(0.001)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoIAFNormal(rng_guide, dual_moon_model, get_params, hidden_dims=[args.num_hidden])
    svi_init, svi_update, _ = svi(dual_moon_model, guide, elbo, opt_init, opt_update, get_params)
    opt_state, _ = svi_init(rng_init)

    def body_fn(val):
        i, _, opt_state_, rng_ = val  # the loss carried in `val` is unused inside the body
        loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_)
        return i + 1, loss, opt_state_, rng_

    print("Start training guide...")
    losses, opt_states = fori_collect(0, args.num_iters, jit(body_fn),
                                      (0, 0., opt_state, rng_train),
                                      transform=lambda x: (x[1], x[2]), progbar=False)
    last_state = tree_map(lambda x: x[-1], opt_states)
    print("Finish training guide. Extract samples...")
    guide_samples = guide.sample_posterior(random.PRNGKey(0), last_state,
                                           sample_shape=(args.num_samples,))

    transform = guide.get_transform(last_state)
    unpack_fn = lambda u: guide.unpack_latent(u, transform={})  # noqa: E731

    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn)
    transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x)))  # noqa: E731

    # TODO: expose latent_size in autoguide
    init_params = np.zeros(np.size(guide._init_latent))
    print("\nStart NeuTra HMC...")
    zs = mcmc(args.num_warmup, args.num_samples, init_params, potential_fn=transformed_potential_fn)
    print("Transform samples into unwarped space...")
    samples = vmap(transformed_constrain_fn)(zs)
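    # prepend a chain dimension, since summary expects samples grouped by chain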
    summary(tree_map(lambda x: x[None, ...], samples))

    # make plots
    x1 = np.linspace(-3, 3, 100)
    x2 = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(x1, x2)
    P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.)

    fig = plt.figure(figsize=(12, 12), constrained_layout=True)
    gs = GridSpec(2, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])

    ax1.plot(losses[1000:])
    ax1.set_title('Autoguide training loss (skip the first 1000 steps)')

    ax2.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(guide_samples['x'][:, 0], guide_samples['x'][:, 1], ax=ax2)
    ax2.set(xlim=[-3, 3], ylim=[-3, 3], aspect='equal',
            xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide')

    ax3.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax3)
    ax3.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5)
    ax3.set(xlim=[-3, 3], ylim=[-3, 3], aspect='equal',
            xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler')

    ax4.contourf(X1, X2, P, cmap='OrRd')
    sns.kdeplot(samples['x'][:, 0], samples['x'][:, 1], ax=ax4)
    ax4.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.5)
    ax4.set(xlim=[-3, 3], ylim=[-3, 3], aspect='equal',
            xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler')

    plt.savefig("neutra.pdf")
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="NeuTra HMC")
    parser.add_argument('-n', '--num-samples', nargs='?', default=10000, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=0, type=int)
    parser.add_argument('--num-hidden', nargs='?', default=15, type=int)
    parser.add_argument('--num-iters', nargs='?', default=10000, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()
    main(args)
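
The warping trick in make_transformed_pe is independent of numpyro: any bijection T with a
tractable Jacobian lets HMC target the pullback density q(z) = p(T(z)) |det J_T(z)| instead of
p itself. Below is a minimal sketch in plain JAX, with a hand-rolled affine map standing in for
the trained IAF; affine_push and neutra_potential are illustrative names, not numpyro API:

import jax.numpy as np
from jax.scipy.special import logsumexp


def dual_moon_pe(x):
    # same dual moon potential as in the example above
    term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
    term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2
    return term1 - logsumexp(term2, axis=-1)


# stand-in "flow": u = loc + scale * z, whose log|det J| is sum(log(scale))
loc = np.zeros(2)
scale = np.array([1.5, 0.8])


def affine_push(z):
    return loc + scale * z


def neutra_potential(z):
    # potential energy of the pullback density: U(T(z)) - log|det J_T(z)|
    return dual_moon_pe(affine_push(z)) - np.sum(np.log(scale))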
27 changes: 19 additions & 8 deletions numpyro/contrib/autoguide/__init__.py
@@ -13,7 +13,7 @@
from numpyro.contrib.nn.auto_reg_nn import AutoregressiveNN
import numpyro.distributions as dist
from numpyro.distributions import constraints
-from numpyro.distributions.constraints import PermuteTransform, biject_to
+from numpyro.distributions.constraints import AffineTransform, ComposeTransform, PermuteTransform, biject_to
from numpyro.distributions.flows import InverseAutoregressiveTransform
from numpyro.distributions.util import sum_rightmost
from numpyro.handlers import block, param, sample, seed, substitute, trace
@@ -157,12 +157,14 @@ def _setup_prototype(self, *args, **kwargs):
            raise RuntimeError('{} found no latent variables; Use an empty guide instead'.format(type(self).__name__))
        self._init_latent, self._unravel_fn = ravel_pytree(unconstrained_sites)

-    def _unpack_latent(self, latent_sample):
+    def unpack_latent(self, latent_sample, transform=None):
        sample_shape = np.shape(latent_sample)[:-1]
        latent_sample = np.reshape(latent_sample, (-1, np.shape(latent_sample)[-1]))
        unpacked_samples = vmap(self._unravel_fn)(latent_sample)
        unpacked_samples = tree_map(lambda x: np.reshape(x, sample_shape + np.shape(x)[1:]),
                                    unpacked_samples)

-        return transform_fn(self._inv_transforms, unpacked_samples)
+        transform = self._inv_transforms if transform is None else transform
+        return transform_fn(transform, unpacked_samples)

    def __call__(self, *args, **kwargs):
@@ -193,6 +195,14 @@ def __call__(self, *args, **kwargs):

        return result

    def sample_posterior(self, rng, opt_state, *args, **kwargs):
        sample_shape = kwargs.pop('sample_shape', ())
        latent_size = np.size(self._init_latent)
        transform = self.get_transform(opt_state)
        posterior = dist.TransformedDistribution(_Normal(np.zeros(latent_size), 1.), transform)
        latent_sample = posterior.sample(rng, sample_shape)
        return self.unpack_latent(latent_sample)


class AutoDiagonalNormal(AutoContinuous):
"""
@@ -209,7 +219,11 @@ def sample_posterior(self, rng, opt_state, *args, **kwargs):
        sample_shape = kwargs.pop('sample_shape', ())
        loc, scale = self._loc_scale(opt_state)
        latent_sample = dist.Normal(loc, scale).sample(rng, sample_shape)
-        return self._unpack_latent(latent_sample)
+        return self.unpack_latent(latent_sample)

    def get_transform(self, opt_state):
        loc, scale = self._loc_scale(opt_state)
        return AffineTransform(loc, scale, domain=constraints.real_vector)

    def _sample_latent(self, *args, **kwargs):
        init_loc = self._init_latent
@@ -301,8 +315,7 @@ def __init__(self, rng, model, get_params_fn, prefix="auto", init_loc_fn=init_to
        super(AutoIAFNormal, self).__init__(rng, model, get_params_fn, prefix=prefix,
                                            init_loc_fn=init_loc_fn)

-    def sample_posterior(self, rng, opt_state, *args, **kwargs):
-        sample_shape = kwargs.pop('sample_shape', ())
+    def get_transform(self, opt_state):
        params = self.get_params(opt_state)
        latent_size = np.size(self._init_latent)
        flows = []
@@ -311,9 +324,7 @@ def sample_posterior(self, rng, opt_state, *args, **kwargs):
            if i > 0:
                flows.append(PermuteTransform(np.arange(latent_size)[::-1]))
            flows.append(InverseAutoregressiveTransform(self.arns[i], arn_params))
-        iaf_dist = dist.TransformedDistribution(_Normal(np.zeros(latent_size), 1.), flows)
-        latent_sample = iaf_dist.sample(rng, sample_shape)
-        return self._unpack_latent(latent_sample)
+        return ComposeTransform(flows)

    def _sample_latent(self, *args, **kwargs):
        latent_size = np.size(self._init_latent)
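
With get_transform factored out, the same trained flow now serves two callers: sample_posterior
here, and the potential warping in examples/neutra.py. A sketch of the equivalence (guide,
opt_state, rng, and latent_size are assumed to come from a trained SVI run like the one above):

transform = guide.get_transform(opt_state)  # ComposeTransform of the IAF flows
z = dist.Normal(np.zeros(latent_size), 1.).sample(rng, (1000,))
samples = guide.unpack_latent(transform(z))
# equivalent to: guide.sample_posterior(rng, opt_state, sample_shape=(1000,))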
2 changes: 1 addition & 1 deletion numpyro/distributions/constraints.py
@@ -268,7 +268,7 @@ def inv(self, y):
            y = part.inv(y)
        return y

-    def log_abs_det_jacobian(self, x, y):
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
        result = 0.
        for part in self.parts[:-1]:
            y_tmp = part(x)
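
The new intermediates argument mirrors the caching contract used by the flows: a transform's
call_with_intermediates returns values computed during the forward pass so that
log_abs_det_jacobian can reuse them instead of recomputing. A hypothetical minimal transform
showing the shape of that contract (ScaleTransform is illustrative, not part of numpyro):

import jax.numpy as np


class ScaleTransform:
    def __init__(self, scale):
        self.scale = scale  # assumed to be a vector of per-dimension scales

    def __call__(self, x):
        return self.scale * x

    def call_with_intermediates(self, x):
        # nothing expensive to cache for this simple map
        return self(x), None

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # a real flow would reuse values from `intermediates` here
        return np.sum(np.log(np.abs(self.scale)), axis=-1)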
