Skip to content

Commit

Permalink
Add default log_lik handling for pystan and cmdstan (arviz-devs#1599)
Browse files Browse the repository at this point in the history
* default log_lik handling

* update changelog

* update pystan test

* fix test

* change log_lik tests for cmdstan

* fix type issue with boolean arrays

* remove utils.full
  • Loading branch information
ahartikainen authored Mar 5, 2021
1 parent 70fcf84 commit f86ddec
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Added `arviz.labels` module with classes and utilities ([1201](https://github.com/arviz-devs/arviz/pull/1201))
* Added probability estimate within ROPE in `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570))
* Added `rope_color` and `ref_val_color` arguments to `plot_posterior` ([1570](https://github.com/arviz-devs/arviz/pull/1570))
* Improved retrieving or pointwise log likelihood in `from_cmdstanpy` ([1579](https://github.com/arviz-devs/arviz/pull/1579) and [1598](https://github.com/arviz-devs/arviz/pull/1598))
* Improved retrieving or pointwise log likelihood in `from_cmdstanpy`, `from_cmdstan` and `from_pystan` ([1579](https://github.com/arviz-devs/arviz/pull/1579) and [1599](https://github.com/arviz-devs/arviz/pull/1599))
* Added interactive legend to bokeh `forestplot` ([1591](https://github.com/arviz-devs/arviz/pull/1591))

### Maintenance and fixes
Expand Down
9 changes: 8 additions & 1 deletion arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def __init__(
self._parse_posterior()
self._parse_prior()

if (
self.log_likelihood is None
and self.posterior_ is not None
and any(name.split(".")[0] == "log_lik" for name in self.posterior_columns)
):
self.log_likelihood = ["log_lik"]

@requires("posterior_")
def _parse_posterior(self):
"""Read csv paths to list of ndarrays."""
Expand Down Expand Up @@ -871,7 +878,7 @@ def _unpack_ndarrays(arrays, columns, dtypes=None):
for key, cols_locs in col_groups.items():
ndim = np.array([loc for _, loc in cols_locs]).max(0) + 1
dtype = dtypes.get(key, np.float64)
sample[key] = utils.full((chains, draws, *ndim), 0, dtype=dtype)
sample[key] = np.zeros((chains, draws, *ndim), dtype=dtype)
for col, loc in cols_locs:
for chain_id, arr in enumerate(arrays):
draw = arr[:, col]
Expand Down
14 changes: 14 additions & 0 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def __init__(
self.dims = dims
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup

if (
self.log_likelihood is None
and self.posterior is not None
and "log_lik" in self.posterior.sim["pars_oi"]
):
self.log_likelihood = ["log_lik"]

import pystan # pylint: disable=import-error

self.pystan = pystan
Expand Down Expand Up @@ -313,6 +320,13 @@ def __init__(
self.coords = coords
self.dims = dims

if (
self.log_likelihood is None
and self.posterior is not None
and "log_lik" in self.posterior.param_names
):
self.log_likelihood = ["log_lik"]

import stan # pylint: disable=import-error

self.stan = stan
Expand Down
5 changes: 2 additions & 3 deletions arviz/tests/external_tests/test_data_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def test_inference_data_input_types2(self, paths, observed_data_paths):
constant_data_var=["y"],
predictions_constant_data=observed_data_paths[0],
predictions_constant_data_var=["y"],
log_likelihood="log_lik",
coords={"school": np.arange(8)},
dims={
"theta": ["school"],
Expand Down Expand Up @@ -277,7 +276,7 @@ def test_inference_data_input_types5(self, paths, observed_data_paths):
prior_predictive=None,
observed_data=observed_data_paths[0],
observed_data_var=["y"],
log_likelihood=["log_lik"],
log_likelihood=["y_hat"],
coords={"school": np.arange(8), "log_lik_dim": np.arange(8)},
dims={
"theta": ["school"],
Expand All @@ -290,7 +289,7 @@ def test_inference_data_input_types5(self, paths, observed_data_paths):
test_dict = {
"posterior": ["mu", "tau", "theta_tilde", "theta"],
"prior": ["mu", "tau", "theta_tilde", "theta"],
"log_likelihood": ["log_lik"],
"log_likelihood": ["y_hat"],
"observed_data": ["y"],
"sample_stats_prior": ["lp"],
}
Expand Down
8 changes: 5 additions & 3 deletions arviz/tests/external_tests/test_data_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def get_inference_data4(self, data):
coords=None,
dims=None,
posterior_model=data.model,
log_likelihood=[],
prior_model=data.model,
save_warmup=pystan_version() == 2,
)
Expand Down Expand Up @@ -173,6 +174,7 @@ def test_inference_data(self, data, eight_schools_params):
"predictions_constant_data": ["sigma", "y"],
"sample_stats_prior": ["diverging"],
"sample_stats": ["diverging", "lp"],
"log_likelihood": ["log_lik"],
"prior_predictive": ["y_hat", "log_lik"],
}
fails = check_multiple_attrs(test_dict, inference_data3)
Expand Down Expand Up @@ -261,10 +263,10 @@ def test_index_order(self, data, eight_schools_params):
idata = from_pystan(posterior=fit)
assert idata is not None
for j, fpar in enumerate(fit.sim["fnames_oi"]):
if fpar == "lp__":
continue
par, *shape = fpar.replace("]", "").split("[")
assert hasattr(idata.posterior, par)
if par in {"lp__", "log_lik"}:
continue
assert hasattr(idata.posterior, par), (par, list(idata.posterior.data_vars))
if shape:
shape = [slice(None), slice(None)] + list(map(int, shape))
assert idata.posterior[par][tuple(shape)].values.mean() == float(j)
Expand Down
6 changes: 0 additions & 6 deletions arviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,6 @@ def _cov(data):
raise ValueError("{} dimension arrays are not supported".format(data.ndim))


@conditional_jit(nopython=True)
def full(shape, x, dtype=None):
"""Jitting numpy full."""
return np.full(shape, x, dtype=dtype)


def flatten_inference_data_to_dict(
data,
var_names=None,
Expand Down

0 comments on commit f86ddec

Please sign in to comment.