diff --git a/train_tacotron.py b/train_tacotron.py
index 168bd1b5..462ec3c3 100644
--- a/train_tacotron.py
+++ b/train_tacotron.py
@@ -1,9 +1,9 @@
 import torch
 from torch import optim
 import torch.nn.functional as F
+from utils import import_from_file
 from utils.display import *
 from utils.dataset import get_tts_datasets
-import hparams as hp
 from utils.text.symbols import symbols
 from utils.paths import Paths
 from models.tacotron import Tacotron
@@ -24,8 +24,10 @@ def main():
     parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
     parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
     parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
+    parser.add_argument('--hp_file', '-p', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
     args = parser.parse_args()
 
+    hp = import_from_file('hparams', args.hp_file)
     paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
 
     force_train = args.force_train
@@ -273,4 +275,4 @@ def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, *,
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/train_wavernn.py b/train_wavernn.py
index 27bebe90..8867c009 100644
--- a/train_wavernn.py
+++ b/train_wavernn.py
@@ -3,6 +3,7 @@
 import torch
 from torch import optim
 import torch.nn.functional as F
+from utils import import_from_file
 from utils.display import stream, simple_table
 from utils.dataset import get_vocoder_datasets
 from utils.distribution import discretized_mix_logistic_loss
@@ -81,7 +82,7 @@ def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr
     print(' ')
 
 
-if __name__ == "__main__":
+def main():
 
     # Parse Arguments
     parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
@@ -90,10 +91,15 @@ def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr
     parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
     parser.add_argument('--gta', '-g', action='store_true', help='train wavernn on GTA features')
     parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
-    parser.set_defaults(lr=hp.voc_lr)
-    parser.set_defaults(batch_size=hp.voc_batch_size)
+    parser.add_argument('--hp_file', '-p', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
     args = parser.parse_args()
 
+    hp = import_from_file('hparams', args.hp_file)
+    if args.lr is None:
+        args.lr = hp.voc_lr
+    if args.batch_size is None:
+        args.batch_size = hp.voc_batch_size
+
     batch_size = args.batch_size
     force_train = args.force_train
     train_gta = args.gta
@@ -151,3 +157,6 @@ def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr
     print('Training Complete.')
     print('To continue training increase voc_total_steps in hparams.py or use --force_train')
 
+
+if __name__ == "__main__":
+    main()
diff --git a/utils/__init__.py b/utils/__init__.py
index 8751513c..45258a84 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -5,6 +5,8 @@ import sys
 
 import torch
 
+from importlib.util import spec_from_file_location, module_from_spec
+
 # Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599)
 # Modified by: Ryan Butler (https://github.com/TheButlah)
 # workaround for https://github.com/pytorch/pytorch/issues/15716
@@ -30,3 +32,10 @@ def data_parallel_workaround(model, *input):
     _replicas_ref = replicas
 
     return y_hat
+
+def import_from_file(name, path):
+    """Programmatically imports a module"""
+    spec = spec_from_file_location(name, path)
+    m = module_from_spec(spec)
+    spec.loader.exec_module(m)
+    return m
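
Note on the new helper: import_from_file behaves like a regular import, except the module is resolved from an explicit filesystem path, which is what lets --hp_file point at any hyperparameter file. A minimal usage sketch of the call pattern the patch introduces (the custom_hparams.py filename below is hypothetical, not from the patch):

    from utils import import_from_file

    # Load hyperparameters from a user-supplied file. The module is created
    # under the name 'hparams', but this helper does not register it in
    # sys.modules, so repeated calls each produce a fresh module object.
    hp = import_from_file('hparams', 'custom_hparams.py')  # hypothetical path
    print(hp.voc_lr, hp.voc_batch_size)  # attributes read as with a normal import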