forked from asteroid-team/asteroid
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[src&tests] Fix complex and add tests (asteroid-team#358)
- Add CLI tests - Add tests for complex_nn - Deprecate `as_torch_complex` - Fix DCCRN
- Loading branch information
Showing
12 changed files
with
236 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
import soundfile as sf | ||
import numpy as np | ||
import os | ||
from asteroid.models import ConvTasNet, save_publishable | ||
from asteroid.data.wham_dataset import wham_noise_license, wsj0_license | ||
|
||
|
||
def setup_register_sr(): | ||
model = ConvTasNet( | ||
n_src=2, | ||
n_repeats=2, | ||
n_blocks=3, | ||
bn_chan=16, | ||
hid_chan=4, | ||
skip_chan=8, | ||
n_filters=32, | ||
) | ||
to_save = model.serialize() | ||
to_save["model_args"].pop("sample_rate") | ||
torch.save(to_save, "tmp.th") | ||
|
||
|
||
def setup_infer(): | ||
sf.write("tmp.wav", np.random.randn(16000), 8000) | ||
sf.write("tmp2.wav", np.random.randn(16000), 8000) | ||
|
||
|
||
def setup_upload(): | ||
train_set_infos = dict( | ||
dataset="WHAM", task="sep_noisy", licenses=[wsj0_license, wham_noise_license] | ||
) | ||
final_results = {"si_sdr": 8.67, "si_sdr_imp": 13.16} | ||
model = ConvTasNet( | ||
n_src=2, | ||
n_repeats=2, | ||
n_blocks=3, | ||
bn_chan=16, | ||
hid_chan=4, | ||
skip_chan=8, | ||
n_filters=32, | ||
) | ||
model_dict = model.serialize() | ||
model_dict.update(train_set_infos) | ||
|
||
os.makedirs("publish_dir", exist_ok=True) | ||
save_publishable( | ||
"publish_dir", | ||
model_dict, | ||
metrics=final_results, | ||
train_conf=dict(), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
setup_register_sr() | ||
setup_infer() | ||
setup_upload() |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Save a model (tmp.th) and two wavfiles (tmp.wav, tmp2.wav) | ||
python -m pip install -e . --quiet | ||
python tests/cli_setup.py | ||
|
||
# asteroid-register-sr` | ||
coverage run -a `which asteroid-register-sr` tmp.th 8000 | ||
|
||
# asteroid-infer | ||
coverage run -a `which asteroid-infer` tmp.th --files tmp.wav | ||
coverage run -a `which asteroid-infer` tmp.th --files tmp.wav tmp2.wav --force-overwrite | ||
coverage run -a `which asteroid-infer` tmp.th --files tmp.wav --ola-window 1000 --force-overwrite | ||
coverage run -a `which asteroid-infer` tmp.th --files tmp.wav --ola-window 1000 --ola-no-reorder --force-overwrite | ||
|
||
# asteroid-upload | ||
echo "n" | coverage run -a `which asteroid-upload` publish_dir --uploader "Manuel Pariente" --affiliation "Loria" --use_sandbox --token $ACCESS_TOKEN | ||
|
||
# asteroid-version | ||
coverage run -a `which asteroid-versions` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
from torch.testing import assert_allclose | ||
import pytest | ||
import math | ||
|
||
from asteroid import complex_nn as cnn | ||
from asteroid.utils.deprecation_utils import VisibleDeprecationWarning | ||
from asteroid_filterbanks import transforms | ||
|
||
|
||
def test_is_torch_complex(): | ||
cnn.is_torch_complex(torch.randn(10, 10, dtype=torch.complex64)) | ||
|
||
|
||
def test_torch_complex_from_magphase(): | ||
shape = (1, 257, 100) | ||
mag = torch.randn(shape).abs() | ||
phase = torch.remainder(torch.randn(shape), math.pi) | ||
out = cnn.torch_complex_from_magphase(mag, phase) | ||
assert_allclose(torch.abs(out), mag) | ||
assert_allclose(out.angle(), phase) | ||
|
||
|
||
def test_torch_complex_from_reim(): | ||
comp = torch.randn(10, 12, dtype=torch.complex64) | ||
assert_allclose(cnn.torch_complex_from_reim(comp.real, comp.imag), comp) | ||
|
||
|
||
def test_as_torch_complex(): | ||
shape = (1, 257, 100) | ||
re = torch.randn(shape) | ||
im = torch.randn(shape) | ||
# From mag and phase | ||
with pytest.warns(VisibleDeprecationWarning): | ||
out = cnn.as_torch_complex((re, im)) | ||
# From torch.complex | ||
with pytest.warns(VisibleDeprecationWarning): | ||
out2 = cnn.as_torch_complex(out) | ||
assert_allclose(out, out2) | ||
# From torchaudio, ambiguous, error | ||
with pytest.raises(RuntimeError): | ||
with pytest.warns(VisibleDeprecationWarning): | ||
cnn.as_torch_complex(torch.view_as_real(out)) | ||
|
||
# From torchaudio, unambiguous | ||
with pytest.warns(VisibleDeprecationWarning): | ||
_ = cnn.as_torch_complex(torch.randn(1, 5, 2)) | ||
# From asteroid | ||
with pytest.warns(VisibleDeprecationWarning): | ||
out4 = cnn.as_torch_complex(transforms.from_torchaudio(torch.view_as_real(out), dim=-2)) | ||
assert_allclose(out4, out) | ||
|
||
|
||
def test_as_torch_complex_raises(): | ||
with pytest.raises(RuntimeError): | ||
with pytest.warns(VisibleDeprecationWarning): | ||
cnn.as_torch_complex(torch.randn(1, 5, 3)) | ||
|
||
|
||
def test_onreim(): | ||
inp = torch.randn(10, 10, dtype=torch.complex64) | ||
# Identity | ||
fn = cnn.on_reim(lambda x: x) | ||
assert_allclose(fn(inp), inp) | ||
# Top right quadrant | ||
fn = cnn.on_reim(lambda x: x.abs()) | ||
assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs())) | ||
|
||
|
||
def test_on_reim_class(): | ||
inp = torch.randn(10, 10, dtype=torch.complex64) | ||
|
||
class Identity(torch.nn.Module): | ||
def __init__(self, a=0, *args, **kwargs): | ||
super().__init__() | ||
self.a = a | ||
|
||
def forward(self, x): | ||
return x + self.a | ||
|
||
fn = cnn.OnReIm(Identity, 0) | ||
assert_allclose(fn(inp), inp) | ||
fn = cnn.OnReIm(Identity, 1) | ||
assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1)) | ||
|
||
|
||
def test_complex_mul_wrapper(): | ||
a = torch.randn(10, 10, dtype=torch.complex64) | ||
|
||
fn = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU) | ||
assert_allclose( | ||
fn(a), | ||
cnn.torch_complex_from_reim( | ||
torch.relu(a.real) - torch.relu(a.imag), torch.relu(a.real) + torch.relu(a.imag) | ||
), | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("bound_type", ("BDSS", "sigmoid", "BDT", "tanh", "UBD", None)) | ||
def test_bound_complex_mask(bound_type): | ||
cnn.bound_complex_mask(torch.randn(4, 2, 257, dtype=torch.complex64), bound_type=bound_type) | ||
|
||
|
||
def test_bound_complex_mask_raises(): | ||
with pytest.raises(ValueError): | ||
cnn.bound_complex_mask(torch.randn(4, 2, 257, dtype=torch.complex64), bound_type="foo") |
Oops, something went wrong.