Skip to content

Commit

Permalink
FIX Fixes pandas extension arrays with objects in check_array (#25814)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored Mar 22, 2023
1 parent 65dfab0 commit 18af550
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ Changelog
- |FIX| Fixes :func:`utils.validation.check_array` to properly convert pandas
extension arrays. :pr:`25813` by `Thomas Fan`_.

- |Fix| :func:`utils.validation.check_array` now suports pandas DataFrames with
extension arrays and object dtypes by return an ndarray with object dtype.
:pr:`25814` by `Thomas Fan`_.

:mod:`sklearn.semi_supervised`
..............................

Expand Down
31 changes: 31 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,3 +1795,34 @@ def test_check_array_array_api_has_non_finite(array_namespace):
with config_context(array_api_dispatch=True):
with pytest.raises(ValueError, match="infinity or a value too large"):
check_array(X_inf)


@pytest.mark.parametrize(
"extension_dtype, regular_dtype",
[
("boolean", "bool"),
("Int64", "int64"),
("Float64", "float64"),
("category", "object"),
],
)
@pytest.mark.parametrize("include_object", [True, False])
def test_check_array_multiple_extensions(
extension_dtype, regular_dtype, include_object
):
"""Check pandas extension arrays give the same result as non-extension arrays."""
pd = pytest.importorskip("pandas")
X_regular = pd.DataFrame(
{
"a": pd.Series([1, 0, 1, 0], dtype=regular_dtype),
"c": pd.Series([9, 8, 7, 6], dtype="int64"),
}
)
if include_object:
X_regular["b"] = pd.Series(["a", "b", "c", "d"], dtype="object")

X_extension = X_regular.assign(a=X_regular["a"].astype(extension_dtype))

X_regular_checked = check_array(X_regular, dtype=None)
X_extension_checked = check_array(X_extension, dtype=None)
assert_array_equal(X_regular_checked, X_extension_checked)
3 changes: 3 additions & 0 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,9 @@ def check_array(
)
if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
dtype_orig = np.result_type(*dtypes_orig)
elif pandas_requires_conversion and any(d == object for d in dtypes_orig):
# Force object if any of the dtypes is an object
dtype_orig = object

elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
array, "dtype"
Expand Down

0 comments on commit 18af550

Please sign in to comment.