-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an example for NeuTra HMC (#248)
* initialize the work on neutra hmc * sampling and plot neutra * add training plot * fix init step size * update plots * tunning numbers to get nicer plot? * add args to tune
- Loading branch information
1 parent
38df2f0
commit 781baf5
Showing
3 changed files
with
163 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import argparse | ||
|
||
from matplotlib.gridspec import GridSpec | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
from jax import jit, random, vmap | ||
from jax.config import config as jax_config | ||
from jax.experimental import optimizers | ||
import jax.numpy as np | ||
from jax.scipy.special import logsumexp | ||
from jax.tree_util import tree_map | ||
|
||
from numpyro.contrib.autoguide import AutoIAFNormal | ||
from numpyro.diagnostics import summary | ||
import numpyro.distributions as dist | ||
from numpyro.handlers import sample | ||
from numpyro.hmc_util import initialize_model | ||
from numpyro.mcmc import mcmc | ||
from numpyro.svi import elbo, svi | ||
from numpyro.util import fori_collect | ||
|
||
""" | ||
This example illustrates how to use a trained AutoIAFNormal autoguide to transform a posterior to a | ||
Gaussian-like one. The transform will be used to get better mixing rate for NUTS sampler. | ||
[1] Hoffman, M. et al. (2019), ["NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport"] | ||
(https://arxiv.org/abs/1903.03704). | ||
""" | ||
|
||
|
||
def dual_moon_pe(x): | ||
term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2 | ||
term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2 | ||
return term1 - logsumexp(term2, axis=-1) | ||
|
||
|
||
def dual_moon_model(): | ||
x = sample('x', dist.Uniform(-10 * np.ones(2), 10 * np.ones(2))) | ||
pe = dual_moon_pe(x) | ||
sample('log_density', dist.Delta(log_density=-pe), obs=0.) | ||
|
||
|
||
def make_transformed_pe(potential_fn, transform, unpack_fn): | ||
def transformed_potential_fn(z): | ||
# NB: currently, intermediates for ComposeTransform is None, so this has no effect | ||
# see https://github.com/pyro-ppl/numpyro/issues/242 | ||
u, intermediates = transform.call_with_intermediates(z) | ||
logdet = transform.log_abs_det_jacobian(z, u, intermediates=intermediates) | ||
return potential_fn(unpack_fn(u)) + logdet | ||
|
||
return transformed_potential_fn | ||
|
||
|
||
def main(args): | ||
jax_config.update('jax_platform_name', args.device) | ||
|
||
print("Start vanilla HMC...") | ||
# TODO: set progbar=True when https://github.com/google/jax/issues/939 is resolved | ||
vanilla_samples = mcmc(args.num_warmup, args.num_samples, init_params=np.array([2., 0.]), | ||
potential_fn=dual_moon_pe, progbar=False) | ||
|
||
opt_init, opt_update, get_params = optimizers.adam(0.001) | ||
rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3) | ||
guide = AutoIAFNormal(rng_guide, dual_moon_model, get_params, hidden_dims=[args.num_hidden]) | ||
svi_init, svi_update, _ = svi(dual_moon_model, guide, elbo, opt_init, opt_update, get_params) | ||
opt_state, _ = svi_init(rng_init) | ||
|
||
def body_fn(val): | ||
i, loss, opt_state_, rng_ = val | ||
loss, opt_state_, rng_ = svi_update(i, rng_, opt_state_) | ||
return i + 1, loss, opt_state_, rng_ | ||
|
||
print("Start training guide...") | ||
losses, opt_states = fori_collect(0, args.num_iters, jit(body_fn), | ||
(0, 0., opt_state, rng_train), | ||
transform=lambda x: (x[1], x[2]), progbar=False) | ||
last_state = tree_map(lambda x: x[-1], opt_states) | ||
print("Finish training guide. Extract samples...") | ||
guide_samples = guide.sample_posterior(random.PRNGKey(0), last_state, | ||
sample_shape=(args.num_samples,)) | ||
|
||
transform = guide.get_transform(last_state) | ||
unpack_fn = lambda u: guide.unpack_latent(u, transform={}) # noqa: E731 | ||
|
||
_, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), dual_moon_model) | ||
transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn) | ||
transformed_constrain_fn = lambda x: constrain_fn(unpack_fn(transform(x))) # noqa: E731 | ||
|
||
# TODO: expose latent_size in autoguide | ||
init_params = np.zeros(np.size(guide._init_latent)) | ||
print("\nStart NeuTra HMC...") | ||
zs = mcmc(args.num_warmup, args.num_samples, init_params, potential_fn=transformed_potential_fn) | ||
print("Transform samples into unwarped space...") | ||
samples = vmap(transformed_constrain_fn)(zs) | ||
summary(tree_map(lambda x: x[None, ...], samples)) | ||
|
||
# make plots | ||
x1 = np.linspace(-3, 3, 100) | ||
x2 = np.linspace(-3, 3, 100) | ||
X1, X2 = np.meshgrid(x1, x2) | ||
P = np.clip(np.exp(-dual_moon_pe(np.stack([X1, X2], axis=-1))), a_min=0.) | ||
|
||
fig = plt.figure(figsize=(12, 12), constrained_layout=True) | ||
gs = GridSpec(2, 2, figure=fig) | ||
ax1 = fig.add_subplot(gs[0, 0]) | ||
ax2 = fig.add_subplot(gs[0, 1]) | ||
ax3 = fig.add_subplot(gs[1, 0]) | ||
ax4 = fig.add_subplot(gs[1, 1]) | ||
|
||
ax1.plot(losses[1000:]) | ||
ax1.set_title('Autoguide training loss (skip the first 1000 steps)') | ||
|
||
ax2.contourf(X1, X2, P, cmap='OrRd') | ||
sns.kdeplot(guide_samples['x'][:, 0], guide_samples['x'][:, 1], ax=ax2) | ||
ax2.set(xlim=[-3, 3], ylim=[-3, 3], aspect='equal', | ||
xlabel='x0', ylabel='x1', title='Posterior using AutoIAFNormal guide') | ||
|
||
ax3.contourf(X1, X2, P, cmap='OrRd') | ||
sns.kdeplot(vanilla_samples[:, 0], vanilla_samples[:, 1], ax=ax3) | ||
ax3.plot(vanilla_samples[-50:, 0], vanilla_samples[-50:, 1], 'bo-', alpha=0.5) | ||
ax3.set(xlim=[-3, 3], ylim=[-3, 3], aspect='equal', | ||
xlabel='x0', ylabel='x1', title='Posterior using vanilla HMC sampler') | ||
|
||
ax4.contourf(X1, X2, P, cmap='OrRd') | ||
sns.kdeplot(samples['x'][:, 0], samples['x'][:, 1], ax=ax4) | ||
ax4.plot(samples['x'][-50:, 0], samples['x'][-50:, 1], 'bo-', alpha=0.5) | ||
ax4.set(xlim=[-3, 3], ylim=[-3, 3], aspect='equal', | ||
xlabel='x0', ylabel='x1', title='Posterior using NeuTra HMC sampler') | ||
|
||
plt.savefig("neutra.pdf") | ||
plt.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="NeuTra HMC") | ||
parser.add_argument('-n', '--num-samples', nargs='?', default=10000, type=int) | ||
parser.add_argument('--num-warmup', nargs='?', default=0, type=int) | ||
parser.add_argument('--num-hidden', nargs='?', default=15, type=int) | ||
parser.add_argument('--num-iters', nargs='?', default=10000, type=int) | ||
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".') | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters