Skip to content

Commit

Permalink
[src] Beamforming: Souden MVDR and optimal channel selection (asteroi…
Browse files Browse the repository at this point in the history
…d-team#484)

Co-authored-by: popcornell <cornellsamuele@gmail.com>
  • Loading branch information
mpariente and popcornell authored Apr 26, 2021
1 parent b508c1d commit efb95a4
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 20 deletions.
155 changes: 145 additions & 10 deletions asteroid/dsp/beamforming.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Union
import torch
from torch import nn
from torch.nn import functional as F


class SCM(nn.Module):
Expand All @@ -19,6 +21,53 @@ def apply_beamforming_vector(bf_vector: torch.Tensor, mix: torch.Tensor):
"""
return torch.einsum("...mf,...mft->...ft", bf_vector.conj(), mix)

@staticmethod
def get_reference_mic_vects(
    ref_mic,
    bf_mat: torch.Tensor,
    target_scm: torch.Tensor = None,
    noise_scm: torch.Tensor = None,
):
    """Build the one-hot reference-microphone selection vectors for a batch.

    Args:
        ref_mic (Optional[Union[int, torch.Tensor]]): The reference channel.
            If torch.Tensor with ``ndim > 1``, it is taken to already be the
            reference mic vector and is returned unchanged.
            If 1D integer torch.Tensor of size ``batch``, select an independent
            reference mic for each element of the batch.
            If int, select the same reference mic for the whole batch.
            If None and both SCMs are given, the reference mics are computed
            with :func:`get_optimal_reference_mic`.
            If None and either SCM is None, ``ref_mic`` is set to ``0``.
        bf_mat: beamforming matrix of shape (batch, freq, mics, mics).
        target_scm (torch.ComplexTensor): (batch, freqs, mics, mics).
        noise_scm (torch.ComplexTensor): (batch, freqs, mics, mics).

    Returns:
        torch.Tensor: one-hot selection vectors of shape (batch, 1, mics, 1),
        with the dtype and device of ``bf_mat``. When ``ref_mic`` is a tensor
        with ``ndim > 1``, ``ref_mic`` itself is returned instead.

    Raises:
        ValueError: if ``ref_mic`` is neither None, an int nor a torch.Tensor.
    """
    # A tensor with more than one dimension is already a reference mic vector.
    if isinstance(ref_mic, torch.Tensor) and ref_mic.ndim > 1:
        return ref_mic

    n_batch, n_mics = bf_mat.shape[0], bf_mat.shape[-1]

    # Without both SCMs the a posteriori SNR cannot be computed: use mic 0.
    if ref_mic is None and (target_scm is None or noise_scm is None):
        ref_mic = 0

    if ref_mic is None:
        batch_mic_idx = get_optimal_reference_mic(
            bf_mat=bf_mat, target_scm=target_scm, noise_scm=noise_scm
        )
    elif isinstance(ref_mic, int):
        batch_mic_idx = torch.full(
            (n_batch,), ref_mic, dtype=torch.long, device=bf_mat.device
        )
    elif isinstance(ref_mic, torch.Tensor):  # Must be 1D (ndim > 1 returned above)
        # F.one_hot only accepts int64 indices: widen narrower integer dtypes.
        # Floating/complex tensors are left as is so that F.one_hot rejects them
        # instead of being silently truncated.
        is_integer = not (ref_mic.is_floating_point() or ref_mic.is_complex())
        batch_mic_idx = ref_mic.long() if is_integer else ref_mic
    else:
        raise ValueError(
            f"Unsupported reference microphone format. Support None, int and 1D "
            f"torch.LongTensor and torch.Tensor, received {type(ref_mic)}."
        )
    # (batch, n_mics) -> (batch, 1, n_mics, 1), broadcastable against bf_mat.
    ref_mic_vects = F.one_hot(batch_mic_idx, num_classes=n_mics)[:, None, :, None]
    return ref_mic_vects.to(device=bf_mat.device, dtype=bf_mat.dtype)


class RTFMVDRBeamformer(Beamformer):
def forward(
Expand All @@ -30,7 +79,8 @@ def forward(
r"""Compute and apply MVDR beamformer from the speech and noise SCM matrices.
:math:`\mathbf{w} = \displaystyle \frac{\Sigma_{nn}^{-1} \mathbf{a}}{
\mathbf{a}^H \Sigma_{nn}^{-1} \mathbf{a}}` where :math:`\mathbf{a}` is the ATF estimated from the target SCM.
\mathbf{a}^H \Sigma_{nn}^{-1} \mathbf{a}}` where :math:`\mathbf{a}` is the
ATF estimated from the target SCM.
Args:
mix (torch.ComplexTensor): shape (batch, mics, freqs, frames)
Expand Down Expand Up @@ -74,8 +124,50 @@ def from_rtf_vect(


class SoudenMVDRBeamformer(Beamformer):
    """MVDR beamformer computed with Souden's formulation."""

    def forward(
        self,
        mix: torch.Tensor,
        target_scm: torch.Tensor,
        noise_scm: torch.Tensor,
        ref_mic: Union[torch.Tensor, torch.LongTensor, int] = 0,
        eps: float = 1e-8,
    ):
        r"""Compute and apply MVDR beamformer from the speech and noise SCM matrices.

        This class uses Souden's formulation [1].

        :math:`\mathbf{w} = \displaystyle \frac{\Sigma_{nn}^{-1} \Sigma_{ss}}{
        Tr\left( \Sigma_{nn}^{-1} \Sigma_{ss} \right) }\mathbf{u}` where :math:`\mathbf{u}`
        is the vector selecting the reference microphone (one-hot unless a
        custom reference vector is passed as ``ref_mic``).

        Args:
            mix (torch.ComplexTensor): shape (batch, mics, freqs, frames)
            target_scm (torch.ComplexTensor): (batch, mics, mics, freqs)
            noise_scm (torch.ComplexTensor): (batch, mics, mics, freqs)
            ref_mic (Union[int, torch.Tensor]): reference microphone, forwarded
                to ``get_reference_mic_vects`` (int, 1D index tensor of size
                ``batch``, or an already-built reference vector). Defaults to 0.
            eps: numerical stabilizer added to the trace in the denominator.

        Returns:
            Filtered mixture. torch.ComplexTensor (batch, freqs, frames)

        References
            [1] Souden, M., Benesty, J., & Affes, S. (2009). On optimal frequency-domain
            multichannel linear filtering for noise reduction. IEEE Transactions on audio,
            speech, and language processing, 18(2), 260-276.
        """
        noise_scm = noise_scm.permute(0, 3, 1, 2)  # -> bfmm
        target_scm = target_scm.permute(0, 3, 1, 2)  # -> bfmm

        # NOTE(review): presumably stable_solve(b, a) returns a^{-1} b, i.e. the
        # numerator of the formula above -- confirm against stable_solve.
        numerator = stable_solve(target_scm, noise_scm)
        bf_mat = numerator / (batch_trace(numerator)[..., None, None] + eps)  # bfmm

        # allow for a-posteriori SNR selection
        batch_mic_vects = self.get_reference_mic_vects(
            ref_mic, bf_mat, target_scm=target_scm, noise_scm=noise_scm
        )
        bf_vect = torch.matmul(bf_mat, batch_mic_vects)  # -> bfmm -> bfm1
        bf_vect = bf_vect.squeeze(-1).transpose(-1, -2)  # bfm1 -> bmf
        output = self.apply_beamforming_vector(bf_vect, mix=mix)  # -> bft
        return output


class SDWMWFBeamformer(Beamformer):
Expand All @@ -84,9 +176,13 @@ def __init__(self, mu=1.0):
self.mu = mu

def forward(
self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor, ref_mic: int = 0
self,
mix: torch.Tensor,
target_scm: torch.Tensor,
noise_scm: torch.Tensor,
ref_mic: Union[torch.Tensor, torch.LongTensor, int] = None,
):
"""Compute and apply SDW-MWF beamformer.
r"""Compute and apply SDW-MWF beamformer.
:math:`\mathbf{w} = \displaystyle (\Sigma_{ss} + \mu \Sigma_{nn})^{-1} \Sigma_{ss}`.
Expand All @@ -102,19 +198,27 @@ def forward(
noise_scm_t = noise_scm.permute(0, 3, 1, 2) # -> bfmm
target_scm_t = target_scm.permute(0, 3, 1, 2) # -> bfmm

# import ipdb; ipdb.set_trace()

denominator = target_scm_t + self.mu * noise_scm_t
bf_vect = stable_solve(target_scm_t, denominator)
bf_vect = bf_vect[..., ref_mic].transpose(-1, -2) # -> bfm1 -> bmf
bf_mat = stable_solve(target_scm_t, denominator)
# Reference mic selection and application
batch_mic_vects = self.get_reference_mic_vects(
ref_mic, bf_mat, target_scm=target_scm_t, noise_scm=noise_scm_t
) # b1m1
bf_vect = torch.matmul(bf_mat, batch_mic_vects) # -> bfmm -> bfm1
bf_vect = bf_vect.squeeze(-1).transpose(-1, -2) # bfm1 -> bmf
output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft
return output


class GEVBeamformer(Beamformer):
def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.Tensor):
"""Compute and apply the GEV beamformer.
r"""Compute and apply the GEV beamformer.
:math:`\mathbf{w} = \displaystyle MaxEig\{ \Sigma_{nn}^{-1}\Sigma_{ss} \}`, where
MaxEig extracts the eigenvector corresponding to the maximum eigenvalue (using the GEV decomposition).
MaxEig extracts the eigenvector corresponding to the maximum eigenvalue
(using the GEV decomposition).
Args:
mix: shape (batch, mics, freqs, frames)
Expand Down Expand Up @@ -166,6 +270,37 @@ def compute_scm(x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = Tr
return scm


def get_optimal_reference_mic(
    bf_mat: torch.Tensor,
    target_scm: torch.Tensor,
    noise_scm: torch.Tensor,
    eps: float = 1e-6,
):
    """Compute the optimal reference mic given the a posteriori SNR, see [1].

    For each candidate reference mic ``m``, the filter is the ``m``-th column
    of ``bf_mat``. The target and noise powers at its output are summed over
    frequencies, and the mic with the largest target-to-noise ratio is kept.

    Args:
        bf_mat: (batch, freq, mics, mics)
        target_scm (torch.ComplexTensor): (batch, freqs, mics, mics)
        noise_scm (torch.ComplexTensor): (batch, freqs, mics, mics)
        eps: value to clip the denominator.

    Returns:
        torch.LongTensor: index of the selected reference mic for each batch
        element, i.e. shape (batch,) for the input shapes above.

    Raises:
        AssertionError: if the a posteriori SNR contains non-finite values.

    References
        [1] Erdogan et al. 2016: "Improved MVDR beamforming using single-channel
        mask prediction networks"
        https://www.merl.com/publications/docs/TR2016-072.pdf
    """
    # w_m^H Sigma w_m for every column m of bf_mat, summed over frequencies.
    output_power_eq = "...flm,...fln,...fnm->...m"
    bf_mat_conj = bf_mat.conj()
    noise_power = torch.einsum(output_power_eq, bf_mat_conj, noise_scm, bf_mat).real
    target_power = torch.einsum(output_power_eq, bf_mat_conj, target_scm, bf_mat).real
    snr_post = target_power / torch.clamp(noise_power, min=eps)
    # Explicit raise so the check survives `python -O`; AssertionError is kept
    # so callers see the same exception type as before.
    if not torch.all(torch.isfinite(snr_post)):
        raise AssertionError(snr_post)
    return torch.argmax(snr_post, dim=-1)


def condition_scm(x, eps=1e-6, dim1=-2, dim2=-1):
"""Condition input SCM with (x + eps tr(x) I) / (1 + eps) along `dim1` and `dim2`.
Expand All @@ -175,7 +310,7 @@ def condition_scm(x, eps=1e-6, dim1=-2, dim2=-1):
if dim1 != -2 or dim2 != -1:
raise NotImplementedError
scale = eps * batch_trace(x, dim1=dim1, dim2=dim2)[..., None, None] / x.shape[dim1]
scaled_eye = torch.eye(x.shape[dim1])[None, None] * scale
scaled_eye = torch.eye(x.shape[dim1], device=x.device)[None, None] * scale
return (x + scaled_eye) / (1 + eps)


Expand Down
47 changes: 37 additions & 10 deletions tests/dsp/beamforming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Beamformer,
SCM,
RTFMVDRBeamformer,
SoudenMVDRBeamformer,
SDWMWFBeamformer,
GEVBeamformer,
stable_cholesky,
Expand All @@ -20,11 +21,11 @@


@pytest.mark.skipif(not torch_has_complex_support, "No complex support ")
def _default_beamformer_test(beamformer: Beamformer, n_mics=4, *args, **kwargs):
def _default_beamformer_test(beamformer: Beamformer, batch_size=2, n_mics=4, **forward_kwargs):
scm = SCM()

speech = torch.randn(1, n_mics, 16000 * 6)
noise = torch.randn(1, n_mics, 16000 * 6)
speech = torch.randn(batch_size, n_mics, 16000 * 6)
noise = torch.randn(batch_size, n_mics, 16000 * 6)
mix = speech + noise
# GeV Beamforming
mix_stft = stft(mix)
Expand All @@ -33,27 +34,53 @@ def _default_beamformer_test(beamformer: Beamformer, n_mics=4, *args, **kwargs):
sigma_ss = scm(speech_stft)
sigma_nn = scm(noise_stft)

Ys_gev = beamformer.forward(mix=mix_stft, target_scm=sigma_ss, noise_scm=sigma_nn)
Ys_gev = beamformer.forward(
mix=mix_stft, target_scm=sigma_ss, noise_scm=sigma_nn, **forward_kwargs
)
ys_gev = istft(Ys_gev)


@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ")
@pytest.mark.parametrize("n_mics", [2, 3, 4])
@pytest.mark.parametrize("batch_size", [1, 2])
def test_gev(n_mics, batch_size):
    """Run the GEV beamformer through the shared beamformer test helper."""
    _default_beamformer_test(GEVBeamformer(), n_mics=n_mics, batch_size=batch_size)


@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ")
@pytest.mark.parametrize("n_mics", [2, 3, 4])
@pytest.mark.parametrize("batch_size", [1, 2])
def test_mvdr(n_mics, batch_size):
    """Run both MVDR variants (RTF and Souden) through the shared test helper."""
    _default_beamformer_test(RTFMVDRBeamformer(), n_mics=n_mics, batch_size=batch_size)
    _default_beamformer_test(SoudenMVDRBeamformer(), n_mics=n_mics, batch_size=batch_size)


@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ")
@pytest.mark.parametrize("n_mics", [2, 3, 4])
@pytest.mark.parametrize("mu", [1.0, 2.0, 0])
@pytest.mark.parametrize("batch_size", [1, 2])
def test_mwf(n_mics, mu, batch_size):
    """Run the SDW-MWF beamformer for several values of ``mu`` through the shared helper."""
    _default_beamformer_test(SDWMWFBeamformer(mu=mu), n_mics=n_mics, batch_size=batch_size)


@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support ")
@pytest.mark.parametrize("n_mics", [2, 3, 4])
@pytest.mark.parametrize("batch_size", [1, 2])
def test_mwf_indices(n_mics, batch_size):
    """Run SDW-MWF with every supported kind of ``ref_mic`` argument."""

    def ref_mic_cases():
        # Yielded lazily so each random tensor is drawn right before its run.
        yield 0  # fixed channel index
        yield None  # let the beamformer pick the reference mic
        yield torch.randint(0, n_mics, size=(batch_size,))  # one index per batch item
        # Arbitrary, already-shaped reference vector.
        yield torch.randn(batch_size, 1, n_mics, 1, dtype=torch.complex64)

    for ref_mic in ref_mic_cases():
        _default_beamformer_test(
            SDWMWFBeamformer(), n_mics=n_mics, batch_size=batch_size, ref_mic=ref_mic
        )


def test_stable_cholesky():
Expand Down

0 comments on commit efb95a4

Please sign in to comment.