Skip to content

Commit

Permalink
Fixed train_tacotron.py so that it works with new way of checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Jul 26, 2019
1 parent c738ee5 commit 924fbeb
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 43 deletions.
80 changes: 37 additions & 43 deletions train_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@
def np_now(x: torch.Tensor):
    """Detach *x* from the autograd graph, move it to CPU, and return it as a NumPy array."""
    detached = x.detach()
    return detached.cpu().numpy()


if __name__ == "__main__":

def main():
# Parse Arguments
parser = argparse.ArgumentParser(description='Train Tacotron TTS')
parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
args = parser.parse_args()

paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

force_train = args.force_train
force_gta = args.force_gta

Expand All @@ -40,9 +41,8 @@ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
device = torch.device('cpu')
print('Using device:', device)

print('\nInitialising Tacotron Model...\n')

# Instantiate Tacotron Model
print('\nInitialising Tacotron Model...\n')
model = Tacotron(embed_dims=hp.tts_embed_dims,
num_chars=len(symbols),
encoder_dims=hp.tts_encoder_dims,
Expand All @@ -56,40 +56,26 @@ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
num_highways=hp.tts_num_highways,
dropout=hp.tts_dropout).to(device=device)

paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

model.restore(paths.tts_latest_weights)

# model.reset_step()

# model.set_r(hp.tts_r)

optimizer = optim.Adam(model.parameters())
if paths.tts_latest_optim.exists():
print(f'Loading Optimizer State: "{paths.tts_latest_optim}"\n')
optimizer.load_state_dict(torch.load(paths.tts_latest_optim))
restore_checkpoint(paths, model, optimizer, create_if_missing=True)

current_step = model.get_step()

if not force_gta:

for session in hp.tts_schedule:

r, lr, max_step, batch_size = session

# Skip previous sessions based on model step
if current_step < max_step:

train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)

model.r = r

training_steps = max_step - current_step

