Skip to content

Commit

Permalink
[src&tests] Fix complex and add tests (asteroid-team#358)
Browse files Browse the repository at this point in the history
- Add CLI tests
- Add tests for complex_nn
- Deprecate `as_torch_complex`
- Fix DCCRN
  • Loading branch information
mpariente authored Nov 27, 2020
1 parent 18c430e commit ca55d98
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 27 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/test_asteroid_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,21 @@ jobs:
- name: Source code tests
run: |
coverage run -a -m py.test tests --ignore tests/models/publish_test.py
chmod +x ./tests/cli_test.sh
./tests/cli_test.sh
# This tends to fail even though the code works, when Zenodo is slow.
- name: Model-sharing tests
run: |
coverage run -a -m py.test tests/models/publish_test.py
env: # Access token as an env variable
env:
ACCESS_TOKEN: ${{ secrets.ACCESS_TOKEN }}

- name: CLI tests
run: |
chmod +x ./tests/cli_test.sh
./tests/cli_test.sh
env:
ACCESS_TOKEN: ${{ secrets.ACCESS_TOKEN }}

- name: Coverage report
Expand Down
28 changes: 18 additions & 10 deletions asteroid/complex_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
- Asteroid style representation, identical to the Torchaudio representation, but
with the last dimension concatenated: tensor([r1, r2, ..., rn, i1, i2, ..., in]).
The concatenated (2 * n) dimension may be at an arbitrary position, i.e. the tensor
is of shape [..., 2 * n, ...]. See `asteroid.filterbanks.transforms` for details.
is of shape [..., 2 * n, ...]. See `asteroid_filterbanks.transforms` for details.
"""
from typing import Union, List, Tuple
import functools
import torch
import warnings
from asteroid_filterbanks import transforms
from .utils.torch_utils import script_if_tracing
from .utils.deprecation_utils import mark_deprecated

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import torchaudio
from torch import nn
from asteroid_filterbanks import transforms
from .utils.torch_utils import script_if_tracing


# Alias to denote PyTorch native complex tensor (complex64/complex128).
Expand All @@ -38,6 +38,14 @@ def torch_complex_from_magphase(mag, phase):
)


def torch_complex_from_reim(re, im):
return torch.view_as_complex(torch.stack([re, im], dim=-1))


@mark_deprecated(
"Use `torch.view_as_complex`, `torch_complex_from_magphase`, `torch_complex_from_reim` or "
"`asteroid_filterbanks.transforms.from_torch_complex` instead."
)
@script_if_tracing
def as_torch_complex(x, asteroid_dim: int = -2):
"""Convert complex `x` to complex. Input may be one of:
Expand All @@ -54,14 +62,14 @@ def as_torch_complex(x, asteroid_dim: int = -2):
ValueError: If type of `x` is not understood.
"""
if isinstance(x, (list, tuple)) and len(x) == 2:
return torch_complex_from_magphase(*x)
return torch_complex_from_reim(*x)
elif is_torch_complex(x):
return x
else:
is_torchaudio_complex = transforms.is_torchaudio_complex(x)
is_asteroid_complex = transforms.is_asteroid_complex(x, asteroid_dim)
if is_torchaudio_complex and is_asteroid_complex:
raise ValueError(
raise RuntimeError(
f"Tensor of shape {x.shape} is both a valid Torchaudio-style and "
"Asteroid-style complex. PyTorch complex conversion is ambiguous."
)
Expand All @@ -70,7 +78,7 @@ def as_torch_complex(x, asteroid_dim: int = -2):
elif is_asteroid_complex:
return torch.view_as_complex(transforms.to_torchaudio(x, asteroid_dim))
else:
raise ValueError(
raise RuntimeError(
f"Do not know how to convert tensor of shape {x.shape}, dtype={x.dtype} to complex"
)

Expand All @@ -86,7 +94,7 @@ def on_reim(f):

@functools.wraps(f)
def cf(x):
return torch_complex_from_magphase(f(x.real), f(x.imag))
return torch_complex_from_reim(f(x.real), f(x.imag))

# functools.wraps keeps the original name of `f`, which might be confusing,
# since we are creating a new function that behaves differently.
Expand All @@ -110,7 +118,7 @@ def __init__(self, module_cls, *args, **kwargs):
self.im_module = module_cls(*args, **kwargs)

def forward(self, x):
return torch_complex_from_magphase(self.re_module(x.real), self.im_module(x.imag))
return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))


class ComplexMultiplicationWrapper(nn.Module):
Expand All @@ -133,7 +141,7 @@ def __init__(self, module_cls, *args, **kwargs):
self.im_module = module_cls(*args, **kwargs)

def forward(self, x: ComplexTensor) -> ComplexTensor:
return torch_complex_from_magphase(
return torch_complex_from_reim(
self.re_module(x.real) - self.im_module(x.imag),
self.re_module(x.imag) + self.im_module(x.real),
)
Expand Down
2 changes: 1 addition & 1 deletion asteroid/losses/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def deep_clustering_loss(embedding, tgt_index, binary_mask=None):
tgt_index (torch.Tensor): Dominating source index in each TF bin.
Expected shape: [batch, frequency, frame]
binary_mask (torch.Tensor): VAD in TF plane. Bool or Float.
See asteroid.filterbanks.transforms.ebased_vad.
See asteroid.dsp.vad.ebased_vad.
Returns:
`torch.Tensor`. Deep clustering loss for every batch sample.
Expand Down
2 changes: 1 addition & 1 deletion asteroid/losses/pmsqe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class (see Tensorflow implementation).
Examples
>>> import torch
>>> from asteroid.filterbanks import STFTFB, Encoder, transforms
>>> from asteroid_filterbanks import STFTFB, Encoder, transforms
>>> from asteroid.losses import PITLossWrapper, SingleSrcPMSQE
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
Expand Down
9 changes: 3 additions & 6 deletions asteroid/models/dccrnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from .. import complex_nn
from ..filterbanks.transforms import from_torchaudio
from asteroid_filterbanks.transforms import from_torch_complex, to_torch_complex
from ..masknn.recurrent import DCCRMaskNet
from .dcunet import BaseDCUNet

Expand Down Expand Up @@ -33,11 +32,9 @@ def __init__(self, *args, stft_kernel_size=512, **masknet_kwargs):
def forward_encoder(self, wav):
tf_rep = self.encoder(wav)
# Remove Nyquist frequency bin
return complex_nn.as_torch_complex(tf_rep)[..., :-1, :]
return to_torch_complex(tf_rep)[..., :-1, :]

def apply_masks(self, tf_rep, est_masks):
masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
# Pad Nyquist frequency bin
return from_torchaudio(
torch.view_as_real(torch.nn.functional.pad(masked_tf_rep, (0, 0, 0, 1)))
)
return from_torch_complex(torch.nn.functional.pad(masked_tf_rep, [0, 0, 0, 1]))
9 changes: 3 additions & 6 deletions asteroid/models/dcunet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import torch

from .. import complex_nn
from asteroid_filterbanks import make_enc_dec
from asteroid_filterbanks.transforms import from_torchaudio
from asteroid_filterbanks.transforms import from_torch_complex, to_torch_complex
from ..masknn.convolutional import DCUMaskNet
from .base_models import BaseEncoderMaskerDecoder

Expand Down Expand Up @@ -45,11 +42,11 @@ def __init__(

def forward_encoder(self, wav):
tf_rep = self.encoder(wav)
return complex_nn.as_torch_complex(tf_rep)
return to_torch_complex(tf_rep)

def apply_masks(self, tf_rep, est_masks):
masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
return from_torchaudio(torch.view_as_real(masked_tf_rep))
return from_torch_complex(masked_tf_rep)

def get_model_args(self):
"""Arguments needed to re-instantiate the model."""
Expand Down
2 changes: 0 additions & 2 deletions asteroid/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def file_separate(
import soundfile as sf

if not hasattr(model, "sample_rate"):
if isinstance(model, LambdaOverlapAdd):
model = model.nnet
raise TypeError(
f"This function requires your model ({type(model).__name__}) to have a "
"'sample_rate' attribute. See `BaseModel.sample_rate` for details."
Expand Down
58 changes: 58 additions & 0 deletions tests/cli_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
import soundfile as sf
import numpy as np
import os
from asteroid.models import ConvTasNet, save_publishable
from asteroid.data.wham_dataset import wham_noise_license, wsj0_license


def setup_register_sr():
    """Save a small ConvTasNet checkpoint without ``sample_rate`` to ``tmp.th``.

    The model is serialized, the ``sample_rate`` entry is removed from its
    ``model_args``, and the resulting dict is written with ``torch.save``.
    """
    net_kwargs = {
        "n_src": 2,
        "n_repeats": 2,
        "n_blocks": 3,
        "bn_chan": 16,
        "hid_chan": 4,
        "skip_chan": 8,
        "n_filters": 32,
    }
    checkpoint = ConvTasNet(**net_kwargs).serialize()
    # Drop the sample rate so the checkpoint lacks it (KeyError if absent).
    del checkpoint["model_args"]["sample_rate"]
    torch.save(checkpoint, "tmp.th")


def setup_infer():
    """Write two wav files of Gaussian noise: ``tmp.wav`` and ``tmp2.wav``.

    Each file receives 16000 random samples and is written with a sample
    rate of 8000.
    """
    n_samples, sample_rate = 16000, 8000
    for wav_path in ("tmp.wav", "tmp2.wav"):
        sf.write(wav_path, np.random.randn(n_samples), sample_rate)


def setup_upload():
    """Populate ``publish_dir`` with a publishable small ConvTasNet.

    The serialized model is extended with dataset/task/license entries and
    handed to ``save_publishable`` together with dummy metrics and an empty
    training configuration.
    """
    dataset_info = {
        "dataset": "WHAM",
        "task": "sep_noisy",
        "licenses": [wsj0_license, wham_noise_license],
    }
    metrics = {"si_sdr": 8.67, "si_sdr_imp": 13.16}

    net = ConvTasNet(
        n_src=2,
        n_repeats=2,
        n_blocks=3,
        bn_chan=16,
        hid_chan=4,
        skip_chan=8,
        n_filters=32,
    )
    serialized = net.serialize()
    serialized.update(dataset_info)

    target_dir = "publish_dir"
    os.makedirs(target_dir, exist_ok=True)
    save_publishable(target_dir, serialized, metrics=metrics, train_conf={})


if __name__ == "__main__":
    # Build every fixture in the same order as before: checkpoint, wav files,
    # then the publishable directory.
    for make_fixture in (setup_register_sr, setup_infer, setup_upload):
        make_fixture()
File renamed without changes.
18 changes: 18 additions & 0 deletions tests/cli_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Smoke tests for the asteroid command-line entry points, run under coverage.
# Expects to be launched from the repository root with ACCESS_TOKEN set in the
# environment (used by the upload command below).
# NOTE(review): there is no `set -e`, so only the exit status of the last
# command is reported to the caller -- confirm whether that is intended.

# Install the package, then save a model (tmp.th), two wavfiles
# (tmp.wav, tmp2.wav) and a publishable directory (publish_dir).
python -m pip install -e . --quiet
python tests/cli_setup.py

# asteroid-register-sr
coverage run -a "$(which asteroid-register-sr)" tmp.th 8000

# asteroid-infer
coverage run -a "$(which asteroid-infer)" tmp.th --files tmp.wav
coverage run -a "$(which asteroid-infer)" tmp.th --files tmp.wav tmp2.wav --force-overwrite
coverage run -a "$(which asteroid-infer)" tmp.th --files tmp.wav --ola-window 1000 --force-overwrite
coverage run -a "$(which asteroid-infer)" tmp.th --files tmp.wav --ola-window 1000 --ola-no-reorder --force-overwrite

# asteroid-upload ("n" is piped to the command's standard input)
# The token is quoted so it is always passed as exactly one argument.
echo "n" | coverage run -a "$(which asteroid-upload)" publish_dir --uploader "Manuel Pariente" --affiliation "Loria" --use_sandbox --token "$ACCESS_TOKEN"

# asteroid-versions
coverage run -a "$(which asteroid-versions)"
106 changes: 106 additions & 0 deletions tests/complex_nn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import torch
from torch.testing import assert_allclose
import pytest
import math

from asteroid import complex_nn as cnn
from asteroid.utils.deprecation_utils import VisibleDeprecationWarning
from asteroid_filterbanks import transforms


def test_is_torch_complex():
    """`is_torch_complex` is truthy for native complex tensors, falsy for real ones.

    The previous version called the predicate but discarded its result, so the
    test could only fail if the call itself raised.
    """
    assert cnn.is_torch_complex(torch.randn(10, 10, dtype=torch.complex64))
    # A plain float tensor is not a native complex tensor.
    assert not cnn.is_torch_complex(torch.randn(10, 10))


def test_torch_complex_from_magphase():
    """Magnitude and angle of the constructed tensor match the inputs."""
    dims = (1, 257, 100)
    magnitude = torch.randn(dims).abs()
    # Phases are folded into [0, pi) before building the complex tensor.
    angle = torch.remainder(torch.randn(dims), math.pi)
    built = cnn.torch_complex_from_magphase(magnitude, angle)
    assert_allclose(torch.abs(built), magnitude)
    assert_allclose(built.angle(), angle)


def test_torch_complex_from_reim():
    """Rebuilding a complex tensor from its real/imag parts gives it back."""
    reference = torch.randn(10, 12, dtype=torch.complex64)
    rebuilt = cnn.torch_complex_from_reim(reference.real, reference.imag)
    assert_allclose(rebuilt, reference)


def test_as_torch_complex():
    """Run the deprecated `as_torch_complex` on each input format it handles.

    Every call is expected to emit a `VisibleDeprecationWarning`.
    """
    dims = (1, 257, 100)
    real_part = torch.randn(dims)
    imag_part = torch.randn(dims)

    # From a (real, imaginary) pair
    with pytest.warns(VisibleDeprecationWarning):
        from_pair = cnn.as_torch_complex((real_part, imag_part))

    # From a native torch complex tensor
    with pytest.warns(VisibleDeprecationWarning):
        from_native = cnn.as_torch_complex(from_pair)
    assert_allclose(from_pair, from_native)

    # From a torchaudio-style tensor whose layout is ambiguous: must error
    with pytest.raises(RuntimeError), pytest.warns(VisibleDeprecationWarning):
        cnn.as_torch_complex(torch.view_as_real(from_pair))

    # From a torchaudio-style tensor, unambiguous
    with pytest.warns(VisibleDeprecationWarning):
        cnn.as_torch_complex(torch.randn(1, 5, 2))

    # From an asteroid-style tensor
    with pytest.warns(VisibleDeprecationWarning):
        from_asteroid = cnn.as_torch_complex(
            transforms.from_torchaudio(torch.view_as_real(from_pair), dim=-2)
        )
    assert_allclose(from_asteroid, from_pair)


def test_as_torch_complex_raises():
    """A tensor of shape (1, 5, 3) is expected to be rejected with RuntimeError."""
    unsupported = torch.randn(1, 5, 3)
    with pytest.raises(RuntimeError), pytest.warns(VisibleDeprecationWarning):
        cnn.as_torch_complex(unsupported)


def test_onreim():
    """`on_reim` applies the wrapped function to real and imaginary parts."""
    z = torch.randn(10, 10, dtype=torch.complex64)

    # Identity leaves the tensor untouched
    identity = cnn.on_reim(lambda part: part)
    assert_allclose(identity(z), z)

    # Absolute value on both parts
    to_abs = cnn.on_reim(lambda part: part.abs())
    expected = cnn.torch_complex_from_reim(z.real.abs(), z.imag.abs())
    assert_allclose(to_abs(z), expected)


def test_on_reim_class():
    """`OnReIm` builds the module from the given args and applies it per part."""
    z = torch.randn(10, 10, dtype=torch.complex64)

    class Identity(torch.nn.Module):
        """Adds the constant `a` to its input (identity when `a` is 0)."""

        def __init__(self, a=0, *args, **kwargs):
            super().__init__()
            self.a = a

        def forward(self, x):
            return x + self.a

    # Offset 0: output equals input
    unchanged = cnn.OnReIm(Identity, 0)
    assert_allclose(unchanged(z), z)

    # Offset 1: both real and imaginary parts are shifted by one
    shifted = cnn.OnReIm(Identity, 1)
    expected = cnn.torch_complex_from_reim(z.real + 1, z.imag + 1)
    assert_allclose(shifted(z), expected)


def test_complex_mul_wrapper():
    """`ComplexMultiplicationWrapper` around ReLU matches the hand-written result."""
    z = torch.randn(10, 10, dtype=torch.complex64)
    wrapped = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU)

    relu_re = torch.relu(z.real)
    relu_im = torch.relu(z.imag)
    expected = cnn.torch_complex_from_reim(relu_re - relu_im, relu_re + relu_im)

    assert_allclose(wrapped(z), expected)


@pytest.mark.parametrize("bound_type", ("BDSS", "sigmoid", "BDT", "tanh", "UBD", None))
def test_bound_complex_mask(bound_type):
    """Smoke test: `bound_complex_mask` runs for each listed `bound_type`."""
    mask = torch.randn(4, 2, 257, dtype=torch.complex64)
    cnn.bound_complex_mask(mask, bound_type=bound_type)


def test_bound_complex_mask_raises():
    """An unknown `bound_type` is expected to raise ValueError."""
    mask = torch.randn(4, 2, 257, dtype=torch.complex64)
    with pytest.raises(ValueError):
        cnn.bound_complex_mask(mask, bound_type="foo")
Loading

0 comments on commit ca55d98

Please sign in to comment.