Skip to content

Commit

Permalink
[src] Style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente committed May 31, 2020
1 parent b172619 commit 7296c2b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
19 changes: 11 additions & 8 deletions asteroid/data/kinect_wsj.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch.utils import data
import json
import os
import numpy as np
import soundfile as sf
Expand All @@ -12,11 +11,11 @@ def make_dataloaders(train_dir, valid_dir, n_src=2, sample_rate=16000,
**kwargs):
num_workers = num_workers if num_workers else batch_size
train_set = KinectWsjMixDataset(train_dir, n_src=n_src,
sample_rate=sample_rate,
segment=segment)
sample_rate=sample_rate,
segment=segment)
val_set = KinectWsjMixDataset(valid_dir, n_src=n_src,
sample_rate=sample_rate,
segment=segment)
sample_rate=sample_rate,
segment=segment)
train_loader = data.DataLoader(train_set, shuffle=True,
batch_size=batch_size,
num_workers=num_workers,
Expand All @@ -30,7 +29,9 @@ def make_dataloaders(train_dir, valid_dir, n_src=2, sample_rate=16000,

class KinectWsjMixDataset(Wsj0mixDataset):
def __init__(self, json_dir, n_src=2, sample_rate=16000, segment=4.0):
super().__init__( json_dir, n_src=n_src, sample_rate=sample_rate, segment=segment)
super().__init__(
json_dir, n_src=n_src, sample_rate=sample_rate, segment=segment
)
noises = []
for i in range(len(self.mix)):
path = self.mix[i][0]
Expand Down Expand Up @@ -58,8 +59,10 @@ def __getitem__(self, idx):
else:
stop = rand_start + self.seg_len
# Load mixture
x, _ = sf.read(self.mix[idx][0], start=rand_start, stop=stop, dtype='float32', always_2d=True)
noise, _ = sf.read(self.noises[idx][0], start=rand_start, stop=stop, dtype='float32', always_2d=True)
x, _ = sf.read(self.mix[idx][0], start=rand_start, stop=stop,
dtype='float32', always_2d=True)
noise, _ = sf.read(self.noises[idx][0], start=rand_start, stop=stop,
dtype='float32', always_2d=True)
seg_len = torch.as_tensor([len(x)])
# Load sources
source_arrays = []
Expand Down
2 changes: 1 addition & 1 deletion egs/kinect-wsj/DeepClustering/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main(conf):
os.path.join(exp_dir, 'checkpoints/final.pth'))


# TODO:Should ideally be inherited from wsj0-mix
# TODO:Should ideally be inherited from wsj0-mix
class ChimeraSystem(System):
def __init__(self, *args, mask_mixture=True, **kwargs):
super().__init__(*args, **kwargs)
Expand Down

0 comments on commit 7296c2b

Please sign in to comment.