[src&test] Add Sinkhorn PIT loss (asteroid-team#302)
- Fix devices of indices in hungarian permutation solving

Co-authored-by: mpariente <pariente.mnl@gmail.com>
Showing 6 changed files with 363 additions and 2 deletions.
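The heart of the change is a new SinkPITLossWrapper, which approximates the permutation invariant training (PIT) loss with Sinkhorn iterations instead of an exact assignment (see http://arxiv.org/abs/2010.11871). As orientation before the diff, here is a minimal sketch of the call pattern, mirroring the docstring example in the file below (tensor shapes are illustrative only):

import torch
from asteroid.losses import SinkPITLossWrapper, pairwise_neg_sisdr

sources = torch.randn(10, 3, 16000)      # (batch, n_src, time) references
est_sources = torch.randn(10, 3, 16000)  # (batch, n_src, time) estimates
loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)  # beta defaults to 10
loss_val = loss_func(est_sources, sources)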
@@ -0,0 +1,187 @@
import torch
from torch import nn
import pytorch_lightning as pl

from . import PITLossWrapper


class SinkPITLossWrapper(nn.Module):
    r"""Permutation invariant loss wrapper.

    Args:
        loss_func: function with signature (est_targets, targets, **kwargs).
        n_iter (int): number of Sinkhorn iterations (default = 200).
            Supposed to be an even number.
        hungarian_validation (bool): whether to use the Hungarian algorithm
            for validation (default = True).

    `loss_func` computes pairwise losses and returns a torch.Tensor of shape
    :math:`(batch, n\_src, n\_src)`. Each element :math:`[batch, i, j]`
    corresponds to the loss between :math:`targets[:, i]` and
    :math:`est\_targets[:, j]`.
    The wrapper evaluates an approximate value of the PIT loss
    using Sinkhorn's iterative algorithm.
    See :meth:`~SinkPITLossWrapper.best_softperm_sinkhorn`
    and http://arxiv.org/abs/2010.11871

    Examples:
        >>> import torch
        >>> from asteroid.losses import pairwise_neg_sisdr
        >>> sources = torch.randn(10, 3, 16000)
        >>> est_sources = torch.randn(10, 3, 16000)
        >>> # Compute SinkPIT loss based on pairwise losses
        >>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
        >>> loss_val = loss_func(est_sources, sources)
        >>> # A fixed temperature parameter `beta` (=10) is used
        >>> # unless a cooling callback is set. The value can be
        >>> # dynamically changed using a cooling callback module as follows.
        >>> model = NeuralNetworkModel()
        >>> optimizer = optim.Adam(model.parameters(), lr=1e-3)
        >>> dataset = YourDataset()
        >>> loader = data.DataLoader(dataset, batch_size=16)
        >>> system = System(
        >>>     model,
        >>>     optimizer,
        >>>     loss_func=SinkPITLossWrapper(pairwise_neg_sisdr),
        >>>     train_loader=loader,
        >>>     val_loader=loader,
        >>> )
        >>>
        >>> trainer = pl.Trainer(
        >>>     max_epochs=100,
        >>>     callbacks=[SinkPITBetaScheduler(lambda epoch: 1.02 ** epoch)],
        >>> )
        >>>
        >>> trainer.fit(system)
    """

    def __init__(self, loss_func, n_iter=200, hungarian_validation=True):
        super().__init__()
        self.loss_func = loss_func
        self._beta = 10
        self.n_iter = n_iter
        self.hungarian_validation = hungarian_validation

    @property
    def beta(self):
        return self._beta

    @beta.setter
    def beta(self, beta):
        assert beta > 0
        self._beta = beta

    def forward(self, est_targets, targets, return_est=False, **kwargs):
        """Evaluate the loss using Sinkhorn's algorithm.

        Args:
            est_targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of target estimates.
            targets: torch.Tensor. Expected shape [batch, nsrc, *].
                The batch of training targets.
            return_est: Boolean. Whether to return the reordered target
                estimates (to compute metrics or to save examples).
            **kwargs: additional keyword arguments that will be passed to the
                loss function.

        Returns:
            - Best permutation loss for each batch sample, averaged over
              the batch. torch.Tensor(loss_value)
            - The reordered target estimates if return_est is True.
              torch.Tensor of shape [batch, nsrc, *].
        """
        n_src = targets.shape[1]
        assert n_src < 100, f"Expected source axis along dim 1, found {n_src}"

        # Evaluate the pairwise losses needed by Sinkhorn's iterative algorithm
        pw_losses = self.loss_func(est_targets, targets, **kwargs)

        assert pw_losses.ndim == 3, (
            "Something went wrong with the loss function, please read the docs."
        )
        assert pw_losses.shape[0] == targets.shape[0], "PIT loss needs same batch dim as input"

        if not return_est:
            if self.training or not self.hungarian_validation:
                # Training, or validation with the Sinkhorn approximation
                min_loss, soft_perm = self.best_softperm_sinkhorn(
                    pw_losses, self._beta, self.n_iter
                )
                mean_loss = torch.mean(min_loss)
                return mean_loss
            else:
                # Validation with the exact Hungarian assignment
                min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses)
                mean_loss = torch.mean(min_loss)
                return mean_loss
        else:
            # Test -> reorder the output using the exact Hungarian assignment
            min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses)
            mean_loss = torch.mean(min_loss)
            reordered = PITLossWrapper.reorder_source(est_targets, batch_indices)
            return mean_loss, reordered

    @staticmethod
    def best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200):
        """Compute an approximate PIT loss using Sinkhorn's algorithm.

        See http://arxiv.org/abs/2010.11871

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape [batch, n_src, n_src]. Pairwise losses.
            beta (float): Inverse temperature parameter (default = 10).
            n_iter (int): Number of iterations. Should be even (default = 200).

        Returns:
            tuple:
                :class:`torch.Tensor`: The loss corresponding to the best
                permutation of size (batch,).

                :class:`torch.Tensor`: A soft permutation matrix.
        """
        C = pair_wise_losses.transpose(-1, -2)
        n_src = C.shape[-1]
        # Initial log-domain assignment matrix
        Z = -beta * C
        # Sinkhorn iterations: each pass normalizes both source axes in the
        # log domain, so exp(Z) approaches a doubly stochastic (soft
        # permutation) matrix. Two normalizations per pass, hence n_iter // 2.
        for it in range(n_iter // 2):
            Z = Z - torch.logsumexp(Z, dim=1, keepdim=True)
            Z = Z - torch.logsumexp(Z, dim=2, keepdim=True)
        # Expected cost under the soft permutation P = exp(Z),
        # including the entropic term Z / beta, averaged over sources
        min_loss = torch.einsum("bij,bij->b", C + Z / beta, torch.exp(Z))
        min_loss = min_loss / n_src
        return min_loss, torch.exp(Z)


def sinkpit_default_beta_schedule(epoch):
    return min([1.02 ** epoch, 10])


class SinkPITBetaScheduler(pl.callbacks.Callback):
    r"""Scheduler of the beta value of SinkPITLossWrapper.

    This module is used as a callback of `pytorch_lightning.Trainer`.

    Args:
        cooling_schedule (callable): A callable that takes a parameter
            `epoch` (int) and returns the value of `beta` (float).
            The default is `sinkpit_default_beta_schedule`,
            :math:`\beta = \min(1.02^{epoch}, 10)`.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from asteroid.losses import SinkPITBetaScheduler
        >>> # Default scheduling function
        >>> sinkpit_beta_schedule = SinkPITBetaScheduler()
        >>> trainer = Trainer(callbacks=[sinkpit_beta_schedule])
        >>> # User-defined schedule
        >>> sinkpit_beta_schedule = SinkPITBetaScheduler(lambda ep: 1. if ep < 10 else 100.)
        >>> trainer = Trainer(callbacks=[sinkpit_beta_schedule])
    """

    def __init__(self, cooling_schedule=sinkpit_default_beta_schedule):
        self.cooling_schedule = cooling_schedule

    def on_epoch_start(self, trainer, pl_module):
        assert isinstance(pl_module.loss_func, SinkPITLossWrapper)
        assert trainer.current_epoch == pl_module.current_epoch  # same
        epoch = pl_module.current_epoch
        # step = pl_module.global_step
        beta = self.cooling_schedule(epoch)
        pl_module.loss_func.beta = beta
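For intuition: with a large enough beta and enough iterations, the Sinkhorn approximation should land on nearly the same loss as the exact Hungarian assignment computed by PITLossWrapper.find_best_perm. The test file below (test_proximity_sinkhorn_hungarian) verifies exactly this; here is a condensed, illustrative sketch of the same check, reusing the beta=100 and n_iter=2000 values from the test parametrization (not part of the commit itself):

import torch
from asteroid.losses import PITLossWrapper, SinkPITLossWrapper, pairwise_neg_sisdr

targets = torch.randn(2, 3, 16000) * 10
est_targets = targets[:, torch.randperm(3), :] + 0.1 * torch.randn(2, 3, 16000)

pw_losses = pairwise_neg_sisdr(est_targets, targets)  # (batch, n_src, n_src)
soft_loss, soft_perm = SinkPITLossWrapper.best_softperm_sinkhorn(pw_losses, beta=100.0, n_iter=2000)
hard_loss, _ = PITLossWrapper.find_best_perm(pw_losses)
print(soft_loss.mean().item(), hard_loss.mean().item())  # expected to be close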
@@ -0,0 +1,150 @@
import pytest
import itertools
import torch
from torch import nn, optim
from torch.utils import data
from torch.testing import assert_allclose

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from asteroid.losses import PITLossWrapper
from asteroid.losses import sdr
from asteroid.losses import singlesrc_mse, pairwise_mse, multisrc_mse
from asteroid.engine.system import System
from asteroid.utils.test_utils import DummyWaveformDataset

# target modules
from asteroid.losses import SinkPITLossWrapper, SinkPITBetaScheduler, sinkpit_default_beta_schedule


def bad_loss_func_ndim0(y_pred, y_true):
    return torch.randn(1).mean()


def bad_loss_func_ndim1(y_pred, y_true):
    return torch.randn(1)


def good_batch_loss_func(y_pred, y_true):
    batch, *_ = y_true.shape
    return torch.randn(batch)


def good_pairwise_loss_func(y_pred, y_true):
    batch, n_src, *_ = y_true.shape
    return torch.randn(batch, n_src, n_src)


@pytest.mark.parametrize("batch_size", [1, 2, 8])
@pytest.mark.parametrize("n_src", [2, 5, 10])
@pytest.mark.parametrize("time", [16000, 1221])
def test_wrapper(batch_size, n_src, time):
    targets = torch.randn(batch_size, n_src, time)
    est_targets = torch.randn(batch_size, n_src, time)
    for bad_loss_func in [bad_loss_func_ndim0, bad_loss_func_ndim1]:
        loss = SinkPITLossWrapper(bad_loss_func)
        with pytest.raises(AssertionError):
            loss(est_targets, targets)

    loss = SinkPITLossWrapper(good_pairwise_loss_func)
    loss(est_targets, targets)
    loss_value, reordered_est = loss(est_targets, targets, return_est=True)
    assert reordered_est.shape == est_targets.shape


@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("n_src", [2, 3, 4])
@pytest.mark.parametrize("beta,n_iter", [(100.0, 2000)])
@pytest.mark.parametrize(
    "function_triplet",
    [
        [sdr.pairwise_neg_sisdr, sdr.singlesrc_neg_sisdr, sdr.multisrc_neg_sisdr],
        [sdr.pairwise_neg_sdsdr, sdr.singlesrc_neg_sdsdr, sdr.multisrc_neg_sdsdr],
        [sdr.pairwise_neg_snr, sdr.singlesrc_neg_snr, sdr.multisrc_neg_snr],
        [pairwise_mse, singlesrc_mse, multisrc_mse],
    ],
)
def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter, function_triplet):
    time = 16000
    noise_level = 0.1
    pairwise, nosrc, nonpit = function_triplet

    # random data
    targets = torch.randn(batch_size, n_src, time) * 10  # ground truth
    noise = torch.randn(batch_size, n_src, time) * noise_level
    est_targets = (
        targets[:, torch.randperm(n_src), :] + noise
    )  # reorder channels, and add small noise

    # initialize wrappers
    loss_sinkhorn = SinkPITLossWrapper(pairwise, n_iter=n_iter)
    loss_hungarian = PITLossWrapper(pairwise, pit_from="pw_mtx")

    # compute loss by sinkhorn
    loss_sinkhorn.beta = beta
    mean_loss_sinkhorn = loss_sinkhorn(est_targets, targets, return_est=False)

    # compute loss by hungarian
    mean_loss_hungarian = loss_hungarian(est_targets, targets, return_est=False)

    # compare
    assert_allclose(mean_loss_sinkhorn, mean_loss_hungarian)


class _TestCallback(pl.callbacks.Callback):
    def __init__(self, function, total, batch_size):
        self.f = function
        self.epoch = 0
        self.n_batch = total // batch_size

    def on_batch_end(self, trainer, pl_module):
        step = trainer.global_step
        assert self.epoch * self.n_batch <= step
        assert step <= (self.epoch + 1) * self.n_batch

    def on_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        assert epoch == self.epoch
        assert pl_module.loss_func.beta == self.f(epoch)
        self.epoch += 1


@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("n_src", [2, 10])
@pytest.mark.parametrize("len_wave", [100])
@pytest.mark.parametrize(
    "beta_schedule",
    [
        sinkpit_default_beta_schedule,  # default
        lambda epoch: 123.0 if epoch < 3 else 456.0,  # check that a lambda function works
    ],
)
def test_sinkpit_beta_scheduler(batch_size, n_src, len_wave, beta_schedule):
    model = nn.Sequential(nn.Conv1d(1, n_src, 1), nn.ReLU())
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    dataset = DummyWaveformDataset(total=2 * batch_size, n_src=n_src, len_wave=len_wave)
    loader = data.DataLoader(
        dataset, batch_size=batch_size, num_workers=0
    )  # num_workers=0 runs everything in the main process, without spawning subprocesses

    system = System(
        model,
        optimizer,
        loss_func=SinkPITLossWrapper(sdr.pairwise_neg_sisdr, n_iter=5),
        train_loader=loader,
        val_loader=loader,
    )

    trainer = pl.Trainer(
        max_epochs=10,
        fast_dev_run=False,
        callbacks=[
            SinkPITBetaScheduler(beta_schedule),
            _TestCallback(
                beta_schedule, len(dataset), batch_size
            ),  # checks that beta is the same at epoch_start and epoch_end
        ],
    )

    trainer.fit(system)
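As a small aside (illustrative only, not part of the commit): the default cooling schedule exercised above grows beta geometrically and caps it at 10, so the soft permutation hardens gradually over training. A quick way to see the values it produces:

from asteroid.losses import sinkpit_default_beta_schedule

for epoch in [0, 10, 50, 100, 150]:
    print(epoch, round(sinkpit_default_beta_schedule(epoch), 3))
# 1.0, 1.219, 2.692, 7.245, 10 -- the cap of 10 is reached at roughly epoch 117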