Skip to content

Commit

Permalink
Merge pull request asteroid-team#133 from mpariente/kinect_bis
Browse files Browse the repository at this point in the history
[src & egs] Add Kinect-WSJ licenses + small fixes
  • Loading branch information
sunits authored May 31, 2020
2 parents b172619 + 855c43c commit 5b32afb
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 9 deletions.
42 changes: 34 additions & 8 deletions asteroid/data/kinect_wsj.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch.utils import data
import json
import os
import numpy as np
import soundfile as sf
Expand All @@ -12,11 +11,11 @@ def make_dataloaders(train_dir, valid_dir, n_src=2, sample_rate=16000,
**kwargs):
num_workers = num_workers if num_workers else batch_size
train_set = KinectWsjMixDataset(train_dir, n_src=n_src,
sample_rate=sample_rate,
segment=segment)
sample_rate=sample_rate,
segment=segment)
val_set = KinectWsjMixDataset(valid_dir, n_src=n_src,
sample_rate=sample_rate,
segment=segment)
sample_rate=sample_rate,
segment=segment)
train_loader = data.DataLoader(train_set, shuffle=True,
batch_size=batch_size,
num_workers=num_workers,
Expand All @@ -29,8 +28,12 @@ def make_dataloaders(train_dir, valid_dir, n_src=2, sample_rate=16000,


class KinectWsjMixDataset(Wsj0mixDataset):
dataset_name = 'Kinect-WSJ'

def __init__(self, json_dir, n_src=2, sample_rate=16000, segment=4.0):
super().__init__( json_dir, n_src=n_src, sample_rate=sample_rate, segment=segment)
super().__init__(
json_dir, n_src=n_src, sample_rate=sample_rate, segment=segment
)
noises = []
for i in range(len(self.mix)):
path = self.mix[i][0]
Expand Down Expand Up @@ -58,8 +61,10 @@ def __getitem__(self, idx):
else:
stop = rand_start + self.seg_len
# Load mixture
x, _ = sf.read(self.mix[idx][0], start=rand_start, stop=stop, dtype='float32', always_2d=True)
noise, _ = sf.read(self.noises[idx][0], start=rand_start, stop=stop, dtype='float32', always_2d=True)
x, _ = sf.read(self.mix[idx][0], start=rand_start, stop=stop,
dtype='float32', always_2d=True)
noise, _ = sf.read(self.noises[idx][0], start=rand_start, stop=stop,
dtype='float32', always_2d=True)
seg_len = torch.as_tensor([len(x)])
# Load sources
source_arrays = []
Expand All @@ -73,3 +78,24 @@ def __getitem__(self, idx):
source_arrays.append(s)
sources = torch.from_numpy(np.stack(source_arrays))
return torch.from_numpy(x), sources, torch.from_numpy(noise)

def get_infos(self):
    """ Get dataset infos (for publishing models).

    Takes the infos returned by the parent class and appends the CHiME-5
    license to its `licenses` list.

    Returns:
        dict, dataset infos with keys `dataset`, `task` and `licenses`.
    """
    # Start from whatever the parent dataset reports.
    infos = super().get_infos()
    # Add the CHiME-5 license on top of the licenses already listed by the
    # parent (the list is modified in place).
    infos['licenses'].append(chime5_license)
    return infos


# License metadata describing the CHiME-5 speech corpus (title, authors,
# license name, the corresponding links, and a non-commercial flag).
chime5_license = {
    'title': 'The CHiME-5 speech corpus',
    'title_link': 'http://spandh.dcs.shef.ac.uk/chime_challenge/CHiME5/index.html',
    'author': 'Jon Barker, Shinji Watanabe and Emmanuel Vincent',
    'author_link': 'http://spandh.dcs.shef.ac.uk/chime_challenge/chime2018/contact.html',
    'license': 'CHiME-5 data licence - non-commercial 1.00',
    'license_link': 'https://licensing.sheffield.ac.uk/i/data/chime5.html',
    'non_commercial': True,
}
2 changes: 1 addition & 1 deletion egs/kinect-wsj/DeepClustering/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main(conf):
os.path.join(exp_dir, 'checkpoints/final.pth'))


# TODO:Should ideally be inherited from wsj0-mix
# TODO:Should ideally be inherited from wsj0-mix
class ChimeraSystem(System):
def __init__(self, *args, mask_mixture=True, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
22 changes: 22 additions & 0 deletions egs/kinect-wsj/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
### About the Kinect-WSJ dataset
Kinect-WSJ is a reverberated, noisy version of the WSJ0-2MIX dataset. Microphones are placed on a linear array with spacing between the devices resembling that of the Microsoft Kinect™, the device used to record the CHiME-5 dataset. This was done so that we could use the real ambient noise captured as part of the CHiME-5 dataset. The room impulse responses (RIRs) were simulated for a sampling rate of 16,000 Hz.

## Path to the dataset
https://github.com/sunits/Reverberated_WSJ_2MIX/

# Requirements to create Kinect-WSJ dataset
* wsj_path : Path to the precomputed wsj0-2mix dataset. Should contain the folder 2speakers/wav16k/. If you don't have the wsj0-2mix dataset, please create it using the scripts in egs/wsj0-mix
* chime_path : Path to the CHiME-5 dataset. Should contain the folders train, dev and eval
* dihard_path : Path to the DIHARD labels. Should contain ```*.lab``` files for the train and dev sets

# References

```
@inproceedings{sivasankaran2020,
booktitle = {2020 28th {{European Signal Processing Conference}} ({{EUSIPCO}})},
title={Analyzing the impact of speaker localization errors on speech separation for automatic speech recognition},
author={Sunit Sivasankaran and Emmanuel Vincent and Dominique Fohr},
year={2021},
month = Jan,
}
```
2 changes: 2 additions & 0 deletions egs/wsj0-mix/DeepClustering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from asteroid.engine.optimizers import make_optimizer
from asteroid.filterbanks.transforms import take_mag, apply_mag_mask, ebased_vad
from asteroid.masknn.blocks import SingleRNN
from asteroid.utils.torch_utils import pad_x_to_y

EPS = 1e-8


Expand Down

0 comments on commit 5b32afb

Please sign in to comment.