Two step source separation #67

Merged
merged 56 commits on Apr 9, 2020
Changes from 1 commit
Commits (56)
e68b59f
Wham Dataset normalize audio by removing mean and dividing with standar…
etzinis Mar 27, 2020
bea01e4
Pytorch Lightning System module overwrite common_step method for integ…
etzinis Mar 27, 2020
2e42ce2
Adding the two-step separation block module for the separator in appr…
etzinis Mar 27, 2020
c8d4fc0
Adding the corresponding encoder/decoder or filterbank for the two st…
etzinis Mar 27, 2020
229cb1f
Training recipe for the two step process, doing separately the optimi…
etzinis Mar 27, 2020
651b8da
Adding recipe for evaluating the two-step process similar to ConvTasnet
etzinis Mar 27, 2020
98aca36
Similar to ConvTasnet, adding bash script for creating data, training …
etzinis Mar 27, 2020
a49259b
Adding model wrapper for the filterbank and separator as well as clas…
etzinis Mar 27, 2020
7c7c5b1
Local subdirectory copied from ConvTasnet
etzinis Mar 27, 2020
444eb08
Utils bash scripts copied from ConvTasnet recipe
etzinis Mar 27, 2020
bdcf07e
Adding README for the two step process
etzinis Mar 27, 2020
30fbede
[egs] Upload Tasnet WHAMR results
mpariente Mar 28, 2020
2d3e838
[ci] Restore old travis.yml (reverts part of #61)
mpariente Mar 28, 2020
fd22c4d
[src] Add train kwarg in System's common_step for easier subclassing.
mpariente Mar 28, 2020
1645d5e
Wham Dataset normalize audio by removing mean and dividing with standar…
etzinis Mar 27, 2020
cd3f8d7
Pytorch Lightning System module overwrite common_step method for integ…
etzinis Mar 27, 2020
857a413
Adding the two-step separation block module for the separator in appr…
etzinis Mar 27, 2020
0e31c58
Adding the corresponding encoder/decoder or filterbank for the two st…
etzinis Mar 27, 2020
b56ab97
Training recipe for the two step process, doing separately the optimi…
etzinis Mar 27, 2020
4d936c4
Adding recipe for evaluating the two-step process similar to ConvTasnet
etzinis Mar 27, 2020
5b0cfb8
Similar to ConvTasnet, adding bash script for creating data, training …
etzinis Mar 27, 2020
e99049d
Adding model wrapper for the filterbank and separator as well as clas…
etzinis Mar 27, 2020
1983407
Local subdirectory copied from ConvTasnet
etzinis Mar 27, 2020
cb47c66
Utils bash scripts copied from ConvTasnet recipe
etzinis Mar 27, 2020
aebb0cb
Adding README for the two step process
etzinis Mar 27, 2020
103bc5a
System merging with recent train optional argument change
etzinis Apr 8, 2020
761a137
Wham Dataset normalize audio by removing mean and dividing with standar…
etzinis Mar 27, 2020
cb18372
Adding the two-step separation block module for the separator in appr…
etzinis Mar 27, 2020
3f22b0d
Adding the corresponding encoder/decoder or filterbank for the two st…
etzinis Mar 27, 2020
73ddd54
Training recipe for the two step process, doing separately the optimi…
etzinis Mar 27, 2020
ccf73cf
Adding recipe for evaluating the two-step process similar to ConvTasnet
etzinis Mar 27, 2020
59de212
Similar to ConvTasnet, adding bash script for creating data, training …
etzinis Mar 27, 2020
92a1681
Adding model wrapper for the filterbank and separator as well as clas…
etzinis Mar 27, 2020
52339c6
Local subdirectory copied from ConvTasnet
etzinis Mar 27, 2020
8ffbd84
Utils bash scripts copied from ConvTasnet recipe
etzinis Mar 27, 2020
ed58286
Adding README for the two step process
etzinis Mar 27, 2020
5752f2b
Pytorch Lightning System module overwrite common_step method for integ…
etzinis Mar 27, 2020
34dc2b0
System training step according to recent change
etzinis Apr 8, 2020
cc663a8
Wham dataset 1e-8 small fix
etzinis Apr 8, 2020
6201064
Wham dataset argument passing for audio normalization
etzinis Apr 8, 2020
18772c7
System ValueError instead of NotImplementedError
etzinis Apr 8, 2020
9d21508
Conflicts from origin
etzinis Apr 8, 2020
052d900
Proper calling of audio normalizer inside wham dataset
etzinis Apr 9, 2020
00efb18
Removing the two step methods from the core system module
etzinis Apr 9, 2020
ba2af48
Calling the new system for the two step separation only
etzinis Apr 9, 2020
4af20b8
Adding subclass for overriding the system file as requested
etzinis Apr 9, 2020
5475cc0
Fixing raised exception error messages for file not found and value…
etzinis Apr 9, 2020
9d998cf
Train part in better format as requested
etzinis Apr 9, 2020
6f98632
Removing the filterbanks used in two step source separation
etzinis Apr 9, 2020
c5775a8
Adding specific filterbanks used for two step inside the model of the…
etzinis Apr 9, 2020
6a80ccb
Removing input as specified in the pull request because of potential …
etzinis Apr 9, 2020
5c92463
Adding arguments for several important filterbank settings
etzinis Apr 9, 2020
2495f63
Specifying in conf file whether to reuse a potential pretrained adapt…
etzinis Apr 9, 2020
c482c54
Moving all model modules inside the recipe to avoid messing with the c…
etzinis Apr 9, 2020
7a06bc7
Two step default values for easier experimentation
etzinis Apr 9, 2020
c9aa338
Update EPS in WhamDataset
mpariente Apr 9, 2020
Training recipe for the two step process, doing separately the optimization for the filterbank and the separator.
etzinis committed Mar 27, 2020
commit 229cb1fd03145af6de95aecbed1ee3d1d374efe5
155 changes: 155 additions & 0 deletions egs/wham/TwoStep/train.py
@@ -0,0 +1,155 @@
import os
import argparse
import json

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from asteroid import torch_utils
from asteroid.data.wham_dataset import WhamDataset
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr

from model import get_encoded_paths
from model import load_best_filterbank_if_available
from model import make_model_and_optimizer

# Keys which are not in the conf.yml file can be added here.
# In the hierarchical dictionary created when parsing, the key `key` can be
# found at dic['main_args'][key]

# By default train.py will use all available GPUs. The `id` option in run.sh
# will limit the number of available GPUs for train.py .
# This can be changed: `python train.py --gpus 0,1` will only train on 2 GPUs.
parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=str, help='list of GPUs', default='0')
parser.add_argument('--exp_dir', default='exp/model_logs',
                    help='Full path to save best validation model')


def get_data_loaders(conf, train_part='filterbank'):
    train_set = WhamDataset(conf['data']['train_dir'], conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'],
                            normalize_audio=True)
    val_set = WhamDataset(conf['data']['valid_dir'], conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'],
                          normalize_audio=True)

    if (not train_part == 'filterbank') and (not train_part == 'separator'):
        raise NotImplementedError('Part to train: {} is not '
                                  'available.'.format(train_part))

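    # The lookups below build config keys from the part name: for
    # train_part='filterbank' the batch size comes from
    # conf['filterbank_training']['f_batch_size'], and for
    # train_part='separator' from conf['separator_training']['s_batch_size']
    # (same pattern for the '*_num_workers' keys, presumably set in
    # local/conf.yml).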
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True,
                              batch_size=conf[train_part + '_training'][
                                  train_part[0] + '_batch_size'],
                              num_workers=conf[train_part + '_training'][
                                  train_part[0] + '_num_workers'])
    val_loader = DataLoader(val_set, shuffle=True, drop_last=True,
                            batch_size=conf[train_part + '_training'][
                                train_part[0] + '_batch_size'],
                            num_workers=conf[train_part + '_training'][
                                train_part[0] + '_num_workers'])
    # Update the number of sources (it depends on the task).
    conf['masknet'].update({'n_src': train_set.n_src})

    return train_set, val_set, train_loader, val_loader


def train_model_part(conf, train_part='filterbank',
                     pretrained_filterbank=None):
    train_set, val_set, train_loader, val_loader = get_data_loaders(
        conf, train_part=train_part)

    # Define model and optimizer in a local function (defined in the recipe):
    # this makes re-instantiating the model and optimizer for retraining
    # and evaluating straightforward.
    model, optimizer = make_model_and_optimizer(
        conf, model_part=train_part,
        pretrained_filterbank=pretrained_filterbank)
    # Define scheduler
    scheduler = None
    if conf[train_part + '_training'][train_part[0] + '_half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)
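    # Note: the `yaml` module used above is imported in the __main__ block at
    # the bottom of this file, which runs before main() when the script is
    # executed directly.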

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf,
                    two_step_approach=train_part)
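    # The two_step_approach keyword passes the part currently being trained to
    # the System, presumably to switch the training behaviour between the two
    # parts; judging from the later commits in this PR, this hook was
    # subsequently moved out of the core System module into a recipe-level
    # subclass.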

    # Define callbacks
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=1, verbose=1)
    early_stopping = False
    if conf[train_part + '_training'][train_part[0] + '_early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)
    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPUs were found, setting gpus to None.')
        conf['main_args']['gpus'] = None

    trainer = pl.Trainer(
        max_nb_epochs=conf[train_part + '_training'][train_part[0] + '_epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(checkpoint.best_k_models, file, indent=0)
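    # best_k_models maps checkpoint paths to their monitored validation
    # losses; the JSON dump is presumably read back later (e.g. by the
    # evaluation recipe) to select the best checkpoint.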


def main(conf):
    filterbank = load_best_filterbank_if_available(conf)
    _, checkpoint_dir = get_encoded_paths(conf, 'filterbank')
    if filterbank is None:
        print('There are no available filterbanks under: {}. Going to '
              'training.'.format(checkpoint_dir))
        train_model_part(conf, train_part='filterbank')
    else:
        print('Found available filterbank at: {}'.format(checkpoint_dir))
        inp = input('Do you want to refine it further? y/n\n')
        if inp.lower() == 'y':
            print('Refining filterbank...')
            train_model_part(conf, train_part='filterbank')
    train_model_part(conf, train_part='separator',
                     pretrained_filterbank=filterbank)


if __name__ == '__main__':
    import yaml
    from pprint import pprint as print
    from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict

    # We start with opening the config file conf.yml as a dictionary from
    # which we can create parsers. Each top level key in the dictionary defined
    # by the YAML file creates a group in the parser.
    with open('local/conf.yml') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    # Arguments are then parsed into a hierarchical dictionary (instead of
    # flat, as returned by argparse) to facilitate calls to the different
    # asteroid methods (see in main).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It can be useful to
    # have around, so it is included here, but it is not used.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    main(arg_dic)
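
For orientation, this is the two-step flow that main(conf) implements, plus an example invocation; the command line is illustrative (only --gpus and --exp_dir are defined in this file, everything else comes from local/conf.yml and is normally driven through run.sh):

# Step 1: train (or optionally refine) the filterbank part.
#     train_model_part(conf, train_part='filterbank')
# Step 2: train the separator on top of the pretrained filterbank.
#     train_model_part(conf, train_part='separator',
#                      pretrained_filterbank=filterbank)
# Example:
#     python train.py --gpus 0 --exp_dir exp/model_logs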