[src] Add MixITWrapper loss (asteroid-team#320)
 Mixture Invariant Training for unsupervised source separation 

Co-authored-by: mpariente <pariente.mnl@gmail.com>
giorgiacantisani and mpariente authored Nov 16, 2020
1 parent 769fcb6 commit aedb22f
Showing 5 changed files with 295 additions and 4 deletions.
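
A minimal sketch of how the new wrapper is meant to be used for unsupervised training (illustrative only: the random tensors stand in for a separation model's output, and only the MixITLossWrapper API added in this commit plus the existing multisrc_mse loss are assumed):

import torch
from asteroid.losses import MixITLossWrapper, multisrc_mse

# Two reference mixtures per example; the separator outputs 4 estimated sources.
mixtures = torch.randn(10, 2, 16000)
est_sources = torch.randn(10, 4, 16000, requires_grad=True)  # stand-in for model output

loss_func = MixITLossWrapper(multisrc_mse, generalized=True)
loss = loss_func(est_sources, mixtures)  # best-partition loss, averaged over the batch
loss.backward()
loss_val, est_mixes = loss_func(est_sources, mixtures, return_est=True)
# est_mixes: [10, 2, 16000], estimated sources summed according to the best partition.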
2 changes: 2 additions & 0 deletions asteroid/losses/__init__.py
@@ -1,4 +1,5 @@
from .pit_wrapper import PITLossWrapper
from .mixit_wrapper import MixITLossWrapper
from .sinkpit_wrapper import SinkPITLossWrapper, SinkPITBetaScheduler
from .sdr import singlesrc_neg_sisdr, multisrc_neg_sisdr
from .sdr import singlesrc_neg_sdsdr, multisrc_neg_sdsdr
@@ -19,6 +20,7 @@

__all__ = [
"PITLossWrapper",
"MixITLossWrapper",
"SinkPITLossWrapper",
"SinkPITBetaScheduler",
"singlesrc_neg_sisdr",
225 changes: 225 additions & 0 deletions asteroid/losses/mixit_wrapper.py
@@ -0,0 +1,225 @@
import warnings
from itertools import combinations
import torch
from torch import nn


class MixITLossWrapper(nn.Module):
r"""Mixture invariant loss wrapper.
Args:
loss_func: function with signature (est_targets, targets, **kwargs).
generalized (bool): Determines how MixIT is applied. If False,
apply MixIT for any number of mixtures as long as they contain
the same number of sources (see :meth:`~MixITLossWrapper.best_part_mixit`).
If True (default), apply MixIT for two mixtures, but those mixtures do
not necessarily have to contain the same number of sources.
See :meth:`~MixITLossWrapper.best_part_mixit_generalized`.
For each of these modes, the best partition and reordering are
computed automatically.
Examples:
>>> import torch
>>> from asteroid.losses import multisrc_mse
>>> mixtures = torch.randn(10, 2, 16000)
>>> est_sources = torch.randn(10, 4, 16000)
>>> # Compute MixIT loss based on pairwise losses
>>> loss_func = MixITLossWrapper(multisrc_mse)
>>> loss_val = loss_func(est_sources, mixtures)
References
[1] Scott Wisdom et al. "Unsupervised sound separation using
mixtures of mixtures." arXiv:2006.12701 (2020)
"""

def __init__(self, loss_func, generalized=True):
super().__init__()
self.loss_func = loss_func
self.generalized = generalized

def forward(self, est_targets, targets, return_est=False, **kwargs):
"""Find the best partition and return the loss.
Args:
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
targets: torch.Tensor. Expected shape [batch, nmix, *].
The batch of training targets (mixtures).
return_est: Boolean. Whether to also return the estimated mixtures
(e.g. to compute metrics or to save examples).
**kwargs: additional keyword arguments that will be passed to the
loss function.
Returns:
- Best partition loss for each batch sample, averaged over
the batch. torch.Tensor(loss_value)
- The estimated mixtures (estimated sources summed according
to the partition) if return_est is True.
torch.Tensor of shape [batch, nmix, *].
"""
# Check input dimensions
assert est_targets.shape[0] == targets.shape[0]
assert est_targets.shape[2] == targets.shape[2]

if not self.generalized:
min_loss, min_loss_idx, parts = self.best_part_mixit(
self.loss_func, est_targets, targets, **kwargs
)
else:
min_loss, min_loss_idx, parts = self.best_part_mixit_generalized(
self.loss_func, est_targets, targets, **kwargs
)
# Take the mean over the batch
mean_loss = torch.mean(min_loss)
if not return_est:
return mean_loss
# Order and sum on the best partition to get the estimated mixtures
reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts)
return mean_loss, reordered

@staticmethod
def best_part_mixit(loss_func, est_targets, targets, **kwargs):
"""Find best partition of the estimated sources that gives the minimum
loss for the MixIT training paradigm in [1]. Valid for any number of
mixtures as soon as they contain the same number of sources.
Args:
loss_func: function with signature (est_targets, targets, **kwargs)
The loss function to get batch losses from.
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
targets: torch.Tensor. Expected shape [batch, nmix, *].
The batch of training targets (mixtures).
**kwargs: additional keyword arguments that will be passed to the
loss function.
Returns:
tuple:
:class:`torch.Tensor`: The loss corresponding to the best
partition, of size (batch, 1).
:class:`torch.LongTensor`: The indices of the best partition
for each batch sample.
:class:`list`: List of the possible partitions of the sources.
"""
nmix = targets.shape[1]
nsrc = est_targets.shape[1]
if nsrc % nmix != 0:
raise ValueError("The number of estimated sources must be a multiple of the number of mixtures")
nsrcmix = nsrc // nmix

# Generate all unique partitions of size k from a list lst of
# length n, where l = n // k is the number of parts. The parts are
# ordered (each part is matched to a specific mixture), so the total
# number of such partitions is n! / (k!)**l.
# The algorithm recursively distributes items over parts.
def parts_mixit(lst, k, l):
if l == 0:
yield []
else:
for c in combinations(lst, k):
rest = [x for x in lst if x not in c]
for r in parts_mixit(rest, k, l - 1):
yield [list(c), *r]
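# Illustration (not executed): with nsrc = 4 and nmix = 2, we have k = 2 and
# l = 2, so parts_mixit(range(4), 2, 2) yields the 4! / (2!)**2 = 6 ordered
# partitions [[0, 1], [2, 3]], [[0, 2], [1, 3]], [[0, 3], [1, 2]],
# [[1, 2], [0, 3]], [[1, 3], [0, 2]] and [[2, 3], [0, 1]].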

# Generate all the possible partitions
parts = list(parts_mixit(range(nsrc), nsrcmix, nmix))
# Compute the loss corresponding to each partition
loss_set = MixITLossWrapper.loss_set_from_parts(
loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
)
# Indexes and values of min losses for each batch element
min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
return min_loss, min_loss_indexes, parts

@staticmethod
def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs):
"""Find best partition of the estimated sources that gives the minimum
loss for the MixIT training paradigm in [1]. Valid only for two mixtures,
but those mixtures do not necessarly have to contain the same number of
sources e.g the case where one mixture is silent is allowed..
Args:
loss_func: function with signature (est_targets, targets, **kwargs)
The loss function to get batch losses from.
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
targets: torch.Tensor. Expected shape [batch, nmix, *].
The batch of training targets (mixtures).
**kwargs: additional keyword arguments that will be passed to the
loss function.
Returns:
tuple:
:class:`torch.Tensor`: The loss corresponding to the best
partition, of size (batch, 1).
:class:`torch.LongTensor`: The indices of the best partitions.
:class:`list`: List of the possible partitions of the sources.
"""
nmix = targets.shape[1] # number of mixtures
nsrc = est_targets.shape[1] # number of estimated sources
if nmix != 2:
raise ValueError("Works only with two mixtures")

# Generate all possible two-part partitions (parts may have different
# sizes, and may be empty) from a list lst of length n.
# There are 2**n such ordered partitions.
def parts_mixit_gen(lst):
partitions = []
for k in range(len(lst) + 1):
for c in combinations(lst, k):
rest = [x for x in lst if x not in c]
partitions.append([list(c), rest])
return partitions
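# Illustration (not executed): parts_mixit_gen(range(2)) returns the 2**2 = 4
# assignments [[], [0, 1]], [[0], [1]], [[1], [0]] and [[0, 1], []], i.e. each
# estimated source goes to one of the two mixtures and empty parts are allowed.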

# Generate all the possible partitions
parts = parts_mixit_gen(range(nsrc))
# Compute the loss corresponding to each partition
loss_set = MixITLossWrapper.loss_set_from_parts(
loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
)
# Indexes and values of min losses for each batch element
min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
return min_loss, min_loss_indexes, parts

@staticmethod
def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs):
"""Common loop between both best_part_mixit"""
loss_set = []
for partition in parts:
# sum the sources according to the given partition
est_mixes = torch.stack([est_targets[:, idx, :].sum(1) for idx in partition], dim=1)
# get loss for the given partition
loss_set.append(loss_func(est_mixes, targets, **kwargs)[:, None])
loss_set = torch.cat(loss_set, dim=1)
return loss_set

@staticmethod
def reorder_source(est_targets, targets, min_loss_idx, parts):
"""Reorder sources according to the best partition.
Args:
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
targets: torch.Tensor. Expected shape [batch, nmix, *].
The batch of training targets.
min_loss_idx: torch.LongTensor. The indices of the best partitions.
parts: list of the possible partitions of the sources.
Returns:
:class:`torch.Tensor`:
Estimated mixtures (sources summed according to the best partition)
of shape [batch, nmix, time].
"""
# For each batch there is a different min_loss_idx
ordered = torch.zeros_like(targets)
for b, idx in enumerate(min_loss_idx):
right_partition = parts[idx]
# Sum the estimated sources to get the estimated mixtures
ordered[b, :, :] = torch.stack(
[est_targets[b, srcs, :][None, :, :].sum(1) for srcs in right_partition], dim=1
)

return ordered
6 changes: 3 additions & 3 deletions asteroid/losses/pit_wrapper.py
@@ -8,7 +8,7 @@ class PITLossWrapper(nn.Module):
r"""Permutation invariant loss wrapper.
Args:
loss_func: function with signature (targets, est_targets, **kwargs).
loss_func: function with signature (est_targets, targets, **kwargs).
pit_from (str): Determines how PIT is applied.
* ``'pw_mtx'`` (pairwise matrix): `loss_func` computes pairwise
@@ -135,7 +135,7 @@ def get_pw_losses(loss_func, est_targets, targets, **kwargs):
for a given loss function.
Args:
loss_func: function with signature (targets, est_targets, **kwargs)
loss_func: function with signature (est_targets, targets, **kwargs)
The loss function to get pair-wise losses from.
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
@@ -164,7 +164,7 @@ def best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs):
"""Find best permutation from loss function with source axis.
Args:
loss_func: function with signature (targets, est_targets, **kwargs)
loss_func: function with signature (est_targets, targets, **kwargs)
The loss function to get batch losses from.
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
12 changes: 11 additions & 1 deletion docs/source/package_reference/losses.rst
@@ -11,9 +11,19 @@ Losses & Metrics
Permutation invariant training (PIT) made easy
----------------------------------------------

Asteroid supports regular Permutation Invariant Training (PIT), its extension
using the Sinkhorn algorithm (SinkPIT), as well as Mixture Invariant Training (MixIT).
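
A short, illustrative snippet of how a wrapper is chosen (a sketch; the pairwise
and multi-source losses below are assumed to be importable from ``asteroid.losses``,
as in the existing docstrings):

.. code-block:: python

    from asteroid.losses import PITLossWrapper, MixITLossWrapper
    from asteroid.losses import pairwise_neg_sisdr, multisrc_mse

    # Supervised training with clean references: PIT over a pairwise loss matrix.
    pit_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    # Unsupervised training from mixtures of mixtures: MixIT over a multi-source loss.
    mixit_loss = MixITLossWrapper(multisrc_mse, generalized=True)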


.. automodule:: asteroid.losses.pit_wrapper
:members:

.. automodule:: asteroid.losses.sinkpit_wrapper
:members:

.. automodule:: asteroid.losses.mixit_wrapper
:members:

Available loss functions
------------------------

@@ -62,4 +72,4 @@ into both PIT and nonPIT training.
Computing metrics
-----------------

.. autofunction:: asteroid.metrics.get_metrics
.. autofunction:: asteroid.metrics.get_metrics
54 changes: 54 additions & 0 deletions tests/losses/mixit_wrapper_test.py
@@ -0,0 +1,54 @@
import pytest
import torch

from asteroid.losses import MixITLossWrapper


def good_batch_loss_func(y_pred, y_true):
batch, *_ = y_true.shape
return torch.randn(batch)


@pytest.mark.parametrize("batch_size", [1, 2, 8])
@pytest.mark.parametrize("n_src", [2, 3, 4])
@pytest.mark.parametrize("time", [16000])
def test_mixitwrapper_as_pit_wrapper(batch_size, n_src, time):
targets = torch.randn(batch_size, n_src, time)
est_targets = torch.randn(batch_size, n_src, time)

# mix_it base case: targets == mixtures / With and without return estimates
loss = MixITLossWrapper(good_batch_loss_func, generalized=False)
loss(est_targets, targets)
loss_value, reordered_est = loss(est_targets, targets, return_est=True)
assert reordered_est.shape == est_targets.shape


@pytest.mark.parametrize("batch_size", [1, 2, 4])
@pytest.mark.parametrize("factor", [1, 2, 3])
@pytest.mark.parametrize("n_mix", [2, 3])
@pytest.mark.parametrize("time", [16000])
def test_mixit_wrapper(batch_size, factor, n_mix, time):
mixtures = torch.randn(batch_size, n_mix, time)
n_src = n_mix * factor
est_targets = torch.randn(batch_size, n_src, time)

# mix_it / With and without return estimates
loss = MixITLossWrapper(good_batch_loss_func, generalized=False)
loss(est_targets, mixtures)
loss_value, reordered_mix = loss(est_targets, mixtures, return_est=True)
assert reordered_mix.shape == mixtures.shape


@pytest.mark.parametrize("batch_size", [1, 2, 8])
@pytest.mark.parametrize("n_src", [2, 3, 4, 5])
@pytest.mark.parametrize("n_mix", [2])
@pytest.mark.parametrize("time", [16000])
def test_mixit_gen_wrapper(batch_size, n_src, n_mix, time):
mixtures = torch.randn(batch_size, n_mix, time)
est_targets = torch.randn(batch_size, n_src, time)

# mix_it_gen / With and without return estimates. Works only with two mixtures
loss = MixITLossWrapper(good_batch_loss_func)
loss(est_targets, mixtures)
loss_value, reordered_est = loss(est_targets, mixtures, return_est=True)
assert reordered_est.shape == mixtures.shape