simple_table([(f'Steps with r={r}', str(training_steps//1000) + 'k Steps'),
('Batch Size', batch_size),
('Learning Rate', lr),
('Outputs/Step (r)', model.r)])

train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
tts_train_loop(model, optimizer, train_set, lr, training_steps, attn_example)

print('Training Complete.')
Expand All @@ -107,7 +93,7 @@ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
def tts_train_loop(model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
device = next(model.parameters()).device # use same device as model parameters

for p in optimizer.param_groups: p['lr'] = lr
for g in optimizer.param_groups: g['lr'] = lr

total_iters = len(train_set)
epochs = train_steps // total_iters + 1
Expand All @@ -117,6 +103,7 @@ def tts_train_loop(model: Tacotron, optimizer, train_set, lr, train_steps, attn_
start = time.time()
running_loss = 0

# Perform 1 epoch
for i, (x, m, ids, _) in enumerate(train_set, 1):

optimizer.zero_grad()
Expand All @@ -127,7 +114,7 @@ def tts_train_loop(model: Tacotron, optimizer, train_set, lr, train_steps, attn_
if device.type == 'cuda' and torch.cuda.device_count() > 1:
m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
else:
m1_hat, m2_hat, attention = model(x, m)
m1_hat, m2_hat, attention = model(x, m)

m1_loss = F.l1_loss(m1_hat, m)
m2_loss = F.l1_loss(m2_hat, m)
Expand All @@ -152,7 +139,9 @@ def tts_train_loop(model: Tacotron, optimizer, train_set, lr, train_steps, attn_
avg_loss = running_loss / i

if step % hp.tts_checkpoint_every == 0:
model.checkpoint(paths.tts_checkpoints, optimizer)
ckpt_name = f'taco_step{k}K'
save_checkpoint(paths, model, optimizer,
name=ckpt_name, is_silent=False)

if attn_example in ids:
idx = ids.index(attn_example)
Expand All @@ -164,8 +153,7 @@ def tts_train_loop(model: Tacotron, optimizer, train_set, lr, train_steps, attn_

# Must save latest optimizer state to ensure that resuming training
# doesn't produce artifacts
torch.save(optimizer.state_dict(), paths.tts_latest_optim)
model.save(paths.tts_latest_weights)
save_checkpoint(paths, model, optimizer, is_silent=True)
model.log(paths.tts_log, msg)
print(' ')

Expand Down Expand Up @@ -193,7 +181,8 @@ def create_gta_features(model: Tacotron, train_set, save_path: Path):
stream(msg)


def save_checkpoint(paths: Paths, model: Tacotron, optimizer, name=None):
def save_checkpoint(paths: Paths, model: Tacotron, optimizer, *,
name=None, is_silent=False):
"""Saves the training session to disk.
Args:
Expand All @@ -205,29 +194,29 @@ def save_checkpoint(paths: Paths, model: Tacotron, optimizer, name=None):
will always update the files specified in `paths` that give the
location of the latest weights and optimizer state. Saving
a named checkpoint happens in addition to this update.
"""
"""
def helper(path_dict, is_named):
s = 'named' if is_named else 'latest'
num_exist = sum(p.exists() for p in path_dict.values())

if num_exist not in (0,2):
# Checkpoint broken
raise FileNotFoundError(
f'We expected either both or no files in the {s} checkpoint to '
'exist, but instead we got exactly one!')

if num_exist == 0:
print('Creating {s} checkpoint...')
if not is_silent: print(f'Creating {s} checkpoint...')
for p in path_dict.values():
p.parent.mkdir(parents=True)
p.parent.mkdir(parents=True, exist_ok=True)
else:
print('Saving to existing {s} checkpoint...')
print(f'Saving {s} weights: {path_dict["w"]}')
if not is_silent: print(f'Saving to existing {s} checkpoint...')

if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
model.save(path_dict['w'])
print(f'Saving {s} optimizer state: {path_dict["o"]}')
torch.save(optimizer.state_dict(), path_dict['o'])
if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
torch.save(optimizer.state_dict(), path_dict['o'])

latest_paths = {'w': paths.tts_latest_weights, 'o': paths.tts_latest_optim}
helper(latest_paths, False)

Expand All @@ -239,7 +228,8 @@ def helper(path_dict, is_named):
helper(named_paths, True)


def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, name=None, create_if_missing=False):
def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.
Args:
Expand All @@ -252,7 +242,7 @@ def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, name=None, crea
create_if_missing: If `True`, will create the checkpoint if it doesn't
yet exist, as well as update the files specified in `paths` that
give the location of the current latest weights and optimizer state.
If `False` and the checkpoint doesn't exist, will raise a
If `False` and the checkpoint doesn't exist, will raise a
`FileNotFoundError`.
"""
if name:
Expand All @@ -267,16 +257,20 @@ def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, name=None, crea
'o': paths.tts_latest_optim
}
s = 'latest'

num_exist = sum(p.exists() for p in path_dict.values())
if num_exist == 2:
# Checkpoint exists
print(f'Restoring from {s} checkpoint...')
print(f'Loading {s} weights: {path_dict["w"]}')
model.load(path_dict['w'])
print(f'Loading {s} optimizer state: {path_dict["o"]}')
optimizer.load_state_dict(torch.load())
optimizer.load_state_dict(torch.load(path_dict['o']))
elif create_if_missing:
save_checkpoint(paths, model, optimizer, name)
save_checkpoint(paths, model, optimizer, name=name, is_silent=False)
else:
raise FileNotFoundError(f'The {s} checkpoint could not be found!')


# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def get_tts_datasets(path: Path, batch_size, r):
pin_memory=True)

longest = mel_lengths.index(max(mel_lengths))

# Used to evaluate attention during training process
attn_example = dataset_ids[longest]

# print(attn_example)
Expand Down

0 comments on commit 924fbeb

Please sign in to comment.