-
Notifications
You must be signed in to change notification settings - Fork 0
/
StyleSpeech.py
339 lines (280 loc) · 13.2 KB
/
StyleSpeech.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
import torch
import torch.nn as nn
import numpy as np
from text.symbols import symbols
import models.Constants as Constants
from models.Modules import Mish, LinearNorm, ConvNorm, Conv1dGLU, \
MultiHeadAttention, StyleAdaptiveLayerNorm, get_sinusoid_encoding_table
from models.VarianceAdaptor import VarianceAdaptor
from models.Loss import StyleSpeechLoss
from utils import get_mask_from_lengths
class StyleSpeech(nn.Module):
    """StyleSpeech model.

    Wires together four submodules, created in this order: a mel style
    encoder, a style-conditioned text encoder, a variance adaptor and a
    style-conditioned decoder.
    """

    def __init__(self, config):
        super(StyleSpeech, self).__init__()
        self.style_encoder = MelStyleEncoder(config)
        self.encoder = Encoder(config)
        self.variance_adaptor = VarianceAdaptor(config)
        self.decoder = Decoder(config)

    def parse_batch(self, batch):
        """Turn a dict of numpy arrays into typed CUDA tensors.

        Returns (sid, text, mel_target, D, log_D, f0, energy, src_len,
        mel_len, max_src_len, max_mel_len). The two max lengths stay numpy
        int32 scalars computed from the raw length arrays.
        """
        def as_long(key):
            return torch.from_numpy(batch[key]).long().cuda()

        def as_float(key):
            return torch.from_numpy(batch[key]).float().cuda()

        sid = as_long("sid")
        text = as_long("text")
        mel_target = as_float("mel_target")
        D = as_long("D")
        log_D = as_float("log_D")
        f0 = as_float("f0")
        energy = as_float("energy")
        src_len = as_long("src_len")
        mel_len = as_long("mel_len")
        max_src_len = np.max(batch["src_len"]).astype(np.int32)
        max_mel_len = np.max(batch["mel_len"]).astype(np.int32)
        return (sid, text, mel_target, D, log_D, f0, energy,
                src_len, mel_len, max_src_len, max_mel_len)

    def forward(self, src_seq, src_len, mel_target, mel_len=None,
                d_target=None, p_target=None, e_target=None,
                max_src_len=None, max_mel_len=None):
        """Training pass: style from `mel_target`, then encode/adapt/decode.

        The mel mask is only built when `mel_len` is given. The variance
        adaptor may replace both `mel_len` and the mel mask, and the
        replaced values are the ones returned.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        if mel_len is None:
            mel_mask = None
        else:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)

        # Style vector extracted from the reference mel-spectrogram.
        style = self.style_encoder(mel_target, mel_mask)

        # Style-conditioned text encoding.
        enc_out, src_embedded, _ = self.encoder(src_seq, style, src_mask)

        # Variance adaptor: yields the d/p/e predictions and the
        # (possibly updated) mel lengths and mask.
        (adapted, d_pred, p_pred, e_pred,
         mel_len, mel_mask) = self.variance_adaptor(
            enc_out, src_mask, mel_len, mel_mask,
            d_target, p_target, e_target, max_mel_len)

        # Style-conditioned decoding.
        mel_pred, _ = self.decoder(adapted, style, mel_mask)
        return (mel_pred, src_embedded, style, d_pred, p_pred, e_pred,
                src_mask, mel_mask, mel_len)

    def inference(self, style_vector, src_seq, src_len=None, max_src_len=None, return_attn=False):
        """Synthesis pass with a precomputed style vector.

        When `return_attn` is true, only the encoder and decoder
        self-attention lists are returned.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)

        enc_out, src_embedded, enc_attn = self.encoder(src_seq, style_vector, src_mask)

        (adapted, d_pred, p_pred, e_pred,
         mel_len, mel_mask) = self.variance_adaptor(enc_out, src_mask)

        mel_out, dec_attn = self.decoder(adapted, style_vector, mel_mask)

        if return_attn:
            return enc_attn, dec_attn
        return (mel_out, src_embedded, d_pred, p_pred, e_pred,
                src_mask, mel_mask, mel_len)

    def get_style_vector(self, mel_target, mel_len=None):
        """Run only the style encoder, masking by `mel_len` when given."""
        if mel_len is None:
            return self.style_encoder(mel_target, None)
        return self.style_encoder(mel_target, get_mask_from_lengths(mel_len))

    def get_criterion(self):
        """Return a fresh StyleSpeechLoss instance."""
        return StyleSpeechLoss()
