From f62bbc563bd5212724e13ec2b34d2565130d3ff3 Mon Sep 17 00:00:00 2001
From: Ryan Butler
Date: Fri, 9 Aug 2019 21:01:54 -0700
Subject: [PATCH] Added tts_stop_threshold hparam
---
gen_tacotron.py | 3 ++-
hparams.py | 6 ++++--
models/tacotron.py | 5 +++--
quick_start.py | 3 ++-
train_tacotron.py | 3 ++-
5 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/gen_tacotron.py b/gen_tacotron.py
index 38683bb9..04a6eb38 100644
--- a/gen_tacotron.py
+++ b/gen_tacotron.py
@@ -102,7 +102,8 @@
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
- dropout=hp.tts_dropout).to(device)
+ dropout=hp.tts_dropout,
+ stop_threshold=hp.tts_stop_threshold).to(device)
tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
tts_model.load(tts_restore_path)
diff --git a/hparams.py b/hparams.py
index f645e650..61ea943c 100644
--- a/hparams.py
+++ b/hparams.py
@@ -64,7 +64,6 @@
# Model Hparams
-tts_r = 1 # model predicts r frames per output step
tts_embed_dims = 256 # embedding dimension for the graphemes/phoneme inputs
tts_encoder_dims = 128
tts_decoder_dims = 256
@@ -75,10 +74,13 @@
tts_num_highways = 4
tts_dropout = 0.5
tts_cleaner_names = ['english_cleaners']
+tts_stop_threshold = -3.4 # Value below which audio generation ends.
+ # For example, for a range of [-4, 4], generation
+ # stops at the first decoder step with t > 10
+ # whose mel frames all have values < -3.4
# Training
-
tts_schedule = [(7, 1e-3, 10_000, 32), # progressive training schedule
(5, 1e-4, 100_000, 32), # (r, lr, step, batch_size)
(2, 1e-4, 180_000, 16),
diff --git a/models/tacotron.py b/models/tacotron.py
index 3a4df18d..1c5e1d1f 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -281,7 +281,7 @@ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
class Tacotron(nn.Module):
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, fft_bins, postnet_dims,
- encoder_K, lstm_dims, postnet_K, num_highways, dropout):
+ encoder_K, lstm_dims, postnet_K, num_highways, dropout, stop_threshold):
super().__init__()
self.n_mels = n_mels
self.lstm_dims = lstm_dims
@@ -297,6 +297,7 @@ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, ff
self.num_params()
self.register_buffer('step', torch.zeros(1, dtype=torch.long))
+ self.register_buffer('stop_threshold', torch.tensor(stop_threshold, dtype=torch.float32))
@property
def r(self):
@@ -407,7 +408,7 @@ def generate(self, x, steps=2000):
mel_outputs.append(mel_frames)
attn_scores.append(scores)
# Stop the loop if silent frames present
- if (mel_frames < -3.8).all() and t > 10: break
+ if (mel_frames < self.stop_threshold).all() and t > 10: break
# Concat the mel outputs into sequence
mel_outputs = torch.cat(mel_outputs, dim=2)
diff --git a/quick_start.py b/quick_start.py
index 98ece28f..f964062a 100644
--- a/quick_start.py
+++ b/quick_start.py
@@ -82,7 +82,8 @@
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
- dropout=hp.tts_dropout).to(device)
+ dropout=hp.tts_dropout,
+ stop_threshold=hp.tts_stop_threshold).to(device)
tts_model.restore('quick_start/tts_weights/latest_weights.pyt')
diff --git a/train_tacotron.py b/train_tacotron.py
index 18b79e95..7d5c1f23 100644
--- a/train_tacotron.py
+++ b/train_tacotron.py
@@ -56,7 +56,8 @@ def main():
lstm_dims=hp.tts_lstm_dims,
postnet_K=hp.tts_postnet_K,
num_highways=hp.tts_num_highways,
- dropout=hp.tts_dropout).to(device)
+ dropout=hp.tts_dropout,
+ stop_threshold=hp.tts_stop_threshold).to(device)
optimizer = optim.Adam(model.parameters())
restore_checkpoint(paths, model, optimizer, create_if_missing=True)