Skip to content

Commit

Permalink
[src] Stabilize GEV beamformer (asteroid-team#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente authored Apr 17, 2021
1 parent 654f79e commit 2b79902
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
52 changes: 49 additions & 3 deletions asteroid/dsp/beamforming.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from torch import nn
import warnings


class SCM(nn.Module):
Expand Down Expand Up @@ -119,6 +120,12 @@ def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.
Returns:
Filtered mixture. torch.ComplexTensor (batch, freqs, frames)
"""
bf_vect = self.compute_beamforming_vector(target_scm, noise_scm)
output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft
return output

@staticmethod
def compute_beamforming_vector(target_scm: torch.Tensor, noise_scm: torch.Tensor):
noise_scm_t = noise_scm.permute(0, 3, 1, 2)
noise_scm_t = condition_scm(noise_scm_t, 1e-6)
e_val, e_vec = generalized_eigenvalue_decomposition(
Expand All @@ -128,8 +135,7 @@ def forward(self, mix: torch.Tensor, target_scm: torch.Tensor, noise_scm: torch.
# Normalize
bf_vect /= torch.norm(bf_vect, dim=-1, keepdim=True)
bf_vect = bf_vect.squeeze(-1).transpose(-1, -2) # -> bft
output = self.apply_beamforming_vector(bf_vect, mix=mix) # -> bft
return output
return bf_vect


def compute_scm(x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = True):
Expand All @@ -149,6 +155,7 @@ def compute_scm(x: torch.Tensor, mask: torch.Tensor = None, normalize: bool = Tr
if mask.ndim == 3:
mask = mask[:, None]

# torch.matmul((mask * x).transpose(1, 2), x.conj().permute(0, 2, 3, 1))
scm = torch.einsum("bmft,bnft->bmnf", mask * x, x.conj())
if normalize:
scm /= mask.sum(-1, keepdim=True).transpose(-1, -2)
Expand Down Expand Up @@ -186,7 +193,7 @@ def generalized_eigenvalue_decomposition(a, b):
"""Solves the generalized eigenvalue decomposition through Cholesky decomposition.
Returns eigen values and eigen vectors (ascending order).
"""
cholesky = torch.cholesky(b)
cholesky = stable_cholesky(b, max_tries=2)
inv_cholesky = torch.inverse(cholesky)
# Compute C matrix L⁻1 A L^-T
cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2)
Expand All @@ -195,3 +202,42 @@ def generalized_eigenvalue_decomposition(a, b):
# Collecting the eigenvectors
e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec)
return e_val, e_vec


def stable_cholesky(input, upper=False, out=None, jitter=1e-6, max_tries=2, verbose=False):
"""Compute the Cholesky decomposition of A.
If A is only p.s.d, add a small jitter to the diagonal.
Args:
input (Tensor): The tensor to compute the Cholesky decomposition of
upper (bool, optional): See torch.cholesky
out (Tensor, optional): See torch.cholesky
jitter (float): The jitter to add to the diagonal of A in case A is only p.s.d.
max_tries (int, optional): Number of attempts (with increasing jitter) before raising an error.
verbose (bool): Whether to raise a warning if the jitter had to be added.
Adapted from GPytorch https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/utils/cholesky.py#L12
"""
try:
return torch.cholesky(input, upper=upper, out=out)
except RuntimeError as e:
clone = input.clone()
jitter_prev = 0
for i in range(max_tries):
jitter_new = jitter * (10 ** i)
clone.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
jitter_prev = jitter_new
try:
out = torch.cholesky(clone, upper=upper, out=out)
if verbose is True:
warnings.warn(
f"Had to add a jitter of {jitter_new:.1e} to compute the cholesky decomposition.",
RuntimeWarning,
)
return out
except RuntimeError:
continue
raise RuntimeError(
f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}. "
f"Original error on first attempt: {e}"
)
9 changes: 9 additions & 0 deletions tests/dsp/beamforming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MvdrBeamformer,
SdwMwfBeamformer,
GEVBeamformer,
stable_cholesky,
)


Expand Down Expand Up @@ -53,3 +54,11 @@ def test_mvdr(n_mics):
@pytest.mark.parametrize("mu", [1.0, 2.0, 0])
def test_mwf(n_mics, mu):
_default_beamformer_test(SdwMwfBeamformer(mu=mu), n_mics=n_mics)


def test_stable_cholesky():
stable_cholesky(torch.zeros(2, 2))
with pytest.warns(RuntimeWarning):
stable_cholesky(torch.zeros(2, 2), verbose=True)
with pytest.raises(RuntimeError):
stable_cholesky(torch.zeros(2, 2), jitter=0.0)

0 comments on commit 2b79902

Please sign in to comment.