class Encoder(nn.Module):
    """Style-conditioned text encoder.

    Pipeline: token embedding -> Prenet -> + sinusoidal position encoding
    -> stack of FFTBlocks (each conditioned on the style vector) -> linear
    projection from `encoder_hidden` to `decoder_hidden` features.
    """

    def __init__(self, config, n_src_vocab=len(symbols)+1):
        super(Encoder, self).__init__()
        self.max_seq_len = config.max_seq_len
        self.n_layers = config.encoder_layer
        self.d_model = config.encoder_hidden
        self.n_head = config.encoder_head
        self.d_k = config.encoder_hidden // config.encoder_head
        self.d_v = config.encoder_hidden // config.encoder_head
        self.d_inner = config.fft_conv1d_filter_size
        self.fft_conv1d_kernel_size = config.fft_conv1d_kernel_size
        self.d_out = config.decoder_hidden
        self.style_dim = config.style_vector_dim
        self.dropout = config.dropout

        self.src_word_emb = nn.Embedding(n_src_vocab, self.d_model, padding_idx=Constants.PAD)
        self.prenet = Prenet(self.d_model, self.d_model, self.dropout)

        # Precomputed, non-trainable position table covering max_seq_len + 1 positions.
        n_position = self.max_seq_len + 1
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, self.d_model).unsqueeze(0), requires_grad = False)

        self.layer_stack = nn.ModuleList([FFTBlock(
            self.d_model, self.d_inner, self.n_head, self.d_k, self.d_v,
            self.fft_conv1d_kernel_size, self.style_dim, self.dropout) for _ in range(self.n_layers)])

        self.fc_out = nn.Linear(self.d_model, self.d_out)

    def _position_embedding(self, x):
        """Return sinusoidal position encodings expanded to x's batch size.

        Uses the precomputed table when the sequence fits. Otherwise it
        builds a table of the required length and moves it to x's device
        AND dtype. Matching the dtype avoids promoting a half-precision
        activation to float32 on the addition that follows.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if seq_len > self.max_seq_len:
            table = get_sinusoid_encoding_table(seq_len, self.d_model)[:seq_len, :]
            table = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)
        else:
            table = self.position_enc[:, :seq_len, :]
        return table.expand(batch_size, -1, -1)

    def forward(self, src_seq, style_vector, mask):
        """Encode token ids under a style vector.

        Args:
            src_seq: token-id tensor indexed as (batch, seq_len).
            style_vector: conditioning vector passed to every FFTBlock.
            mask: (batch, seq_len) mask. True positions are treated as
                padding (the prenet and FFT blocks zero them out).

        Returns:
            (enc_output, src_embedded, slf_attn). `enc_output` is the
            projected output with `decoder_hidden` features,
            `src_embedded` is the raw token embedding, and `slf_attn`
            is a list with one attention map per layer.
        """
        max_len = src_seq.shape[1]
        # Broadcast the key-padding mask over every query position.
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        src_embedded = self.src_word_emb(src_seq)
        prenet_output = self.prenet(src_embedded, mask)
        enc_output = prenet_output + self._position_embedding(prenet_output)

        slf_attn = []
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output, style_vector,
                mask=mask,
                slf_attn_mask=slf_attn_mask)
            slf_attn.append(enc_slf_attn)

        enc_output = self.fc_out(enc_output)
        return enc_output, src_embedded, slf_attn
class Decoder(nn.Module):
    """Style-conditioned decoder.

    Pipeline: MLP prenet -> + sinusoidal position encoding -> stack of
    FFTBlocks (each conditioned on the style code) -> linear projection
    to `n_mel_channels` outputs.
    """

    def __init__(self, config):
        super(Decoder, self).__init__()
        self.max_seq_len = config.max_seq_len
        self.n_layers = config.decoder_layer
        self.d_model = config.decoder_hidden
        self.n_head = config.decoder_head
        self.d_k = config.decoder_hidden // config.decoder_head
        self.d_v = config.decoder_hidden // config.decoder_head
        self.d_inner = config.fft_conv1d_filter_size
        self.fft_conv1d_kernel_size = config.fft_conv1d_kernel_size
        self.d_out = config.n_mel_channels
        self.style_dim = config.style_vector_dim
        self.dropout = config.dropout

        self.prenet = nn.Sequential(
            nn.Linear(self.d_model, self.d_model//2),
            Mish(),
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model//2, self.d_model)
        )

        # Precomputed, non-trainable position table covering max_seq_len + 1 positions.
        n_position = self.max_seq_len + 1
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, self.d_model).unsqueeze(0), requires_grad = False)

        self.layer_stack = nn.ModuleList([FFTBlock(
            self.d_model, self.d_inner, self.n_head, self.d_k, self.d_v,
            self.fft_conv1d_kernel_size, self.style_dim, self.dropout) for _ in range(self.n_layers)])

        self.fc_out = nn.Linear(self.d_model, self.d_out)

    def _position_embedding(self, x):
        """Return sinusoidal position encodings expanded to x's batch size.

        Uses the precomputed table when the sequence fits. Otherwise it
        builds a table of the required length and moves it to x's device
        AND dtype. Matching the dtype avoids promoting a half-precision
        activation to float32 on the addition that follows.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        if seq_len > self.max_seq_len:
            table = get_sinusoid_encoding_table(seq_len, self.d_model)[:seq_len, :]
            table = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)
        else:
            table = self.position_enc[:, :seq_len, :]
        return table.expand(batch_size, -1, -1)

    def forward(self, enc_seq, style_code, mask):
        """Decode adapted hidden states into mel features.

        Args:
            enc_seq: hidden-state tensor indexed as (batch, length,
                decoder_hidden).
            style_code: conditioning vector passed to every FFTBlock.
            mask: (batch, length) mask. True positions are treated as
                padding by the FFT blocks.

        Returns:
            (dec_output, slf_attn). `dec_output` has `n_mel_channels`
            features and `slf_attn` is a list with one attention map per
            layer.
        """
        max_len = enc_seq.shape[1]
        # Broadcast the key-padding mask over every query position.
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        dec_embedded = self.prenet(enc_seq)
        # Position encoding (the prenet preserves the sequence length).
        dec_output = dec_embedded + self._position_embedding(dec_embedded)

        slf_attn = []
        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output, style_code,
                mask=mask,
                slf_attn_mask=slf_attn_mask)
            slf_attn.append(dec_slf_attn)

        dec_output = self.fc_out(dec_output)
        return dec_output, slf_attn
