Skip to content

Commit

Permalink
Isin (pydata#2031)
Browse files Browse the repository at this point in the history
* gitignore testmon

* initial isin implementation

* gitignore

* dask

* numpy version check not needed

* numpy version check for isin

* move to common

* rename data_set to ds

* Revert "rename data_set to ds"

This reverts commit 75493c2.

* 'expect' test for dataset

* unneeded import

* formatting

* docs

* Raise an informative error message when converting Dataset -> np.ndarray

Makes `np.asarray(dataset)` issue an informative error. Currently,
`np.asarray(xr.Dataset({'x': 0}))` raises `KeyError: 0`, which makes no sense.

* normal tests are better than a weird middle ground

* dask test

* grammar

* try changing skip decorator ordering

* just use has_dask

* another noqa?

* flake for py3.4

* flake
  • Loading branch information
max-sixty authored and shoyer committed Apr 4, 2018
1 parent 8c194b6 commit a5f7d6a
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 9 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ nosetests.xml
.cache
.ropeproject/
.tags*
.testmondata
.testmon*
.pytest_cache

# asv environments
Expand All @@ -51,10 +51,11 @@ nosetests.xml
.project
.pydevproject

# PyCharm and Vim
# IDEs
.idea
*.swp
.DS_Store
.vscode/

# xarray specific
doc/_build
Expand Down
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ Computation
:py:attr:`~Dataset.cumsum`
:py:attr:`~Dataset.cumprod`
:py:attr:`~Dataset.rank`
:py:attr:`~Dataset.isin`

**Grouped operations**:
:py:attr:`~core.groupby.DatasetGroupBy.assign`
Expand Down Expand Up @@ -339,6 +340,7 @@ Computation
:py:attr:`~DataArray.cumsum`
:py:attr:`~DataArray.cumprod`
:py:attr:`~DataArray.rank`
:py:attr:`~DataArray.isin`

**Grouped operations**:
:py:attr:`~core.groupby.DataArrayGroupBy.assign_coords`
Expand Down
7 changes: 6 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ Documentation
Enhancements
~~~~~~~~~~~~

- Some speed improvement to construct :py:class:`~xarray.DataArrayRolling`
- :py:meth:`~xarray.DataArray.isin` and :py:meth:`~xarray.Dataset.isin` methods,
which test each value in the array for whether it is contained in the supplied list,
returning a bool array. Similar to the ``np.isin`` function. Requires NumPy >= 1.13.
By `Maximilian Roos <https://github.com/maxim-lian>`_.

- Some speed improvement to construct :py:class:`~xarray.DataArrayRolling`
object (:issue:`1993`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Handle variables with different values for ``missing_value`` and
Expand Down
30 changes: 30 additions & 0 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import warnings
from distutils.version import LooseVersion

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -744,6 +745,35 @@ def close(self):
self._file_obj.close()
self._file_obj = None

def isin(self, test_elements):
    """Test each value of this object for membership in ``test_elements``.

    Requires NumPy >= 1.13.

    Parameters
    ----------
    test_elements : array_like
        The values against which to test each value of this object.
        Forwarded unchanged to ``np.isin`` as its ``test_elements``
        keyword argument; see the NumPy notes for how non-array-like
        values are treated.

    Returns
    -------
    isin : same as object, bool
        Has the same shape as object.

    Raises
    ------
    ImportError
        If the installed NumPy is older than 1.13.0.
    """
    installed_numpy = LooseVersion(np.__version__)
    if installed_numpy < LooseVersion('1.13.0'):
        raise ImportError('isin requires numpy version 1.13.0 or later')

    # Function-scope import, performed only once the version check passes.
    from .computation import apply_ufunc

    ufunc_kwargs = {'test_elements': test_elements}
    return apply_ufunc(
        np.isin,
        self,
        kwargs=ufunc_kwargs,
        dask='parallelized',
        output_dtypes=[np.bool_],
    )

def __enter__(self):
return self

Expand Down
33 changes: 33 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3327,6 +3327,14 @@ def da(request):
[0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7],
dims='time')

if request.param == 'repeating_ints':
return DataArray(
np.tile(np.arange(12), 5).reshape(5, 4, 3),
coords={'x': list('abc'),
'y': list('defg')},
dims=list('zyx')
)


@pytest.fixture
def da_dask(seed=123):
Expand All @@ -3339,6 +3347,31 @@ def da_dask(seed=123):
return da


@pytest.mark.skipif(LooseVersion(np.__version__) < LooseVersion('1.13.0'),
                    reason='requires numpy version 1.13.0 or later')
@pytest.mark.parametrize('da', ('repeating_ints', ), indirect=True)
def test_isin(da):
    """Compare ``da.isin`` with hand-written masks on the y=('d', 'e'), z=0 slice."""

    def build_expected(mask):
        # Expected boolean DataArray carrying the same dims/coords as the
        # selected slice of the result.
        return DataArray(
            np.asarray(mask),
            dims=list('yx'),
            coords={'x': list('abc'),
                    'y': list('de')},
        ).astype('bool')

    # (values passed to isin, expected 0/1 mask for the selected slice)
    cases = (
        ([3], [[0, 0, 0], [1, 0, 0]]),
        ([2, 3], [[0, 0, 1], [1, 0, 0]]),
    )

    for test_elements, mask in cases:
        expected = build_expected(mask)
        result = da.isin(test_elements).sel(y=list('de'), z=0)
        assert_equal(result, expected)


@pytest.mark.parametrize('da', (1, 2), indirect=True)
def test_rolling_iter(da):

Expand Down
65 changes: 61 additions & 4 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from . import (
InaccessibleArray, TestCase, UnexpectedDataAccess, assert_allclose,
assert_array_equal, assert_equal, assert_identical, raises_regex,
assert_array_equal, assert_equal, assert_identical, has_dask, raises_regex,
requires_bottleneck, requires_dask, requires_scipy, source_ndarray)

try:
Expand Down Expand Up @@ -4037,9 +4037,66 @@ def test_ipython_key_completion(self):
# Py.test tests


@pytest.fixture()
def data_set(seed=None):
    """Return the dataset produced by ``create_test_data(seed)``."""
    dataset = create_test_data(seed)
    return dataset
@pytest.fixture(params=[None])
def data_set(request):
    """Return ``create_test_data`` called with the fixture's parameter.

    The fixture is parametrized with the single value ``None``, which is
    forwarded as the only argument to ``create_test_data``.
    """
    seed = request.param
    return create_test_data(seed)


@pytest.mark.skipif(LooseVersion(np.__version__) < LooseVersion('1.13.0'),
                    reason='requires numpy version 1.13.0 or later')
@pytest.mark.parametrize('test_elements', (
    [1, 2],
    np.array([1, 2]),
    DataArray([1, 2]),
    # Attach the mark through ``pytest.param``: wrapping the value directly
    # in ``pytest.mark.xfail(...)`` is the removed "mark applied to a
    # parameter" form and would hand the MarkDecorator itself to the test.
    pytest.param(Dataset({'x': [1, 2]}), marks=pytest.mark.xfail),
))
def test_isin(test_elements):
    """``Dataset.isin`` matches the expected per-variable boolean masks.

    The same check runs for each parametrized kind of ``test_elements``
    (list, ndarray, DataArray); the Dataset case is marked as an expected
    failure.
    """
    expected = Dataset(
        data_vars={
            'var1': (('dim1',), [0, 1]),
            'var2': (('dim1',), [1, 1]),
            'var3': (('dim1',), [0, 1]),
        }
    ).astype('bool')

    result = Dataset(
        data_vars={
            'var1': (('dim1',), [0, 1]),
            'var2': (('dim1',), [1, 2]),
            'var3': (('dim1',), [0, 1]),
        }
    ).isin(test_elements)

    assert_equal(result, expected)


@pytest.mark.skipif(LooseVersion(np.__version__) < LooseVersion('1.13.0') or  # noqa
                    not has_dask,  # noqa
                    reason='requires dask and numpy version 1.13.0 or later')
@pytest.mark.parametrize('test_elements', (
    [1, 2],
    np.array([1, 2]),
    DataArray([1, 2]),
    # Attach the mark through ``pytest.param``: wrapping the value directly
    # in ``pytest.mark.xfail(...)`` is the removed "mark applied to a
    # parameter" form and would hand the MarkDecorator itself to the test.
    pytest.param(Dataset({'x': [1, 2]}), marks=pytest.mark.xfail),
))
def test_isin_dask(test_elements):
    """``Dataset.isin`` on a chunked dataset matches the expected masks.

    The input dataset is chunked with ``.chunk(1)`` before ``isin`` is
    applied, and the result is computed before comparison. The Dataset
    ``test_elements`` case is marked as an expected failure.
    """
    expected = Dataset(
        data_vars={
            'var1': (('dim1',), [0, 1]),
            'var2': (('dim1',), [1, 1]),
            'var3': (('dim1',), [0, 1]),
        }
    ).astype('bool')

    result = Dataset(
        data_vars={
            'var1': (('dim1',), [0, 1]),
            'var2': (('dim1',), [1, 2]),
            'var3': (('dim1',), [0, 1]),
        }
    ).chunk(1).isin(test_elements).compute()

    assert_equal(result, expected)


def test_dir_expected_attrs(data_set):
Expand Down
2 changes: 0 additions & 2 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,8 +1422,6 @@ def test_reduce(self):
with raises_regex(ValueError, 'cannot supply both'):
v.mean(dim='x', axis=0)

@pytest.mark.skipif(LooseVersion(np.__version__) < LooseVersion('1.10.0'),
reason='requires numpy version 1.10.0 or later')
def test_quantile(self):
v = Variable(['x', 'y'], self.d)
for q in [0.25, [0.50], [0.25, 0.75]]:
Expand Down

0 comments on commit a5f7d6a

Please sign in to comment.