Commit

[src & egs] Replace WSJ0-mix Dataset (asteroid-team#97)
mpariente authored May 11, 2020
1 parent 40a0139 commit e09c712
Showing 5 changed files with 91 additions and 303 deletions.
1 change: 1 addition & 0 deletions asteroid/data/__init__.py
@@ -2,3 +2,4 @@
from .whamr_dataset import WhamRDataset
from .dns_dataset import DNSDataset
from .librimix_dataset import LibriMix
from .wsj0_mix import Wsj0mixDataset
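
With this export, the new dataset is importable at the package level. A minimal usage sketch (the data path is a hypothetical example):

    from asteroid.data import Wsj0mixDataset
    dataset = Wsj0mixDataset('data/tr', n_src=2, sample_rate=8000, segment=4.0)
    mixture, sources = dataset[0]
    # mixture: (seg_len,) float tensor, sources: (n_src, seg_len) float tensor
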
288 changes: 88 additions & 200 deletions asteroid/data/wsj0_mix.py
@@ -1,213 +1,101 @@
import torch
from torch.utils import data
from torch.utils.data.sampler import Sampler
import json
import os
import numpy as np
import soundfile as sf
from glob import glob

from asteroid.data.wav import SingleWav
from asteroid.filterbanks.transforms import take_mag

EPS = torch.finfo(torch.float).eps


def collate_fn(batch):
    """ Trim all elements of the batch to the length of the shortest one. """
    batch = sorted(batch, key=lambda sample: sample[0].shape[0],
                   reverse=True)
    smallest_sample = batch[-1]
    minibatch_size = len(batch)
    src_cnt = smallest_sample[1].shape[0]
    sample_len = smallest_sample[0].shape[0]

    mixture = torch.zeros(minibatch_size, sample_len)
    sources = torch.zeros(minibatch_size, src_cnt, sample_len)

    for sample_idx in range(minibatch_size):
        sample = batch[sample_idx]
        mixture[sample_idx] = sample[0][:sample_len]
        sources[sample_idx] = sample[1][..., :sample_len]
    return mixture, sources


class BucketingSampler(Sampler):
    def __init__(self, data_source, batch_size=1, percentage=1):
        """ Samples batches assuming the data source is ordered by size,
        so that similarly sized samples are batched together.
        Taken from the deepspeech GitHub codebase, in data/data_loader.py.
        percentage: float. Fraction of the data to use.
        """
        Sampler.__init__(self, data_source)
        self.data_source = data_source
        print('At BucketingSampler: Available', data_source.len)
        to_take = int(percentage * data_source.len)
        print('At BucketingSampler: Samples to take', to_take)
        assert to_take > batch_size, \
            'Number of samples should be greater than batch_size'
        ids = list(range(0, to_take))
        self.bins = [ids[i:i + batch_size]
                     for i in range(0, len(ids), batch_size)]
        self.shuffle(0)

    def __iter__(self):
        self.shuffle(0)
        for ids in self.bins:
            np.random.shuffle(ids)
            yield ids

    def __len__(self):
        return len(self.bins)

    def shuffle(self, epoch):
        print('shuffling bins')
        np.random.shuffle(self.bins)


class WSJmixDataset(data.Dataset):
    """ An interface to WSJ0-mix style data.
    Args:
        wav_len_list: str. A file containing <wav_id> <sample_len> per line.
        wav_base_path: str. Base directory of the wav files.
            Should contain the mix, s1, s2, etc. subfolders.
        callback_func: func. A function to process the raw wav files.
        elements: list. The elements to access, e.g. mix, s1, s2 and so on.
        sample_rate: int. Sampling rate of the data.
        segment: float. Length of the segments used for training, in seconds.
            By default, the full signal is returned. If segment is set to a
            float value, signals shorter than segment are removed.
    """
    def __init__(self, wav_len_list, wav_base_path, callback_func=None,
                 elements=['mix', 's1', 's2'], sample_rate=8000, segment=None):
        segment_samples = segment * sample_rate if segment is not None else -1
        self.segment = float(segment) if segment is not None else -1
        assert os.path.exists(wav_len_list), wav_len_list + ' does not exist'
        data.Dataset.__init__(self)
        id_list = []
        id_wav_map = {}
        with open(wav_len_list) as fid:
            for line in fid:
                wav_id, wav_len = line.strip().split()
                wav_len = int(wav_len)
                id_list.append(wav_id)
                if segment_samples != -1 and wav_len < segment_samples:
                    # print("Drop {} utts. {} (shorter than {} samples)".format(
                    #     wav_id, wav_len/sample_rate, segment))
                    continue
                if wav_id not in id_wav_map:
                    id_wav_map[wav_id] = {}
                    for _ele_ in elements:
                        id_wav_map[wav_id][_ele_] = SingleWav(
                            os.path.join(wav_base_path, _ele_, wav_id))
                    id_wav_map[wav_id]['sample'] = wav_len
        self.id_list = list(id_wav_map.keys())
        self.id_wav_map = id_wav_map
        self.len = len(id_wav_map)
        # Use an identity function if no callback is given
        self.callback_func = callback_func if callback_func is not None \
            else self.identity
        print("{:.2f}% of files dropped".format(
            100 * (1 - self.len / len(id_list))))

    def identity(self, *kargs):
        return kargs

    def __len__(self):
        return self.len

    def shuffle_list(self):
        """ Shuffle the id list. """
        np.random.shuffle(self.id_list)


def make_dataloaders(train_dir, valid_dir, n_src=2, sample_rate=8000,
                     segment=4.0, batch_size=4, num_workers=None,
                     **kwargs):
    num_workers = num_workers if num_workers else batch_size
    train_set = Wsj0mixDataset(train_dir, n_src=n_src,
                               sample_rate=sample_rate,
                               segment=segment)
    val_set = Wsj0mixDataset(valid_dir, n_src=n_src,
                             sample_rate=sample_rate,
                             segment=segment)
    train_loader = data.DataLoader(train_set, shuffle=True,
                                   batch_size=batch_size,
                                   num_workers=num_workers,
                                   drop_last=True)
    val_loader = data.DataLoader(val_set, shuffle=True,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 drop_last=True)
    return train_loader, val_loader


class Wsj0mixDataset(data.Dataset):
    def __init__(self, json_dir, n_src=2, sample_rate=8000, segment=4.0):
        super().__init__()
        # Task setting
        self.json_dir = json_dir
        self.sample_rate = sample_rate
        if segment is None:
            self.seg_len = None
        else:
            self.seg_len = int(segment * sample_rate)
        self.n_src = n_src
        self.like_test = self.seg_len is None
        # Load json files
        mix_json = os.path.join(json_dir, 'mix.json')
        sources_json = [os.path.join(json_dir, source + '.json') for
                        source in [f"s{n+1}" for n in range(n_src)]]
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
        sources_infos = []
        for src_json in sources_json:
            with open(src_json, 'r') as f:
                sources_infos.append(json.load(f))
        # Filter out short utterances only when segment is specified
        orig_len = len(mix_infos)
        drop_utt, drop_len = 0, 0
        if not self.like_test:
            for i in range(len(mix_infos) - 1, -1, -1):  # Go backward
                if mix_infos[i][1] < self.seg_len:
                    drop_utt += 1
                    drop_len += mix_infos[i][1]
                    del mix_infos[i]
                    for src_inf in sources_infos:
                        del src_inf[i]

        print("Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
            drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len))
        self.mix = mix_infos
        self.sources = sources_infos

    def __len__(self):
        return len(self.mix)

    def __getitem__(self, idx):
        """ Gets a mixture/sources pair.
        Returns:
            mixture, vstack([source_arrays])
        """
        # Random start
        if self.mix[idx][1] == self.seg_len or self.like_test:
            rand_start = 0
        else:
            rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
        if self.like_test:
            stop = None
        else:
            stop = rand_start + self.seg_len
        # Load mixture
        x, _ = sf.read(self.mix[idx][0], start=rand_start,
                       stop=stop, dtype='float32')
        seg_len = torch.as_tensor([len(x)])
        # Load sources
        source_arrays = []
        for src in self.sources:
            if src[idx] is None:
                # Target is filled with zeros if n_src > default_nsrc
                s = np.zeros((seg_len, ))
            else:
                s, _ = sf.read(src[idx][0], start=rand_start,
                               stop=stop, dtype='float32')
            source_arrays.append(s)
        sources = torch.from_numpy(np.vstack(source_arrays))
        return torch.from_numpy(x), sources


class WSJ2mixDataset(WSJmixDataset):
    """ Interface to the 2-speaker mixture dataset.
    Args:
        wav_len_list: str. A file containing <wav_id> <sample_len> per line.
        wav_base_path: str. Base directory of the wav files.
            Should contain the mix, s1 and s2 subfolders.
        callback_func: func. A function to process the raw wav files.
        sample_rate: int. Sampling rate of the data.
        segment: float. Length of the segments used for training, in seconds.
            By default, the full signal is returned. If segment is set to a
            float value, signals shorter than segment are removed.
    """
    def __init__(self, wav_len_list, wav_base_path, callback_func=None,
                 sample_rate=8000, segment=None):
        self.sources = ['s1', 's2']
        WSJmixDataset.__init__(self, wav_len_list, wav_base_path,
                               callback_func=callback_func,
                               elements=['mix'] + self.sources,
                               sample_rate=sample_rate, segment=segment)

    def __getitem__(self, idx):
        item_id = self.id_list[idx]
        try:
            mixture = self.id_wav_map[item_id]["mix"].data.T[0]
        except Exception:
            print(self.id_wav_map[item_id]["mix"].file_name)
            exit(0)
        source_arrays = []
        for _src_ in self.sources:
            source_arrays.append(self.id_wav_map[item_id][_src_].data.T[0])
        sources = torch.from_numpy(np.vstack(source_arrays)).type(torch.float32)
        mixture = torch.from_numpy(mixture).type(torch.float32)
        return self.callback_func(mixture, sources)


class WSJ3mixDataset(WSJ2mixDataset):
    """ Interface to the 3-speaker mixture dataset.
    Args:
        wav_len_list: str. A file containing <wav_id> <sample_len> per line.
        wav_base_path: str. Base directory of the wav files.
            Should contain the mix, s1, s2 and s3 subfolders.
        callback_func: func. A function to process the raw wav files.
        sample_rate: int. Sampling rate of the data.
        segment: float. Length of the segments used for training, in seconds.
            By default, the full signal is returned. If segment is set to a
            float value, signals shorter than segment are removed.
    """
    def __init__(self, wav_len_list, wav_base_path, callback_func=None,
                 sample_rate=8000, segment=None):
        sources = ['s1', 's2', 's3']
        WSJmixDataset.__init__(self, wav_len_list, wav_base_path,
                               callback_func=callback_func,
                               elements=['mix'] + sources,
                               sample_rate=sample_rate, segment=segment)
        self.sources = sources


def create_wav_id_sample_count_list(base_path, dest):
    """ Create a list file with one `wav_id sample_count` entry per line.
    Args:
        base_path: str. Path to either the mix, s1 or s2 directory.
        dest: str. Path to save the list file.
    """
    all_wav_files = glob(os.path.join(base_path, '*.wav'))
    id_sample_array = []
    for _file in all_wav_files:
        sample_cnt = sf.info(_file).frames
        wav_id = os.path.basename(_file)
        id_sample_array.append((wav_id, sample_cnt))
    id_sample_array = sorted(id_sample_array, key=lambda x: x[1])
    with open(dest, 'w') as wid:
        for wav_id, sample_cnt in id_sample_array:
            wid.write('{}\t{}\n'.format(wav_id, sample_cnt))


def transform(mixture, sources):
    mix_mag = take_mag(mixture) + EPS
    src_mags = [take_mag(_src_) for _src_ in sources]
    spec_sum = torch.stack(src_mags, 0).sum(0) + EPS
    src_masks = [_src_mag / spec_sum for _src_mag in src_mags]
    return mix_mag, torch.stack(src_masks, 1)
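
For reference, Wsj0mixDataset reads its file lists from json_dir: one mix.json plus s1.json through sN.json, each a list of [wav_path, n_samples] entries, as the __init__ above shows. A quick way to inspect one (the path is a hypothetical example):

    import json
    with open('data/tr/mix.json') as f:
        mix_infos = json.load(f)
    path, n_samples = mix_infos[0]
    print(path, n_samples / 8000, 'seconds')
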
2 changes: 1 addition & 1 deletion egs/wsj0-mix/DeepClustering/eval.py
@@ -14,7 +14,7 @@
from asteroid.metrics import get_metrics

from model import load_best_model
from wsj0_mix_dataset import Wsj0mixDataset
from asteroid.data import Wsj0mixDataset


parser = argparse.ArgumentParser()
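
Note that passing segment=None to Wsj0mixDataset takes the like_test path shown above: no random cropping, full-length signals, which is what an evaluation script wants. A sketch (the test directory is a hypothetical example):

    test_set = Wsj0mixDataset('data/tt', n_src=2,
                              sample_rate=8000, segment=None)
    mixture, sources = test_set[0]  # full utterance, variable length
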
2 changes: 1 addition & 1 deletion egs/wsj0-mix/DeepClustering/train.py
@@ -12,7 +12,7 @@
from asteroid.losses import deep_clustering_loss
from asteroid.filterbanks.transforms import take_mag, ebased_vad

from wsj0_mix_dataset import make_dataloaders
from asteroid.data.wsj0_mix import make_dataloaders
from model import make_model_and_optimizer

EPS = 1e-8
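
With the dataset logic moved into the library, the train script builds its loaders in one call. A minimal sketch (directory paths and argument values are assumptions, not taken from this diff):

    train_loader, val_loader = make_dataloaders('data/tr', 'data/cv',
                                                n_src=2, sample_rate=8000,
                                                segment=4.0, batch_size=4)
    for mixtures, sources in train_loader:
        # Default collation stacks the fixed-length segments:
        # mixtures: (batch, seg_len), sources: (batch, n_src, seg_len)
        break
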