class FFTBlock(nn.Module):
    """Feed-forward Transformer block with style-adaptive normalization.

    The block runs self-attention, then style-adaptive layer norm, then a
    padding zero-fill. It then runs the position-wise feed-forward net,
    another style-adaptive layer norm and a final padding zero-fill.
    """

    def __init__(self, d_model, d_inner,
                 n_head, d_k, d_v, fft_conv1d_kernel_size, style_dim, dropout):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.saln_0 = StyleAdaptiveLayerNorm(d_model, style_dim)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, fft_conv1d_kernel_size, dropout=dropout)
        self.saln_1 = StyleAdaptiveLayerNorm(d_model, style_dim)

    @staticmethod
    def _zero_padding(hidden, mask):
        """Zero the positions where `mask` is True; no-op if mask is None."""
        if mask is None:
            return hidden
        return hidden.masked_fill(mask.unsqueeze(-1), 0)

    def forward(self, input, style_vector, mask=None, slf_attn_mask=None):
        """Return (output, attention) for one block.

        `mask` is a per-position padding mask (True = pad) and
        `slf_attn_mask` is forwarded unchanged to the attention module.
        """
        hidden, attn = self.slf_attn(input, mask=slf_attn_mask)
        hidden = self._zero_padding(self.saln_0(hidden, style_vector), mask)

        hidden = self.pos_ffn(hidden)
        hidden = self._zero_padding(self.saln_1(hidden, style_vector), mask)
        return hidden, attn
