SGLD in a numpyro model #360
-
Hi everyone, how can I implement SGLD with a model built in NumPyro? I have followed the tutorial, but I need to use `grad_estimator`, and that function needs `log_prior_fn` and `log_likelihood_fn`, whereas I only have the `logprob_fn` created through the tutorial.
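For reference, the tutorial builds a single joint log-density roughly like this (a sketch of the usual `numpyro.infer.util.initialize_model` pattern rather than the exact tutorial code; `X_train` and `y_train` are placeholders), so there is no separate log-prior or log-likelihood to pass to `grad_estimator`:

```python
import jax
from numpyro.infer.util import initialize_model

rng_key = jax.random.PRNGKey(0)

# Trace the NumPyro model once; this yields initial values and a potential function,
# i.e. a single (negative) joint log-density rather than separate prior/likelihood terms.
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    model,                          # the NumPyro model
    model_args=(X_train, y_train),  # placeholder training data
    dynamic_args=True,
)

logprob_fn = lambda position: -potential_fn_gen(X_train, y_train)(position)
```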
Replies: 5 comments 7 replies
-
Could you share a (possibly simplified) version of your model?
-
It is a Bayesian MLP:

```python
# (imports not shown in the original comment)
import numpyro
import numpyro.distributions as dists
from numpyro.contrib.module import random_flax_module
from flax import linen as nn


class NN(nn.Module):
    """A simple FCNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=50)(x)
        x = nn.relu(x)
        x = nn.Dense(features=50)(x)
        x = nn.relu(x)
        x = nn.Dense(features=50)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


def model(x, y=None, prior=None):
    N, in_feat = x.shape
    net = NN()
    bayesian_net = random_flax_module(
        "net", net, prior=prior, input_shape=(N, in_feat)
    )
    with numpyro.plate("data", size=N):
        f = numpyro.deterministic("f", value=bayesian_net(x)[:, 0])
        numpyro.sample("y", fn=dists.CategoricalLogits(f), obs=y)
    return f
```
-
You actually don't need NumPyro at all here. The neural network can be implemented with Flax alone, and we will use Distrax for the distributions (`jax.scipy.stats` doesn't have an implementation of the categorical distribution).

First let's generate 100 data points to test the model:

```python
import jax
import jax.numpy as jnp

data_size = 100
rng_key = jax.random.PRNGKey(0)
X_key, y_key, rng_key = jax.random.split(rng_key, 3)
X = jax.random.uniform(X_key, shape=(data_size, 50))
y = jax.random.choice(y_key, 10, shape=(data_size,))
```

The model is the same as the one you provided, so there is no surprise there. Let us also compute an initial value for the parameters of the model (you can check that it's a nested dictionary of arrays):

```python
from flax import linen as nn


class NN(nn.Module):
    """A simple FCNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=50)(x)
        x = nn.relu(x)
        x = nn.Dense(features=50)(x)
        x = nn.relu(x)
        x = nn.Dense(features=50)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


model = NN()
init_key, rng_key = jax.random.split(rng_key)
init_params = jax.jit(model.init)(rng_key, jnp.ones(50))
```

To use SGLD in Blackjax we need to define the log-prior and the log-likelihood separately, and provide the size of the dataset. At every step, the gradient estimator takes the current value of the parameters and a minibatch of data and returns an estimate of the gradient of the log-posterior density (the gradient of the log-prior plus the minibatch log-likelihood gradient rescaled by the ratio of dataset size to minibatch size).

The value of the log-prior is simple to compute. Assuming we set a `Normal(0, 1)` prior on every parameter:

```python
import distrax


def logprior_fn(params):
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return jnp.sum(distrax.Normal(0, 1).log_prob(flat_params))


print(logprior_fn(init_params))
# -7578.037
```

We assume that the logits computed by the neural network are the parameters of a categorical distribution. Blackjax assumes that the log-likelihood function has the signature `loglikelihood_fn(params, data_batch)`:

```python
def loglikelihood_fn(params, X):
    logits = model.apply(params, X)
    return jnp.sum(distrax.Categorical(logits).log_prob(y))


print(loglikelihood_fn(init_params, X))
# -692.422
```

And we're now ready to sample:

```python
import blackjax
from blackjax.sgmcmc.gradients import grad_estimator

grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn, 1e-3)

state = sgld.init(init_params, X)
sample_key, _ = jax.random.split(rng_key, 2)
new_state = jax.jit(sgld.step)(sample_key, state, X)
```

We can check that the chain has indeed moved away from its initial position:

```python
import numpy as np

try:
    np.testing.assert_equal(
        np.asarray(state.position["params"]["Dense_0"]["kernel"]),
        np.asarray(new_state.position["params"]["Dense_0"]["kernel"]),
    )
except AssertionError:
    print("The state has moved!")
# The state has moved!
```
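One thing the snippet above glosses over is minibatching: it passes the full dataset `X` to every step, and the labels `y` are captured as a global inside `loglikelihood_fn`. Here is a minimal sketch of a sampling loop that draws a random `(X, y)` minibatch at each iteration, reusing the names defined above; the `batch_size`, the tuple convention for the batch, and the loop itself are my assumptions, not something from the thread:

```python
import jax
import jax.numpy as jnp
import distrax
import blackjax
from blackjax.sgmcmc.gradients import grad_estimator

batch_size = 32  # hypothetical minibatch size


def loglikelihood_fn(params, batch):
    """Log-likelihood where the labels travel together with the inputs."""
    X_batch, y_batch = batch
    logits = model.apply(params, X_batch)
    return jnp.sum(distrax.Categorical(logits=logits).log_prob(y_batch))


grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn, 1e-3)

state = sgld.init(init_params, (X[:batch_size], y[:batch_size]))
step_fn = jax.jit(sgld.step)

positions = []
for _ in range(1_000):
    rng_key, batch_key, step_key = jax.random.split(rng_key, 3)
    # Subsample a fresh minibatch of inputs and labels at every iteration.
    idx = jax.random.choice(batch_key, data_size, shape=(batch_size,), replace=False)
    state = step_fn(step_key, state, (X[idx], y[idx]))
    positions.append(state.position)
```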
-
I think you should replace `jnp.ravel(model.apply(params, X))` with `model.apply(params, X)` in `loglikelihood_fn`, so that the logits keep their `(batch, classes)` shape.
-
On Wed, Oct 5, 2022 at 1:53 AM Rémi Louf wrote:

> You are right! Just to illustrate for those who read the thread:
>
> ```python
> logits = model.apply(init_params, X)
> print(logits.shape)
> # (100, 10)
>
> logits_flat = jnp.ravel(model.apply(init_params, X))
> print(logits_flat.shape)
> # (1000,)
> ```
>
> What surprises me is that distrax does not check shapes at runtime; indeed both give (different) results:

Yes, this seems like a distrax bug…

> ```python
> lp = jnp.sum(distrax.Categorical(logits).log_prob(y))
> print(lp)
> # -231.50316
>
> lp_flat = jnp.sum(distrax.Categorical(logits_flat).log_prob(y))
> print(lp_flat)
> # -697.4608
> ```
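To spell out why distrax raises no error here (my reading, not something stated in the thread): with flat logits of shape `(1000,)`, `Categorical` is a single distribution over 1000 classes and `log_prob` simply broadcasts over the 100 labels in `y`, whereas the `(100, 10)` logits give a batch of 100 ten-class distributions. A small sketch:

```python
import jax.numpy as jnp
import distrax

# One categorical over 1000 classes: the batch shape is empty, and log_prob
# broadcasts the labels against that (empty) batch shape.
flat = distrax.Categorical(logits=jnp.zeros(1000))
print(flat.batch_shape)                             # ()
print(flat.log_prob(jnp.array([0, 1, 2])).shape)    # (3,)

# 100 categoricals over 10 classes each: one distribution per data point.
batched = distrax.Categorical(logits=jnp.zeros((100, 10)))
print(batched.batch_shape)                          # (100,)
print(batched.log_prob(jnp.zeros(100, dtype=jnp.int32)).shape)  # (100,)
```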