Skip to content

Commit

Permalink
Enabled specifying hparams location in train_tacotron/wavernn
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah authored and Ryan Butler committed Jul 30, 2019
1 parent 54bb541 commit 45707ef
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
6 changes: 4 additions & 2 deletions train_tacotron.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
from torch import optim
import torch.nn.functional as F
from utils import import_from_file
from utils.display import *
from utils.dataset import get_tts_datasets
import hparams as hp
from utils.text.symbols import symbols
from utils.paths import Paths
from models.tacotron import Tacotron
Expand All @@ -24,8 +24,10 @@ def main():
parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', '-p', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
args = parser.parse_args()

hp = import_from_file('hparams', args.hp_file)
paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

force_train = args.force_train
Expand Down Expand Up @@ -273,4 +275,4 @@ def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, *,


if __name__ == "__main__":
main()
main()
15 changes: 12 additions & 3 deletions train_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch import optim
import torch.nn.functional as F
from utils import import_from_file
from utils.display import stream, simple_table
from utils.dataset import get_vocoder_datasets
from utils.distribution import discretized_mix_logistic_loss
Expand Down Expand Up @@ -81,7 +82,7 @@ def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr
print(' ')


if __name__ == "__main__":
def main():

# Parse Arguments
parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
Expand All @@ -90,10 +91,15 @@ def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr
parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
parser.add_argument('--gta', '-g', action='store_true', help='train wavernn on GTA features')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.set_defaults(lr=hp.voc_lr)
parser.set_defaults(batch_size=hp.voc_batch_size)
parser.add_argument('--hp_file', '-p', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
args = parser.parse_args()

hp = import_from_file('hparams', args.hp_file)
if args.lr is None:
args.lr = hp.voc_lr
if args.batch_size is None:
args.batch_size = hp.voc_batch_size

batch_size = args.batch_size
force_train = args.force_train
train_gta = args.gta
Expand Down Expand Up @@ -151,3 +157,6 @@ def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr

print('Training Complete.')
print('To continue training increase voc_total_steps in hparams.py or use --force_train')

if __name__ == "__main__":
main()
9 changes: 9 additions & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sys
import torch

from importlib.util import spec_from_file_location, module_from_spec

# Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599)
# Modified by: Ryan Butler (https://github.com/TheButlah)
# workaround for https://github.com/pytorch/pytorch/issues/15716
Expand All @@ -30,3 +32,10 @@ def data_parallel_workaround(model, *input):
_replicas_ref = replicas
return y_hat


def import_from_file(name, path):
"""Programmatically imports a module"""
spec = spec_from_file_location(name, path)
m = module_from_spec(spec)
spec.loader.exec_module(m)
return m

0 comments on commit 45707ef

Please sign in to comment.