Skip to content

Commit

Permalink
[src] Better names in asteroid.filterbanks.transforms (asteroid-team#342
Browse files Browse the repository at this point in the history
)

* Better deprecation decorator

* Deprecate take_reim take_mag, take_cat and from_mag_and_phase
  • Loading branch information
mpariente authored Nov 20, 2020
1 parent 6e517bd commit 3d3be45
Showing 14 changed files with 116 additions and 78 deletions.
12 changes: 6 additions & 6 deletions asteroid/filterbanks/griffin_lim.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ def griffin_lim(mag_specgram, stft_enc, angles=None, istft_dec=None, n_iter=6, m
>>> wav = torch.randn(2, 1, 8000)
>>> spec = stft(wav)
>>> masked_spec = spec * torch.sigmoid(torch.randn_like(spec))
>>> mag = transforms.take_mag(masked_spec, -2)
>>> mag = transforms.mag(masked_spec, -2)
>>> est_wav = griffin_lim(mag, stft, n_iter=32)
References
@@ -60,15 +60,15 @@ def griffin_lim(mag_specgram, stft_enc, angles=None, istft_dec=None, n_iter=6, m
for _ in range(n_iter):
prev_built = rebuilt
# Go to the time domain
complex_specgram = transforms.from_mag_and_phase(mag_specgram, angles)
complex_specgram = transforms.from_magphase(mag_specgram, angles)
waveform = istft_dec(complex_specgram)
# And back to TF domain
rebuilt = stft_enc(waveform)
# Update phase estimates (with momentum)
diff = rebuilt - momentum / (1 + momentum) * prev_built
angles = transforms.angle(diff)

final_complex_spec = transforms.from_mag_and_phase(mag_specgram, angles)
final_complex_spec = transforms.from_magphase(mag_specgram, angles)
return istft_dec(final_complex_spec)


@@ -116,7 +116,7 @@ def misi(
>>> wav = torch.randn(2, 3, 8000)
>>> specs = stft(wav)
>>> masked_specs = specs * torch.sigmoid(torch.randn_like(specs))
>>> mag = transforms.take_mag(masked_specs, -2)
>>> mag = transforms.mag(masked_specs, -2)
>>> est_wav = misi(wav.sum(1), mag, stft, n_iter=32)
References
@@ -149,7 +149,7 @@ def misi(
for _ in range(n_iter):
prev_built = rebuilt
# Go to the time domain
complex_specgram = transforms.from_mag_and_phase(mag_specgrams, angles)
complex_specgram = transforms.from_magphase(mag_specgrams, angles)
wavs = istft_dec(complex_specgram)
# Make wavs sum up to the mixture
consistent_wavs = mixture_consistency(
@@ -162,5 +162,5 @@ def misi(
diff = rebuilt - momentum / (1 + momentum) * prev_built
angles = transforms.angle(diff)
# Final source estimates
final_complex_spec = transforms.from_mag_and_phase(mag_specgrams, angles)
final_complex_spec = transforms.from_magphase(mag_specgrams, angles)
return istft_dec(final_complex_spec)
2 changes: 1 addition & 1 deletion asteroid/filterbanks/melgram_fb.py
Original file line number Diff line number Diff line change
@@ -96,6 +96,6 @@ def __init__(
self.register_buffer("fb_mat", torch.from_numpy(fb_mat).unsqueeze(0))

def forward(self, spec: torch.Tensor):
mag_spec = transforms.take_mag(spec, dim=-2)
mag_spec = transforms.mag(spec, dim=-2)
mel_spec = torch.matmul(self.fb_mat, mag_spec)
return mel_spec
72 changes: 51 additions & 21 deletions asteroid/filterbanks/transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Tuple

import math
import torch
import numpy as np
from typing import Tuple

from ..utils.torch_utils import script_if_tracing
from ..utils.deprecation_utils import mark_deprecated


def mul_c(inp, other, dim: int = -2):
@@ -47,11 +47,18 @@ def mul_c(inp, other, dim: int = -2):
return torch.cat([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=dim)


def take_reim(x, dim: int = -2):
return x
def reim(x, dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns a tuple (re, im).
Args:
x (:class:`torch.Tensor`): Complex valued tensor.
dim (int): frequency (or equivalent) dimension along which real and
imaginary values are concatenated.
"""
return torch.chunk(x, 2, dim=dim)

def take_mag(x, dim: int = -2, EPS: float = 1e-8):

def mag(x, dim: int = -2, EPS: float = 1e-8):
"""Takes the magnitude of a complex tensor.
The operands is assumed to have the real parts of each entry followed by
@@ -86,8 +93,15 @@ def take_mag(x, dim: int = -2, EPS: float = 1e-8):
return power.pow(0.5)


def take_cat(x, dim: int = -2):
return torch.cat([take_mag(x, dim=dim), x], dim=dim)
def magreim(x, dim: int = -2):
"""Returns a concatenation of (mag, re, im).
Args:
x (:class:`torch.Tensor`): Complex valued tensor.
dim (int): frequency (or equivalent) dimension along which real and
imaginary values are concatenated.
"""
return torch.cat([mag(x, dim=dim), x], dim=dim)


def apply_real_mask(tf_rep, mask, dim: int = -2):
@@ -253,7 +267,7 @@ def is_torchaudio_complex(x):
"""Check if tensor is Torchaudio-style complex-like (last dimension is 2).
Args:
tensor (torch.Tensor): tensor to be checked.
x (torch.Tensor): tensor to be checked.
Returns:
True if last dimension is 2, else False.
@@ -329,11 +343,11 @@ def angle(tensor, dim: int = -2):
return torch.atan2(imag, real)


def from_mag_and_phase(mag, phase, dim: int = -2):
def from_magphase(mag_spec, phase, dim: int = -2):
"""Return a complex-like torch tensor from magnitude and phase components.
Args:
mag (torch.tensor): magnitude of the tensor.
mag_spec (torch.tensor): magnitude of the tensor.
phase (torch.tensor): angle of the tensor
dim(int, optional): the frequency (or equivalent) dimension along which
real and imaginary values are concatenated.
@@ -342,18 +356,14 @@ def from_mag_and_phase(mag, phase, dim: int = -2):
:class:`torch.Tensor`:
The corresponding complex-like torch tensor.
"""
return torch.cat([mag * torch.cos(phase), mag * torch.sin(phase)], dim=dim)


# Alias
from_polar = from_mag_and_phase
return torch.cat([mag_spec * torch.cos(phase), mag_spec * torch.sin(phase)], dim=dim)


def magphase(spec: torch.Tensor, dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
"""Splits Asteroid complex-like tensor into magnitude and phase."""
mag = take_mag(spec, dim=dim)
mag_val = mag(spec, dim=dim)
phase = angle(spec, dim=dim)
return mag, phase
return mag_val, phase


@script_if_tracing
@@ -460,9 +470,9 @@ def centerfreq_correction(
if stride is None:
stride = kernel_size // 2
# Phase will be (batch, n_freq // 2 + 1, frames)
mag, phase = magphase(spec, dim=dim)
mag_spec, phase = magphase(spec, dim=dim)
new_phase = phase_centerfreq_correction(phase, kernel_size=kernel_size, stride=stride)
new_spec = from_polar(mag, new_phase, dim=dim)
new_spec = from_magphase(mag_spec, new_phase, dim=dim)
return new_spec


@@ -484,5 +494,25 @@ def phase_centerfreq_correction(
"""
*_, freq, frames = phase.shape
tmp = torch.arange(freq).unsqueeze(-1) * torch.arange(frames)[None]
correction = -2 * tmp * stride * np.pi / kernel_size
correction = -2 * tmp * stride * math.pi / kernel_size
return phase + correction


@mark_deprecated(None, None)
def take_reim(x, dim: int = -2):
return x


@mark_deprecated("Please use `asteroid.filterbanks.transforms.mag` instead.", None)
def take_mag(*args, **kwargs):
return mag(*args, **kwargs)


@mark_deprecated("Please use `asteroid.filterbanks.transforms.magreim` instead.", None)
def take_cat(*args, **kwargs):
return magreim(*args, **kwargs)


@mark_deprecated("Please use `asteroid.filterbanks.transforms.from_magphase` instead.", None)
def from_mag_and_phase(*args, **kwargs):
return from_magphase(*args, **kwargs)
6 changes: 3 additions & 3 deletions asteroid/losses/multi_scale_spectral.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from ..filterbanks import STFTFB, Encoder
from ..filterbanks.transforms import take_mag
from ..filterbanks.transforms import mag


class SingleSrcMultiScaleSpectral(_Loss):
@@ -80,8 +80,8 @@ def forward(self, est_target, target):

def compute_spectral_loss(self, encoder, est_target, target, EPS=1e-8):
batch_size = est_target.shape[0]
spect_est_target = take_mag(encoder(est_target)).view(batch_size, -1)
spect_target = take_mag(encoder(target)).view(batch_size, -1)
spect_est_target = mag(encoder(est_target)).view(batch_size, -1)
spect_target = mag(encoder(target)).view(batch_size, -1)
linear_loss = self.norm1(spect_est_target - spect_target)
log_loss = self.norm1(torch.log(spect_est_target + EPS) - torch.log(spect_target + EPS))
return linear_loss + self.alpha * log_loss
8 changes: 4 additions & 4 deletions asteroid/losses/pmsqe.py
Original file line number Diff line number Diff line change
@@ -45,15 +45,15 @@ class (see Tensorflow implementation).
>>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
>>> # Usage by itself
>>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_func = SingleSrcPMSQE()
>>> loss_value = loss_func(est_spec, ref_spec)
>>> # Usage with PITLossWrapper
>>> loss_func = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt')
>>> ref, est = torch.randn(2, 3, 16000), torch.randn(2, 3, 16000)
>>> ref_spec = transforms.take_mag(stft(ref))
>>> est_spec = transforms.take_mag(stft(est))
>>> ref_spec = transforms.mag(stft(ref))
>>> est_spec = transforms.mag(stft(est))
>>> loss_value = loss_func(ref_spec, est_spec)
"""

6 changes: 3 additions & 3 deletions asteroid/models/demask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch import nn
from .base_models import BaseEncoderMaskerDecoder
from ..filterbanks import make_enc_dec
from ..filterbanks.transforms import take_mag, take_cat
from ..filterbanks.transforms import mag, magreim
from ..masknn import norms, activations
from ..utils.torch_utils import pad_x_to_y
from ..utils.deprecation_utils import VisibleDeprecationWarning
@@ -121,9 +121,9 @@ def forward_masker(self, tf_rep):
"""
masker_input = tf_rep
if self.input_type == "mag":
masker_input = take_mag(masker_input)
masker_input = mag(masker_input)
elif self.input_type == "cat":
masker_input = take_cat(masker_input)
masker_input = magreim(masker_input)
est_masks = self.masker(masker_input)
if self.output_type == "mag":
est_masks = est_masks.repeat(1, 2, 1)
34 changes: 21 additions & 13 deletions asteroid/utils/deprecation_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from functools import wraps


class VisibleDeprecationWarning(UserWarning):
@@ -25,18 +26,25 @@ def warn_deprecated(self):
)


def deprecate_func(func, old_name):
"""Function to return DeprecationWarning when a deprecated function
is called. Example to come."""
def mark_deprecated(message, version=None):
"""Decorator to add deprecation message.
def func_with_warning(*args, **kwargs):
""" Deprecated function, please read your warnings. """
warnings.warn(
"{} is deprecated since v0.1.0, it will be removed in "
"v0.2.0. Please use {} instead."
"".format(old_name, func.__name__),
VisibleDeprecationWarning,
)
return func(*args, **kwargs)
Args:
message: Migration steps to be given to users.
"""

return func_with_warning
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
from_what = "a future release" if version is None else f"asteroid v{version}"
warn_message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f"and will be removed from {from_what}. "
f"{message}"
)
warnings.warn(warn_message, VisibleDeprecationWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapped

return decorator
10 changes: 5 additions & 5 deletions egs/dns_challenge/baseline/model.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@

from asteroid.engine.system import System
from asteroid.filterbanks import make_enc_dec
from asteroid.filterbanks.transforms import take_cat, take_mag
from asteroid.filterbanks.transforms import magreim, mag
from asteroid.filterbanks.transforms import apply_real_mask
from asteroid.filterbanks.transforms import apply_mag_mask
from asteroid.masknn import blocks
@@ -75,9 +75,9 @@ def forward(self, x):
tf_rep = self.encoder(x)
# Estimate TF mask from STFT features : cat([re, im, mag])
if self.is_complex:
to_masker = take_cat(tf_rep)
to_masker = magreim(tf_rep)
else:
to_masker = take_mag(tf_rep)
to_masker = mag(tf_rep)
# LSTM masker expects a feature dimension last (not like 1D conv)
est_masks = self.masker(to_masker.transpose(1, 2)).transpose(1, 2)
# Apply TF mask
@@ -155,10 +155,10 @@ def distance(estimate, target, is_complex=True):
if is_complex:
# Take the difference in the complex plane and compute the squared norm
# of the remaining vector.
return take_mag(estimate - target).pow(2).mean()
return mag(estimate - target).pow(2).mean()
else:
# Compute the mean difference between magnitudes.
return (take_mag(estimate) - take_mag(target)).pow(2).mean()
return (mag(estimate) - mag(target)).pow(2).mean()


def load_best_model(train_conf, exp_dir):
8 changes: 4 additions & 4 deletions egs/kinect-wsj/DeepClustering/train.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_mse
from asteroid.losses import deep_clustering_loss
from asteroid.filterbanks.transforms import take_mag, ebased_vad
from asteroid.filterbanks.transforms import mag, ebased_vad

from asteroid.data.kinect_wsj import make_dataloaders
from model import make_model_and_optimizer
@@ -91,7 +91,7 @@ def __init__(self, *args, mask_mixture=True, **kwargs):
def common_step(self, batch, batch_nb, train=False):
inputs, targets, masks = self.unpack_data(batch)
embeddings, est_masks = self(inputs)
spec = take_mag(self.model.encoder(inputs.unsqueeze(1)))
spec = mag(self.model.encoder(inputs.unsqueeze(1)))
if self.mask_mixture:
est_masks = est_masks * spec.unsqueeze(1)
masks = masks * spec.unsqueeze(1)
@@ -136,8 +136,8 @@ def unpack_data(self, batch, EPS=1e-8):
noise = noise[..., 0]
noise = noise.unsqueeze(1)
# Compute magnitude spectrograms and IRM
src_mag_spec = take_mag(self.model.encoder(sources))
noise_mag_spec = take_mag(self.model.encoder(noise))
src_mag_spec = mag(self.model.encoder(sources))
noise_mag_spec = mag(self.model.encoder(noise))
noise_mag_spec = noise_mag_spec.unsqueeze(1)
real_mask = src_mag_spec / (noise_mag_spec + src_mag_spec.sum(1, keepdim=True) + EPS)
# Get the src idx having the maximum energy
Loading

0 comments on commit 3d3be45

Please sign in to comment.