class PositionwiseFeedForward(nn.Module):
    """Two ConvNorm layers (Mish + dropout between them) with a residual add.

    `fft_conv1d_kernel_size[0]` and `[1]` give the kernel sizes of the
    first and second convolution.
    """

    def __init__(self, d_in, d_hid, fft_conv1d_kernel_size, dropout=0.1):
        super().__init__()
        first_kernel = fft_conv1d_kernel_size[0]
        second_kernel = fft_conv1d_kernel_size[1]
        self.w_1 = ConvNorm(d_in, d_hid, kernel_size=first_kernel)
        self.w_2 = ConvNorm(d_hid, d_in, kernel_size=second_kernel)
        self.mish = Mish()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input):
        """Apply the conv feed-forward net along dims 1/2 and add the input back."""
        # Swap dims 1 and 2 so the convolutions run along the sequence axis
        # (input presumably (batch, time, channels); confirm against callers).
        hidden = input.transpose(1, 2)
        hidden = self.w_1(hidden)
        hidden = self.mish(hidden)
        hidden = self.dropout(hidden)
        hidden = self.w_2(hidden)
        hidden = hidden.transpose(1, 2)
        return self.dropout(hidden) + input
class MelStyleEncoder(nn.Module):
    """Extract a fixed-size style vector from a mel-spectrogram.

    Pipeline: two LinearNorm+Mish+Dropout layers (spectral) -> two
    Conv1dGLU layers (temporal) -> multi-head self-attention -> LinearNorm
    to `style_vector_dim` -> masked average over the time axis.
    """

    def __init__(self, config):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = config.n_mel_channels
        self.hidden_dim = config.style_hidden
        self.out_dim = config.style_vector_dim
        self.kernel_size = config.style_kernel_size
        self.n_head = config.style_head
        self.dropout = config.dropout

        self.spectral = nn.Sequential(
            LinearNorm(self.in_dim, self.hidden_dim),
            Mish(),
            nn.Dropout(self.dropout),
            LinearNorm(self.hidden_dim, self.hidden_dim),
            Mish(),
            nn.Dropout(self.dropout)
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = MultiHeadAttention(self.n_head, self.hidden_dim,
                                           self.hidden_dim//self.n_head, self.hidden_dim//self.n_head, self.dropout)

        self.fc = LinearNorm(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        """Average `x` over dim 1, ignoring positions where `mask` is True.

        Args:
            x: tensor indexed as (batch, time, features).
            mask: optional (batch, time) bool tensor, True = padding.

        Returns:
            (batch, features) tensor. A row whose positions are all masked
            yields zeros rather than NaN.
        """
        if mask is None:
            return torch.mean(x, dim=1)
        # Count of valid frames per row. Clamping to 1 turns the 0/0 of a
        # fully padded row into 0/1 instead of NaN.
        lengths = (~mask).sum(dim=1, keepdim=True).clamp(min=1)
        x = x.masked_fill(mask.unsqueeze(-1), 0)
        return x.sum(dim=1) / lengths

    def forward(self, x, mask=None):
        """Return the style vector for mel input `x`.

        `x` is presumably (batch, frames, n_mel_channels). `mask` is an
        optional (batch, frames) padding mask with True = pad.
        """
        max_len = x.shape[1]
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None

        # spectral
        x = self.spectral(x)
        # temporal (the convolutions run on dims 1/2 swapped)
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention, with padded frames zeroed first
        if mask is not None:
            x = x.masked_fill(mask.unsqueeze(-1), 0)
        x, _ = self.slf_attn(x, mask=slf_attn_mask)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=mask)
        return w
class Prenet(nn.Module):
    """Residual convolutional prenet.

    Applies two (ConvNorm k=3, Mish, Dropout) stages along dims 1/2
    swapped, then a LinearNorm. The result is added to the input and
    positions flagged by `mask` are zeroed.
    """

    def __init__(self, hidden_dim, out_dim, dropout):
        super(Prenet, self).__init__()
        stages = []
        for _ in range(2):
            stages += [
                ConvNorm(hidden_dim, hidden_dim, kernel_size=3),
                Mish(),
                nn.Dropout(dropout),
            ]
        self.convs = nn.Sequential(*stages)
        self.fc = LinearNorm(hidden_dim, out_dim)

    def forward(self, input, mask=None):
        """Return fc(convs(input)) + input, with True positions of `mask` zeroed."""
        conv_out = self.convs(input.transpose(1, 2)).transpose(1, 2)
        output = self.fc(conv_out) + input
        if mask is None:
            return output
        return output.masked_fill(mask.unsqueeze(-1), 0)