Skip to content

Commit

Permalink
Added note to call model.to(device) after restore_checkpoint()
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Aug 7, 2019
1 parent d30754f commit d2fc0a5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
4 changes: 4 additions & 0 deletions train_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def restore_checkpoint(paths: Paths, model: Tacotron, optimizer, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.
NOTE: The optimizer's state is placed on the same device as it's model
parameters. Therefore, be sure you have done `model.to(device)` before
calling this method.
Args:
paths: Provides information about the different paths to use.
model: A `Tacotron` model to save the parameters and buffers from.
Expand Down
5 changes: 4 additions & 1 deletion train_wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from utils.paths import Paths
import argparse
from utils import data_parallel_workaround
import os


def main():
Expand Down Expand Up @@ -206,6 +205,10 @@ def restore_checkpoint(paths: Paths, model: WaveRNN, optimizer, *,
name=None, create_if_missing=False):
"""Restores from a training session saved to disk.
NOTE: The optimizer's state is placed on the same device as it's model
parameters. Therefore, be sure you have done `model.to(device)` before
calling this method.
Args:
paths: Provides information about the different paths to use.
model: A `WaveRNN` model to save the parameters and buffers from.
Expand Down
15 changes: 8 additions & 7 deletions utils/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@ class Paths:
"""Manages and configures the paths used by WaveRNN, Tacotron, and the data."""
def __init__(self, data_path, voc_id, tts_id):
self.base = Path(__file__).parent.parent.expanduser().resolve()

# Data Paths
self.data = Path(data_path).expanduser().resolve()
self.quant = self.data/'quant'
self.mel = self.data/'mel'
self.gta = self.data/'gta'

# WaveRNN/Vocoder Paths
self.voc_checkpoints = self.base/'checkpoints'/f'{voc_id}.wavernn'
self.voc_latest_weights = self.voc_checkpoints/'latest_weights.pyt'
self.voc_latest_optim = self.voc_checkpoints/'latest_optim.pyt'
self.voc_output = self.base/'model_outputs'/f'{voc_id}.wavernn'
self.voc_step = self.voc_checkpoints/'step.npy'
self.voc_log = self.voc_checkpoints/'log.txt'

# Tactron/TTS Paths
self.tts_checkpoints = self.base/'checkpoints'/f'{tts_id}.tacotron'
self.tts_latest_weights = self.tts_checkpoints/'latest_weights.pyt'
Expand All @@ -30,6 +30,7 @@ def __init__(self, data_path, voc_id, tts_id):
self.tts_log = self.tts_checkpoints/'log.txt'
self.tts_attention = self.tts_checkpoints/'attention'
self.tts_mel_plot = self.tts_checkpoints/'mel_plots'

self.create_paths()

def create_paths(self):
Expand All @@ -43,21 +44,21 @@ def create_paths(self):
os.makedirs(self.tts_output, exist_ok=True)
os.makedirs(self.tts_attention, exist_ok=True)
os.makedirs(self.tts_mel_plot, exist_ok=True)

def get_tts_named_weights(self, name):
"""Gets the path for the weights in a named tts checkpoint."""
return self.tts_checkpoints/f'{name}_weights.pyt'

def get_tts_named_optim(self, name):
"""Gets the path for the optimizer state in a named tts checkpoint."""
return self.tts_checkpoints/f'{name}_optim.pyt'

def get_voc_named_weights(self, name):
"""Gets the path for the weights in a named voc checkpoint."""
return self.voc_checkpoints/f'{name}_weights.pyt'

def get_voc_named_optim(self, name):
"""Gets the path for the optimizer state in a named voc checkpoint."""
return self.voc_checkpoints/f'{name}_optim.pyt'


0 comments on commit d2fc0a5

Please sign in to comment.