Skip to content

Commit

Permalink
Fix phasor_calibrate to handle multi harmonic calibration (phasorpy#69
Browse files Browse the repository at this point in the history
)

* Added phasorpy_dev virtual environment to gitignore

* Update main branch

* Update of `phasor_transform` and `phasor_center`

* Fix bug with `phasor_transform` when `skip_axis = None`

* Add tests for changes to 'phasor_calibrate', 'phasor_transform' and 'phasor_center' to handle multiple harmonics. Modified 'phasorpy_introduction' tutorial to add multiple harmonic calibration.

* Delete test file

* Moved fix for multiple harmonics to `phasor_calibrate` from `phasor_transform`. Updated tests and removed example of multiple harmonic calibration from `phasorpy_introduction`.

* Fix minor error in calling `real` variable in `phasor_calibrate` instead of `re`.

* Add `_determine_axis` helper function to be used in `phasor_center` and `phasor_calibrate`.

* Applied blackdoc formatting

* Update `_parse_skip_axis` and add test.

* Minor fix doctest in `_parse_skip_axis`.
  • Loading branch information
bruno-pannunzio authored and schutyb committed Aug 7, 2024
1 parent dd2d2f1 commit cf41a64
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 42 deletions.
88 changes: 72 additions & 16 deletions src/phasorpy/phasor.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def phasor_calibrate(
preexponential: bool = False,
unit_conversion: float = 1e-3,
method: Literal['mean', 'median'] = 'mean',
skip_axes: int | Sequence[int] | None = None,
skip_axis: int | Sequence[int] | None = None,
) -> tuple[NDArray[Any], NDArray[Any]]:
"""
Return calibrated/referenced phasor coordinates.
Expand Down Expand Up @@ -525,7 +525,7 @@ def phasor_calibrate(
- ``'mean'``: Arithmetic mean of phasor coordinates.
- ``'median'``: Spatial median of phasor coordinates.
skip_axes : int or sequence of int, optional
skip_axis : int or sequence of int, optional
Axes to be excluded during center calculation. If None, all
axes are considered.
Expand Down Expand Up @@ -562,7 +562,7 @@ def phasor_calibrate(
*phasor_center(
reference_real,
reference_imag,
skip_axes,
skip_axis,
method,
),
*phasor_from_lifetime(
Expand Down Expand Up @@ -600,7 +600,7 @@ def phasor_calibrate(
f'!= reference_imag.shape{ref_im.shape}'
)
measured_re, measured_im = phasor_center(
reference_real, reference_imag, skip_axes=skip_axes, method=method
reference_real, reference_imag, skip_axis=skip_axis, method=method
)
known_re, known_im = phasor_from_lifetime(
frequency,
Expand All @@ -612,6 +612,17 @@ def phasor_calibrate(
phi_zero, mod_zero = polar_from_reference_phasor(
measured_re, measured_im, known_re, known_im
)
if numpy.ndim(phi_zero) > 0:
_, axis = _parse_skip_axis(skip_axis, re.ndim)
if axis is not None:
phi_zero = numpy.expand_dims(
phi_zero,
axis=axis,
)
mod_zero = numpy.expand_dims(
mod_zero,
axis=axis,
)
return phasor_transform(re, im, phi_zero, mod_zero)


Expand Down Expand Up @@ -1928,8 +1939,9 @@ def phasor_center(
imag: ArrayLike,
/,
*,
skip_axes: int | Sequence[int] | None = None,
skip_axis: int | Sequence[int] | None = None,
method: Literal['mean', 'median'] = 'mean',
**kwargs: Any,
) -> tuple[NDArray[Any], NDArray[Any]]:
"""Return center of phasor coordinates.
Expand All @@ -1939,7 +1951,7 @@ def phasor_center(
Real component of phasor coordinates.
imag : array_like
Imaginary component of phasor coordinates.
skip_axes : int or sequence of int, optional
skip_axis : int or sequence of int, optional
Axes to be excluded during center calculation. If None, all
axes are considered.
method : str, optional
Expand All @@ -1948,6 +1960,10 @@ def phasor_center(
- ``'mean'``: Arithmetic mean of phasor coordinates.
- ``'median'``: Spatial median of phasor coordinates.
**kwargs
Optional arguments passed to :py:func:`numpy.mean` or
:py:func:`numpy.median`.
Returns
-------
real_center : ndarray
Expand Down Expand Up @@ -1985,22 +2001,14 @@ def phasor_center(
if real.shape != imag.shape:
raise ValueError(f'{real.shape=} != {imag.shape=}')

if skip_axes is None:
axis = None
else:
if not isinstance(skip_axes, Sequence):
skip_axes = (skip_axes,)
if any(i >= real.ndim for i in skip_axes):
raise IndexError(f'{skip_axes=} out of range {real.ndim=}')
skip_axes = tuple(i % real.ndim for i in skip_axes)
axis = tuple(i for i in range(real.ndim) if i not in skip_axes)
_, axis = _parse_skip_axis(skip_axis, real.ndim)

return {
'mean': _mean,
'median': _median,
}[
method
](real, imag, axis=axis)
](real, imag, axis=axis, **kwargs)


def _mean(
Expand Down Expand Up @@ -2061,3 +2069,51 @@ def _median(
"""
return numpy.median(real, **kwargs), numpy.median(imag, **kwargs)


def _parse_skip_axis(
skip_axis: int | Sequence[int] | None,
/,
ndim: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Return axes to skip and not to skip.
This helper function is used to validate and parse `skip_axis`
parameters.
Parameters
----------
skip_axis : Sequence of int, or None
Axes to skip. If None, no axes are skipped.
ndim : int
Dimensionality of array in which to skip axes.
Returns
-------
skip_axis
Ordered, positive values of `skip_axis`.
other_axis
Axes indices not included in `skip_axis`.
Raises
------
IndexError
If any `skip_axis` value is out of bounds of `ndim`.
Examples
--------
>>> _parse_skip_axis((1, -2), 5)
((1, 3), (0, 2, 4))
"""
if ndim < 0:
raise ValueError(f'invalid {ndim=}')
if skip_axis is None:
return (), tuple(range(ndim))
if not isinstance(skip_axis, Sequence):
skip_axis = (skip_axis,)
if any(i >= ndim or i < -ndim for i in skip_axis):
raise IndexError(f"skip_axis={skip_axis} out of range for {ndim=}")
skip_axis = tuple(sorted(int(i % ndim) for i in skip_axis))
other_axis = tuple(i for i in range(ndim) if i not in skip_axis)
return skip_axis, other_axis
144 changes: 118 additions & 26 deletions tests/test_phasor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
mkl_fft = None

from phasorpy.phasor import (
_parse_skip_axis,
fraction_from_amplitude,
fraction_to_amplitude,
frequency_from_lifetime,
Expand Down Expand Up @@ -492,8 +493,22 @@ def test_polar_from_reference_functions():
[
(2, 2, None, None, 2, 2),
(2, 2, 0, 1, 2, 2),
(2, 2, 0.5, 2.0, 1.592628093144679, 5.428032401978303),
(2, 2, -0.5, -2.0, -5.428032401978303, -1.592628093144679),
(
2,
2,
0.5,
2.0,
1.592628093144679,
5.428032401978303,
),
(
2,
2,
-0.5,
-2.0,
-5.428032401978303,
-1.592628093144679,
),
(
SYNTH_DATA_LIST,
SYNTH_DATA_LIST,
Expand Down Expand Up @@ -565,10 +580,7 @@ def test_phasor_transform(
imag_copy = copy.deepcopy(imag)
if phase_zero is not None and modulation_zero is not None:
calibrated_real, calibrated_imag = phasor_transform(
real_copy,
imag_copy,
phase_zero,
modulation_zero,
real_copy, imag_copy, phase_zero, modulation_zero
)
else:
calibrated_real, calibrated_imag = phasor_transform(
Expand Down Expand Up @@ -602,56 +614,74 @@ def test_phasor_transform_more():


@pytest.mark.parametrize(
"""real, imag,
skip_axes, method,
"""real, imag, kwargs,
expected_real_center, expected_imag_center""",
[
(1.0, 4.0, None, 'mean', 1.0, 4.0),
(1.0, -4.0, None, 'median', 1.0, -4.0),
(1.0, 4.0, {'skip_axis': None, 'method': 'mean'}, 1.0, 4.0),
(1.0, -4.0, {'skip_axis': None, 'method': 'median'}, 1.0, -4.0),
(
SYNTH_DATA_LIST,
SYNTH_DATA_LIST,
None,
'mean',
{'skip_axis': None, 'method': 'mean'},
2.3333333333333335,
2.3333333333333335,
),
(SYNTH_DATA_LIST, SYNTH_DATA_LIST, None, 'median', 2.0, 2.0),
(SYNTH_DATA_ARRAY, SYNTH_DATA_ARRAY, None, 'mean', 13.25, 13.25),
(SYNTH_DATA_ARRAY, SYNTH_DATA_ARRAY, None, 'median', 1.0, 1.0),
# with skip_axes
(
SYNTH_DATA_LIST,
SYNTH_DATA_LIST,
{'skip_axis': None, 'method': 'median'},
2.0,
2.0,
),
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY,
0,
'mean',
{'skip_axis': None, 'method': 'mean'},
13.25,
13.25,
),
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY,
{'skip_axis': None, 'method': 'median'},
1.0,
1.0,
),
# with skip_axis
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY,
{'skip_axis': 0, 'method': 'mean'},
numpy.asarray([25.5, 1.0]),
numpy.asarray([25.5, 1.0]),
),
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY,
(-2,),
'median',
{'skip_axis': (-2,), 'method': 'median'},
numpy.asarray([25.5, 1.0]),
numpy.asarray([25.5, 1.0]),
),
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY,
{'keepdims': True},
[[13.25]],
[[13.25]],
), # with kwargs for numpy functions
],
)
def test_phasor_center(
real,
imag,
skip_axes,
method,
kwargs,
expected_real_center,
expected_imag_center,
):
"""Test `phasor_center` function with various inputs and methods."""
real_copy = copy.deepcopy(real)
imag_copy = copy.deepcopy(imag)
real_center, imag_center = phasor_center(
real_copy, imag_copy, skip_axes=skip_axes, method=method
)
real_center, imag_center = phasor_center(real_copy, imag_copy, **kwargs)
assert_array_equal(real, real_copy)
assert_array_equal(imag, imag_copy)
assert_almost_equal(real_center, expected_real_center)
Expand All @@ -665,7 +695,7 @@ def test_phasor_center_exceptions():
with pytest.raises(ValueError):
phasor_center([0], [0, 0])
with pytest.raises(IndexError):
phasor_center([0, 0], [0, 0], skip_axes=1)
phasor_center([0, 0], [0, 0], skip_axis=1)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -978,6 +1008,49 @@ def test_phasor_from_lifetime_modify():
),
),
), # multiple lifetime with median method
(
(
numpy.stack(
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY / 2,
SYNTH_DATA_ARRAY / 3,
),
axis=0,
),
numpy.stack(
(
SYNTH_DATA_ARRAY,
SYNTH_DATA_ARRAY / 2,
SYNTH_DATA_ARRAY / 3,
),
axis=0,
),
[0.5, 0.25, 0.1],
[0.3, 0.2, 0.1],
),
{
'frequency': [80, 160, 240],
'lifetime': 4,
'skip_axis': 0,
},
(
numpy.array(
[
[[11.60340173, 0.23206803], [0.23206803, 0.23206803]],
[[3.53612871, 0.07072257], [0.07072257, 0.07072257]],
[[4.45831758, 0.08916635], [0.08916635, 0.08916635]],
]
),
numpy.array(
[
[[52.74178815, 1.05483576], [1.05483576, 1.05483576]],
[[26.41473921, 0.52829478], [0.52829478, 0.52829478]],
[[26.89193809, 0.53783876], [0.53783876, 0.53783876]],
]
),
),
), # multiple harmonics with skip_axis
],
)
def test_phasor_calibrate(args, kwargs, expected):
Expand Down Expand Up @@ -1425,3 +1498,22 @@ def test_phasor_at_harmonic():
phasor_at_harmonic(0.5, 0, 1)
with pytest.raises(ValueError):
phasor_at_harmonic(0.5, 1, 0)


def test_parse_skip_axis():
"""Test _parse_skip_axis function."""
assert _parse_skip_axis(None, 0) == ((), ())
assert _parse_skip_axis(None, 1) == ((), (0,))
assert _parse_skip_axis((), 1) == ((), (0,))
assert _parse_skip_axis(0, 1) == ((0,), ())
assert _parse_skip_axis(0, 2) == ((0,), (1,))
assert _parse_skip_axis(-1, 2) == ((1,), (0,))
assert _parse_skip_axis((1, -2), 5) == ((1, 3), (0, 2, 4))
with pytest.raises(ValueError):
_parse_skip_axis(0, -1)
with pytest.raises(IndexError):
_parse_skip_axis(0, 0)
with pytest.raises(IndexError):
_parse_skip_axis(1, 1)
with pytest.raises(IndexError):
_parse_skip_axis(-2, 1)

0 comments on commit cf41a64

Please sign in to comment.