Skip to content

Commit

Permalink
Added tts_stop_threshold hparam
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Aug 10, 2019
1 parent 8652895 commit f62bbc5
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 7 deletions.
3 changes: 2 additions & 1 deletion gen_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
dropout=hp.tts_dropout).to(device)
dropout=hp.tts_dropout,
stop_threshold=hp.tts_stop_threshold).to(device)

tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
tts_model.load(tts_restore_path)
Expand Down
6 changes: 4 additions & 2 deletions hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@


# Model Hparams
tts_r = 1 # model predicts r frames per output step
tts_embed_dims = 256 # embedding dimension for the graphemes/phoneme inputs
tts_encoder_dims = 128
tts_decoder_dims = 256
Expand All @@ -75,10 +74,13 @@
tts_num_highways = 4
tts_dropout = 0.5
tts_cleaner_names = ['english_cleaners']
tts_stop_threshold = -3.4 # Value below which audio generation ends.
# For example, for a range of [-4, 4], this
# will terminate the sequence at the first
# frame that has all values < -3.4

# Training


tts_schedule = [(7, 1e-3, 10_000, 32), # progressive training schedule
(5, 1e-4, 100_000, 32), # (r, lr, step, batch_size)
(2, 1e-4, 180_000, 16),
Expand Down
5 changes: 3 additions & 2 deletions models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,

class Tacotron(nn.Module):
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, fft_bins, postnet_dims,
encoder_K, lstm_dims, postnet_K, num_highways, dropout):
encoder_K, lstm_dims, postnet_K, num_highways, dropout, stop_threshold):
super().__init__()
self.n_mels = n_mels
self.lstm_dims = lstm_dims
Expand All @@ -297,6 +297,7 @@ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, ff
self.num_params()

self.register_buffer('step', torch.zeros(1, dtype=torch.long))
self.register_buffer('stop_threshold', torch.tensor(stop_threshold, dtype=torch.float32))

@property
def r(self):
Expand Down Expand Up @@ -407,7 +408,7 @@ def generate(self, x, steps=2000):
mel_outputs.append(mel_frames)
attn_scores.append(scores)
# Stop the loop if silent frames present
if (mel_frames < -3.8).all() and t > 10: break
if (mel_frames < self.stop_threshold).all() and t > 10: break

# Concat the mel outputs into sequence
mel_outputs = torch.cat(mel_outputs, dim=2)
Expand Down
3 changes: 2 additions & 1 deletion quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
dropout=hp.tts_dropout).to(device)
dropout=hp.tts_dropout,
stop_threshold=hp.tts_stop_threshold).to(device)


tts_model.restore('quick_start/tts_weights/latest_weights.pyt')
Expand Down
3 changes: 2 additions & 1 deletion train_tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def main():
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
dropout=hp.tts_dropout).to(device)
dropout=hp.tts_dropout,
stop_threshold=hp.tts_stop_threshold).to(device)

optimizer = optim.Adam(model.parameters())
restore_checkpoint(paths, model, optimizer, create_if_missing=True)
Expand Down

0 comments on commit f62bbc5

Please sign in to comment.