SGLD in a numpyro model #360

Answered by rlouf
Davidmolin4 asked this question in Q&A
Sep 28, 2022 · 5 comments · 7 replies
You don’t actually need NumPyro here. The neural network can be implemented with Flax alone, and we can use Distrax for the distributions (jax.scipy.stats doesn’t have an implementation of the categorical distribution).

First let’s generate 100 data points to test the model:

import jax
import jax.numpy as jnp

data_size = 100

rng_key = jax.random.PRNGKey(0)
X_key, y_key, rng_key = jax.random.split(rng_key, 3)

X = jax.random.uniform(X_key, shape=(data_size, 50))
y = jax.random.choice(y_key, 10, shape=(data_size,))
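As a quick sanity check on the data, and to make concrete what the categorical log-likelihood amounts to, here is a minimal sketch in plain JAX. The answer uses Distrax for this step; `categorical_logpdf` and the all-zero logits are made up for illustration only.

```python
import jax
import jax.numpy as jnp

data_size = 100
num_classes = 10  # the labels are drawn from range(10)

rng_key = jax.random.PRNGKey(0)
X_key, y_key, rng_key = jax.random.split(rng_key, 3)

X = jax.random.uniform(X_key, shape=(data_size, 50))
y = jax.random.choice(y_key, num_classes, shape=(data_size,))


def categorical_logpdf(logits, labels):
    """Log-probability of integer labels under softmax(logits).

    Hypothetical helper, not part of JAX or Distrax.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick, for each row, the log-probability of the observed label.
    return jnp.take_along_axis(log_probs, labels[:, None], axis=-1).squeeze(-1)


# Placeholder logits: all zeros means a uniform distribution over classes.
logits = jnp.zeros((data_size, num_classes))
loglik = categorical_logpdf(logits, y)

print(X.shape, y.shape, loglik.shape)
```

With uniform logits every data point gets the same log-probability, which is a handy baseline to compare a real model against.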

The model is the same as the one you provided, so there is no surprise there. Let us also compute an initial value for the parameters of the model (you can check that …

Answer selected by rlouf
rlouf Oct 5, 2022
Maintainer
