Generalise checkpointing for both model types
fatchord committed Sep 5, 2019
1 parent 942dddb commit 72d28b0
Showing 3 changed files with 136 additions and 196 deletions.
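The change threads a `checkpoint_type` string ('tts' or 'voc') through the checkpointing helpers so that both training scripts share a single implementation in utils/checkpoints.py. Condensed from the hunks below, the updated call sites look like this; `paths`, `model`, `voc_model`, `optimizer` and `ckpt_name` are the objects the scripts already define:

    from utils.checkpoints import save_checkpoint, restore_checkpoint

    # train_tacotron.py
    restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)
    save_checkpoint('tts', paths, model, optimizer, name=ckpt_name, is_silent=True)

    # train_wavernn.py
    restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)
    save_checkpoint('voc', paths, voc_model, optimizer, name=ckpt_name, is_silent=True)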
102 changes: 4 additions & 98 deletions train_tacotron.py
@@ -14,6 +14,7 @@
import time
import numpy as np
import sys
from utils.checkpoints import save_checkpoint, restore_checkpoint


def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
@@ -61,7 +62,7 @@ def main():
stop_threshold=hp.tts_stop_threshold).to(device)

optimizer = optim.Adam(model.parameters())
- restore_checkpoint(paths, model, optimizer, create_if_missing=True)
+ restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)

if not force_gta:
for i, session in enumerate(hp.tts_schedule):
@@ -156,7 +157,7 @@ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, trai

if step % hp.tts_checkpoint_every == 0:
ckpt_name = f'taco_step{k}K'
- save_checkpoint(paths, model, optimizer,
+ save_checkpoint('tts', paths, model, optimizer,
name=ckpt_name, is_silent=True)

if attn_example in ids:
@@ -169,7 +170,7 @@ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, trai

# Must save latest optimizer state to ensure that resuming training
# doesn't produce artifacts
- save_checkpoint(paths, model, optimizer, is_silent=True)
+ save_checkpoint('tts', paths, model, optimizer, is_silent=True)
model.log(paths.tts_log, msg)
print(' ')

@@ -197,100 +198,5 @@ def create_gta_features(model: Tacotron, train_set, save_path: Path):
stream(msg)


def save_checkpoint(paths: Paths, model: Tacotron, optimizer, *,
name=None, is_silent=False):
"""Saves the training session to disk.
Args:
paths: Provides information about the different paths to use.
model: A `Tacotron` model to save the parameters and buffers from.
optimizer: An optmizer to save the state of (momentum, etc).
name: If provided, will name to a checkpoint with the given name. Note
that regardless of whether this is provided or not, this function
will always update the files specified in `paths` that give the
location of the latest weights and optimizer state. Saving
a named checkpoint happens in addition to this update.
"""
def helper(path_dict, is_named):
s = 'named' if is_named else 'latest'
num_exist = sum(p.exists() for p in path_dict.values())

if num_exist not in (0,2):
# Checkpoint broken
raise FileNotFoundError(
f'We expected either both or no files in the {s} checkpoint to '
'exist, but instead we got exactly one!')

if num_exist == 0:
if not is_silent: print(f'Creating {s} checkpoint...')
for p in path_dict.values():
p.parent.mkdir(parents=True, exist_ok=True)
else:
if not is_silent: print(f'Saving to existing {s} checkpoint...')

if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
model.save(path_dict['w'])
if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
torch.save(optimizer.state_dict(), path_dict['o'])

latest_paths = {'w': paths.tts_latest_weights, 'o': paths.tts_latest_optim}
helper(latest_paths, False)

if name:
named_paths ={
'w': paths.tts_checkpoints/f'{name}_weights.pyt',
'o': paths.tts_checkpoints/f'{name}_optim.pyt',
}
helper(named_paths, True)


def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.
NOTE: The optimizer's state is placed on the same device as it's model
parameters. Therefore, be sure you have done `model.to(device)` before
calling this method.
Args:
paths: Provides information about the different paths to use.
model: A `Tacotron` model to save the parameters and buffers from.
optimizer: An optmizer to save the state of (momentum, etc).
name: If provided, will restore from a checkpoint with the given name.
Otherwise, will restore from the latest weights and optimizer state
as specified in `paths`.
create_if_missing: If `True`, will create the checkpoint if it doesn't
yet exist, as well as update the files specified in `paths` that
give the location of the current latest weights and optimizer state.
If `False` and the checkpoint doesn't exist, will raise a
`FileNotFoundError`.
"""
if name:
path_dict = {
'w': paths.tts_checkpoints/f'{name}_weights.pyt',
'o': paths.tts_checkpoints/f'{name}_optim.pyt',
}
s = 'named'
else:
path_dict = {
'w': paths.tts_latest_weights,
'o': paths.tts_latest_optim
}
s = 'latest'

num_exist = sum(p.exists() for p in path_dict.values())
if num_exist == 2:
# Checkpoint exists
print(f'Restoring from {s} checkpoint...')
print(f'Loading {s} weights: {path_dict["w"]}')
model.load(path_dict['w'])
print(f'Loading {s} optimizer state: {path_dict["o"]}')
optimizer.load_state_dict(torch.load(path_dict['o']))
elif create_if_missing:
save_checkpoint(paths, model, optimizer, name=name, is_silent=False)
else:
raise FileNotFoundError(f'The {s} checkpoint could not be found!')


if __name__ == "__main__":
main()
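Besides the latest checkpoint, the shared helper can also restore a specific named checkpoint through its `name` argument. A minimal sketch, assuming training has already written a checkpoint named 'taco_step50K' (the step count is illustrative) and that `paths`, `model` and `optimizer` are set up as in main() above:

    from utils.checkpoints import restore_checkpoint

    # Loads paths.tts_checkpoints/'taco_step50K_weights.pyt' plus the matching
    # optimizer state; raises FileNotFoundError if either file is missing.
    restore_checkpoint('tts', paths, model, optimizer, name='taco_step50K')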
102 changes: 4 additions & 98 deletions train_wavernn.py
@@ -12,6 +12,7 @@
from utils.paths import Paths
import argparse
from utils import data_parallel_workaround
from utils.checkpoints import save_checkpoint, restore_checkpoint


def main():
@@ -67,7 +68,7 @@ def main():
assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

optimizer = optim.Adam(voc_model.parameters())
- restore_checkpoint(paths, voc_model, optimizer, create_if_missing=True)
+ restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)

train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)

@@ -141,113 +142,18 @@ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set
gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
hp.voc_target, hp.voc_overlap, paths.voc_output)
ckpt_name = f'wave_step{k}K'
- save_checkpoint(paths, model, optimizer,
+ save_checkpoint('voc', paths, model, optimizer,
name=ckpt_name, is_silent=True)

msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
stream(msg)

# Must save latest optimizer state to ensure that resuming training
# doesn't produce artifacts
- save_checkpoint(paths, model, optimizer, is_silent=True)
+ save_checkpoint('voc', paths, model, optimizer, is_silent=True)
model.log(paths.voc_log, msg)
print(' ')


def save_checkpoint(paths: Paths, model: WaveRNN, optimizer, *,
name=None, is_silent=False):
"""Saves the training session to disk.
Args:
paths: Provides information about the different paths to use.
model: A `WaveRNN` model to save the parameters and buffers from.
optimizer: An optmizer to save the state of (momentum, etc).
name: If provided, will name to a checkpoint with the given name. Note
that regardless of whether this is provided or not, this function
will always update the files specified in `paths` that give the
location of the latest weights and optimizer state. Saving
a named checkpoint happens in addition to this update.
"""
def helper(path_dict, is_named):
s = 'named' if is_named else 'latest'
num_exist = sum(p.exists() for p in path_dict.values())

if num_exist not in (0,2):
# Checkpoint broken
raise FileNotFoundError(
f'We expected either both or no files in the {s} checkpoint to '
'exist, but instead we got exactly one!')

if num_exist == 0:
if not is_silent: print(f'Creating {s} checkpoint...')
for p in path_dict.values():
p.parent.mkdir(parents=True, exist_ok=True)
else:
if not is_silent: print(f'Saving to existing {s} checkpoint...')

if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
model.save(path_dict['w'])
if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
torch.save(optimizer.state_dict(), path_dict['o'])

latest_paths = {'w': paths.voc_latest_weights, 'o': paths.voc_latest_optim}
helper(latest_paths, False)

if name:
named_paths ={
'w': paths.voc_checkpoints/f'{name}_weights.pyt',
'o': paths.voc_checkpoints/f'{name}_optim.pyt',
}
helper(named_paths, True)


def restore_checkpoint(paths: Paths, model: WaveRNN, optimizer, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.
NOTE: The optimizer's state is placed on the same device as it's model
parameters. Therefore, be sure you have done `model.to(device)` before
calling this method.
Args:
paths: Provides information about the different paths to use.
model: A `WaveRNN` model to save the parameters and buffers from.
optimizer: An optmizer to save the state of (momentum, etc).
name: If provided, will restore from a checkpoint with the given name.
Otherwise, will restore from the latest weights and optimizer state
as specified in `paths`.
create_if_missing: If `True`, will create the checkpoint if it doesn't
yet exist, as well as update the files specified in `paths` that
give the location of the current latest weights and optimizer state.
If `False` and the checkpoint doesn't exist, will raise a
`FileNotFoundError`.
"""
if name:
path_dict = {
'w': paths.voc_checkpoints/f'{name}_weights.pyt',
'o': paths.voc_checkpoints/f'{name}_optim.pyt',
}
s = 'named'
else:
path_dict = {
'w': paths.voc_latest_weights,
'o': paths.voc_latest_optim
}
s = 'latest'

num_exist = sum(p.exists() for p in path_dict.values())
if num_exist == 2:
# Checkpoint exists
print(f'Restoring from {s} checkpoint...')
print(f'Loading {s} weights: {path_dict["w"]}')
model.load(path_dict['w'])
print(f'Loading {s} optimizer state: {path_dict["o"]}')
optimizer.load_state_dict(torch.load(path_dict['o']))
elif create_if_missing:
save_checkpoint(paths, model, optimizer, name=name, is_silent=False)
else:
raise FileNotFoundError(f'The {s} checkpoint could not be found!')


if __name__ == "__main__":
main()
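The vocoder script mirrors the same pattern with checkpoint type 'voc'. Worth noting: on a fresh run there is nothing to restore, and create_if_missing=True makes the helper write an initial 'latest' checkpoint from the freshly initialised model instead of raising. A short sketch of that behaviour, with `paths`, `voc_model` and `optimizer` as constructed in main() above:

    from utils.checkpoints import restore_checkpoint

    # First run: no vocoder checkpoint exists yet, so the latest weights and
    # optimizer state files are created from the current voc_model/optimizer.
    restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)

    # Without create_if_missing, the same call would raise FileNotFoundError.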
128 changes: 128 additions & 0 deletions utils/checkpoints.py
@@ -0,0 +1,128 @@
import torch
from utils.paths import Paths
from models.tacotron import Tacotron


def get_checkpoint_paths(checkpoint_type: str, paths: Paths):
"""
Returns the correct checkpointing paths
depending on whether model is Vocoder or TTS
Args:
checkpoint_type: Either 'voc' or 'tts'
paths: Paths object
"""
    if checkpoint_type == 'tts':
weights_path = paths.tts_latest_weights
optim_path = paths.tts_latest_optim
checkpoint_path = paths.tts_checkpoints
    elif checkpoint_type == 'voc':
weights_path = paths.voc_latest_weights
optim_path = paths.voc_latest_optim
checkpoint_path = paths.voc_checkpoints
    else:
        raise NotImplementedError(f'Unknown checkpoint type: {checkpoint_type}')

return weights_path, optim_path, checkpoint_path


def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
name=None, is_silent=False):
"""Saves the training session to disk.
Args:
paths: Provides information about the different paths to use.
model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
optimizer: An optmizer to save the state of (momentum, etc).
name: If provided, will name to a checkpoint with the given name. Note
that regardless of whether this is provided or not, this function
will always update the files specified in `paths` that give the
location of the latest weights and optimizer state. Saving
a named checkpoint happens in addition to this update.
"""
def helper(path_dict, is_named):
s = 'named' if is_named else 'latest'
num_exist = sum(p.exists() for p in path_dict.values())

if num_exist not in (0,2):
# Checkpoint broken
raise FileNotFoundError(
f'We expected either both or no files in the {s} checkpoint to '
'exist, but instead we got exactly one!')

if num_exist == 0:
if not is_silent: print(f'Creating {s} checkpoint...')
for p in path_dict.values():
p.parent.mkdir(parents=True, exist_ok=True)
else:
if not is_silent: print(f'Saving to existing {s} checkpoint...')

if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
model.save(path_dict['w'])
if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
torch.save(optimizer.state_dict(), path_dict['o'])

weights_path, optim_path, checkpoint_path = \
get_checkpoint_paths(checkpoint_type, paths)

latest_paths = {'w': weights_path, 'o': optim_path}
helper(latest_paths, False)

if name:
named_paths = {
'w': checkpoint_path/f'{name}_weights.pyt',
'o': checkpoint_path/f'{name}_optim.pyt',
}
helper(named_paths, True)


def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.
NOTE: The optimizer's state is placed on the same device as it's model
parameters. Therefore, be sure you have done `model.to(device)` before
calling this method.
Args:
paths: Provides information about the different paths to use.
model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
optimizer: An optmizer to save the state of (momentum, etc).
name: If provided, will restore from a checkpoint with the given name.
Otherwise, will restore from the latest weights and optimizer state
as specified in `paths`.
create_if_missing: If `True`, will create the checkpoint if it doesn't
yet exist, as well as update the files specified in `paths` that
give the location of the current latest weights and optimizer state.
If `False` and the checkpoint doesn't exist, will raise a
`FileNotFoundError`.
"""

weights_path, optim_path, checkpoint_path = \
get_checkpoint_paths(checkpoint_type, paths)

if name:
path_dict = {
'w': checkpoint_path/f'{name}_weights.pyt',
'o': checkpoint_path/f'{name}_optim.pyt',
}
s = 'named'
else:
path_dict = {
'w': weights_path,
'o': optim_path
}
s = 'latest'

num_exist = sum(p.exists() for p in path_dict.values())
if num_exist == 2:
# Checkpoint exists
print(f'Restoring from {s} checkpoint...')
print(f'Loading {s} weights: {path_dict["w"]}')
model.load(path_dict['w'])
print(f'Loading {s} optimizer state: {path_dict["o"]}')
optimizer.load_state_dict(torch.load(path_dict['o']))
elif create_if_missing:
save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
else:
raise FileNotFoundError(f'The {s} checkpoint could not be found!')
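To see which files a given checkpoint type maps to, get_checkpoint_paths can be queried directly. A small sketch, assuming `paths` is the same `Paths` object the training scripts construct:

    from utils.checkpoints import get_checkpoint_paths

    for checkpoint_type in ('tts', 'voc'):
        weights_path, optim_path, checkpoint_dir = get_checkpoint_paths(checkpoint_type, paths)
        print(checkpoint_type, weights_path, optim_path, checkpoint_dir)

    # Named checkpoints land in the checkpoints directory as f'{name}_weights.pyt'
    # and f'{name}_optim.pyt'; the other two paths are the 'latest' files updated
    # on every save.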
