Skip to content

Commit

Permalink
Fix plot_lm and _z_scale for scipy 1.10 (arviz-devs#2186)
Browse files Browse the repository at this point in the history
* stop using np.tile on plotter list

* take nan_policy into account for scipy>=1.10

* fix version comparisons to take prereleases into account

* update changelog

* workarounds to be able to merge
  • Loading branch information
OriolAbril authored Dec 22, 2022
1 parent 3c82bdd commit fadbc20
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
- plot_bpv smooth discrete data only when computing u_values ([2179](https://github.com/arviz-devs/arviz/pull/2179))
- Fix bug when beanmachine objects lack some fields ([2154](https://github.com/arviz-devs/arviz/pull/2154))
- Fix gap for `plot_trace` with option `kind="rank_bars"` ([2180](https://github.com/arviz-devs/arviz/pull/2180))
- Fix `plot_lm` unsupported usage of `np.tile` ([2186](https://github.com/arviz-devs/arviz/pull/2186))
- Update `_z_scale` to work with SciPy 1.10 ([2186](https://github.com/arviz-devs/arviz/pull/2186))

### Deprecation

Expand Down
13 changes: 9 additions & 4 deletions arviz/plots/lmplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Plot regression figure."""
import warnings
from numbers import Integral
from itertools import repeat

import xarray as xr
import numpy as np
Expand All @@ -11,6 +12,10 @@
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function


def _repeat_flatten_list(lst, n):
return [item for sublist in repeat(lst, n) for item in sublist]


def plot_lm(
y,
idata=None,
Expand Down Expand Up @@ -268,8 +273,8 @@ def plot_lm(
len_y = len(y)
len_x = len(x)
length_plotters = len_x * len_y
y = np.tile(y, (len_x, 1))
x = np.tile(x, (len_y, 1))
y = _repeat_flatten_list(y, len_x)
x = _repeat_flatten_list(x, len_y)

# Filter out the required values to generate plotters
if y_hat is not None:
Expand All @@ -289,7 +294,7 @@ def plot_lm(
)
]

y_hat = np.tile(y_hat, (len_x, 1))
y_hat = _repeat_flatten_list(y_hat, len_x)

# Filter out the required values to generate plotters
if y_model is not None:
Expand All @@ -307,7 +312,7 @@ def plot_lm(
),
)
]
y_model = np.tile(y_model, (len_x, 1))
y_model = _repeat_flatten_list(y_model, len_x)

rows, cols = default_grid(length_plotters)

Expand Down
10 changes: 9 additions & 1 deletion arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from collections.abc import Sequence

import numpy as np
import packaging
import pandas as pd
import scipy
from scipy import stats

from ..data import convert_to_dataset
Expand Down Expand Up @@ -538,7 +540,13 @@ def _z_scale(ary):
np.ndarray
"""
ary = np.asarray(ary)
rank = stats.rankdata(ary, method="average")
if packaging.version.parse(scipy.__version__) < packaging.version.parse("1.10.0.dev0"):
rank = stats.rankdata(ary, method="average")
else:
# the .ravel part is only needed to overcom a bug in scipy 1.10.0.rc1
rank = stats.rankdata( # pylint: disable=unexpected-keyword-arg
ary.ravel(), method="average", nan_policy="omit"
)
rank = _backtransform_ranks(rank)
z = stats.norm.ppf(rank)
z = z.reshape(ary.shape)
Expand Down
6 changes: 5 additions & 1 deletion arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import os

import numpy as np
import packaging
import pandas as pd
import pytest
import scipy
from numpy.testing import assert_almost_equal

from ...data import from_cmdstan, load_arviz_data
Expand Down Expand Up @@ -488,9 +490,11 @@ def test_nan_behaviour(self, func):
data[0, 0] = np.nan # pylint: disable=unsupported-assignment-operation
if func == "_mcse_quantile":
assert np.isnan(_mcse_quantile(data, 0.5)).all(None)
else:
elif packaging.version.parse(scipy.__version__) < packaging.version.parse("1.10.0.dev0"):
assert not np.isnan(_z_scale(data)).all(None)
assert not np.isnan(_z_scale(data)).any(None)
else:
assert np.isnan(_z_scale(data)).sum() == 1

@pytest.mark.parametrize("chains", (None, 1, 2, 3))
@pytest.mark.parametrize("draws", (2, 3, 100, 101))
Expand Down

0 comments on commit fadbc20

Please sign in to comment.