From f62bbc563bd5212724e13ec2b34d2565130d3ff3 Mon Sep 17 00:00:00 2001 From: Ryan Butler Date: Fri, 9 Aug 2019 21:01:54 -0700 Subject: [PATCH] Added tts_stop_threshold hparam --- gen_tacotron.py | 3 ++- hparams.py | 6 ++++-- models/tacotron.py | 5 +++-- quick_start.py | 3 ++- train_tacotron.py | 3 ++- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/gen_tacotron.py b/gen_tacotron.py index 38683bb9..04a6eb38 100644 --- a/gen_tacotron.py +++ b/gen_tacotron.py @@ -102,7 +102,8 @@ lstm_dims=hp.tts_lstm_dims, postnet_K=hp.tts_postnet_K, num_highways=hp.tts_num_highways, - dropout=hp.tts_dropout).to(device) + dropout=hp.tts_dropout, + stop_threshold=hp.tts_stop_threshold).to(device) tts_restore_path = weights_path if weights_path else paths.tts_latest_weights tts_model.load(tts_restore_path) diff --git a/hparams.py b/hparams.py index f645e650..61ea943c 100644 --- a/hparams.py +++ b/hparams.py @@ -64,7 +64,6 @@ # Model Hparams -tts_r = 1 # model predicts r frames per output step tts_embed_dims = 256 # embedding dimension for the graphemes/phoneme inputs tts_encoder_dims = 128 tts_decoder_dims = 256 @@ -75,10 +74,13 @@ tts_num_highways = 4 tts_dropout = 0.5 tts_cleaner_names = ['english_cleaners'] +tts_stop_threshold = -3.4 # Value below which audio generation ends. + # For example, for a range of [-4, 4], this + # will terminate the sequence at the first + # frame that has all values < -3.4 # Training - tts_schedule = [(7, 1e-3, 10_000, 32), # progressive training schedule (5, 1e-4, 100_000, 32), # (r, lr, step, batch_size) (2, 1e-4, 180_000, 16), diff --git a/models/tacotron.py b/models/tacotron.py index 3a4df18d..1c5e1d1f 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -281,7 +281,7 @@ def forward(self, encoder_seq, encoder_seq_proj, prenet_in, class Tacotron(nn.Module): def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, fft_bins, postnet_dims, - encoder_K, lstm_dims, postnet_K, num_highways, dropout): + encoder_K, lstm_dims, postnet_K, num_highways, dropout, stop_threshold): super().__init__() self.n_mels = n_mels self.lstm_dims = lstm_dims @@ -297,6 +297,7 @@ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, ff self.num_params() self.register_buffer('step', torch.zeros(1, dtype=torch.long)) + self.register_buffer('stop_threshold', torch.tensor(stop_threshold, dtype=torch.float32)) @property def r(self): @@ -407,7 +408,7 @@ def generate(self, x, steps=2000): mel_outputs.append(mel_frames) attn_scores.append(scores) # Stop the loop if silent frames present - if (mel_frames < -3.8).all() and t > 10: break + if (mel_frames < self.stop_threshold).all() and t > 10: break # Concat the mel outputs into sequence mel_outputs = torch.cat(mel_outputs, dim=2) diff --git a/quick_start.py b/quick_start.py index 98ece28f..f964062a 100644 --- a/quick_start.py +++ b/quick_start.py @@ -82,7 +82,8 @@ lstm_dims=hp.tts_lstm_dims, postnet_K=hp.tts_postnet_K, num_highways=hp.tts_num_highways, - dropout=hp.tts_dropout).to(device) + dropout=hp.tts_dropout, + stop_threshold=hp.tts_stop_threshold).to(device) tts_model.restore('quick_start/tts_weights/latest_weights.pyt') diff --git a/train_tacotron.py b/train_tacotron.py index 18b79e95..7d5c1f23 100644 --- a/train_tacotron.py +++ b/train_tacotron.py @@ -56,7 +56,8 @@ def main(): lstm_dims=hp.tts_lstm_dims, postnet_K=hp.tts_postnet_K, num_highways=hp.tts_num_highways, - dropout=hp.tts_dropout).to(device) + dropout=hp.tts_dropout, + stop_threshold=hp.tts_stop_threshold).to(device) optimizer = optim.Adam(model.parameters()) restore_checkpoint(paths, model, optimizer, create_if_missing=True)