rename arguments and variables in LinguisticEncoder for readability
keonlee9420 committed Oct 11, 2021
1 parent 2ad88da · commit bf23771
Showing 4 changed files with 45 additions and 47 deletions.
README.md (5 changes: 2 additions & 3 deletions)
```diff
@@ -68,7 +68,7 @@ Please note that the controllability is originated from [FastSpeech2](https://ar
 
 ## Datasets
 
-The supported datasets is
+The supported datasets are
 
 - [LJSpeech](https://keithito.com/LJ-Speech-Dataset/): a **single-speaker** English dataset consists of 13100 short audio clips of a female speaker reading passages from 7 non-fiction books, approximately 24 hours in total.
 <!-- - [VCTK](https://datashare.ed.ac.uk/handle/10283/3443): The CSTR VCTK Corpus includes speech data uttered by 110 English speakers (**multi-speaker TTS**) with various accents. Each speaker reads out about 400 sentences, which were selected from a newspaper, the rainbow passage and an elicitation paragraph used for the speech accent archive.
```
```diff
@@ -120,8 +120,7 @@ to serve TensorBoard on your localhost.
 # Notes
 
 - For vocoder, **HiFi-GAN** and **MelGAN** are supported.
-- Add convolution layer and residual layer in **VariationalGenerator** to match the shape of conditioner and output.
-- No ReLU activation and LayerNorm in **VariationalGenerator** for convergence of word-to-phoneme alignment of **LinguisticEncoder**.
+- No ReLU activation and LayerNorm in **VariationalGenerator** to avoid mashed output.
 - Will be extended to a **multi-speaker TTS**.
 <!-- - Two options for embedding for the **multi-speaker TTS** setting: training speaker embedder from scratch or using a pre-trained [philipperemy's DeepSpeaker](https://github.com/philipperemy/deep-speaker) model (as [STYLER](https://github.com/keonlee9420/STYLER) did). You can toggle it by setting the config (between `'none'` and `'DeepSpeaker'`).
 - DeepSpeaker on VCTK dataset shows clear identification among speakers. The following figure shows the T-SNE plot of extracted speaker embedding.
```
model/blocks.py (22 changes: 11 additions & 11 deletions)
```diff
@@ -555,7 +555,7 @@ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0):
 
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, q, k, v, mask_1=None, mask_2=None, mapping_mask=None, indivisual_attn=False):
+    def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None, indivisual_attn=False):
 
         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
 
```
```diff
@@ -575,14 +575,14 @@ def forward(self, q, k, v, mask_1=None, mask_2=None, mapping_mask=None, indivisual_attn=False):
         v = v.permute(2, 0, 1, 3).contiguous().view(-1,
                                                     len_v, d_v)  # (n*b) x lv x dv
 
-        if mask_1 is not None:
-            mask_1 = mask_1.repeat(n_head, 1, 1)  # (n*b) x .. x ..
-        if mask_2 is not None:
-            mask_2 = mask_2.repeat(n_head, 1, 1)  # (n*b) x .. x ..
+        if key_mask is not None:
+            key_mask = key_mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
+        if query_mask is not None:
+            query_mask = query_mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
         if mapping_mask is not None:
             mapping_mask = mapping_mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
         output, attn = self.attention(
-            q, k, v, mask_1=mask_1, mask_2=mask_2, mapping_mask=mapping_mask)
+            q, k, v, key_mask=key_mask, query_mask=query_mask, mapping_mask=mapping_mask)
 
         output = output.view(n_head, sz_b, len_q, d_v)
         output = (
```
```diff
@@ -605,17 +605,17 @@ def __init__(self, temperature):
         self.temperature = temperature
         self.softmax = nn.Softmax(dim=2)
 
-    def forward(self, q, k, v, mask_1=None, mask_2=None, mapping_mask=None):
+    def forward(self, q, k, v, key_mask=None, query_mask=None, mapping_mask=None):
 
         attn = torch.bmm(q, k.transpose(1, 2))
         attn = attn / self.temperature
 
-        if mask_1 is not None:
-            attn = attn.masked_fill(mask_1==0., -np.inf)
+        if key_mask is not None:
+            attn = attn.masked_fill(key_mask==0., -np.inf)
         attn = self.softmax(attn)
 
-        if mask_2 is not None:
-            attn = attn * mask_2
+        if query_mask is not None:
+            attn = attn * query_mask
         if mapping_mask is not None:
             attn = attn * mapping_mask
         output = torch.bmm(attn, v)
```
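After the rename, the two masks in `ScaledDotProductAttention` read as they behave: `key_mask` removes padded key (phoneme) positions from the softmax, while `query_mask` zeroes out rows belonging to padded query (frame) positions after the softmax. Below is a minimal standalone sketch of that semantics, not the repository's exact class; the shapes and the sqrt(d) scaling are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, key_mask=None, query_mask=None):
    # q: [batch, len_q, d], k/v: [batch, len_k, d]
    # key_mask / query_mask: [batch, len_q, len_k], 1 = valid, 0 = padding
    attn = torch.bmm(q, k.transpose(1, 2)) / (q.size(-1) ** 0.5)
    if key_mask is not None:
        # padded keys get -inf logits, hence exactly zero weight after softmax
        attn = attn.masked_fill(key_mask == 0., float("-inf"))
    attn = F.softmax(attn, dim=2)
    if query_mask is not None:
        # padded queries produce all-zero rows instead of garbage attention
        attn = attn * query_mask
    return torch.bmm(attn, v), attn

q = torch.randn(1, 4, 8)
k = v = torch.randn(1, 6, 8)
key_mask = torch.ones(1, 4, 6)
key_mask[..., -1] = 0  # pretend the last phoneme position is padding
out, attn = masked_attention(q, k, v, key_mask=key_mask)
print(attn[0].sum(dim=-1))  # each row still sums to 1; the masked key gets weight 0
```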
model/linguistic_encoder.py (57 changes: 28 additions & 29 deletions)
```diff
@@ -105,7 +105,7 @@ def __init__(self, config):
 
     def get_mapping_mask(self, q, kv, dur_w, wb, src_w_len):
         """
-        A word-to-phoneme mapping mask to the attention weight to force each query (Q)
+        For applying a word-to-phoneme mapping mask to the attention weight to force each query (Q)
         to only attend to the phonemes belongs to the word corresponding to this query.
         """
         batch_size, q_len, kv_len, device = q.shape[0], q.shape[1], kv.shape[1], kv.device
```
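The docstring above can be made concrete with a toy construction. This is a hedged sketch of what such a mapping mask looks like (a hypothetical helper, not the repository's `get_mapping_mask`): frame `t` may attend to phoneme `p` only when both belong to the same word, given per-word frame durations and per-word phoneme counts.

```python
import torch

def toy_mapping_mask(dur_w, wb):
    # dur_w: frames per word, wb: phonemes per word (single utterance, no batch dim)
    frame_word = torch.repeat_interleave(torch.arange(len(dur_w)), dur_w)  # word id of each frame
    phone_word = torch.repeat_interleave(torch.arange(len(wb)), wb)        # word id of each phoneme
    # [n_frames, n_phonemes]: 1 where frame and phoneme share a word
    return (frame_word.unsqueeze(1) == phone_word.unsqueeze(0)).float()

print(toy_mapping_mask(torch.tensor([3, 2]), torch.tensor([2, 1])))
# frames 0-2 (word 0) may attend only to phonemes 0-1; frames 3-4 only to phoneme 2
```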
```diff
@@ -142,7 +142,7 @@ def add_position_enc(self, src_seq, position_enc=None, coef=None):
 
     def get_rel_coef(self, dur, dur_len, mask):
         """
-        A well-designed positional encoding to the inputs of word-to-phoneme attention module.
+        For adding a well-designed positional encoding to the inputs of word-to-phoneme attention module.
         """
         idx, L, device = [], [], dur.device
         for d, dl in zip(dur, dur_len):
```
```diff
@@ ... @@
 
     def forward(
         self,
-        src_seq,
-        src_len,
-        wb,
-        p_mask,
+        src_p_seq,
+        src_p_len,
+        word_boundary,
+        src_p_mask,
         src_w_len,
-        w_mask,
+        src_w_mask,
         mel_mask=None,
         max_len=None,
         duration_target=None,
-        d_control=1.0,
-        return_attns=False,
+        duration_control=1.0,
     ):
         # Phoneme Encoding
-        src_seq = self.src_emb(src_seq)
-        enc_out_p = self.phoneme_encoder(src_seq.transpose(
-            1, 2), p_mask.unsqueeze(1)).transpose(1, 2)
+        src_p_seq = self.src_emb(src_p_seq)
+        enc_p_out = self.phoneme_encoder(src_p_seq.transpose(
+            1, 2), src_p_mask.unsqueeze(1)).transpose(1, 2)
 
         # Word-level Pooing
-        src_seq_w = word_level_pooling(
-            enc_out_p, src_len, wb, src_w_len, reduce_mean=True)
+        src_w_seq = word_level_pooling(
+            enc_p_out, src_p_len, word_boundary, src_w_len, reduce="mean")
 
         # Word Encoding
-        enc_out_w = self.word_encoder(src_seq_w.transpose(
-            1, 2), w_mask.unsqueeze(1)).transpose(1, 2)
+        enc_w_out = self.word_encoder(src_w_seq.transpose(
+            1, 2), src_w_mask.unsqueeze(1)).transpose(1, 2)
 
         # Phoneme-level Duration Prediction
-        log_duration_p_prediction = self.duration_predictor(enc_out_p, p_mask)
+        log_duration_p_prediction = self.duration_predictor(enc_p_out, src_p_mask)
 
         # Word-level Pooling
         log_duration_w_prediction = word_level_pooling(
-            log_duration_p_prediction.unsqueeze(-1), src_len, wb, src_w_len, reduce_sum=True).squeeze(-1)
+            log_duration_p_prediction.unsqueeze(-1), src_p_len, word_boundary, src_w_len, reduce="sum").squeeze(-1)
 
-        x = enc_out_w
+        x = enc_w_out
         if duration_target is not None:
             # Word-level Pooing
             duration_w_rounded = word_level_pooling(
-                duration_target.unsqueeze(-1), src_len, wb, src_w_len, reduce_sum=True).squeeze(-1)
+                duration_target.unsqueeze(-1), src_p_len, word_boundary, src_w_len, reduce="sum").squeeze(-1)
             # Word-level Length Regulate
             x, mel_len = self.length_regulator(x, duration_w_rounded, max_len)
         else:
             # Word-level Duration
             duration_w_rounded = torch.clamp(
-                (torch.round(torch.exp(log_duration_w_prediction) - 1) * d_control),
+                (torch.round(torch.exp(log_duration_w_prediction) - 1) * duration_control),
                 min=0,
             ).long()
             # Word-level Length Regulate
```
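The renamed `duration_control` knob enters only through the clamp/round/exp line above. Here is a small sketch of that arithmetic in isolation; the log-duration tensor is random stand-in data, not model output.

```python
import torch

# Stand-in for the duration predictor's output, shaped [batch, n_words].
# The predictor is trained on log(duration + 1), so exp(.) - 1 recovers frames.
log_duration_w_prediction = torch.randn(2, 5)

for duration_control in (0.8, 1.0, 1.2):  # < 1 speaks faster, > 1 slower
    duration_w_rounded = torch.clamp(
        (torch.round(torch.exp(log_duration_w_prediction) - 1) * duration_control),
        min=0,
    ).long()  # .long() truncates any fractional frames introduced by the scaling
    print(duration_control, duration_w_rounded.sum(dim=1))  # total frames per utterance
```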
```diff
@@ -209,23 +208,23 @@ def forward(
 
         # Word-to-Phoneme Attention
         # [batch, mel_len, seq_len]
-        src_mask_ = p_mask.unsqueeze(1).expand(-1, mel_mask.shape[1], -1)
+        src_mask_ = src_p_mask.unsqueeze(1).expand(-1, mel_mask.shape[1], -1)
         # [batch, mel_len, seq_len]
-        mel_mask_ = mel_mask.unsqueeze(-1).expand(-1, -1, p_mask.shape[1])
+        mel_mask_ = mel_mask.unsqueeze(-1).expand(-1, -1, src_p_mask.shape[1])
         mapping_mask = self.get_mapping_mask(
-            x, enc_out_p, duration_w_rounded, wb, src_w_len)  # [batch, mel_len, seq_len]
+            x, enc_p_out, duration_w_rounded, word_boundary, src_w_len)  # [batch, mel_len, seq_len]
 
         q = self.add_position_enc(x, position_enc=self.q_position_enc, coef=self.get_rel_coef(
             duration_w_rounded, src_w_len, mel_mask))
         k = self.add_position_enc(
-            enc_out_p, position_enc=self.kv_position_enc, coef=self.get_rel_coef(wb, src_len, p_mask))
+            enc_p_out, position_enc=self.kv_position_enc, coef=self.get_rel_coef(word_boundary, src_p_len, src_p_mask))
         v = self.add_position_enc(
-            enc_out_p, position_enc=self.kv_position_enc, coef=self.get_rel_coef(wb, src_len, p_mask))
+            enc_p_out, position_enc=self.kv_position_enc, coef=self.get_rel_coef(word_boundary, src_p_len, src_p_mask))
         # q = self.add_position_enc(x)
-        # k = self.add_position_enc(enc_out_p)
-        # v = self.add_position_enc(enc_out_p)
+        # k = self.add_position_enc(enc_p_out)
+        # v = self.add_position_enc(enc_p_out)
         x, alignment = self.w2p_attn(
-            q, k, v, mask_1=src_mask_, mask_2=mel_mask_, mapping_mask=mapping_mask, indivisual_attn=True
+            q, k, v, key_mask=src_mask_, query_mask=mel_mask_, mapping_mask=mapping_mask, indivisual_attn=True
         )
 
         return (
```
utils/tools.py (8 changes: 4 additions & 4 deletions)
```diff
@@ -395,7 +395,7 @@ def reparameterize(mu, logvar):
     return eps * std + mu
 
 
-def word_level_pooling(src_seq, src_len, wb, src_w_len, reduce_sum=False, reduce_mean=False):
+def word_level_pooling(src_seq, src_len, wb, src_w_len, reduce="sum"):
     """
     src_seq -- [batch_size, max_time, dim]
     src_len -- [batch_size,]
```
```diff
@@ -406,11 +406,11 @@ def word_level_pooling(src_seq, src_len, wb, src_w_len, reduce_sum=False, reduce_mean=False):
     for s, sl, w, wl in zip(src_seq, src_len, wb, src_w_len):
         m, split_size = s[:sl, :], list(w[:wl].int())
         m = nn.utils.rnn.pad_sequence(torch.split(m, split_size, dim=0))
-        if reduce_mean and not reduce_sum:
+        if reduce == "sum":
+            m = torch.sum(m, dim=0)  # [src_w_len, hidden]
+        elif reduce == "mean":
             m = torch.div(torch.sum(m, dim=0), torch.tensor(
                 split_size, device=device).unsqueeze(-1))  # [src_w_len, hidden]
-        elif reduce_sum and not reduce_mean:
-            m = torch.sum(m, dim=0)  # [src_w_len, hidden]
         else:
             raise ValueError()
         batch.append(m)
```
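To exercise the new `reduce` argument end to end, here is a self-contained toy version of the pooling; the final re-padding across the batch is an assumption, since the diff is truncated right after `batch.append(m)`.

```python
import torch
import torch.nn as nn

def word_level_pooling(src_seq, src_len, wb, src_w_len, reduce="sum"):
    device = src_seq.device
    batch = []
    for s, sl, w, wl in zip(src_seq, src_len, wb, src_w_len):
        m, split_size = s[:sl, :], w[:wl].int().tolist()
        # pad the per-word phoneme groups to [max_phonemes_per_word, n_words, hidden]
        m = nn.utils.rnn.pad_sequence(torch.split(m, split_size, dim=0))
        if reduce == "sum":
            m = torch.sum(m, dim=0)  # [n_words, hidden]
        elif reduce == "mean":
            m = torch.div(torch.sum(m, dim=0), torch.tensor(
                split_size, device=device).unsqueeze(-1))  # [n_words, hidden]
        else:
            raise ValueError()
        batch.append(m)
    return nn.utils.rnn.pad_sequence(batch, batch_first=True)  # assumed final packing

# one utterance: 6 phoneme vectors, grouped into words of 2, 3, and 1 phonemes
seq = torch.arange(6, dtype=torch.float32).view(1, 6, 1)
out = word_level_pooling(seq, torch.tensor([6]),
                         torch.tensor([[2, 3, 1]]), torch.tensor([3]), reduce="mean")
print(out.squeeze())  # tensor([0.5000, 3.0000, 5.0000])
```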
