Skip to content

Commit

Permalink
[CI] Catch warnings we expect (asteroid-team#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored Nov 24, 2020
1 parent cbd0bd0 commit c00403d
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion asteroid/masknn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def forward(self, x):
x = output.transpose(1, 2).transpose(2, -1).reshape(B * K, L, N)
x = self.inter_RNN(x)
x = self.inter_linear(x)
x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1)
x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1).contiguous()
x = self.inter_norm(x)
return output + x

Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ requires = [
line-length = 100
target-version = ["py36"]
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"

[tool.pytest.ini_options]
filterwarnings = [
"ignore:Using or importing the ABCs.*:DeprecationWarning"
]
5 changes: 5 additions & 0 deletions tests/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pytest

ignored_warnings = ["ignore:Could not log computational graph since"]

pytestmark = pytest.mark.filterwarnings(*ignored_warnings)
1 change: 1 addition & 0 deletions tests/filterbanks/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_comp_mask(encoder_list):
assert_allclose(masked, tf_rep)


@pytest.mark.filterwarnings("ignore:asteroid.filterbanks.transforms.take_reim")
def test_reim(encoder_list):
for (enc, fb_dim) in encoder_list:
tf_rep = enc(torch.randn(2, 1, 16000)) # [batch, freq, time]
Expand Down
10 changes: 10 additions & 0 deletions tests/jit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

ignored_warnings = [
"ignore:torch.tensor results are registered as constants in the trace.",
"ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect.",
"ignore:Converting a tensor to a Python float might cause the trace to be incorrect.",
"ignore:Using or importing the ABCs from",
]

pytestmark = pytest.mark.filterwarnings(*ignored_warnings)
9 changes: 9 additions & 0 deletions tests/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

ignored_warnings = [
"ignore:Could not log computational graph since",
"ignore:The dataloader, val dataloader",
"ignore:The dataloader, train dataloader",
]

pytestmark = pytest.mark.filterwarnings(*ignored_warnings)
8 changes: 4 additions & 4 deletions tests/models/demask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
# 'activation': 'relu',
# 'dropout': 0,
# 'fb_kwargs': {},
# 'fb_type': 'stft',
# 'fb_name': 'stft',
# 'hidden_dims': [1024],
# 'input_type': 'mag',
# 'kernel_size': 512,
Expand All @@ -78,7 +78,7 @@


@pytest.mark.parametrize("input_type", ("mag", "cat", "reim"))
@pytest.mark.parametrize("fb_type", ("stft", "free"))
@pytest.mark.parametrize("fb_name", ("stft", "free"))
@pytest.mark.parametrize("output_type", ("mag", "reim"))
@pytest.mark.parametrize(
"data",
Expand All @@ -90,11 +90,11 @@
(torch.rand(2, 1, 50, requires_grad=False) - 0.5) * 2,
),
)
def test_forward(input_type, output_type, fb_type, data):
def test_forward(input_type, output_type, fb_name, data):
demask = DeMask(
input_type=input_type,
output_type=output_type,
fb_type=fb_type,
fb_name=fb_name,
hidden_dims=(16,),
kernel_size=8,
n_filters=8,
Expand Down
3 changes: 2 additions & 1 deletion tests/models/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def test_sudormrf_imp():
)


@pytest.mark.filterwarnings("ignore: DPTransformer input dim")
@pytest.mark.parametrize("fb", ["free", "stft", "analytic_free", "param_sinc"])
def test_dptnet(fb):
_default_test_model(DPTNet(2, ff_hid=10, chunk_size=4, n_repeats=2, fb_name=fb))
Expand Down Expand Up @@ -224,7 +225,7 @@ def test_available_models():

@pytest.mark.parametrize("fb", ["free", "stft", "analytic_free", "param_sinc"])
def test_demask(fb):
model = DeMask(fb_type=fb)
model = DeMask(fb_name=fb)
test_input = torch.randn(1, 801)

model_conf = model.serialize()
Expand Down

0 comments on commit c00403d

Please sign in to comment.