Skip to content

Commit

Permalink
[src] Rename _separate to forward_wav (asteroid-team#337)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonas Haag <jonas@lophus.org>

Co-authored-by: Jonas Haag <jonas@lophus.org>
  • Loading branch information
mpariente and jonashaag authored Nov 25, 2020
1 parent e882b4a commit c7464de
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 17 deletions.
32 changes: 24 additions & 8 deletions asteroid/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..masknn import activations
from ..utils.torch_utils import pad_x_to_y, script_if_tracing, jitable_shape
from ..utils.hub_utils import cached_download
from ..utils.deprecation_utils import is_overridden, mark_deprecated, VisibleDeprecationWarning


@script_if_tracing
Expand All @@ -27,8 +28,9 @@ class BaseModel(torch.nn.Module):
the `get_model_args` method.
Models inheriting from `BaseModel` can be used by :mod:`asteroid.separate`
and by the `asteroid-infer` CLI. For models whose `forward` doesn't return
waveform tensors, overwrite `_separate` to return waveform tensors.
and by the `asteroid-infer` CLI. For models whose `forward` doesn't go from
waveform to waveform tensors, overwrite `forward_wav` to return
waveform tensors.
"""

def __init__(self, sample_rate: float = 8000.0):
Expand Down Expand Up @@ -68,16 +70,30 @@ def file_separate(self, *args, **kwargs):
"""Convenience for ``asteroid.separate.file_separate(self, ...)``."""
return separate.file_separate(self, *args, **kwargs)

def _separate(self, wav, *args, **kwargs):
"""Hidden separation method
def forward_wav(self, wav, *args, **kwargs):
"""Separation method for waveforms.
In case the network's `forward` doesn't have waveforms as input/output,
overwrite this method to separate from waveform to waveform.
Should return a single torch.Tensor, the separated waveforms.
Args:
wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
wav (torch.Tensor): waveform array/tensor.
Shape: 1D, 2D or 3D tensor, time last.
Returns:
The output of self(wav, *args, **kwargs).
"""
if is_overridden("_separate", self, parent=BaseModel):
# If `_separate` is overridden, the mark_deprecated won't be triggered.
warnings.warn(
"`BaseModel._separate` has been deprecated and will be remove from a "
"future release. Use `forward_wav` instead",
VisibleDeprecationWarning,
)
return self._separate(wav, *args, **kwargs)
return self(wav, *args, **kwargs)

@mark_deprecated("Use `forward_wav` instead.")
def _separate(self, wav, *args, **kwargs):
"""Deprecated."""
return self(wav, *args, **kwargs)

@classmethod
Expand Down
20 changes: 11 additions & 9 deletions asteroid/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Protocol:
class Separatable(Protocol):
"""Things that are separatable."""

def _separate(self, wav, **kwargs):
def forward_wav(self, wav, **kwargs):
"""
Args:
wav (torch.Tensor): waveform tensor.
Expand All @@ -45,23 +45,25 @@ def separate(
Also supports filenames.
Args:
model (Separatable, for example asteroid.models.BaseModel): Model to use
model (Separatable, for example asteroid.models.BaseModel): Model to use.
wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
Shape: 1D, 2D or 3D tensor, time last.
output_dir (str): path to save all the wav files. If None,
estimated sources will be saved next to the original ones.
force_overwrite (bool): whether to overwrite existing files (when separating from file)..
resample (bool): Whether to resample input files with wrong sample rate (when separating from file).
**kwargs: keyword arguments to be passed to `_separate`.
force_overwrite (bool): whether to overwrite existing files
(when separating from file).
resample (bool): Whether to resample input files with wrong sample rate
(when separating from file).
**kwargs: keyword arguments to be passed to `forward_wav`.
Returns:
Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
(batch, n_src, time) or (n_src, time) w/o batch dim.
.. note::
By default, `separate` calls `model._separate` which calls `forward`.
For models whose `forward` doesn't return waveform tensors,
overwrite their `_separate` method to return waveform tensors.
`separate` calls `model.forward_wav` which calls `forward` by default.
For models whose `forward` doesn't have waveform tensors as input/ouput,
overwrite their `forward_wav` method to separate from waveform to waveform.
"""
if isinstance(wav, str):
file_separate(
Expand Down Expand Up @@ -90,7 +92,7 @@ def torch_separate(model: Separatable, wav: torch.Tensor, **kwargs) -> torch.Ten
model_device = get_device(model, default="cpu")
wav = wav.to(model_device)
# Forward
separate_func = getattr(model, "_separate", model)
separate_func = getattr(model, "forward_wav", model)
out_wavs = separate_func(wav, **kwargs)

# FIXME: for now this is the best we can do.
Expand Down
49 changes: 49 additions & 0 deletions asteroid/utils/deprecation_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import inspect
from functools import wraps


Expand Down Expand Up @@ -48,3 +49,51 @@ def wrapped(*args, **kwargs):
return wrapped

return decorator


def is_overridden(method_name, obj, parent=None) -> bool:
"""Check if `method_name` from parent is overridden in `obj`.
Args:
method_name (str): Name of the method.
obj: Instance or class that potentially overrode the method.
parent: parent class with which to compare. If None, traverse the MRO
for the first parent that has the method.
Raises RuntimeError if `parent` is not a parent class and if `parent`
doesn't have the method. Or, if `parent` was None, that none of the
potential parents had the method.
"""

def get_mro(cls):
try:
return inspect.getmro(cls)
except AttributeError:
return inspect.getmro(cls.__class__)

def first_parent_with_method(fn, mro_list):
for cls in mro_list[::-1]:
if hasattr(cls, fn):
return cls
return None

if not hasattr(obj, method_name):
return False

try:
instance_attr = getattr(obj, method_name)
except AttributeError:
return False
return False

mro = get_mro(obj)[1:] # All parent classes in order, self excluded
parent = parent if parent is not None else first_parent_with_method(method_name, mro)

if parent not in mro:
raise RuntimeError(f"`{obj}` has no parent that defined method {method_name}`.")

if not hasattr(parent, method_name):
raise RuntimeError(f"Parent `{parent}` does have method `{method_name}`")

super_attr = getattr(parent, method_name)
return instance_attr.__code__ is not super_attr.__code__
82 changes: 82 additions & 0 deletions tests/utils/deprecation_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
import warnings
from asteroid.utils import deprecation_utils as dp


def test_warning():
with pytest.warns(dp.VisibleDeprecationWarning):
warnings.warn("Expected warning.", dp.VisibleDeprecationWarning)


def test_deprecated():
class Foo:
def new_func(self):
pass

@dp.mark_deprecated("Please use `new_func`", "0.5.0")
def old_func(self):
pass

@dp.mark_deprecated("Please use `new_func`")
def no_version_old_func(self):
pass

@dp.mark_deprecated(message="")
def no_message_old_func(self):
pass

foo = Foo()
foo.new_func()

with pytest.warns(dp.VisibleDeprecationWarning) as record:
foo.old_func()
# check that only one warning was raised
assert len(record) == 1
# check that the message matches
assert "0.5.0" in record[0].message.args[0]

with pytest.warns(dp.VisibleDeprecationWarning):
foo.no_version_old_func()
foo.no_message_old_func()


def test_is_overidden():
class Foo:
def some_func(self):
return None

class Bar(Foo):
def some_func(self):
something_changed = None
return None

class Ho(Bar):
pass

# On class
assert dp.is_overridden("some_func", Bar, parent=Foo)
assert dp.is_overridden("some_func", Bar)
# On instance
bar = Bar()
assert dp.is_overridden("some_func", bar, parent=Foo)
assert dp.is_overridden("some_func", bar)

class Hey(Foo):
def some_other_func(self):
return None

# On class
assert not dp.is_overridden("some_func", Hey, parent=Foo)
# On instance
hey = Hey()
assert not dp.is_overridden("some_func", hey, parent=Foo)
assert not dp.is_overridden("some_func", hey, parent=Foo)

with pytest.raises(RuntimeError):
dp.is_overridden("some_func", hey, parent=Bar)

with pytest.raises(RuntimeError):
dp.is_overridden("some_other_func", hey, parent=Foo)

with pytest.raises(RuntimeError):
dp.is_overridden("some_other_func", hey)

0 comments on commit c7464de

Please sign in to comment.