Fixed train_wavernn and inference scripts to use new ckpt code
TheButlah committed Aug 7, 2019
1 parent 697ec9c commit e185a96
Showing 5 changed files with 206 additions and 116 deletions.
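
In short: the inference scripts now call the model's plain load() instead of the old restore() (which, when the weights file was missing, silently started a new session and wrote fresh weights), while training goes through a restore_checkpoint(...) helper, as seen in train_tacotron.py below. A minimal sketch of the before/after call pattern, using only the calls that appear in this diff (anything beyond their names and arguments is an assumption):

    # old inference-side flow: restore() created a new checkpoint if none existed
    # voc_model.restore(paths.voc_latest_weights)

    # new inference-side flow: just load existing weights
    voc_model.load(paths.voc_latest_weights)

    # new training-side flow: one helper restores model + optimizer state,
    # creating an initial checkpoint when none exists yet
    restore_checkpoint(paths, model, optimizer, create_if_missing=True)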
20 changes: 15 additions & 5 deletions gen_tacotron.py
@@ -20,14 +20,24 @@
parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights')
parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.set_defaults(batched=hp.voc_gen_batched)
parser.set_defaults(target=hp.voc_target)
parser.set_defaults(overlap=hp.voc_overlap)
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')

parser.set_defaults(batched=None)
parser.set_defaults(input_text=None)
parser.set_defaults(weights_path=None)
parser.set_defaults(save_attention=False)

args = parser.parse_args()

hp.configure(args.hp_file) # Load hparams from file
# set defaults for any arguments that depend on hparams
if args.target is None:
args.target = hp.voc_target
if args.overlap is None:
args.overlap = hp.voc_overlap
if args.batched is None:
args.batched = hp.voc_gen_batched

batched = args.batched
target = args.target
overlap = args.overlap
@@ -59,7 +69,7 @@
sample_rate=hp.sample_rate,
mode=hp.voc_mode).to(device)

voc_model.restore(paths.voc_latest_weights)
voc_model.load(paths.voc_latest_weights)

print('\nInitialising Tacotron Model...\n')

@@ -78,7 +88,7 @@
dropout=hp.tts_dropout).to(device)

tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
tts_model.restore(tts_restore_path)
tts_model.load(tts_restore_path)

if input_text:
inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
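
The argparse changes above follow one pattern: any default that depends on the hyperparameters is left as None at parse time and only resolved after hp.configure() has loaded the chosen hparams file. A stripped-down sketch of that pattern (the import path and the --target argument are illustrative, mirroring how these scripts use it):

    import argparse
    from utils import hparams as hp  # assumed import path for the hparams module

    parser = argparse.ArgumentParser()
    parser.add_argument('--target', '-t', type=int, default=None)  # depends on hparams
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py')
    args = parser.parse_args()

    hp.configure(args.hp_file)   # load hparams first...
    if args.target is None:      # ...then fill in hparams-dependent defaults
        args.target = hp.voc_target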
19 changes: 14 additions & 5 deletions gen_wavernn.py
@@ -63,17 +63,26 @@ def gen_from_file(model: WaveRNN, load_path, save_path, batched, target, overlap
parser.add_argument('--weights', '-w', type=str, help='[string/path] checkpoint file to load weights from')
parser.add_argument('--gta', '-g', dest='gta', action='store_true', help='Generate from GTA testset')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')

parser.set_defaults(batched=hp.voc_gen_batched)
parser.set_defaults(samples=hp.voc_gen_at_checkpoint)
parser.set_defaults(target=hp.voc_target)
parser.set_defaults(overlap=hp.voc_overlap)
parser.set_defaults(file=None)
parser.set_defaults(weights=None)
parser.set_defaults(gta=False)
parser.set_defaults(batched=None)

args = parser.parse_args()

hp.configure(args.hp_file) # Load hparams from file
# set defaults for any arguments that depend on hparams
if args.target is None:
args.target = hp.voc_target
if args.overlap is None:
args.overlap = hp.voc_overlap
if args.batched is None:
args.batched = hp.voc_gen_batched
if args.samples is None:
args.samples = hp.voc_gen_at_checkpoint

batched = args.batched
samples = args.samples
target = args.target
@@ -106,7 +115,7 @@ def gen_from_file(model: WaveRNN, load_path, save_path, batched, target, overlap

restore_path = args.weights if args.weights else paths.voc_latest_weights

model.restore(restore_path)
model.load(restore_path)

simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
('Target Samples', target if batched else 'N/A'),
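
One behavioural note: unlike the removed restore(), load() does not create a fresh checkpoint when the file is missing, so a first run needs existing weights. A caller that wants an explicit failure could guard the call itself; a hypothetical sketch, not part of this commit:

    from pathlib import Path

    restore_path = Path(args.weights) if args.weights else paths.voc_latest_weights
    if not Path(restore_path).exists():
        raise FileNotFoundError(f'No WaveRNN weights at "{restore_path}" - '
                                'train a model first or pass --weights')
    model.load(restore_path)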
28 changes: 6 additions & 22 deletions models/fatchord_version.py
@@ -105,19 +105,19 @@ def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,

# List of rnns to call `flatten_parameters()` on
self._to_flatten = []

self.rnn_dims = rnn_dims
self.aux_dims = res_out_dims // 4
self.hop_length = hop_length
self.sample_rate = sample_rate

self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)

self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
self._to_flatten += [self.rnn1, self.rnn2]

self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
self.fc3 = nn.Linear(fc_dims, self.n_classes)
@@ -131,11 +131,11 @@ def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
def forward(self, x, mels):
device = next(self.parameters()).device # use same device as parameters

# Although we `_flatten_parameters()` on init, when using DataParallel
# Although we `_flatten_parameters()` on init, when using DataParallel
# the model gets replicated, making it no longer guaranteed that the
# weights are contiguous in GPU memory. Hence, we must call it again
self._flatten_parameters()

self.step += 1
bsize = x.size(0)
h1 = torch.zeros(1, bsize, self.rnn_dims, device=device)
@@ -168,7 +168,7 @@ def forward(self, x, mels):

def generate(self, mels, save_path: Union[str, Path], batched, target, overlap, mu_law):
self.eval()

device = next(self.parameters()).device # use same device as parameters

mu_law = mu_law if self.mode == 'RAW' else False
@@ -406,26 +406,10 @@ def xfade_and_unfold(self, y, target, overlap):
def get_step(self):
return self.step.data.item()

def checkpoint(self, path: Union[str, Path], optimizer):
# Optimizer can be given as an argument because checkpoint function is
# only useful in context of already existing training process.
if isinstance(path, str): path = Path(path)
k_steps = self.get_step() // 1000
self.save(path/f'checkpoint_{k_steps}k_steps.pyt')
torch.save(optimizer.state_dict(), path/f'checkpoint_{k_steps}k_steps_optim.pyt')

def log(self, path, msg):
with open(path, 'a') as f:
print(msg, file=f)

def restore(self, path: Union[str, Path]):
if not os.path.exists(path):
print('\nNew WaveRNN Training Session...\n')
self.save(path)
else:
print(f'\nLoading Weights: "{path}"\n')
self.load(path)

def load(self, path: Union[str, Path]):
# Use device of model params as location for loaded state
device = next(self.parameters()).device
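
For reference, load() (whose body is cut off at the end of this hunk) maps the checkpoint onto whatever device the model's parameters currently live on. A plausible minimal version, assuming the checkpoint is a plain state_dict, with the save() counterpart sketched the same way:

    def load(self, path: Union[str, Path]):
        # Use device of model params as location for loaded state
        device = next(self.parameters()).device
        self.load_state_dict(torch.load(path, map_location=device))

    def save(self, path: Union[str, Path]):
        # model weights only; optimizer state is handled by the checkpoint helpers
        torch.save(self.state_dict(), path)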
18 changes: 8 additions & 10 deletions train_tacotron.py
@@ -27,7 +27,7 @@ def main():
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
args = parser.parse_args()

hp.configure(args.hp_file) # Load hparams from file.
hp.configure(args.hp_file) # Load hparams from file
paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

force_train = args.force_train
@@ -56,7 +56,7 @@ def main():
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
dropout=hp.tts_dropout).to(device=device)
dropout=hp.tts_dropout).to(device)

optimizer = optim.Adam(model.parameters())
restore_checkpoint(paths, model, optimizer, create_if_missing=True)
@@ -92,7 +92,7 @@ def main():
print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')


def tts_train_loop(paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
device = next(model.parameters()).device # use same device as model parameters

for g in optimizer.param_groups: g['lr'] = lr
@@ -108,8 +108,6 @@ def tts_train_loop(paths, model: Tacotron, optimizer, train_set, lr, train_steps
# Perform 1 epoch
for i, (x, m, ids, _) in enumerate(train_set, 1):

optimizer.zero_grad()

x, m = x.to(device), m.to(device)

# Parallelize model onto GPUS using workaround due to python bug
@@ -123,8 +121,7 @@ def tts_train_loop(paths, model: Tacotron, optimizer, train_set, lr, train_steps

loss = m1_loss + m2_loss

running_loss += loss.item()

optimizer.zero_grad()
loss.backward()
if hp.tts_clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
@@ -133,12 +130,13 @@ def tts_train_loop(paths, model: Tacotron, optimizer, train_set, lr, train_steps

optimizer.step()

step = model.get_step()
k = step // 1000
running_loss += loss.item()
avg_loss = running_loss / i

speed = i / (time.time() - start)

avg_loss = running_loss / i
step = model.get_step()
k = step // 1000

if step % hp.tts_checkpoint_every == 0:
ckpt_name = f'taco_step{k}K'
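
The train-loop edits above mostly reorder the per-iteration bookkeeping: zero_grad() now sits immediately before backward(), and the running-loss/step accounting happens after the optimizer update. Condensed, one iteration of the loop shown in this diff reads roughly as follows (the forward pass and loss computation are elided behind a placeholder):

    x, m = x.to(device), m.to(device)
    loss = compute_losses(model, x, m)   # placeholder: loss = m1_loss + m2_loss as in the diff

    optimizer.zero_grad()                # clear stale gradients right before backward
    loss.backward()
    if hp.tts_clip_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
    optimizer.step()

    running_loss += loss.item()          # loss/step bookkeeping now follows the update
    avg_loss = running_loss / i
    step = model.get_step()
    k = step // 1000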
train_wavernn.py (remaining file's diff not loaded in this view)
