forked from pymc-devs/pymc-examples
BART: Fully non-parametric curve fit example (pymc-devs#519)
* init
* initial complete version
* plot improvements
* improve last sentence
1 parent 9028ba3, commit 9cbb346
Showing 4 changed files with 879 additions and 0 deletions.
@@ -0,0 +1,166 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
kernelspec:
  display_name: pymc-examples-env
  language: python
  name: python3
---

(bart_heteroscedasticity)=
# Modeling Heteroscedasticity with BART

:::{post} January, 2023
:tags: bart regression
:category: beginner, reference
:author: [Juan Orduz](https://juanitorduz.github.io/)
:::

+++

In this notebook we show how to use BART to model heteroscedasticity, as described in Section 4.1 of the [`pymc-bart`](https://github.com/pymc-devs/pymc-bart) paper {cite:p}`quiroga2022bart`. We use the `marketing` data set provided by the R package `datarium` {cite:p}`kassambara2019datarium`. The idea is to model a marketing channel's contribution to sales as a function of its budget.

```{code-cell} ipython3
:tags: []
import os
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
```

```{code-cell} ipython3
:tags: []
%config InlineBackend.figure_format = "retina"
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [10, 6]
rng = np.random.default_rng(42)
```

## Read Data

```{code-cell} ipython3
# The file uses ";" as the field separator and "," as the decimal mark.
try:
    df = pd.read_csv(os.path.join("..", "data", "marketing.csv"), sep=";", decimal=",")
except FileNotFoundError:
    # Fall back to pm.get_data when the local copy is missing.
    df = pd.read_csv(pm.get_data("marketing.csv"), sep=";", decimal=",")

n_obs = df.shape[0]
df.head()
```
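
To make explicit what `sep=";"` and `decimal=","` do in the cell above, here is a minimal dependency-free sketch that performs the same conversion by hand on the first rows of `marketing.csv`. The helper name `parse_decimal_comma` is hypothetical and only serves this illustration; the notebook itself relies on `pd.read_csv`.

```python
import csv
import io

# Header and first two rows of marketing.csv: ";" separates the fields
# and "," is the decimal mark.
raw = """youtube;facebook;newspaper;sales
276,12;45,36;83,04;26,52
53,40;47,16;54,12;12,48
"""


def parse_decimal_comma(text):
    """Parse a ';'-separated table whose numbers use ',' as the decimal mark."""
    reader = csv.DictReader(io.StringIO(text), delimiter=";")
    return [
        {col: float(val.replace(",", ".")) for col, val in row.items()}
        for row in reader
    ]


rows = parse_decimal_comma(raw)
print(rows[0]["youtube"], rows[0]["sales"])
```

Without the `decimal=","` option, a value such as `276,12` would be read as text rather than as a number, so the columns could not be used directly as model inputs.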

## EDA

We start by looking at the data. We focus on the *YouTube* channel.

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(df["youtube"], df["sales"], "o", c="C0")
ax.set(title="Sales as a function of Youtube budget", xlabel="budget", ylabel="sales");
```

Both the mean and the variance of sales clearly increase with the budget. One option is to choose an explicit parametrization for these two functions by hand, e.g. a square root or a logarithm. In this example, however, we want to learn both functions from the data using a BART model.

+++
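
The visual impression from the scatter plot (mean and spread both growing with budget) can also be checked numerically by binning the predictor and summarizing the response within each bin. The sketch below uses synthetic data whose noise scale grows with `x`; the coefficients, the sample size and the helper `binned_summary` are illustrative assumptions, not part of the notebook or estimates from the marketing data.

```python
import random
import statistics

random.seed(0)

# Synthetic stand-in for (budget, sales): the noise scale grows with x.
# All numbers here are illustrative assumptions.
xs = [random.uniform(0, 350) for _ in range(600)]
ys = [5 + 0.05 * x + random.gauss(0, 0.5 + 0.015 * x) for x in xs]


def binned_summary(x, y, n_bins=3):
    """Return (mean, standard deviation) of y within equal-width bins of x."""
    lo, hi = min(x), max(x)
    width = (hi - lo) / n_bins
    groups = [[] for _ in range(n_bins)]
    for xi, yi in zip(x, y):
        # Clamp the index so that the maximum x falls in the last bin.
        k = min(int((xi - lo) / width), n_bins - 1)
        groups[k].append(yi)
    return [(statistics.mean(g), statistics.stdev(g)) for g in groups]


summary = binned_summary(xs, ys)
for k, (m, s) in enumerate(summary):
    print(f"bin {k}: mean={m:.2f}, sd={s:.2f}")
```

The same summary could be applied to `df["youtube"]` and `df["sales"]`. Note that the within-bin standard deviation also picks up the trend of the mean inside each bin, so this is only a rough diagnostic.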

## Model Specification

We now prepare the data for modeling. We use the `youtube` budget as the predictor and `sales` as the response.

```{code-cell} ipython3
X = df["youtube"].to_numpy().reshape(-1, 1)
Y = df["sales"].to_numpy()
```

Next, we specify the model. Note that we need only one BART random variable, which can be vectorized to model both the mean and the standard deviation of the likelihood: `w[0]` is used as the mean and the absolute value of `w[1]` as the standard deviation. We use a Gamma distribution as the likelihood because we expect sales to be positive.

```{code-cell} ipython3
with pm.Model() as model_marketing_full:
    # One BART variable with two rows: w[0] for the mean, w[1] for the scale.
    w = pmb.BART(name="w", X=X, Y=Y, m=200, shape=(2, n_obs))
    y = pm.Gamma(name="y", mu=w[0], sigma=pm.math.abs(w[1]), observed=Y)

pm.model_to_graphviz(model=model_marketing_full)
```
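
The likelihood above is specified through `mu` and `sigma` rather than the usual shape and rate. For a Gamma distribution the two parametrizations are related by `alpha = mu**2 / sigma**2` and `beta = mu / sigma**2`. The following sketch checks this relation by simulation with the standard library only; the helper `gamma_shape_rate` and the values `mu=20`, `sigma=4` are hypothetical and chosen purely for illustration.

```python
import random
import statistics


def gamma_shape_rate(mu, sigma):
    """Convert a (mean, standard deviation) pair to Gamma (shape, rate)."""
    alpha = mu**2 / sigma**2
    beta = mu / sigma**2
    return alpha, beta


mu, sigma = 20.0, 4.0  # illustrative values, not fitted quantities
alpha, beta = gamma_shape_rate(mu, sigma)

random.seed(1)
# random.gammavariate takes (shape, scale), so the scale is 1 / rate.
draws = [random.gammavariate(alpha, 1 / beta) for _ in range(50_000)]

print(alpha, beta)
print(statistics.mean(draws), statistics.stdev(draws))
```

The conversion requires a strictly positive standard deviation, which is why the model wraps `w[1]` in `pm.math.abs` before passing it as `sigma`.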

We now fit the model.

```{code-cell} ipython3
with model_marketing_full:
    idata_marketing_full = pm.sample(random_seed=rng)
    posterior_predictive_marketing_full = pm.sample_posterior_predictive(
        trace=idata_marketing_full, random_seed=rng
    )
```

## Results

We can now visualize the posterior distribution of the mean together with the posterior predictive distribution of the likelihood.

```{code-cell} ipython3
# Posterior mean of w[0], the component used as the mean of the likelihood.
posterior_mean = idata_marketing_full.posterior["w"].mean(dim=("chain", "draw"))[0]
w_hdi = az.hdi(ary=idata_marketing_full, group="posterior", var_names=["w"])
# Posterior predictive samples, transposed so that observations are on the last axis.
pps = az.extract(
    posterior_predictive_marketing_full, group="posterior_predictive", var_names=["y"]
).T
```

```{code-cell} ipython3
# Sort by budget so that the posterior mean is drawn as a single curve.
idx = np.argsort(X[:, 0])
fig, ax = plt.subplots()
az.plot_hdi(x=X[:, 0], y=pps, ax=ax, fill_kwargs={"alpha": 0.3, "label": r"Likelihood $94\%$ HDI"})
az.plot_hdi(
    x=X[:, 0],
    hdi_data=w_hdi["w"].sel(w_dim_0=0),
    ax=ax,
    fill_kwargs={"alpha": 0.6, "label": r"Mean $94\%$ HDI"},
)
ax.plot(X[:, 0][idx], posterior_mean[idx], c="black", lw=3, label="Posterior Mean")
ax.plot(df["youtube"], df["sales"], "o", c="C0", label="Raw Data")
ax.legend(loc="upper left")
ax.set(
    title="Sales as a function of Youtube budget - Posterior Predictive",
    xlabel="budget",
    ylabel="sales",
);
```
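
The shaded bands in the plot are 94% highest density intervals (HDI). As a rough guide to what such an interval means, the sketch below computes the narrowest interval that contains a given fraction of a set of samples. It is a simplified illustration for unimodal samples, written with the standard library only; it is not ArviZ's implementation, and the function name `hdi` and the simulated draws are assumptions made for this example.

```python
import math
import random


def hdi(samples, prob=0.94):
    """Narrowest interval containing a fraction `prob` of the samples."""
    xs = sorted(samples)
    n = len(xs)
    k = math.ceil(prob * n)  # number of samples the interval must contain
    # Slide a window of k consecutive sorted samples and keep the narrowest.
    start = min(range(n - k + 1), key=lambda i: xs[i + k - 1] - xs[i])
    return xs[start], xs[start + k - 1]


random.seed(2)
# Illustrative right-skewed draws; any one-dimensional sample would do.
draws = [random.gammavariate(25.0, 0.8) for _ in range(20_000)]
low, high = hdi(draws)
print(low, high)
```

For a skewed distribution the HDI is not symmetric around the mean, which is one reason the bands in the figure need not be centered on the posterior mean curve.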

The fit looks good! Indeed, both the mean and the variance increase as a function of the budget.

+++

## Authors
- Authored by [Juan Orduz](https://juanitorduz.github.io/) in February 2023

+++

## References
:::{bibliography}
:filter: docname in docnames
:::

+++

## Watermark

```{code-cell} ipython3
:tags: []
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
```

:::{include} ../page_footer.md
:::
@@ -0,0 +1,201 @@
youtube;facebook;newspaper;sales | ||
276,12;45,36;83,04;26,52 | ||
53,40;47,16;54,12;12,48 | ||
20,64;55,08;83,16;11,16 | ||
181,80;49,56;70,20;22,20 | ||
216,96;12,96;70,08;15,48 | ||
10,44;58,68;90,00;8,64 | ||
69,00;39,36;28,20;14,16 | ||
144,24;23,52;13,92;15,84 | ||
10,32;2,52;1,20;5,76 | ||
239,76;3,12;25,44;12,72 | ||
79,32;6,96;29,04;10,32 | ||
257,64;28,80;4,80;20,88 | ||
28,56;42,12;79,08;11,04 | ||
117,00;9,12;8,64;11,64 | ||
244,92;39,48;55,20;22,80 | ||
234,48;57,24;63,48;26,88 | ||
81,36;43,92;136,80;15,00 | ||
337,68;47,52;66,96;29,28 | ||
83,04;24,60;21,96;13,56 | ||
176,76;28,68;22,92;17,52 | ||
262,08;33,24;64,08;21,60 | ||
284,88;6,12;28,20;15,00 | ||
15,84;19,08;59,52;6,72 | ||
273,96;20,28;31,44;18,60 | ||
74,76;15,12;21,96;11,64 | ||
315,48;4,20;23,40;14,40 | ||
171,48;35,16;15,12;18,00 | ||
288,12;20,04;27,48;19,08 | ||
298,56;32,52;27,48;22,68 | ||
84,72;19,20;48,96;12,60 | ||
351,48;33,96;51,84;25,68 | ||
135,48;20,88;46,32;14,28 | ||
116,64;1,80;36,00;11,52 | ||
318,72;24,00;0,36;20,88 | ||
114,84;1,68;8,88;11,40 | ||
348,84;4,92;10,20;15,36 | ||
320,28;52,56;6,00;30,48 | ||
89,64;59,28;54,84;17,64 | ||
51,72;32,04;42,12;12,12 | ||
273,60;45,24;38,40;25,80 | ||
243,00;26,76;37,92;19,92 | ||
212,40;40,08;46,44;20,52 | ||
352,32;33,24;2,16;24,84 | ||
248,28;10,08;31,68;15,48 | ||
30,12;30,84;51,96;10,20 | ||
210,12;27,00;37,80;17,88 | ||
107,64;11,88;42,84;12,72 | ||
287,88;49,80;22,20;27,84 | ||
272,64;18,96;59,88;17,76 | ||
80,28;14,04;44,16;11,64 | ||
239,76;3,72;41,52;13,68 | ||
120,48;11,52;4,32;12,84 | ||
259,68;50,04;47,52;27,12 | ||
219,12;55,44;70,44;25,44 | ||
315,24;34,56;19,08;24,24 | ||
238,68;59,28;72,00;28,44 | ||
8,76;33,72;49,68;6,60 | ||
163,44;23,04;19,92;15,84 | ||
252,96;59,52;45,24;28,56 | ||
252,84;35,40;11,16;22,08 | ||
64,20;2,40;25,68;9,72 | ||
313,56;51,24;65,64;29,04 | ||
287,16;18,60;32,76;18,84 | ||
123,24;35,52;10,08;16,80 | ||
157,32;51,36;34,68;21,60 | ||
82,80;11,16;1,08;11,16 | ||
37,80;29,52;2,64;11,40 | ||
167,16;17,40;12,24;16,08 | ||
284,88;33,00;13,20;22,68 | ||
260,16;52,68;32,64;26,76 | ||
238,92;36,72;46,44;21,96 | ||
131,76;17,16;38,04;14,88 | ||
32,16;39,60;23,16;10,56 | ||
155,28;6,84;37,56;13,20 | ||
256,08;29,52;15,72;20,40 | ||
20,28;52,44;107,28;10,44 | ||
33,00;1,92;24,84;8,28 | ||
144,60;34,20;17,04;17,04 | ||
6,48;35,88;11,28;6,36 | ||
139,20;9,24;27,72;13,20 | ||
91,68;32,04;26,76;14,16 | ||
287,76;4,92;44,28;14,76 | ||
90,36;24,36;39,00;13,56 | ||
82,08;53,40;42,72;16,32 | ||
256,20;51,60;40,56;26,04 | ||
231,84;22,08;78,84;18,24 | ||
91,56;33,00;19,20;14,40 | ||
132,84;48,72;75,84;19,20 | ||
105,96;30,60;88,08;15,48 | ||
131,76;57,36;61,68;20,04 | ||
161,16;5,88;11,16;13,44 | ||
34,32;1,80;39,60;8,76 | ||
261,24;40,20;70,80;23,28 | ||
301,08;43,80;86,76;26,64 | ||
128,88;16,80;13,08;13,80 | ||
195,96;37,92;63,48;20,28 | ||
237,12;4,20;7,08;14,04 | ||
221,88;25,20;26,40;18,60 | ||
347,64;50,76;61,44;30,48 | ||
162,24;50,04;55,08;20,64 | ||
266,88;5,16;59,76;14,04 | ||
355,68;43,56;121,08;28,56 | ||
336,24;12,12;25,68;17,76 | ||
225,48;20,64;21,48;17,64 | ||
285,84;41,16;6,36;24,84 | ||
165,48;55,68;70,80;23,04 | ||
30,00;13,20;35,64;8,64 | ||
108,48;0,36;27,84;10,44 | ||
15,72;0,48;30,72;6,36 | ||
306,48;32,28;6,60;23,76 | ||
270,96;9,84;67,80;16,08 | ||
290,04;45,60;27,84;26,16 | ||
210,84;18,48;2,88;16,92 | ||
251,52;24,72;12,84;19,08 | ||
93,84;56,16;41,40;17,52 | ||
90,12;42,00;63,24;15,12 | ||
167,04;17,16;30,72;14,64 | ||
91,68;0,96;17,76;11,28 | ||
150,84;44,28;95,04;19,08 | ||
23,28;19,20;26,76;7,92 | ||
169,56;32,16;55,44;18,60 | ||
22,56;26,04;60,48;8,40 | ||
268,80;2,88;18,72;13,92 | ||
147,72;41,52;14,88;18,24 | ||
275,40;38,76;89,04;23,64 | ||
104,64;14,16;31,08;12,72 | ||
9,36;46,68;60,72;7,92 | ||
96,24;0,00;11,04;10,56 | ||
264,36;58,80;3,84;29,64 | ||
71,52;14,40;51,72;11,64 | ||
0,84;47,52;10,44;1,92 | ||
318,24;3,48;51,60;15,24 | ||
10,08;32,64;2,52;6,84 | ||
263,76;40,20;54,12;23,52 | ||
44,28;46,32;78,72;12,96 | ||
57,96;56,40;10,20;13,92 | ||
30,72;46,80;11,16;11,40 | ||
328,44;34,68;71,64;24,96 | ||
51,60;31,08;24,60;11,52 | ||
221,88;52,68;2,04;24,84 | ||
88,08;20,40;15,48;13,08 | ||
232,44;42,48;90,72;23,04 | ||
264,60;39,84;45,48;24,12 | ||
125,52;6,84;41,28;12,48 | ||
115,44;17,76;46,68;13,68 | ||
168,36;2,28;10,80;12,36 | ||
288,12;8,76;10,44;15,84 | ||
291,84;58,80;53,16;30,48 | ||
45,60;48,36;14,28;13,08 | ||
53,64;30,96;24,72;12,12 | ||
336,84;16,68;44,40;19,32 | ||
145,20;10,08;58,44;13,92 | ||
237,12;27,96;17,04;19,92 | ||
205,56;47,64;45,24;22,80 | ||
225,36;25,32;11,40;18,72 | ||
4,92;13,92;6,84;3,84 | ||
112,68;52,20;60,60;18,36 | ||
179,76;1,56;29,16;12,12 | ||
14,04;44,28;54,24;8,76 | ||
158,04;22,08;41,52;15,48 | ||
207,00;21,72;36,84;17,28 | ||
102,84;42,96;59,16;15,96 | ||
226,08;21,72;30,72;17,88 | ||
196,20;44,16;8,88;21,60 | ||
140,64;17,64;6,48;14,28 | ||
281,40;4,08;101,76;14,28 | ||
21,48;45,12;25,92;9,60 | ||
248,16;6,24;23,28;14,64 | ||
258,48;28,32;69,12;20,52 | ||
341,16;12,72;7,68;18,00 | ||
60,00;13,92;22,08;10,08 | ||
197,40;25,08;56,88;17,40 | ||
23,52;24,12;20,40;9,12 | ||
202,08;8,52;15,36;14,04 | ||
266,88;4,08;15,72;13,80 | ||
332,28;58,68;50,16;32,40 | ||
298,08;36,24;24,36;24,24 | ||
204,24;9,36;42,24;14,04 | ||
332,04;2,76;28,44;14,16 | ||
198,72;12,00;21,12;15,12 | ||
187,92;3,12;9,96;12,60 | ||
262,20;6,48;32,88;14,64 | ||
67,44;6,84;35,64;10,44 | ||
345,12;51,60;86,16;31,44 | ||
304,56;25,56;36,00;21,12 | ||
246,00;54,12;23,52;27,12 | ||
167,40;2,52;31,92;12,36 | ||
229,32;34,44;21,84;20,76 | ||
343,20;16,68;4,44;19,08 | ||
22,44;14,52;28,08;8,04 | ||
47,40;49,32;6,96;12,96 | ||
90,60;12,96;7,20;11,88 | ||
20,64;4,92;37,92;7,08 | ||
200,16;50,40;4,32;23,52 | ||
179,64;42,72;7,20;20,76 | ||
45,84;4,44;16,56;9,12 | ||
113,04;5,88;9,72;11,64 | ||
212,40;11,16;7,68;15,36 | ||
340,32;50,40;79,44;30,60 | ||
278,52;10,32;10,44;16,08