Skip to content

Commit

Permalink
Made vocoder and synth weights explicitly separated for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Aug 14, 2019
1 parent d2f0bb9 commit 352c51a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
13 changes: 8 additions & 5 deletions gen_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Parse Arguments
parser = argparse.ArgumentParser(description='TTS Generator')
parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights')
parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights')
parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
Expand All @@ -31,6 +31,7 @@
wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
wr_parser.set_defaults(batched=None)

gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
Expand Down Expand Up @@ -60,7 +61,7 @@
overlap = args.overlap

input_text = args.input_text
weights_path = args.weights_path
tts_weights = args.tts_weights
save_attn = args.save_attn

paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
Expand All @@ -86,7 +87,9 @@
hop_length=hp.hop_length,
sample_rate=hp.sample_rate,
mode=hp.voc_mode).to(device)
voc_model.load(paths.voc_latest_weights)

voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
voc_model.load(voc_load_path)

print('\nInitialising Tacotron Model...\n')

Expand All @@ -105,8 +108,8 @@
dropout=hp.tts_dropout,
stop_threshold=hp.tts_stop_threshold).to(device)

tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
tts_model.load(tts_restore_path)
tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
tts_model.load(tts_load_path)

if input_text:
inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
Expand Down
9 changes: 3 additions & 6 deletions gen_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,11 @@ def gen_from_file(model: WaveRNN, load_path, save_path: Path, batched, target, o
parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
parser.add_argument('--file', '-f', type=str, help='[string/path] for testing a wav outside dataset')
parser.add_argument('--weights', '-w', type=str, help='[string/path] checkpoint file to load weights from')
parser.add_argument('--voc_weights', '-w', type=str, help='[string/path] Load in different WaveRNN weights')
parser.add_argument('--gta', '-g', dest='gta', action='store_true', help='Generate from GTA testset')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')

parser.set_defaults(file=None)
parser.set_defaults(weights=None)
parser.set_defaults(gta=False)
parser.set_defaults(batched=None)

args = parser.parse_args()
Expand Down Expand Up @@ -114,9 +111,9 @@ def gen_from_file(model: WaveRNN, load_path, save_path: Path, batched, target, o

paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

restore_path = args.weights if args.weights else paths.voc_latest_weights
voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights

model.load(restore_path)
model.load(voc_weights)

simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
('Target Samples', target if batched else 'N/A'),
Expand Down

0 comments on commit 352c51a

Please sign in to comment.