Skip to content

Commit

Permalink
Made force_train in tacotron actually do something
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Aug 10, 2019
1 parent f62bbc5 commit 1b4aaa7
Showing 1 changed file with 28 additions and 13 deletions.
41 changes: 28 additions & 13 deletions train_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pathlib import Path
import time
import numpy as np
import sys


def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
Expand Down Expand Up @@ -65,21 +66,35 @@ def main():
current_step = model.get_step()

if not force_gta:
for session in hp.tts_schedule:
for i, session in enumerate(hp.tts_schedule):
r, lr, max_step, batch_size = session

# Skip previous sessions based on model step
if current_step < max_step:
model.r = r
training_steps = max_step - current_step

simple_table([(f'Steps with r={r}', str(training_steps//1000) + 'k Steps'),
('Batch Size', batch_size),
('Learning Rate', lr),
('Outputs/Step (r)', model.r)])

train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)
training_steps = max_step - current_step

# Do we need to change to the next session?
if current_step >= max_step:
# Are there no further sessions than the current one?
if i == len(hp.tts_schedule)-1:
# There are no more sessions. Check if we force training.
if force_train:
# Don't finish the loop - train forever
training_steps = 999_999_999
else:
# We have completed training. Breaking is same as continue
break
else:
# There is a following session, go to it
continue

model.r = r

simple_table([(f'Steps with r={r}', str(training_steps//1000) + 'k Steps'),
('Batch Size', batch_size),
('Learning Rate', lr),
('Outputs/Step (r)', model.r)])

train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)

print('Training Complete.')
print('To continue training increase tts_total_steps in hparams.py or use --force_train\n')
Expand Down

0 comments on commit 1b4aaa7

Please sign in to comment.