[egs] Fixed bug in X-UMX (asteroid-team#521)
* [Fix] Update "egs/musdb18/X-UMX/requirements.txt"

* [Fix] Bug of X-UMX and README.md
r-sawata authored Jun 25, 2021
1 parent 86b6c7b commit b49069f
Showing 4 changed files with 34 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -111,6 +111,7 @@ More information in [egs/README.md](./egs).
 * [x] [DPTNet](./asteroid/models/dptnet.py) ([Chen et al.](https://arxiv.org/abs/2007.13975))
 * [x] [DCCRNet](./asteroid/models/dccrnet.py) ([Hu et al.](https://arxiv.org/abs/2008.00264))
 * [x] [DCUNet](./asteroid/models/dcunet.py) ([Choi et al.](https://arxiv.org/abs/1903.03107))
+* [x] [CrossNet-Open-Unmix](./asteroid/models/x_umx.py) ([Sawata et al.](https://arxiv.org/abs/2010.04228))
 * [ ] Open-Unmix (coming) ([Stöter et al.](https://sigsep.github.io/open-unmix/))
 * [ ] Wavesplit (coming) ([Zeghidour et al.](https://arxiv.org/abs/2002.08933))

2 changes: 2 additions & 0 deletions egs/musdb18/X-UMX/README.md
@@ -2,6 +2,8 @@

 This recipe contains __CrossNet-Open-Unmix (X-UMX)__, an improved version of [Open-Unmix (UMX)](https://github.com/sigsep/open-unmix-nnabla) for music source separation. X-UMX achieves an improved performance without additional learnable parameters compared to the original UMX model. Details of X-UMX can be found in [this paper](https://arxiv.org/abs/2010.04228). X-UMX is one of the two official baseline models for the [Music Demixing (MDX) Challenge 2021](https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021).

+__Related Projects:__ [umx-pytorch](https://github.com/sigsep/open-unmix-pytorch) | [umx-nnabla](https://github.com/sigsep/open-unmix-nnabla) | x-umx-pytorch | [x-umx-nnabla](https://github.com/sony/ai-research-code/tree/master/x-umx) | [musdb](https://github.com/sigsep/sigsep-mus-db) | [museval](https://github.com/sigsep/sigsep-mus-eval)
+
 ### Source separation with pretrained model
 Pretrained models on MUSDB18 for X-UMX, which reproduce the results from our paper, are available and can be easily tried out:
```
8 changes: 5 additions & 3 deletions egs/musdb18/X-UMX/local/dataloader.py
@@ -69,9 +69,11 @@ def filtering_out_valid(input_dataset):
     Return:
         input_dataset (w/o validation tracks)
     """
-    for i, tmp in enumerate(input_dataset.tracks):
-        if str(tmp["path"]).split("/")[-1] in validation_tracks:
-            del input_dataset.tracks[i]
+    input_dataset.tracks = [
+        tmp
+        for tmp in input_dataset.tracks
+        if not (str(tmp["path"]).split("/")[-1] in validation_tracks)
+    ]

     return input_dataset

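The loop removed from `filtering_out_valid` deleted items from `input_dataset.tracks` while iterating over that same list with `enumerate`. In Python, each deletion shifts the remaining items left by one, so the item immediately after a deleted one is never inspected. The sketch below reproduces both patterns on a plain list of dicts. The track names and the `/data/train/` prefix are hypothetical stand-ins, not MUSDB18 entries, and the helper functions are illustrative rather than part of the recipe.

```python
# Hypothetical validation names; the recipe's real list is not reproduced here.
validation_tracks = ["b", "c"]


def make_tracks():
    """Build a small stand-in for `input_dataset.tracks`."""
    return [{"path": "/data/train/" + name} for name in ["a", "b", "c", "d"]]


def filter_in_place(tracks):
    """Pre-commit pattern: delete while enumerating the same list."""
    for i, tmp in enumerate(tracks):
        if str(tmp["path"]).split("/")[-1] in validation_tracks:
            del tracks[i]  # shifts later items left; the next one is skipped
    return tracks


def filter_rebuild(tracks):
    """Post-commit pattern: build a new list and keep only non-validation tracks."""
    return [
        tmp
        for tmp in tracks
        if not (str(tmp["path"]).split("/")[-1] in validation_tracks)
    ]


def names(tracks):
    """Return the last path component of every track."""
    return [str(t["path"]).split("/")[-1] for t in tracks]


buggy = names(filter_in_place(make_tracks()))
fixed = names(filter_rebuild(make_tracks()))
print(buggy)  # ['a', 'c', 'd']: "c" follows a deleted item and survives
print(fixed)  # ['a', 'd']
```

With two adjacent validation tracks, the in-place version leaves the second one in the training set, while the list comprehension removes both.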
33 changes: 26 additions & 7 deletions egs/musdb18/X-UMX/train.py
@@ -234,15 +234,25 @@ class MultiDomainLoss(_Loss):
         https://arxiv.org/abs/2010.04228 (and ICASSP 2021)
     """

-    def __init__(self, args):
+    def __init__(
+        self,
+        window_length,
+        in_chan,
+        n_hop,
+        spec_power,
+        nb_channels,
+        loss_combine_sources,
+        loss_use_multidomain,
+        mix_coef,
+    ):
         super().__init__()
         self.transform = nn.Sequential(
-            _STFT(window_length=args.window_length, n_fft=args.in_chan, n_hop=args.nhop),
-            _Spectrogram(spec_power=args.spec_power, mono=(args.nb_channels == 1)),
+            _STFT(window_length=window_length, n_fft=in_chan, n_hop=n_hop),
+            _Spectrogram(spec_power=spec_power, mono=(nb_channels == 1)),
         )
-        self._combi = args.loss_combine_sources
-        self._multi = args.loss_use_multidomain
-        self.coef = args.mix_coef
+        self._combi = loss_combine_sources
+        self._multi = loss_use_multidomain
+        self.coef = mix_coef
         print("Combination Loss: {}".format(self._combi))
         if self._multi:
             print(
@@ -413,7 +423,16 @@ def main(conf, args):
     es = EarlyStopping(monitor="val_loss", mode="min", patience=args.patience, verbose=True)

     # Define Loss function.
-    loss_func = MultiDomainLoss(args)
+    loss_func = MultiDomainLoss(
+        window_length=args.window_length,
+        in_chan=args.in_chan,
+        n_hop=args.nhop,
+        spec_power=args.spec_power,
+        nb_channels=args.nb_channels,
+        loss_combine_sources=args.loss_combine_sources,
+        loss_use_multidomain=args.loss_use_multidomain,
+        mix_coef=args.mix_coef,
+    )
     system = XUMXManager(
         model=x_unmix,
         loss_func=loss_func,
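The commit message does not say why `MultiDomainLoss` stopped taking `args`. One visible effect is that the constructor no longer depends on an `argparse` namespace, so the loss can be built from plain values. The sketch below illustrates that calling pattern with a hypothetical stand-in class, `LossSettings`, which only records its arguments and does not build the `_STFT` or `_Spectrogram` modules. All numeric values are illustrative assumptions, not settings taken from the recipe.

```python
import argparse


class LossSettings:
    """Hypothetical stand-in for MultiDomainLoss; it only stores its arguments."""

    def __init__(
        self,
        window_length,
        in_chan,
        n_hop,
        spec_power,
        nb_channels,
        loss_combine_sources,
        loss_use_multidomain,
        mix_coef,
    ):
        # Mirrors the keyword mapping in the diff: in_chan is passed as n_fft.
        self.stft_kwargs = dict(window_length=window_length, n_fft=in_chan, n_hop=n_hop)
        self.spec_kwargs = dict(spec_power=spec_power, mono=(nb_channels == 1))
        self._combi = loss_combine_sources
        self._multi = loss_use_multidomain
        self.coef = mix_coef


# Illustrative values only; not taken from the recipe's configuration.
args = argparse.Namespace(
    window_length=4096,
    in_chan=4096,
    nhop=1024,
    spec_power=1,
    nb_channels=2,
    loss_combine_sources=True,
    loss_use_multidomain=True,
    mix_coef=10.0,
)

# Call-site pattern from the commit: unpack the namespace field by field.
from_args = LossSettings(
    window_length=args.window_length,
    in_chan=args.in_chan,
    n_hop=args.nhop,  # the namespace field is `nhop`, the keyword is `n_hop`
    spec_power=args.spec_power,
    nb_channels=args.nb_channels,
    loss_combine_sources=args.loss_combine_sources,
    loss_use_multidomain=args.loss_use_multidomain,
    mix_coef=args.mix_coef,
)

# The same object can now be built without any namespace.
direct = LossSettings(4096, 4096, 1024, 1, 2, True, True, 10.0)
print(from_args.stft_kwargs == direct.stft_kwargs)
```

Note the naming mismatch the call site has to bridge: the parsed argument is `args.nhop`, while the constructor keyword is `n_hop`.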
