Skip to content

Commit

Permalink
[src] Improve Beamforming naming and add TODOs (asteroid-team#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored Apr 19, 2021
1 parent 228d8bb commit b508c1d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 21 deletions.
8 changes: 8 additions & 0 deletions asteroid/dsp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from .consistency import mixture_consistency
from .overlap_add import LambdaOverlapAdd, DualPathProcessing
from .beamforming import (
SCM,
Beamformer,
RTFMVDRBeamformer,
SoudenMVDRBeamformer,
SDWMWFBeamformer,
GEVBeamformer,
)

__all__ = [
"mixture_consistency",
Expand Down
40 changes: 25 additions & 15 deletions asteroid/dsp/beamforming.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch
from torch import nn
import warnings
from functools import wraps


class SCM(nn.Module):
Expand All @@ -10,7 +8,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor = None, normalize: bool =
return compute_scm(x, mask=mask, normalize=normalize)


class BeamFormer(nn.Module):
class Beamformer(nn.Module):
@staticmethod
def apply_beamforming_vector(bf_vector: torch.Tensor, mix: torch.Tensor):
"""Apply the beamforming vector to the mixture. Output (batch, freqs, frames).
Expand All @@ -22,7 +20,7 @@ def apply_beamforming_vector(bf_vector: torch.Tensor, mix: torch.Tensor):
return torch.einsum("...mf,...mft->...ft", bf_vector.conj(), mix)


class MvdrBeamformer(BeamFormer):
class RTFMVDRBeamformer(Beamformer):
def forward(
self,
mix: torch.Tensor,
Expand All @@ -42,39 +40,45 @@ def forward(
Returns:
Filtered mixture. torch.ComplexTensor (batch, freqs, frames)
"""
# Get acoustic transfer function (1st PCA of Σss)
# TODO: Implement several RTF estimation strategies, and choose one here, or expose all.
# Get relative transfer function (1st PCA of Σss)
e_val, e_vec = torch.symeig(target_scm.permute(0, 3, 1, 2), eigenvectors=True)
atf_vect = e_vec[..., -1] # bfm
return self.from_atf_vect(mix=mix, atf_vec=atf_vect.transpose(-1, -2), noise_scm=noise_scm)
rtf_vect = e_vec[..., -1] # bfm
return self.from_rtf_vect(mix=mix, rtf_vec=rtf_vect.transpose(-1, -2), noise_scm=noise_scm)

def from_atf_vect(
def from_rtf_vect(
self,
mix: torch.Tensor,
atf_vec: torch.Tensor,
rtf_vec: torch.Tensor,
noise_scm: torch.Tensor,
):
"""Compute and apply MVDR beamformer from the ATF vector and noise SCM matrix.
Args:
mix (torch.ComplexTensor): shape (batch, mics, freqs, frames)
atf_vec (torch.ComplexTensor): (batch, mics, freqs)
rtf_vec (torch.ComplexTensor): (batch, mics, freqs)
noise_scm (torch.ComplexTensor): (batch, mics, mics, freqs)
Returns:
Filtered mixture. torch.ComplexTensor (batch, freqs, frames)
"""
noise_scm_t = noise_scm.permute(0, 3, 1, 2) # -> bfmm
atf_vec_t = atf_vec.transpose(-1, -2).unsqueeze(-1) # -> bfm1
rtf_vec_t = rtf_vec.transpose(-1, -2).unsqueeze(-1) # -> bfm1

numerator = stable_solve(atf_vec_t, noise_scm_t) # -> bfm1
numerator = stable_solve(rtf_vec_t, noise_scm_t) # -> bfm1

denominator = torch.matmul(atf_vec_t.conj().transpose(-1, -2), numerator) # -> bf11
denominator = torch.matmul(rtf_vec_t.conj().transpose(-1, -2), numerator) # -> bf11
bf_vect = (numerator / denominator).squeeze(-1).transpose(-1, -2) # -> bfm1 -> bmf
output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft
return output


class SdwMwfBeamformer(BeamFormer):
# Placeholder subclass of Beamformer: the body is only `pass`, so it adds no
# behavior of its own yet (see the TODO below).
class SoudenMVDRBeamformer(Beamformer):
# TODO(popcornell): fill in the code.
pass


class SDWMWFBeamformer(Beamformer):
def __init__(self, mu=1.0):
super().__init__()
self.mu = mu
Expand Down Expand Up @@ -105,7 +109,7 @@ def forward(
return output


class GEVBeamformer(BeamFormer):
class GEVBeamformer(Beamformer):
def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor):
"""Compute and apply the GEV beamformer.
Expand Down Expand Up @@ -262,3 +266,9 @@ def _common_dtype(*args):
if len(set(all_dtypes)) > 1:
raise RuntimeError(f"Expected inputs from the same dtype. Received {all_dtypes}.")
return all_dtypes[0]


# Legacy
# Backward-compatible aliases: each pre-rename class name is bound to the
# corresponding renamed class, so code that still refers to the old names
# resolves to the same class objects.
BeamFormer = Beamformer
SdwMwfBeamformer = SDWMWFBeamformer
MvdrBeamformer = RTFMVDRBeamformer
12 changes: 6 additions & 6 deletions tests/dsp/beamforming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from asteroid_filterbanks import make_enc_dec, transforms as tr

from asteroid.dsp.beamforming import (
BeamFormer,
Beamformer,
SCM,
MvdrBeamformer,
SdwMwfBeamformer,
RTFMVDRBeamformer,
SDWMWFBeamformer,
GEVBeamformer,
stable_cholesky,
)
Expand All @@ -20,7 +20,7 @@


@pytest.mark.skipif(not torch_has_complex_support, "No complex support ")
def _default_beamformer_test(beamformer: BeamFormer, n_mics=4, *args, **kwargs):
def _default_beamformer_test(beamformer: Beamformer, n_mics=4, *args, **kwargs):
scm = SCM()

speech = torch.randn(1, n_mics, 16000 * 6)
Expand All @@ -46,14 +46,14 @@ def test_gev(n_mics):
@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ")
@pytest.mark.parametrize("n_mics", [2, 3, 4])
def test_mvdr(n_mics):
_default_beamformer_test(MvdrBeamformer(), n_mics=n_mics)
_default_beamformer_test(RTFMVDRBeamformer(), n_mics=n_mics)


@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ")
@pytest.mark.parametrize("n_mics", [2, 3, 4])
@pytest.mark.parametrize("mu", [1.0, 2.0, 0])
def test_mwf(n_mics, mu):
_default_beamformer_test(SdwMwfBeamformer(mu=mu), n_mics=n_mics)
_default_beamformer_test(SDWMWFBeamformer(mu=mu), n_mics=n_mics)


def test_stable_cholesky():
Expand Down

0 comments on commit b508c1d

Please sign in to comment.