# advi.py (forked from jax-ml/jax)
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Automatic differentiation variational inference in Numpy and JAX.
This demo fits a Gaussian approximation to an intractable, unnormalized
density, by differentiating through a Monte Carlo estimate of the
variational evidence lower bound (ELBO)."""
from functools import partial

import matplotlib.pyplot as plt

from jax import jit, grad, vmap
from jax import random
from jax.example_libraries import optimizers
import jax.numpy as jnp
from jax.scipy.stats import norm


# ========= Functions to define the evidence lower bound. =========
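# A rough sketch of the math implemented below (informal notation, not from the
# original source): the variational family is a Gaussian with mean `mean` and
# per-dimension standard deviations exp(log_std). For an unnormalized target
# density p, the quantity being maximized is
#
#   ELBO(mean, log_std) = E_{x ~ q}[ log p(x) - log q(x) ],
#
# estimated here with reparameterized samples x = mean + exp(log_std) * eps,
# eps ~ N(0, I), so the Monte Carlo estimate is a deterministic, differentiable
# function of (mean, log_std) and can be pushed through jax.grad.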
def diag_gaussian_sample(rng, mean, log_std):
  # Take a single sample from a diagonal multivariate Gaussian.
  return mean + jnp.exp(log_std) * random.normal(rng, mean.shape)

def diag_gaussian_logpdf(x, mean, log_std):
  # Evaluate the log density of a diagonal multivariate Gaussian at a single point.
  return jnp.sum(vmap(norm.logpdf)(x, mean, jnp.exp(log_std)))

def elbo(logprob, rng, mean, log_std):
  # Single-sample Monte Carlo estimate of the variational lower bound.
  sample = diag_gaussian_sample(rng, mean, log_std)
  return logprob(sample) - diag_gaussian_logpdf(sample, mean, log_std)

def batch_elbo(logprob, rng, params, num_samples):
  # Average over a batch of random samples.
  rngs = random.split(rng, num_samples)
  vectorized_elbo = vmap(partial(elbo, logprob), in_axes=(0, None, None))
  return jnp.mean(vectorized_elbo(rngs, *params))
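# Hypothetical usage sketch (not part of the original script): for a
# standard-normal target the variational family contains the target exactly,
# so at mean = 0, log_std = 0 every single-sample estimate, and hence the
# batch average, is zero up to floating-point error:
#
#   unit_logprob = lambda x: jnp.sum(norm.logpdf(x, 0., 1.))
#   params = (jnp.zeros(2), jnp.zeros(2))
#   batch_elbo(unit_logprob, random.PRNGKey(0), params, 32)  # ~0.0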
# ========= Helper functions for plotting. =========

@partial(jit, static_argnums=(0, 1, 2, 4))
def _mesh_eval(func, x_limits, y_limits, params, num_ticks):
  # Evaluate func on a 2D grid defined by x_limits and y_limits.
  x = jnp.linspace(*x_limits, num=num_ticks)
  y = jnp.linspace(*y_limits, num=num_ticks)
  X, Y = jnp.meshgrid(x, y)
  xy_vec = jnp.stack([X.ravel(), Y.ravel()]).T
  zs = vmap(func, in_axes=(0, None))(xy_vec, params)
  return X, Y, zs.reshape(X.shape)

def mesh_eval(func, x_limits, y_limits, params, num_ticks=101):
  # Thin wrapper that supplies the default grid resolution.
  return _mesh_eval(func, x_limits, y_limits, params, num_ticks)
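# Note: func, x_limits, y_limits and num_ticks are declared static, so they
# become part of jit's compilation cache key (and must be hashable); _mesh_eval
# is retraced whenever any of them changes. In this script they stay fixed, so
# each plotted distribution is compiled only once.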
# ========= Define an intractable unnormalized density =========

def funnel_log_density(params):
  return norm.logpdf(params[0], 0, jnp.exp(params[1])) + \
         norm.logpdf(params[1], 0, 1.35)
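# The density above is a two-dimensional funnel: params[1] sets the scale of
# params[0] (standard deviation exp(params[1])), while params[1] itself is
# Gaussian with standard deviation 1.35. A single diagonal Gaussian cannot
# capture that coupling exactly, which makes this a natural stress test for
# the mean-field approximation fit below.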
if __name__ == "__main__":
  num_samples = 40

  @jit
  def objective(params, t):
    # Seed the PRNG with the iteration index so every step uses fresh but
    # reproducible samples; the callback below re-derives the same key from t.
    rng = random.PRNGKey(t)
    return -batch_elbo(funnel_log_density, rng, params, num_samples)
  # Set up figure.
  fig = plt.figure(figsize=(8, 8), facecolor='white')
  ax = fig.add_subplot(111, frameon=False)
  plt.ion()
  plt.show(block=False)

  # Tuples (rather than lists) so the limits are hashable and can be passed
  # as static arguments to the jitted _mesh_eval.
  x_limits = (-2, 2)
  y_limits = (-4, 2)
  target_dist = lambda x, _: jnp.exp(funnel_log_density(x))
  approx_dist = lambda x, params: jnp.exp(diag_gaussian_logpdf(x, *params))
  def callback(params, t):
    print("Iteration {} lower bound {}".format(t, -objective(params, t)))
    plt.cla()
    X, Y, Z = mesh_eval(target_dist, x_limits, y_limits, 1)
    ax.contour(X, Y, Z, cmap='summer')
    X, Y, Z = mesh_eval(approx_dist, x_limits, y_limits, params)
    ax.contour(X, Y, Z, cmap='winter')
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    ax.set_yticks([])
    ax.set_xticks([])

    # Plot random samples from the variational distribution.
    # Here we clone the rng used in computing the objective
    # so that we can show exactly the same samples.
    rngs = random.split(random.PRNGKey(t), num_samples)
    samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
    ax.plot(samples[:, 0], samples[:, 1], 'b.')

    plt.draw()
    plt.pause(1.0 / 60.0)
  # Set up optimizer.
  D = 2
  init_mean = jnp.zeros(D)
  init_log_std = jnp.zeros(D)  # Log standard deviations, so the initial scale is 1.
  init_params = (init_mean, init_log_std)
  opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
  opt_state = opt_init(init_params)

  @jit
  def update(i, opt_state):
    params = get_params(opt_state)
    gradient = grad(objective)(params, i)
    return opt_update(i, gradient, opt_state)
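  # Note: grad(objective) differentiates through diag_gaussian_sample, i.e.
  # through sample = mean + exp(log_std) * eps, so the gradient above is the
  # reparameterization-trick (pathwise) estimator of the negative-ELBO gradient.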
  # Main loop.
  print("Optimizing variational parameters...")
  for t in range(100):
    opt_state = update(t, opt_state)
    params = get_params(opt_state)
    callback(params, t)
  plt.show(block=True)