Skip to content

Commit

Permalink
fix origin handling for minimum_filter and maximum_filter with axes s…
Browse files Browse the repository at this point in the history
…ubset
  • Loading branch information
grlee77 committed May 6, 2024
1 parent 7f0a5cb commit 1ce327b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
14 changes: 10 additions & 4 deletions scipy/ndimage/_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ def _min_or_max_filter(input, size, footprint, structure, output, mode,
footprint = np.asarray(footprint, dtype=bool)
input = np.asarray(input)
if np.iscomplexobj(input):
raise TypeError('Complex type not supported')
raise TypeError("Complex type not supported")
output = _ni_support._get_output(output, input)
temp_needed = np.may_share_memory(input, output)
if temp_needed:
Expand All @@ -1266,25 +1266,31 @@ def _min_or_max_filter(input, size, footprint, structure, output, mode,
else:
output[...] = input[...]
else:
origins = _ni_support._normalize_sequence(origin, input.ndim)
origins = _ni_support._normalize_sequence(origin, num_axes)
if num_axes < input.ndim:
if footprint.ndim != num_axes:
raise RuntimeError("footprint array has incorrect shape")
footprint = np.expand_dims(
footprint,
tuple(ax for ax in range(input.ndim) if ax not in axes)
)
# set origin = 0 for any axes not being filtered
origins_temp = [0,] * input.ndim
for o, ax in zip(origins, axes):
origins_temp[ax] = o
origins = origins_temp

fshape = [ii for ii in footprint.shape if ii > 0]
if len(fshape) != input.ndim:
raise RuntimeError('footprint array has incorrect shape.')
for origin, lenf in zip(origins, fshape):
if (lenf // 2 + origin < 0) or (lenf // 2 + origin >= lenf):
raise ValueError('invalid origin')
raise ValueError("invalid origin")
if not footprint.flags.contiguous:
footprint = footprint.copy()
if structure is not None:
if len(structure.shape) != input.ndim:
raise RuntimeError('structure array has incorrect shape')
raise RuntimeError("structure array has incorrect shape")
if num_axes != structure.ndim:
structure = np.expand_dims(
structure,
Expand Down
26 changes: 26 additions & 0 deletions scipy/ndimage/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,32 @@ def test_filter_axes_kwargs(self, filter_func, size0, size, kwargs, axes):
expected = filter_func(array, *args, size_3d, **kwargs)
assert_allclose(output, expected)

@pytest.mark.parametrize("filter_func, kwargs",
[(ndimage.minimum_filter, {}),
(ndimage.maximum_filter, {}),
(ndimage.median_filter, {}),
(ndimage.rank_filter, {"rank": 1}),
(ndimage.percentile_filter, {"percentile": 30})])
def test_filter_weights_subset_axes_origins(self, filter_func, kwargs):
axes = (-2, -1)
origins = (0, 1)
array = np.arange(6 * 8 * 12, dtype=np.float64).reshape(6, 8, 12)
axes = np.array(axes)

# weights with ndim matching len(axes)
footprint = np.ones((3, 5), dtype=bool)
footprint[0, 1] = 0 # make non-separable

output = filter_func(
array, footprint=footprint, axes=axes, origin=origins, **kwargs)

output0 = filter_func(
array, footprint=footprint, axes=axes, origin=0, **kwargs)

# output has origin shift on last axis relative to output0, so
# expect shifted arrays to be equal.
np.testing.assert_array_equal(output[:, :, 1:], output0[:, :, :-1])

@pytest.mark.parametrize(
'filter_func, args',
[(ndimage.gaussian_filter, (1.0,)), # args = (sigma,)
Expand Down

0 comments on commit 1ce327b

Please sign in to comment.