update to reflect changes in pymc-bart 0.4.0 (pymc-devs#531)
aloctavodia authored Apr 1, 2023
1 parent bc06eef commit 2ab6aab
Showing 6 changed files with 284 additions and 565 deletions.
527 changes: 169 additions & 358 deletions examples/case_studies/BART_introduction.ipynb

Large diffs are not rendered by default.

76 changes: 53 additions & 23 deletions examples/case_studies/BART_introduction.myst.md
@@ -27,7 +27,6 @@ import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import seaborn as sns
from sklearn.model_selection import train_test_split
@@ -83,43 +82,50 @@ x_centers = x_edges[:-1] + (x_edges[1] - x_edges[0]) / 2
# xdata needs to be 2D for BART
x_data = x_centers[:, None]
# express data as the rate of disasters per year
y_data = hist / 4
y_data = hist
```

In PyMC a BART variable can be defined very similarly to other random variables. One important difference is that we have to pass our Xs and Ys to the BART variable. Here we are also making explicit that we are going to use a sum over 20 trees (`m=20`). A low number of trees, like 20, can be good enough for simple models like this one and can also work very well as a quick approximation for more complex models, in particular during the iterative or explorative phase of modeling. In those cases, once we have more certainty about the model we really like, we can improve the approximation by increasing `m`; in the literature it is common to find reports of good results with numbers like 50, 100 or 200.
In PyMC a BART variable can be defined very similarly to other random variables. One important difference is that we have to pass our Xs and Ys to the BART variable; this information is used when sampling trees, since the prior over the sum of trees is so vast that, without any information from our data, sampling from it would be an impossible task.

Here we are also making explicit that we are going to use a sum over 20 trees (`m=20`). A low number of trees, like 20, can be good enough for simple models like this one and can also work very well as a quick approximation for more complex models, in particular during the early stages of modeling, when we may want to try a few things as quickly as possible in order to better grasp which model may be a good idea for our problem. In those cases, once we have more certainty about the model(s) we really like, we can improve the approximation by increasing `m`; in the literature it is common to find reports of good results with numbers like 50, 100 or 200.

```{code-cell} ipython3
with pm.Model() as model_coal:
    μ_ = pmb.BART("μ_", X=x_data, Y=y_data, m=20)
    μ = pm.Deterministic("μ", pm.math.abs(μ_))
    μ_ = pmb.BART("μ_", X=x_data, Y=np.log(y_data), m=20)
    μ = pm.Deterministic("μ", pm.math.exp(μ_))
    y_pred = pm.Poisson("y_pred", mu=μ, observed=y_data)
    idata_coal = pm.sample(random_seed=RANDOM_SEED)
```

The white line in the following plot shows the median rate of accidents. The darker orange band represents the 50% HDI and the lighter one the 94% HDI. We can see a rapid decrease of coal accidents between 1880 and 1900. Feel free to compare these results with those in the original {ref}`pymc:pymc_overview` example.
Before checking the result, we need to discuss one more detail. The BART variable always samples over the real line, meaning that in principle we can get values that go from $-\infty$ to $\infty$. Thus, we may need to transform its values as we would do for standard Generalized Linear Models; for example, in `model_coal` we computed `pm.math.exp(μ_)` because the Poisson distribution expects values that go from 0 to $\infty$. This is business as usual; the novelty is that we may need to apply the inverse transformation to the values of `Y`, as we did in the previous model where we took $\log(Y)$. The main reason to do this is that the values of `Y` are used to get a reasonable initial value for the sum of trees and also the variance of the leaf nodes, so applying the inverse transformation is a simple way to improve the efficiency and accuracy of the result. Should we do this for every possible likelihood? Well, no. If we are using BART for the location parameter of distributions like Normal, StudentT, or AsymmetricLaplace, we don't need to do anything, as the support of these parameters is also the real line. A nontrivial exception is the Bernoulli likelihood (or Binomial with n=1); in that case, we need to apply the logistic function to the BART variable, but there is no need to apply its inverse to transform `Y`, as PyMC-BART already takes care of that particular case.
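As a minimal, hypothetical sketch (not part of the original example), a Bernoulli-likelihood model could look like the following; the synthetic data and the names `X_bin`, `y_bin`, and `model_bernoulli` are assumptions made purely for illustration.

```{code-cell} ipython3
# Illustrative sketch only: BART with a Bernoulli likelihood.
# Synthetic data: one covariate and a binary outcome (names are made up for this sketch).
rng = np.random.default_rng(0)
X_bin = np.linspace(0, 1, 200)[:, None]
y_bin = rng.binomial(1, p=1 / (1 + np.exp(-6 * (X_bin[:, 0] - 0.5))))

with pm.Model() as model_bernoulli:
    μ_b = pmb.BART("μ_b", X=X_bin, Y=y_bin, m=20)
    # apply the logistic function to map the real line to (0, 1);
    # no need to transform Y, PyMC-BART handles the Bernoulli case internally
    p = pm.Deterministic("p", pm.math.sigmoid(μ_b))
    pm.Bernoulli("y_b", p=p, observed=y_bin)
```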

OK, now let's see the result of `model_coal`.

```{code-cell} ipython3
_, ax = plt.subplots(figsize=(10, 6))
rates = idata_coal.posterior["μ"]
rate_mean = idata_coal.posterior["μ"].mean(dim=["draw", "chain"])
rates = idata_coal.posterior["μ"] / 4
rate_mean = rates.mean(dim=["draw", "chain"])
ax.plot(x_centers, rate_mean, "w", lw=3)
ax.plot(x_centers, y_data / 4, "k.")
az.plot_hdi(x_centers, rates, smooth=False)
az.plot_hdi(x_centers, rates, hdi_prob=0.5, smooth=False, plot_kwargs={"alpha": 0})
ax.plot(coal, np.zeros_like(coal) - 0.5, "k|")
ax.set_xlabel("years")
ax.set_ylabel("rate");
```

In the previous plot the white line is the median over 4000 posterior draws, and each one of those posterior draws is a sum over `m=20` trees.
The white line in the following plot shows the mean rate of accidents. The darker orange band represents the 50% HDI and the lighter one the 94% HDI. We can see a rapid decrease of coal accidents between 1880 and 1900. Feel free to compare these results with those in the original {ref}`pymc:pymc_overview` example.

In the previous plot the white line is the mean over 4000 posterior draws, and each one of those posterior draws is a sum over `m=20` trees.


The following figure shows two samples from the posterior of $\mu$. We can see that these functions are not smooth. This is fine and is a direct consequence of using regression trees. Trees can be seen as a way to represent stepwise functions, and a sum of stepwise functions is just another stepwise function. Thus, when using BART we just need to know that we are assuming that a stepwise function is a good enough approximation for our problem. In practice this is often the case because we sum over many trees, usually values like 50, 100 or 200. Additionally, we often average over the posterior distribution. All this makes the "steps smoother", even though we never really have a smooth function as, for example, with Gaussian processes or splines. A nice theoretical result tells us that in the limit of $m \to \infty$ the BART prior converges to a [nowhere-differentiable](https://en.wikipedia.org/wiki/Weierstrass_function) Gaussian process.

The following figure shows two samples of $\mu$ from the posterior.

```{code-cell} ipython3
plt.step(x_data, idata_coal.posterior["μ"].sel(chain=0, draw=[3, 10]).T);
plt.step(x_data, rates.sel(chain=0, draw=[3, 10]).T);
```

The next figure shows 3 trees. As we can see these are very simple functions and definitely not very good approximators by themselves. Inspecting individual trees is generally not necessary when working with BART; we are showing them just so we can gain further intuition on the inner workings of BART.
@@ -150,23 +156,46 @@ Y = bikes["count"]

```{code-cell} ipython3
with pm.Model() as model_bikes:
    α = pm.Exponential("α", 1 / 10)
    μ = pmb.BART("μ", X, Y)
    y = pm.NegativeBinomial("y", mu=pm.math.abs(μ), alpha=α, observed=Y)
    idata_bikes = pm.sample(random_seed=RANDOM_SEED)
    α = pm.Exponential("α", 1)
    μ = pmb.BART("μ", X, np.log(Y), m=50)
    y = pm.NegativeBinomial("y", mu=pm.math.exp(μ), alpha=α, observed=Y)
    idata_bikes = pm.sample(compute_convergence_checks=False, random_seed=RANDOM_SEED)
```

### Convergence diagnostics

To check the sampling convergence of BART models we recommend a two-step approach.

* For the non-BART variables (like $\alpha$ in `model_bikes`) we follow the standard recommendations, such as checking the R-hat (<= 1.01) and ESS (>= 100x number of chains) numerical diagnostics, as well as using trace plots or, even better, rank plots.
* For the BART variables we recommend using the `pmb.plot_convergence` function.

We can see such checks next:

```{code-cell} ipython3
az.plot_trace(idata_bikes, var_names=["α"], kind="rank_bars");
```

```{code-cell} ipython3
pmb.plot_convergence(idata_bikes, var_name="μ");
```

In the BART literature, the diagnostics of the BART variables are sometimes considered less important than those of the non-BART variables; the main argument is that the individual estimates of the latent variables are of no direct interest, and instead we should only care about how well we are estimating the whole function/regression.

We instead consider checking the convergence of BART variables an important part of the Bayesian workflow. The main reason to use `pmb.plot_convergence` is that the BART variable will usually be a large vector (we estimate a distribution per observation) and thus we will need to check a large number of diagnostics. Additionally, the R-hat threshold of 1.01 is not a hard threshold; this value was chosen assuming that only one or a few R-hats are examined (and that chains are long enough to accurately estimate their autocorrelation). If we examine a large number of R-hats, a few of them are expected to be larger than the 1.01 threshold (or whatever threshold we pick) even if there is nothing wrong with our inference. For that reason, a fair analysis should include a multiple comparisons adjustment, and that's what `pmb.plot_convergence` does automatically for you. So, how do we read its output? We have two panels, one for the ESS and one for the R-hat. The blue line is the empirical cumulative distribution of those values; for the ESS we want the entire curve above the dashed line, and for the R-hat we want the curve entirely below the dashed line. In the previous figure, we can see that we barely make it for the ESS, and for the R-hat we have very few values above the threshold. Are our results useless? Most likely not. But to be sure, we may want to take a few more draws.
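As a rough cross-check (a sketch, not part of the original example), we could also peek at the raw, unadjusted per-observation diagnostics for `μ` directly with ArviZ:

```{code-cell} ipython3
# Illustrative sketch: unadjusted per-observation diagnostics for the BART variable.
# pmb.plot_convergence summarizes these with a multiple-comparison adjustment;
# here we only look at the worst-case values.
rhat_μ = az.rhat(idata_bikes, var_names=["μ"])["μ"]
ess_μ = az.ess(idata_bikes, var_names=["μ"])["μ"]
print(f"largest R-hat: {float(rhat_μ.max()):.3f}")
print(f"smallest bulk ESS: {float(ess_μ.min()):.0f}")
```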

+++

### Partial dependence plots

+++

To help us interpret the results of our model we are going to use a partial dependence plot. This is a type of plot that shows the marginal effect that one covariate has on the predicted variable. That is, what is the effect that a covariate $X_i$ has on $Y$ while we average over all the other covariates ($X_j, \forall j \not = i$). This type of plot is not exclusive to BART, but it is often used in the BART literature. PyMC-BART provides a utility function to make this plot from the inference data.
To help us interpret the results of our model we are going to use partial dependence plots. This is a type of plot that shows the marginal effect that one covariate has on the predicted variable. That is, what is the effect that a covariate $X_i$ has on $Y$ while we average over all the other covariates ($X_j, \forall j \not = i$). This type of plot is not exclusive to BART, but it is often used in the BART literature. PyMC-BART provides a utility function to make this plot from the inference data.

```{code-cell} ipython3
pmb.plot_dependence(μ, X=X, Y=Y, grid=(2, 2), var_discrete=[3]);
pmb.plot_dependence(μ, X=X, Y=Y, grid=(2, 2), func=np.exp);
```

From this plot we can see the main effect of each covariate on the predicted value. This is very useful: we can recover complex relationships beyond monotonic increasing or decreasing effects. For example, for the `hour` covariate we can see two peaks, around 8 and 17 hours, and a minimum at midnight.
From this plot we can see the main effect of each covariate on the predicted value. This is very useful as we can recover complex relationships beyond monotonic increasing or decreasing effects. For example, for the `hour` covariate we can see two peaks, around 8 and 17 hours, and a minimum at midnight.

When interpreting partial dependence plots we should be careful about the assumptions behind this plot. First, we are assuming that the variables are independent. For example, when computing the effect of `hour` we have to marginalize over the effect of `temperature`, and this means that to compute the partial dependence value at `hour=0` we are including all observed values of temperature, which may include temperatures that are never actually observed at midnight, given that lower temperatures are more likely than higher ones. Second, we are seeing only averages, so if for a covariate half the values are positively associated with the predicted variable and the other half negatively associated, the partial dependence plot will be flat, as their contributions will cancel each other out. This is a problem that can be solved by using individual conditional expectation plots, `pmb.plot_dependence(..., kind="ice")`. Notice that all these assumptions are assumptions of the partial dependence plot, not of our model! In fact, BART can easily accommodate interactions of variables (although the prior in BART regularizes high-order interactions). For more on interpreting Machine Learning models you could check the "Interpretable Machine Learning" book {cite:p}`molnar2019`.
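For instance, a minimal sketch of the corresponding ICE plot, reusing the arguments from the partial dependence call above, could look like this:

```{code-cell} ipython3
# Sketch: individual conditional expectation (ICE) curves instead of the averaged effect.
# Same arguments as the partial dependence plot above, only `kind` changes.
pmb.plot_dependence(μ, X=X, Y=Y, grid=(2, 2), func=np.exp, kind="ice");
```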

@@ -180,7 +209,7 @@ As we saw in the previous section a partial dependence plot can visualize give us

The following plot shows the relative importance on a scale from 0 to 1 (less to more importance); the sum of the individual importances is 1. See that, at least in this case, the relative importance qualitatively agrees with the partial dependence plot.

Additionally, PyMC-BART provides a novel method to assess variable importance. You can see an example in the bottom panel. On the x-axis we have the number of covariables and on the y-axis the square of the Pearson correlation coefficient between the predictions made by the full model (all variables included) and the restricted models, those with only a subset of the variables. The components are included following the relative variable importance order, as shown in the top panel. Thus, in this example 1 component means `hour`; 2 components means `hour` and `temperature`; 3 components means `hour`, `temperature`, and `humidity`; and finally, 4 components means `hour`, `temperature`, `humidity`, and `workingday`, i.e., the full model. Hence, from the next figure we can see that even a model with a single component, `hour`, is very close to the full model. Even more, the model with two components, `hour` and `temperature`, is on average indistinguishable from the full model. The error bars represent the 94% HDI from the posterior predictive distribution. It is important to notice that to compute these correlations we do not resample the models; instead the predictions of the restricted models are approximated by *pruning* variables from the full model.
Additionally, PyMC-BART provides a novel method to assess variable importance. You can see an example in the bottom panel. On the x-axis we have the number of covariables and on the y-axis the square of the Pearson correlation coefficient between the predictions made by the full model (all variables included) and the restricted models, those with only a subset of the variables. The covariables are included following the relative variable importance order, as shown in the top panel. Thus, in this example 1 means `hour`; 2 means `hour` and `temperature`; 3 means `hour`, `temperature`, and `humidity`; and finally, 4 means `hour`, `temperature`, `humidity`, and `workingday`, i.e., the full model. Hence, from the next figure we can see that even a model with a single covariable, `hour`, is very close to the full model. Even more, the model with two covariables, `hour` and `temperature`, is on average indistinguishable from the full model. The error bars represent the 94% HDI from the posterior predictive distribution. It is important to notice that to compute these correlations we do not resample the models; instead the predictions of the restricted models are approximated by *pruning* variables from the full model.

```{code-cell} ipython3
pmb.plot_variable_importance(idata_bikes, μ, X, samples=100);
@@ -206,9 +235,9 @@ Now, we fit the same model as above but this time using a *shared variable* for
with pm.Model() as model_oos_regression:
    X = pm.MutableData("X", X_train)
    Y = Y_train
    α = pm.Exponential("α", 1 / 10)
    μ = pmb.BART("μ", X, Y)
    y = pm.NegativeBinomial("y", mu=pm.math.abs(μ), alpha=α, observed=Y, shape=μ.shape)
    α = pm.Exponential("α", 1)
    μ = pmb.BART("μ", X, np.log(Y))
    y = pm.NegativeBinomial("y", mu=pm.math.exp(μ), alpha=α, observed=Y, shape=μ.shape)
    idata_oos_regression = pm.sample(random_seed=RANDOM_SEED)
    posterior_predictive_oos_regression_train = pm.sample_posterior_predictive(
        trace=idata_oos_regression, random_seed=RANDOM_SEED
@@ -273,8 +302,8 @@ with pm.Model() as model_oos_ts:
    X = pm.MutableData("X", X_train)
    Y = Y_train
    α = pm.Exponential("α", 1 / 10)
    μ = pmb.BART("μ", X, Y)
    y = pm.NegativeBinomial("y", mu=pm.math.abs(μ), alpha=α, observed=Y, shape=μ.shape)
    μ = pmb.BART("μ", X, np.log(Y))
    y = pm.NegativeBinomial("y", mu=pm.math.exp(μ), alpha=α, observed=Y, shape=μ.shape)
    idata_oos_ts = pm.sample(random_seed=RANDOM_SEED)
    posterior_predictive_oos_ts_train = pm.sample_posterior_predictive(
        trace=idata_oos_ts, random_seed=RANDOM_SEED
@@ -370,6 +399,7 @@ This plot helps us understand the reason behind the bad performance on the test
* Updated by Osvaldo Martin in Sep, 2022
* Updated by Osvaldo Martin in Nov, 2022
* Juan Orduz added out-of-sample section in Jan, 2023
* Updated by Osvaldo Martin in Mar, 2023

+++

173 changes: 24 additions & 149 deletions examples/case_studies/BART_quantile_regression.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/case_studies/BART_quantile_regression.myst.md
@@ -146,6 +146,7 @@ We can see that when we use a Normal likelihood, and from that fit we compute the

## Authors
* Authored by Osvaldo Martin in Jan, 2023
* Rerun by Osvaldo Martin in March 2023

+++

