Skip to content

Commit

Permalink
Variational inference/Bayesian neural network: fixed data dims in ann…
Browse files Browse the repository at this point in the history
…_input/ann_output (pymc-devs#506)

* fixed data dims

* pre-commit

Co-authored-by: Oriol (ZBook) <oriol.abril.pla@gmail.com>
  • Loading branch information
earlbellinger and OriolAbril authored Jan 25, 2023
1 parent 5288e05 commit 951c0ef
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 75 deletions.
294 changes: 221 additions & 73 deletions examples/variational_inference/bayesian_neural_network_advi.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def construct_nn(ann_input, ann_output):
# "obs_id": np.arange(X_train.shape[0]),
}
with pm.Model(coords=coords) as neural_network:
ann_input = pm.Data("ann_input", X_train, mutable=True)
ann_output = pm.Data("ann_output", Y_train, mutable=True)
ann_input = pm.Data("ann_input", X_train, mutable=True, dims=("obs_id", "train_cols"))
ann_output = pm.Data("ann_output", Y_train, mutable=True, dims="obs_id")
# Weights from input to hidden layer
weights_in_1 = pm.Normal(
Expand All @@ -157,6 +157,7 @@ def construct_nn(ann_input, ann_output):
act_out,
observed=ann_output,
total_size=Y_train.shape[0], # IMPORTANT for minibatches
dims="obs_id",
)
return neural_network
Expand Down Expand Up @@ -340,6 +341,7 @@ You might argue that the above network isn't really deep, but note that we could

- This notebook was originally authored as a [blog post](https://twiecki.github.io/blog/2016/06/01/bayesian-deep-learning/) by Thomas Wiecki in 2016
- Updated by Chris Fonnesbeck for PyMC v4 in 2022
- Updated by Oriol Abril-Pla and Earl Bellinger in 2023

## Watermark

Expand Down

0 comments on commit 951c0ef

Please sign in to comment.