[src&test] Add Sinkhorn PIT loss (asteroid-team#302)
- Fix device placement of indices in Hungarian permutation solving

Co-authored-by: mpariente <pariente.mnl@gmail.com>
tachi-hi and mpariente authored Nov 6, 2020
1 parent a7f628f commit b0b0e0e
Showing 6 changed files with 363 additions and 2 deletions.
4 changes: 4 additions & 0 deletions asteroid/losses/__init__.py
@@ -1,4 +1,5 @@
from .pit_wrapper import PITLossWrapper
from .sinkpit_wrapper import SinkPITLossWrapper, SinkPITBetaScheduler, sinkpit_default_beta_schedule
from .sdr import singlesrc_neg_sisdr, multisrc_neg_sisdr
from .sdr import singlesrc_neg_sdsdr, multisrc_neg_sdsdr
from .sdr import singlesrc_neg_snr, multisrc_neg_snr
@@ -18,6 +19,9 @@

__all__ = [
"PITLossWrapper",
"SinkPITLossWrapper",
"SinkPITBetaScheduler",
"sinkpit_default_beta_schedule",
"singlesrc_neg_sisdr",
"multisrc_neg_sisdr",
"singlesrc_neg_sdsdr",
4 changes: 3 additions & 1 deletion asteroid/losses/pit_wrapper.py
@@ -314,7 +314,9 @@ def find_best_perm_hungarian(pair_wise_losses: torch.Tensor):
# Just bring the numbers to cpu(), not the graph
pwl_copy = pwl.detach().cpu()
# Loop over batch + row indices are always ordered for square matrices.
batch_indices = torch.tensor([linear_sum_assignment(pwl)[1] for pwl in pwl_copy])
batch_indices = torch.tensor([linear_sum_assignment(pwl)[1] for pwl in pwl_copy]).to(
pwl.device
)
min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
return min_loss, batch_indices

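The fix in pit_wrapper.py addresses device placement: scipy.optimize.linear_sum_assignment only operates on CPU data, so the permutation indices it returns must be moved back to the device of the pairwise-loss tensor before torch.gather uses them. A minimal standalone sketch of that pattern (tensor shapes and variable names are illustrative, not part of the library API):

import torch
from scipy.optimize import linear_sum_assignment

# Illustrative pairwise-loss tensor of shape (batch, n_src, n_src), possibly on GPU.
pwl = torch.rand(4, 3, 3, device="cuda" if torch.cuda.is_available() else "cpu")

# linear_sum_assignment needs CPU/NumPy data: detach from the graph and copy to CPU.
pwl_cpu = pwl.detach().cpu()
batch_indices = torch.tensor([linear_sum_assignment(m)[1] for m in pwl_cpu])

# torch.gather requires its index tensor on the same device as the source tensor,
# hence the indices are moved back to pwl.device (the fix in this commit).
batch_indices = batch_indices.to(pwl.device)
min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
print(min_loss.shape)  # torch.Size([4]) -> one loss value per batch element
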
187 changes: 187 additions & 0 deletions asteroid/losses/sinkpit_wrapper.py
@@ -0,0 +1,187 @@
import torch
from torch import nn
import pytorch_lightning as pl

from . import PITLossWrapper


class SinkPITLossWrapper(nn.Module):
r"""Permutation invariant loss wrapper.
Args:
loss_func: function with signature (targets, est_targets, **kwargs).
        n_iter (int): Number of Sinkhorn iterations (default = 200).
            Expected to be an even number.
hungarian_validation (boolean) : Whether to use the Hungarian algorithm
for the validation. (default = True)
`loss_func` computes pairwise
losses and returns a torch.Tensor of shape
:math:`(batch, n\_src, n\_src)`. Each element
:math:`[batch, i, j]` corresponds to the loss between
    :math:`targets[:, i]` and :math:`est\_targets[:, j]`.
    The wrapper evaluates an approximate value of the PIT loss
    using Sinkhorn's iterative algorithm.
    See :meth:`~SinkPITLossWrapper.best_softperm_sinkhorn`
    and http://arxiv.org/abs/2010.11871
Examples
>>> import torch
>>> from asteroid.losses import pairwise_neg_sisdr
>>> sources = torch.randn(10, 3, 16000)
>>> est_sources = torch.randn(10, 3, 16000)
>>> # Compute SinkPIT loss based on pairwise losses
>>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
>>> loss_val = loss_func(est_sources, sources)
>>> # A fixed temperature parameter `beta` (=10) is used
>>> # unless a cooling callback is set. The value can be
>>> # dynamically changed using a cooling callback module as follows.
>>> model = NeuralNetworkModel()
>>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
>>> dataset = YourDataset()
>>> loader = data.DataLoader(dataset, batch_size=16)
>>> system = System(
>>> model,
>>> optimizer,
>>> loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
>>> train_loader=loader,
>>> val_loader=loader,
>>> )
>>>
>>> trainer = pl.Trainer(
>>> max_epochs=100,
>>> callbacks=[SinkPITBetaScheduler(lambda epoch : 1.02 ** epoch)],
>>> )
>>>
>>> trainer.fit(system)
"""

def __init__(self, loss_func, n_iter=200, hungarian_validation=True):
super().__init__()
self.loss_func = loss_func
self._beta = 10
self.n_iter = n_iter
self.hungarian_validation = hungarian_validation

@property
def beta(self):
return self._beta

@beta.setter
def beta(self, beta):
assert beta > 0
self._beta = beta

def forward(self, est_targets, targets, return_est=False, **kwargs):
"""Evaluate the loss using Sinkhorn's algorithm.
Args:
est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
The batch of target estimates.
targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            return_est: Boolean. Whether to return the reordered target
                estimates (to compute metrics or to save examples).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.
        Returns:
            - Best permutation loss for each batch sample, averaged over
                the batch. torch.Tensor(loss_value)
            - The reordered target estimates if return_est is True.
                torch.Tensor of shape [batch, nsrc, *].
"""
n_src = targets.shape[1]
assert n_src < 100, f"Expected source axis along dim 1, found {n_src}"

# Evaluate the loss using Sinkhorn's iterative algorithm
pw_losses = self.loss_func(est_targets, targets, **kwargs)

assert pw_losses.ndim == 3, (
"Something went wrong with the loss " "function, please read the docs."
)
assert pw_losses.shape[0] == targets.shape[0], "PIT loss needs same batch dim as input"

if not return_est:
if self.training or not self.hungarian_validation:
# Train or sinkhorn validation
min_loss, soft_perm = self.best_softperm_sinkhorn(
pw_losses, self._beta, self.n_iter
)
mean_loss = torch.mean(min_loss)
return mean_loss
else:
# Reorder the output by using the Hungarian algorithm below
min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses)
mean_loss = torch.mean(min_loss)
return mean_loss
else:
# Test -> reorder the output by using the Hungarian algorithm below
min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses)
mean_loss = torch.mean(min_loss)
reordered = PITLossWrapper.reorder_source(est_targets, batch_indices)
return mean_loss, reordered

@staticmethod
def best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200):
"""Compute an approximate PIT loss using Sinkhorn's algorithm.
See http://arxiv.org/abs/2010.11871
Args:
pair_wise_losses (:class:`torch.Tensor`):
Tensor of shape [batch, n_src, n_src]. Pairwise losses.
beta (float) : Inverse temperature parameter. (default = 10)
            n_iter (int) : Number of iterations. Expected to be an even number. (default = 200)
Returns:
tuple:
:class:`torch.Tensor`: The loss corresponding to the best
permutation of size (batch,).
:class:`torch.Tensor`: A soft permutation matrix.
"""
C = pair_wise_losses.transpose(-1, -2)
n_src = C.shape[-1]
        # initial values
        Z = -beta * C
        for it in range(n_iter // 2):
            # Alternate row and column normalizations in the log domain,
            # so that exp(Z) approaches a doubly stochastic matrix.
            Z = Z - torch.logsumexp(Z, axis=1, keepdim=True)
            Z = Z - torch.logsumexp(Z, axis=2, keepdim=True)
        # Expected loss under the soft permutation exp(Z), plus the entropic
        # term Z / beta, averaged over the n_src sources.
        min_loss = torch.einsum("bij,bij->b", C + Z / beta, torch.exp(Z))
        min_loss = min_loss / n_src
        return min_loss, torch.exp(Z)


def sinkpit_default_beta_schedule(epoch):
return min([1.02 ** epoch, 10])


class SinkPITBetaScheduler(pl.callbacks.Callback):
r"""Scheduler of the beta value of SinkPITLossWrapper
    This class is used as a callback for `pytorch_lightning.Trainer`.
Args:
cooling_schedule (callable) : A callable
that takes a parameter `epoch` (int)
and returns the value of `beta` (float).
            The default function is `sinkpit_default_beta_schedule`,
            i.e. :math:`\beta = \min(1.02^{epoch}, 10)`.
Example
>>> from pytorch_lightning import Trainer
>>> from asteroid.losses import SinkPITBetaScheduler
>>> # Default scheduling function
    >>> sinkpit_beta_schedule = SinkPITBetaScheduler()
>>> trainer = Trainer(callbacks=[sinkpit_beta_schedule])
>>> # User-defined schedule
>>> sinkpit_beta_schedule = SinkPITBetaScheduler(lambda ep: 1. if ep < 10 else 100.)
>>> trainer = Trainer(callbacks=[sinkpit_beta_schedule])
"""

def __init__(self, cooling_schedule=sinkpit_default_beta_schedule):
self.cooling_schedule = cooling_schedule

def on_epoch_start(self, trainer, pl_module):
assert isinstance(pl_module.loss_func, SinkPITLossWrapper)
assert trainer.current_epoch == pl_module.current_epoch # same
epoch = pl_module.current_epoch
# step = pl_module.global_step
beta = self.cooling_schedule(epoch)
pl_module.loss_func.beta = beta
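
For reference, best_softperm_sinkhorn above alternates row and column normalizations of exp(Z) in the log domain, so exp(Z) approaches a doubly stochastic (soft permutation) matrix; larger beta concentrates it toward a hard permutation and the weighted loss approaches the exact PIT minimum. A standalone sketch of that relaxation, assuming only the pairwise-loss shape described in the docstring (not the library entry point):

import torch

def sinkhorn_soft_perm(pair_wise_losses, beta=10.0, n_iter=200):
    # pair_wise_losses: (batch, n_src, n_src); element [b, i, j] is the loss
    # between target i and estimate j, as produced by a pairwise loss function.
    C = pair_wise_losses.transpose(-1, -2)
    Z = -beta * C
    for _ in range(n_iter // 2):
        # Alternate row / column normalization in the log domain.
        Z = Z - torch.logsumexp(Z, dim=1, keepdim=True)
        Z = Z - torch.logsumexp(Z, dim=2, keepdim=True)
    P = torch.exp(Z)  # approximately doubly stochastic soft permutation
    loss = torch.einsum("bij,bij->b", C + Z / beta, P) / C.shape[-1]
    return loss, P

pw = torch.rand(2, 4, 4)
loss, P = sinkhorn_soft_perm(pw, beta=100.0, n_iter=400)
print(P.sum(dim=2))  # each row sums to ~1
print(P.sum(dim=1))  # each column sums to ~1
print(loss)          # approaches the exact PIT minimum as beta grows
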
17 changes: 17 additions & 0 deletions asteroid/utils/test_utils.py
@@ -12,3 +12,20 @@ def __len__(self):

def __getitem__(self, idx):
return torch.randn(1, self.inp_dim), torch.randn(1, self.out_dim)


class DummyWaveformDataset(data.Dataset):
def __init__(self, total=12, n_src=3, len_wave=16000):
self.inp_len_wave = len_wave
self.out_len_wave = len_wave
self.total = total
self.inp_n_sig = 1
self.out_n_sig = n_src

def __len__(self):
return self.total

def __getitem__(self, idx):
mixed = torch.randn(self.inp_n_sig, self.inp_len_wave)
srcs = torch.randn(self.out_n_sig, self.out_len_wave)
return mixed, srcs
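
DummyWaveformDataset simply emits fixed-size random (mixture, sources) pairs so that the SinkPIT tests below can run end to end without real audio. A possible usage sketch (batch size and signal length chosen arbitrarily):

from torch.utils import data
from asteroid.utils.test_utils import DummyWaveformDataset

dataset = DummyWaveformDataset(total=8, n_src=3, len_wave=100)
loader = data.DataLoader(dataset, batch_size=4, num_workers=0)

mixed, srcs = next(iter(loader))
print(mixed.shape)  # torch.Size([4, 1, 100]) -> single-channel mixtures
print(srcs.shape)   # torch.Size([4, 3, 100]) -> n_src target sources per item
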
3 changes: 2 additions & 1 deletion notebooks/03_PITLossWrapper.ipynb
@@ -350,7 +350,8 @@
"Papers on PIT alternatives :\n",
"- [9] C. Fan et al. \"Utterance-level Permutation Invariant Training with Discriminative Learning for Single Channel Speech Separation,\" 2018 11th International Symposium on Chinese Spoken Language Processing (ISCSLP). \n",
"- [10] Yang, Gene-Ping et al. \"Interrupted and cascaded permutation invariant training for speech separation\" 2019. \n",
"- [11] Yousefi, Midia et al. “Probabilistic Permutation Invariant Training for Speech Separation.” Interspeech 2019 "
"- [11] Yousefi, Midia et al. \"Probabilistic Permutation Invariant Training for Speech Separation.\" Interspeech 2019. \n",
"- [12] Tachibana, H. \"Towards Listening to 10 People Simultaneously: An Efficient Permutation Invariant Training of Audio Source Separation Using Sinkhorn's Algorithm.\" arXiv 2020. "
]
}
],
150 changes: 150 additions & 0 deletions tests/losses/sinkpit_wrapper_test.py
@@ -0,0 +1,150 @@
import pytest
import itertools
import torch
from torch import nn, optim
from torch.utils import data
from torch.testing import assert_allclose

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from asteroid.losses import PITLossWrapper
from asteroid.losses import sdr
from asteroid.losses import singlesrc_mse, pairwise_mse, multisrc_mse
from asteroid.engine.system import System
from asteroid.utils.test_utils import DummyWaveformDataset

# target modules
from asteroid.losses import SinkPITLossWrapper, SinkPITBetaScheduler, sinkpit_default_beta_schedule


def bad_loss_func_ndim0(y_pred, y_true):
return torch.randn(1).mean()


def bad_loss_func_ndim1(y_pred, y_true):
return torch.randn(1)


def good_batch_loss_func(y_pred, y_true):
batch, *_ = y_true.shape
return torch.randn(batch)


def good_pairwise_loss_func(y_pred, y_true):
batch, n_src, *_ = y_true.shape
return torch.randn(batch, n_src, n_src)


@pytest.mark.parametrize("batch_size", [1, 2, 8])
@pytest.mark.parametrize("n_src", [2, 5, 10])
@pytest.mark.parametrize("time", [16000, 1221])
def test_wrapper(batch_size, n_src, time):
targets = torch.randn(batch_size, n_src, time)
est_targets = torch.randn(batch_size, n_src, time)
for bad_loss_func in [bad_loss_func_ndim0, bad_loss_func_ndim1]:
loss = SinkPITLossWrapper(bad_loss_func)
with pytest.raises(AssertionError):
loss(est_targets, targets)

loss = SinkPITLossWrapper(good_pairwise_loss_func)
loss(est_targets, targets)
loss_value, reordered_est = loss(est_targets, targets, return_est=True)
assert reordered_est.shape == est_targets.shape


@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("n_src", [2, 3, 4])
@pytest.mark.parametrize("beta,n_iter", [(100.0, 2000)])
@pytest.mark.parametrize(
"function_triplet",
[
[sdr.pairwise_neg_sisdr, sdr.singlesrc_neg_sisdr, sdr.multisrc_neg_sisdr],
[sdr.pairwise_neg_sdsdr, sdr.singlesrc_neg_sdsdr, sdr.multisrc_neg_sdsdr],
[sdr.pairwise_neg_snr, sdr.singlesrc_neg_snr, sdr.multisrc_neg_snr],
[pairwise_mse, singlesrc_mse, multisrc_mse],
],
)
def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter, function_triplet):
time = 16000
noise_level = 0.1
pairwise, nosrc, nonpit = function_triplet

# random data
targets = torch.randn(batch_size, n_src, time) * 10 # ground truth
noise = torch.randn(batch_size, n_src, time) * noise_level
est_targets = (
targets[:, torch.randperm(n_src), :] + noise
) # reorder channels, and add small noise

# initialize wrappers
loss_sinkhorn = SinkPITLossWrapper(pairwise, n_iter=n_iter)
loss_hungarian = PITLossWrapper(pairwise, pit_from="pw_mtx")

# compute loss by sinkhorn
loss_sinkhorn.beta = beta
mean_loss_sinkhorn = loss_sinkhorn(est_targets, targets, return_est=False)

# compute loss by hungarian
mean_loss_hungarian = loss_hungarian(est_targets, targets, return_est=False)

# compare
assert_allclose(mean_loss_sinkhorn, mean_loss_hungarian)


class _TestCallback(pl.callbacks.Callback):
def __init__(self, function, total, batch_size):
self.f = function
self.epoch = 0
self.n_batch = total // batch_size

def on_batch_end(self, trainer, pl_module):
step = trainer.global_step
assert self.epoch * self.n_batch <= step
assert step <= (self.epoch + 1) * self.n_batch

def on_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
assert epoch == self.epoch
assert pl_module.loss_func.beta == self.f(epoch)
self.epoch += 1


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("n_src", [2, 10])
@pytest.mark.parametrize("len_wave", [100])
@pytest.mark.parametrize(
"beta_schedule",
[
sinkpit_default_beta_schedule, # default
lambda epoch: 123.0 if epoch < 3 else 456.0, # test if lambda function works
],
)
def test_sinkpit_beta_scheduler(batch_size, n_src, len_wave, beta_schedule):
model = nn.Sequential(nn.Conv1d(1, n_src, 1), nn.ReLU())
optimizer = optim.Adam(model.parameters(), lr=1e-3)
dataset = DummyWaveformDataset(total=2 * batch_size, n_src=n_src, len_wave=len_wave)
loader = data.DataLoader(
dataset, batch_size=batch_size, num_workers=0
    )  # num_workers=0 runs everything in the main process (no worker subprocesses)

system = System(
model,
optimizer,
loss_func=SinkPITLossWrapper(sdr.pairwise_neg_sisdr, n_iter=5),
train_loader=loader,
val_loader=loader,
)

trainer = pl.Trainer(
max_epochs=10,
fast_dev_run=False,
callbacks=[
SinkPITBetaScheduler(beta_schedule),
_TestCallback(
beta_schedule, len(dataset), batch_size
            ),  # check that the beta set at epoch start still matches the schedule at epoch end.
],
)

trainer.fit(system)
