Commit

Merge pull request #35 from sony/feature/20210816-d3net-mss-training-code

Add D3Net-MSS training code
TakuyaNarihira authored Oct 11, 2021
2 parents edea6a6 + d576115 commit ab5e896
Showing 36 changed files with 2,440 additions and 1,253 deletions.
2 changes: 1 addition & 1 deletion d3net/README.md
@@ -3,7 +3,7 @@
This is the official NNabla implementation of D3Net, densely connected multidilated convolutional networks for dense prediction tasks ([arXiv](https://arxiv.org/abs/2011.11844)), which was accepted at CVPR 2021.

D3Net is demonstrated on the Music Source Separation and Semantic Segmentation tasks. Please follow the links below to explore inference code and pre-trained models:
D3Net is demonstrated on the Music Source Separation and Semantic Segmentation tasks. Please follow the links below to explore inference code and training code:
* [Music Source Separation](./music-source-separation)
* [Semantic Segmentation](./semantic-segmentation)

106 changes: 106 additions & 0 deletions d3net/d3net_basic_blocks.py
@@ -0,0 +1,106 @@
# Copyright 2021 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.initializer as I

'''
D3Net Basic Blocks definition.
'''


class BasicLayers(object):
'''
    Define 2D-Convolution Layer and BatchNormalization Layer
'''

def __init__(self, comm=None, test=False, recompute=False, init_method=None):
self.comm = comm
self.test = test
self.recompute = recompute
self.init_method = init_method

def conv2d(self, conv_input, out_channels, kernel_size, stride, bias=True, name='', dilation=1, pad=0):
'''
Define 2D-Convolution Layer
'''
if self.init_method == 'xavier':
sigma = I.calc_normal_std_glorot(
conv_input.shape[1], out_channels, kernel=(kernel_size, kernel_size))
w_init = I.NormalInitializer(sigma)
elif self.init_method == 'normal':
w_init = I.NormalInitializer(sigma=0.01)
else:
w_init = None
conv_out = PF.convolution(conv_input, out_channels, kernel=(kernel_size, kernel_size), stride=(
stride, stride), with_bias=bias, dilation=(dilation, dilation), pad=(pad, pad), name=name, w_init=w_init)
conv_out.apply(recompute=self.recompute)
return conv_out

def batch_norm(self, inp, name):
'''
Define BatchNormalization Layer
'''
if self.comm is not None:
return PF.sync_batch_normalization(inp, comm=self.comm, group='world', batch_stat=not self.test, name=name)
return PF.batch_normalization(inp, batch_stat=not self.test, name=name)


class D3NetBase(BasicLayers):
'''
A base class of D3Net.
'''

def __init__(self, comm=None, test=False, recompute=False, init_method=None):
super(D3NetBase, self).__init__(comm=comm, test=test,
recompute=recompute, init_method=init_method)

def bn_conv_block(self, inp, growth_rate, name, kernel_size=3, dilation=1, pad=1, stride=1):
'''
Define Simple Batch-Normalization and Convolution Block
'''
with nn.parameter_scope(name):
            # BN -> ReLU -> Conv (kernel_size x kernel_size)
out = self.batch_norm(inp, name='norm1')
out = F.relu(out, inplace=True)
out = self.conv2d(out, growth_rate, kernel_size=kernel_size,
stride=stride, name='conv1', dilation=dilation, pad=pad)
return out

def dilated_dense_block(self, inp, growth_rate, num_layers, name, kernel_size=3, out_block=1):
'''
Dilated Dense Block
'''
out = inp
if num_layers > 1:
lst = []
for i in range(num_layers):
                # Split the input Variable channel-wise into growth_rate-sized chunks and collect them in lst.
lst.append(inp[:, i*growth_rate:(i+1)*growth_rate])

def update(inp_, n):
for j in range(num_layers-n-1):
lst[j+1+n] += inp_[:, j*growth_rate:(j+1)*growth_rate]

for i in range(num_layers-1):
d = int(2**(i+1))
with nn.parameter_scope('layers/layer%s' % (i+1)):
update(self.bn_conv_block(
lst[i], growth_rate*(num_layers-i-1), name, dilation=d, kernel_size=kernel_size, pad=d), i)

            # Concatenate the split and updated Variables from lst.
out = F.concatenate(*lst, axis=1)
return out[:, -out_block*growth_rate:]
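
The dilated dense block above splits its input channel-wise into `growth_rate`-sized chunks, pushes each chunk through batch-norm / ReLU / dilated-convolution layers with exponentially growing dilation, and adds the results back into the remaining chunks before concatenating. Below is a minimal usage sketch; the shapes, scope name, and the assumption that `d3net_basic_blocks.py` is importable are illustrative only.

```python
import numpy as np
import nnabla as nn

from d3net_basic_blocks import D3NetBase  # assumes d3net/ is on PYTHONPATH

# The block expects at least num_layers * growth_rate input channels,
# so that the input can be split into one chunk per layer.
growth_rate, num_layers = 16, 4
x = nn.Variable((1, growth_rate * num_layers, 64, 64))   # (batch, channels, H, W)
x.d = np.random.randn(*x.shape).astype(np.float32)

d3 = D3NetBase(test=True)
with nn.parameter_scope('demo'):
    y = d3.dilated_dense_block(x, growth_rate, num_layers, name='dense_block')

y.forward()
print(y.shape)   # (1, growth_rate * out_block, 64, 64) -> (1, 16, 64, 64)
```
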
5 changes: 3 additions & 2 deletions d3net/music-source-separation/D3Net-MSS.ipynb
@@ -24,7 +24,7 @@
"metadata": {},
"source": [
"\n",
"Next, clone the code from sony/ai-research-code repository, and then download the [pre-trained weights](https://nnabla.org/pretrained-models/ai-research-code/d3net/mss/d3net-mss.h5) to test on music files(only wav format)."
"Next, clone the code from the sony/ai-research-code repository, then download and extract the [D3Net-MSS model weights](https://nnabla.org/pretrained-models/ai-research-code/d3net/mss/d3net-mss.zip) to test on music files (any audio format supported by FFMPEG)."
]
},
{
@@ -35,7 +35,8 @@
"source": [
"!git clone https://github.com/sony/ai-research-code.git\n",
"%cd ai-research-code/d3net/music-source-separation\n",
"!wget https://nnabla.org/pretrained-models/ai-research-code/d3net/mss/d3net-mss.h5"
"!wget https://nnabla.org/pretrained-models/ai-research-code/d3net/mss/d3net-mss.zip\n",
"!unzip d3net-mss.zip -d d3net-mss"
]
},
{
121 changes: 115 additions & 6 deletions d3net/music-source-separation/README.md
@@ -1,6 +1,6 @@
# D3Net (Music Source Separation)

This is inference code for D3Net based music source separation.
This is the official NNabla implementation of D3Net-based music source separation.

## Quick Music Source Separation Demo by D3Net

@@ -22,11 +22,14 @@ If you want to do test with openvino, you should also install openvino.

## Inference: Music source separation with pretrained model

Download the pre-trained D3Net model for Music Source Separation [here](https://nnabla.org/pretrained-models/ai-research-code/d3net/mss/d3net-mss.h5).
Download and extract the pre-trained weights [here](https://nnabla.org/pretrained-models/ai-research-code/d3net/mss/d3net-mss.zip).
```bash
unzip d3net-mss.zip -d d3net-mss
```

Run the inference command below for a sample audio file `test.wav` in the current directory:
```python
python ./separate.py -i ./test.wav -o output/ -m d3net-mss.h5 -c cudnn
```bash
python ./separate.py -i ./test.wav -o output/ -m ./d3net-mss -c cudnn
```
Arguments:
-i : Input files. (Any audio format files supported by FFMPEG.)
@@ -41,7 +44,7 @@ Download and extract the openvino weights [here](https://nnabla.org/pretrained-m
```bash
unzip d3net-openvino.zip -d openvino_models
```
```python
```bash
python ./separate_with_openvino.py -i ./test.wav -o output/ -m ./openvino_models -n 4
```
Arguments:
@@ -50,4 +53,110 @@ Arguments:
-m : Openvino models directory.
-n : Specifies the number of threads that openvino should use for inference.

## Training: Train the music source separation model from scratch (**coming soon**)
### Evaluation using `museval`

To perform an evaluation comparable to other SiSEC systems, you need to install the `museval` package using

```bash
pip install museval
```

and then run the below command for the evaluation:

```bash
python eval.py -m ./d3net-mss --root [Path of MUSDB18 dataset] --out-dir [Path to save musdb estimates and museval results]
```
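
For reference, the sketch below scores a single track with `museval` directly; the dataset path and the placeholder separator are assumptions, and the `eval.py` command above automates this over the whole MUSDB18 test set.

```python
import musdb
import museval

# Minimal sketch: evaluate one MUSDB18 test track (assumed dataset path).
mus = musdb.DB(root='/data/musdb18', subsets='test')
track = mus.tracks[0]

def separate(audio):
    # Placeholder separator: a real system returns per-target estimates here.
    return {t: audio * 0.25 for t in ('vocals', 'drums', 'bass', 'other')}

estimates = separate(track.audio)      # dict of target -> (samples, channels) array
scores = museval.eval_mus_track(track, estimates, output_dir='./museval_out')
print(scores)                          # per-target SDR / SIR / ISR / SAR medians
```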

#### Scores (Median of frames, Median of tracks)

|target|SDR (Paper)|SDR (NNabla)|
|------|-----------|------------|
|vocals|7.24 |7.14 |
|drums |7.01 |6.85 |
|bass  |5.25 |5.32 |
|other |4.53 |4.82 |
|**Avg** |**6.01** |**6.03** |

The published pre-trained models yield the SDR values listed in the table above. The differences in metrics between the pre-trained models and the paper are due to randomness in data preprocessing and weight initialization.

## Training: Train the music source separation model from scratch

The D3Net model for Music Source Separation can be trained using the default parameters of the `train.py` script.

[MUSDB18](https://sigsep.github.io/datasets/musdb.html#musdb18-compressed-stems) and [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav) are the largest freely available datasets of professionally produced music tracks (~10h duration) of different styles. They come with isolated `drums`, `bass`, `vocals` and `other` stems. _MUSDB18_ contains two subsets: "train", composed of 100 songs, and "test", composed of 50 songs.

To directly train a vocal model with _d3net-mss_, we first need to download one of the datasets and place it, unzipped, in a directory of your choice (called `root`).

| Argument | Description | Default |
|----------|-------------|---------|
| `--root <str>` | path to root of dataset on disk. | `None` |

Also note that if `--root` is not specified, we automatically download a 7-second preview version of the MUSDB18 dataset. While this is convenient for testing purposes, we wouldn't recommend actually training your model on it.

#### Using WAV files

All files from the MUSDB18 dataset are encoded in the Native Instruments stems format (.mp4). If you want to use WAV files (e.g. for faster audio decoding), `musdb` also supports parsing and processing pre-decoded PCM/WAV files. The downloaded STEMS dataset (.mp4) can be decoded into a WAV version either by the [Docker-based solution or by running the scripts manually as shown here](https://github.com/sigsep/sigsep-mus-io).

__When you use the decoded MUSDB18 dataset (WAV version), pass the `--is-wav` argument while running `train.py`.__
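
For illustration, the snippet below shows how the `musdb` package opens either variant; the dataset paths are assumptions, and the `is_wav` flag mirrors what the `--is-wav` training argument selects.

```python
import musdb

# Sketch only: open the encoded STEMS release or a pre-decoded WAV copy.
mus_stems = musdb.DB(root='/data/MUSDB18', subsets='train')                   # .mp4 stems
mus_wav = musdb.DB(root='/data/MUSDB18-WAV', subsets='train', is_wav=True)    # decoded WAVs

track = mus_wav.tracks[0]
print(track.name, track.rate, track.audio.shape)   # mixture: (samples, channels)
print(track.targets['vocals'].audio.shape)         # isolated vocal stem
```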

### Single GPU training

#### For encoded MUSDB18 STEMS version
```bash
python train.py --root [Path of MUSDB18 dataset] --target [target track to be trained] --output [Path to save weights]
```

#### For decoded MUSDB18 WAV version
```bash
python train.py --root [Path of MUSDB18 dataset] --target [target track to be trained] --output [Path to save weights] --is-wav
```

### Distributed Training
For distributed training, [install the NNabla package compatible with Multi-GPU execution](https://nnabla.readthedocs.io/en/latest/python/pip_installation_cuda.html#pip-installation-distributed). Use the commands below to start the distributed training.
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3  # device ids that you want to use
```

#### For encoded MUSDB18 STEMS version
```bash
mpirun -n {no. of devices} python train.py --root [Path of MUSDB18 dataset] --target [target track to be trained] --output [Path to save weights]
```

#### For decoded MUSDB18 WAV version
```bash
mpirun -n {no. of devices} python train.py --root [Path of MUSDB18 dataset] --target [target track to be trained] --output [Path to save weights] --is-wav
```

Please note that the above sample training scripts will work on the high-quality 'STEMS' or low-quality 'MP4' files. If you would like faster data loading, kindly look at [more details here](https://github.com/sigsep/sigsep-mus-db#using-wav-files-optional) to generate decoded 'WAV' files; in that case, please use the `--is-wav` flag for training.

Training `MUSDB18` using _d3net-mss_ comes with several design decisions that we made as part of our defaults to improve efficiency and performance:

* __chunking__: we do not feed full audio tracks into _d3net-mss_ but instead chunk the audio into 6s excerpts (`--seq-dur 6.0`).
* __balanced track sampling__: to avoid a bias towards longer audio tracks, we randomly yield one track from MUSDB18 and then select a random chunk from it. In one epoch we select (on average) 64 samples from each track.
* __source augmentation__: we apply random gains between `0.75` and `1.25` to all sources before mixing. Furthermore, we randomly swap the channels of the input mixture (see the sketch after this list).
* __random track mixing__: for a given target we select a _random track_ with replacement. To yield a mixture we draw the interfering sources from different tracks (again with replacement) to increase generalization of the model.
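
As referenced in the source-augmentation item above, the sketch below illustrates per-source gain and channel-swap augmentations; the `(channels, samples)` layout and the 0.5 swap probability are assumptions rather than the exact training implementation.

```python
import numpy as np

def gain(audio, low=0.75, high=1.25, rng=np.random):
    # Scale a source by a random gain in [low, high].
    return audio * rng.uniform(low, high)

def channelswap(audio, rng=np.random):
    # Randomly swap the two channels of a stereo source.
    if audio.shape[0] == 2 and rng.rand() < 0.5:
        return audio[::-1]
    return audio

def augment(sources):
    # Apply gain, then channel swap, to each isolated source before mixing.
    return [channelswap(gain(s)) for s in sources]
```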

Some of the parameters for the MUSDB sampling can be controlled using the following arguments:

| Argument | Description | Default |
|---------------------|-----------------------------------------------|--------------|
| `--is-wav` | loads the decoded WAVs instead of STEMS for faster data loading. See [more details here](https://github.com/sigsep/sigsep-mus-db#using-wav-files-optional). | `True` |
| `--samples-per-track <int>` | sets the number of samples that are randomly drawn from each track | `64` |
| `--source-augmentations <list[str]>` | applies augmentations to each audio source before mixing | `gain channelswap` |

## Training and Model Parameters

An extensive list of additional training parameters allows researchers to quickly try out different parameterizations, such as a different FFT size. In the table below we list the additional training parameters and their default values (a short sketch after the table shows what the STFT defaults imply for the spectrogram size):

| Argument | Description | Default |
|----------------------------|---------------------------------------------------------------------------------|-----------------|
| `--target <str>` | name of target source (will be passed to the dataset) | `vocals` |
| `--output <str>` | path where to save the trained output model as well as checkpoints. | `./d3net-mss` |
| `--epochs <int>` | Number of epochs to train | `50` |
| `--batch-size <int>`       | Batch size has an influence on memory usage and training performance | `6` |
| `--seq-dur <float>`        | Sequence duration in seconds of chunks taken from the dataset. A value of `<=0.0` results in full/variable length | `6.0` |
| `--lr <float>`             | learning rate | `0.001` |
| `--bandwidth <int>`        | maximum bandwidth in Hertz processed by the model. Input and output are always full bandwidth! | `16000` |
| `--context <str>`          | Extension modules, e.g. `cpu` or `cudnn`. | `cudnn` |
| `--seed <int>` | Initial seed to set the random initialization | `None` |
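
To make the defaults concrete, the back-of-the-envelope sketch below estimates the spectrogram size of one training chunk from `--seq-dur 6.0` together with the STFT defaults in `args.py` (`--nfft 4096`, `--nhop 1024`); the 44.1 kHz sample rate and the no-padding framing are assumptions for illustration.

```python
# Rough spectrogram-size estimate for one 6-second training chunk.
sample_rate = 44100                            # MUSDB18 audio (assumed)
seq_dur, nfft, nhop = 6.0, 4096, 1024          # defaults from the table / args.py

n_samples = int(seq_dur * sample_rate)         # 264,600 samples per chunk
n_frames = 1 + (n_samples - nfft) // nhop      # ~255 STFT frames
n_bins = nfft // 2 + 1                         # 2049 one-sided frequency bins
print(n_frames, n_bins)
```
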
82 changes: 82 additions & 0 deletions d3net/music-source-separation/args.py
@@ -0,0 +1,82 @@
# Copyright (c) 2021 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse


def get_train_args():
'''
    Get command line arguments for training.
    Each argument is defined with a sensible default value.
'''
    parser = argparse.ArgumentParser(
        description="Music Source Separation Trainer using D3Net")

# which target do we want to train?
parser.add_argument('--target', type=str, default='vocals',
help='target source (will be passed to the dataset)')

    # Dataset parameters
parser.add_argument('--root', type=str, help='root path of dataset')
parser.add_argument('--output', type=str, default="d3net-mss",
help='provide output path base folder name')

# Training Parameters
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch-size', type=int, default=6)
parser.add_argument('--lr', type=float, default=0.001,
help='learning rate, defaults to 1e-3')
parser.add_argument('--seed', type=int, default=None, metavar='S',
help='random seed (default: None)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay')
# Model Parameters
    parser.add_argument('--seq-dur', type=float, default=6.0,
                        help='Sequence duration in seconds per training batch. '
                        'A value of <=0.0 will use full/variable length')
parser.add_argument('--nfft', type=int, default=4096,
help='STFT fft size and window size')
parser.add_argument('--nhop', type=int, default=1024,
help='STFT hop size')

# Misc Parameters
    parser.add_argument('--device-id', '-d', type=str, default='0',
                        help='Device ID the training runs on. This is only valid if you specify `-c cudnn`.')
parser.add_argument('--context', '-c', type=str,
default='cudnn', help="Extension modules. ex) 'cpu', 'cudnn'.")

args, _ = parser.parse_known_args()

return parser, args


def get_inference_args():
'''
    Get command line arguments for inference.
    Each argument is defined with a sensible default value.
'''
    parser = argparse.ArgumentParser(
        description="Music Source Separation Inference using D3Net")

parser.add_argument('--inputs', '-i', nargs='+', type=str,
help='List of input audio files supported by FFMPEG.', required=True)
parser.add_argument('--model-dir', '-m', type=str,
default='./d3net-mss/', help='path to the directory of pretrained models.')
parser.add_argument('--targets', nargs='+', default=['vocals', 'drums', 'bass', 'other'],
type=str, help='provide targets to be processed. If none, all available targets will be computed')
parser.add_argument('--out-dir', '-o', type=str,
default='./output/', help='output directory')
parser.add_argument('--context', '-c', type=str,
default='cudnn', help="Extension modules('cpu', 'cudnn')")
return parser.parse_args()
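
A tiny, hypothetical driver showing how a training entry point might consume the helpers above (not part of the original file):

```python
# Hypothetical usage of get_train_args(); prints a few parsed defaults.
from args import get_train_args

if __name__ == '__main__':
    parser, args = get_train_args()
    print(args.target, args.epochs, args.batch_size, args.lr)
```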