Skip to content

Commit

Permalink
[src & tests] Add support for STOI loss (asteroid-team#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored Apr 18, 2020
1 parent 4d8e382 commit 9e6701a
Showing 5 changed files with 34 additions and 2 deletions.
1 change: 1 addition & 0 deletions asteroid/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from .mse import singlesrc_mse, multisrc_mse
from .cluster import deep_clustering_loss
from .pmsqe import SingleSrcPMSQE
from .stoi import NegSTOILoss as SingleSrcNegSTOI

# Legacy
from .sdr import pairwise_neg_sisdr, nosrc_neg_sisdr, nonpit_neg_sisdr
13 changes: 13 additions & 0 deletions asteroid/losses/stoi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch_stoi import NegSTOILoss

asteroid_examples = """
Examples:
>>> import torch
>>> from asteroid.losses import PITLossWrapper
>>> targets = torch.randn(10, 2, 32000)
>>> est_targets = torch.randn(10, 2, 32000)
>>> loss_func = PITLossWrapper(NegSTOILoss(), pit_from='pw_pt')
>>> loss = loss_func(est_targets, targets)
"""

NegSTOILoss.__doc__ += asteroid_examples
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -5,4 +5,5 @@ SoundFile>=0.10.2
torch>=1.3.0
pytorch-lightning==0.6.0
pb_bss_eval>=0.0.1
asranger>=0.0.5
asranger>=0.0.5
torch_stoi>=0.0.1
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
'pytorch-lightning==0.6.0',
'pb_bss_eval',
'asranger',
'torch_stoi',
],
extras_require={
'visualize': ['seaborn>=0.9.0'],
18 changes: 17 additions & 1 deletion tests/losses/loss_functions_test.py
Original file line number Diff line number Diff line change
@@ -7,8 +7,10 @@
from asteroid.losses import sdr
from asteroid.losses import singlesrc_mse, pairwise_mse, multisrc_mse
from asteroid.losses import deep_clustering_loss, SingleSrcPMSQE
from asteroid.losses import SingleSrcNegSTOI
from asteroid.losses.multi_scale_spectral import SingleSrcMultiScaleSpectral


@pytest.mark.parametrize("n_src", [2, 3, 4])
@pytest.mark.parametrize("function_triplet", [
[sdr.pairwise_neg_sisdr, sdr.singlesrc_neg_sisdr, sdr.multisrc_neg_sisdr],
@@ -116,4 +118,18 @@ def test_pmsqe_pit(n_src, sample_rate):
loss_func = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate),
pit_from='pw_pt')
# Assert forward ok.
loss_value = loss_func(ref_spec, est_spec)
loss_value = loss_func(est_spec, ref_spec)


@pytest.mark.parametrize("n_src", [2, 3])
@pytest.mark.parametrize("sample_rate", [8000, 16000])
@pytest.mark.parametrize("use_vad", [True, False])
@pytest.mark.parametrize("extended", [True, False])
def test_negstoi_pit(n_src, sample_rate, use_vad, extended):
ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
singlesrc_negstoi = SingleSrcNegSTOI(sample_rate=sample_rate,
use_vad=use_vad,
extended=extended)
loss_func = PITLossWrapper(singlesrc_negstoi, pit_from='pw_pt')
# Assert forward ok.
loss_value = loss_func(est, ref)

0 comments on commit 9e6701a

Please sign in to comment.