apply #1: fix prior and add glow
keonlee9420 committed Jul 22, 2021
1 parent b63864c commit 919d196
Showing 15 changed files with 349 additions and 422 deletions.
10 changes: 5 additions & 5 deletions config/LJSpeech/train.yaml
@@ -1,7 +1,7 @@
path:
ckpt_path: "./output/ckpt/LJSpeech"
log_path: "./output/log/LJSpeech"
result_path: "./output/result/LJSpeech"
ckpt_path: "./output/ckpt/LJSpeech_pr"
log_path: "./output/log/LJSpeech_pr"
result_path: "./output/result/LJSpeech_pr"
optimizer:
batch_size: 32
betas: [0.9, 0.999]
@@ -22,6 +22,6 @@ length:
length_weight: 1.
kl:
kl_weight: 1.
kl_weight_init: 0.00001
kl_weight_increase_epoch: 1
kl_weight_init: 0.00000001
kl_weight_increase_epoch: 1000
kl_weight_end: 0.00001
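
Note: the KL weight now starts far lower (1e-8 instead of 1e-5) and is annealed toward kl_weight_end over 1000 epochs rather than 1. The exact schedule lives in the training code, which is not part of this hunk; a minimal sketch, assuming a simple linear ramp between kl_weight_init and kl_weight_end:

```python
def annealed_kl_weight(epoch: int,
                       kl_weight_init: float = 1e-8,
                       kl_weight_end: float = 1e-5,
                       kl_weight_increase_epoch: int = 1000) -> float:
    """Linearly ramp the KL weight from init to end over the given number of
    epochs, then hold it constant (assumed schedule, not taken from this diff)."""
    if epoch >= kl_weight_increase_epoch:
        return kl_weight_end
    frac = epoch / kl_weight_increase_epoch
    return kl_weight_init + frac * (kl_weight_end - kl_weight_init)

# e.g. total loss at a given epoch (weights as in this config):
# loss = mel_l2 + length_weight * length_l2 + annealed_kl_weight(epoch) * kl_divergence
```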
5 changes: 3 additions & 2 deletions evaluate.py
@@ -24,7 +24,8 @@ def evaluate(
logger=None,
vocoder=None,
audio_processor=None,
losses_len=4):
losses_len=4,
device="cuda:0"):
preprocess_config, model_config, train_config = configs

# Get dataset
@@ -46,7 +47,7 @@
batch = to_device(batch, device)
with torch.no_grad():
# Forward
predictions, mel_l2, kl_divergence, length_l2, dec_alignments, reduced_mel_lens = model(
(predictions, mel_l2, kl_divergence, length_l2, dec_alignments, reduced_mel_lens, *_) = model(
*(batch[2:]),
reduce_loss=True,
reduction_factor=reduction_factor
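
Note: evaluate() now takes the target device explicitly and discards the extra tensors the model returns via `*_`. A hedged sketch of a call site; the positional arguments (model, step, configs) and the other objects are assumed from the surrounding training script and are not shown in this hunk:

```python
import torch

# choose the device once in the caller and thread it through, instead of a
# module-level global
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# hypothetical call site; model, step, configs, logger, vocoder, and
# audio_processor are assumed to exist in the surrounding script
evaluate(
    model, step, configs,
    logger=logger,
    vocoder=vocoder,
    audio_processor=audio_processor,
    losses_len=4,
    device=device,
)
```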
85 changes: 44 additions & 41 deletions model/VAENAR.py
@@ -1,6 +1,3 @@
import os
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -26,16 +23,16 @@ def __init__(self, preprocess_config, model_config):
self.max_reduction_factor = model_config["common"]["max_reduction_factor"]

self.text_encoder = TransformerEncoder(
vocab_size=len(symbols) + 1,
embd_dim=model_config["transformer"]["encoder"]["embd_dim"],
n_symbols=len(symbols) + 1,
embedding_dim=model_config["transformer"]["encoder"]["embd_dim"],
pre_nconv=model_config["transformer"]["encoder"]["n_conv"],
pre_hidden=model_config["transformer"]["encoder"]["pre_hidden"],
pre_conv_kernel=model_config["transformer"]["encoder"]["conv_kernel"],
pre_activation=self._get_activation(model_config["transformer"]["encoder"]["pre_activation"]),
prenet_drop_rate=model_config["transformer"]["encoder"]["pre_drop_rate"],
bn_before_act=model_config["transformer"]["encoder"]["bn_before_act"],
pos_drop_rate=model_config["transformer"]["encoder"]["pos_drop_rate"],
nblk=model_config["transformer"]["encoder"]["n_blk"],
n_blocks=model_config["transformer"]["encoder"]["n_blk"],
attention_dim=model_config["transformer"]["encoder"]["attention_dim"],
attention_heads=model_config["transformer"]["encoder"]["attention_heads"],
attention_temperature=model_config["transformer"]["encoder"]["attention_temperature"],
@@ -78,8 +75,8 @@ def __init__(self, preprocess_config, model_config):
attention_dim=model_config["transformer"]["prior"]["attention_dim"],
attention_heads=model_config["transformer"]["prior"]["attention_heads"],
temperature=model_config["transformer"]["prior"]["temperature"],
ffn_hidden=model_config["transformer"]["prior"]["ffn_hidden"],
inverse=model_config["transformer"]["prior"]["inverse"])
ffn_hidden=model_config["transformer"]["prior"]["ffn_hidden"]
)

@staticmethod
def _get_activation(activation):
@@ -115,6 +112,8 @@ def _kl_divergence(p, q, reduce=None):
return torch.mean(kl)
else:
return kl
# kl = F.kl_div(p, F.softmax(q, dim=1))
# return kl

@staticmethod
def _length_l2_loss(predicted_lengths, target_lengths, reduce=False):
@@ -126,16 +125,16 @@ def _length_l2_loss(predicted_lengths, target_lengths, reduce=False):
return torch.square(log_pre_lengths - log_tgt_lengths)

def forward(
self,
speakers,
inputs,
text_lengths,
max_src_len,
mel_targets=None,
mel_lengths=None,
max_mel_len=None,
reduction_factor=2,
reduce_loss=False,
self,
speakers,
inputs,
text_lengths,
max_src_len,
mel_targets=None,
mel_lengths=None,
max_mel_len=None,
reduction_factor=2,
reduce_loss=False,
):
"""
:param speakers: speaker inputs, [batch, ]
@@ -168,36 +167,36 @@ def forward(
length_loss = self._length_l2_loss(
predicted_lengths, mel_lengths, reduce=reduce_loss)
logvar, mu, post_alignments = self.posterior(reduced_mels, text_embd,
src_lengths=text_lengths,
target_lengths=reduced_mel_lens)
src_lengths=text_lengths,
target_lengths=reduced_mel_lens)

# prepare batch
# samples, eps: [batch, n_sample, mel_max_time, dim]
samples, eps = self.posterior.reparameterize(mu, logvar, self.n_sample)
# [batch, n_sample]
posterior_logprobs = self.posterior.log_probability(
mu, logvar, eps=eps, seq_lengths=reduced_mel_lens)
posterior_logprobs = self.posterior.log_probability(mu, logvar, eps=eps, seq_lengths=reduced_mel_lens)

# [batch*n_sample, mel_max_len, dim]
batched_samples = samples.view(batch_size * self.n_sample, reduced_mel_max_len, -1)
# [batch*n_sample, text_max_len, dim]
batched_text_embd = torch.tile(
text_embd.unsqueeze(1),
[1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, text_max_len, -1)
text_embd.unsqueeze(1),
[1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, text_max_len, -1)
batched_mel_targets = torch.tile(
mel_targets.unsqueeze(1),
[1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, mel_max_len, -1)
mel_targets.unsqueeze(1),
[1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, mel_max_len, -1)
# [batch*n_sample, ]
batched_mel_lengths = torch.tile(
mel_lengths.unsqueeze(1),
[1, self.n_sample]).view(-1)
mel_lengths.unsqueeze(1),
[1, self.n_sample]).view(-1)
# [batch*n_sample, ]
batched_r_mel_lengths = torch.tile(
reduced_mel_lens.unsqueeze(1),
[1, self.n_sample]).view(-1)
reduced_mel_lens.unsqueeze(1),
[1, self.n_sample]).view(-1)
# [batch*n_sample, ]
batched_text_lengths = torch.tile(
text_lengths.unsqueeze(1),
[1, self.n_sample]).view(-1)
text_lengths.unsqueeze(1),
[1, self.n_sample]).view(-1)

# decoding
decoded_initial, decoded_outs, dec_alignments = self.decoder(
@@ -216,27 +215,31 @@ def forward(
z_lengths=batched_r_mel_lengths,
condition_lengths=batched_text_lengths)
prior_logprobs = prior_logprobs.view(batch_size, self.n_sample)

kl_divergence = self._kl_divergence(posterior_logprobs, prior_logprobs, reduce_loss)

return decoded_outs, l2_loss, kl_divergence, length_loss, dec_alignments, reduced_mel_lens
return (decoded_outs, l2_loss, kl_divergence, length_loss, dec_alignments, reduced_mel_lens,
posterior_logprobs, prior_logprobs)

def inference(self, inputs, mel_lengths, text_lengths=None, reduction_factor=2):
reduced_mel_lens = (mel_lengths + reduction_factor - 1) // reduction_factor
def inference(self, inputs, text_lengths, reduction_factor=2):
text_pos_step = self.mel_text_len_ratio / float(reduction_factor)
text_embd = self.text_encoder(inputs, text_lengths, pos_step=text_pos_step)
prior_latents, prior_logprobs = self.prior.sample(reduced_mel_lens,
text_embd,
text_lengths)
predicted_mel_lengths = (self.length_predictor(text_embd, text_lengths) + 80).long()
reduced_mel_lens = (predicted_mel_lengths + reduction_factor - 1) // reduction_factor

prior_latents, prior_logprobs = self.prior(reduced_mel_lens, text_embd, text_lengths)

_, predicted_mel, dec_alignments = self.decoder(
inputs=prior_latents, text_embd=text_embd, z_lengths=reduced_mel_lens,
text_lengths=text_lengths, reduction_factor=reduction_factor)
return predicted_mel, dec_alignments
return predicted_mel, predicted_mel_lengths, reduced_mel_lens, dec_alignments, prior_logprobs

def init(self, text_inputs, mel_lengths, text_lengths=None):
reduced_mel_lens = (mel_lengths + self.max_reduction_factor - 1) // self.max_reduction_factor
text_pos_step = self.mel_text_len_ratio / float(self.max_reduction_factor)
text_embd = self.text_encoder(text_inputs, text_lengths, pos_step=text_pos_step)
prior_latents, prior_logprobs = self.prior.init(conditions=text_embd,
reduced_mel_lens = (mel_lengths + self.max_reduction_factor - 1) // self.max_reduction_factor

prior_latents, prior_logprobs = self.prior.init(inputs=text_embd,
targets_lengths=reduced_mel_lens,
condition_lengths=text_lengths)
_, predicted_mel, _ = self.decoder(inputs=prior_latents,
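
Note: inference() no longer receives ground-truth mel lengths; it predicts them with the length predictor, derives the reduced lengths, samples latents from the prior, and decodes. A minimal synthesis-time sketch under those assumptions (texts and src_lens are hypothetical tensors prepared by the synthesis script):

```python
import torch

model.eval()
with torch.no_grad():
    # texts: [batch, max_text_len] phoneme ids, src_lens: [batch] text lengths
    mel_pred, mel_lens, reduced_mel_lens, dec_alignments, prior_logprobs = model.inference(
        texts, src_lens, reduction_factor=2)

# assuming the decoder upsamples back to frame resolution, trim each utterance
# to its predicted length before vocoding
mels = [mel_pred[i, : mel_lens[i]] for i in range(mel_pred.size(0))]
```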
18 changes: 9 additions & 9 deletions model/attention.py
@@ -6,7 +6,7 @@
from .utils import LinearNorm, FFN
from utils.tools import get_mask_from_lengths

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BaseAttention(nn.Module):
@@ -27,7 +27,7 @@ def forward(self, inputs, memory, memory_lengths, query_lengths):
raise NotImplementedError

@staticmethod
def _get_key_mask(batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths):
def _get_key_mask(batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths, device):
memory_lengths = (memory_lengths if memory_lengths is not None
else torch.ones(batch_size, dtype=torch.int32, device=device) * memory_max_time)
memeory_mask = get_mask_from_lengths(memory_lengths, memory_max_time)
@@ -81,7 +81,7 @@ def _merge_head(self, inputs):
return reshaped

def _get_key_mask(self, batch_size, memory_max_time, query_max_time,
memory_lengths, query_lengths):
memory_lengths, query_lengths, device):
memory_lengths = (memory_lengths if memory_lengths is not None
else torch.ones(batch_size, dtype=torch.int32, device=device) * memory_max_time)
memory_mask = get_mask_from_lengths(memory_lengths, memory_max_time) # [batch, m_max_time]
@@ -100,7 +100,7 @@ def _get_key_mask(self, batch_size, memory_max_time, query_max_time,

@staticmethod
def _get_causal_mask(logits):
causal_mask = torch.tril(torch.ones(logits.shape, dtype=torch.bool, device=device))
causal_mask = torch.tril(torch.ones(logits.shape, dtype=torch.bool, device=logits.device))
return causal_mask

def forward(self, inputs, memory, memory_lengths=None, query_lengths=None, causality=None):
@@ -120,7 +120,7 @@ def forward(self, inputs, memory, memory_lengths=None, query_lengths=None, causa
memory_max_time = memory.shape[1]
query_max_time = inputs.shape[1]
length_mask = self._get_key_mask(
batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths)
batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths, inputs.device)
if causality:
causal_mask = self._get_causal_mask(logits)
length_mask = torch.logical_and(length_mask, causal_mask)
@@ -134,10 +134,10 @@ def forward(self, inputs, memory, memory_lengths=None, query_lengths=None, causa
return contexts, alignments


class SelfAttentionBLK(nn.Module):
class SelfAttentionBlock(nn.Module):
def __init__(self, input_dim, attention_dim, attention_heads, attention_temperature,
ffn_hidden):
super(SelfAttentionBLK, self).__init__()
super(SelfAttentionBlock, self).__init__()
self.input_dim = input_dim
self.attention_dim = attention_dim
self.attention = MultiHeadScaledProductAttention(attention_dim=attention_dim,
@@ -161,10 +161,10 @@ def forward(self, inputs, memory, query_lengths, memory_lengths, causality=None)
return ffn_outs, alignments


class CrossAttentionBLK(nn.Module):
class CrossAttentionBlock(nn.Module):
def __init__(self, input_dim, memory_dim, attention_dim, attention_heads, attention_temperature,
ffn_hidden, name=None):
super(CrossAttentionBLK, self).__init__()
super(CrossAttentionBlock, self).__init__()
self.name = name
self.input_dim = input_dim
self.attention_dim = attention_dim
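
Note: the attention module no longer relies on a module-level CUDA device; the key-mask helpers take the device explicitly and the causal mask inherits it from the logits. A minimal sketch of the device-from-tensor pattern, using a hypothetical stand-in for get_mask_from_lengths (its mask polarity here is an assumption, not the repository's definition):

```python
import torch

def get_mask_from_lengths_sketch(lengths: torch.Tensor, max_time: int) -> torch.Tensor:
    """Build the mask on the same device as the length tensor instead of a
    module-level global; True marks valid (non-padded) positions here."""
    ids = torch.arange(max_time, device=lengths.device).unsqueeze(0)  # [1, max_time]
    return ids < lengths.unsqueeze(1)                                 # [batch, max_time]

def get_causal_mask(logits: torch.Tensor) -> torch.Tensor:
    # mirrors _get_causal_mask above: the mask inherits the device of the logits
    return torch.tril(torch.ones(logits.shape, dtype=torch.bool, device=logits.device))
```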
14 changes: 7 additions & 7 deletions model/decoder.py
@@ -3,7 +3,7 @@
import torch.nn.functional as F

from .utils import LinearNorm, PostNet
from .attention import CrossAttentionBLK
from .attention import CrossAttentionBlock
from utils.tools import get_mask_from_lengths


@@ -67,12 +67,12 @@ def __init__(self, nblk, embd_dim, attention_dim, attention_heads,
self.pre_projection = LinearNorm(latent_dim, attention_dim)
self.attentions = nn.ModuleList(
[
CrossAttentionBLK(input_dim=attention_dim,
memory_dim=embd_dim,
attention_dim=attention_dim,
attention_heads=attention_heads,
attention_temperature=temperature,
ffn_hidden=ffn_hidden, name='decoder-attention-{}'.format(i))
CrossAttentionBlock(input_dim=attention_dim,
memory_dim=embd_dim,
attention_dim=attention_dim,
attention_heads=attention_heads,
attention_temperature=temperature,
ffn_hidden=ffn_hidden, name='decoder-attention-{}'.format(i))
for i in range(nblk)
]
)
54 changes: 23 additions & 31 deletions model/encoder.py
@@ -1,59 +1,51 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from math import sqrt

from .utils import Conv1D, ConvPreNet, PositionalEncoding
from .attention import SelfAttentionBLK
from .utils import ConvPreNet, PositionalEncoding
from .attention import SelfAttentionBlock

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BaseEncoder(nn.Module):
def __init__(self, vocab_size, embd_dim):
super(BaseEncoder, self).__init__()
self.emb_layer = nn.Embedding(num_embeddings=vocab_size,
embedding_dim=embd_dim,
padding_idx=0)

def forward(self, inputs, input_lengths=None):
"""
:param inputs: text inputs, [batch, max_time]
:param input_lengths: text inputs' lengths, [batch]
:return: (tensor1, tensor2)
tensor1: text encoding, [batch, max_time, hidden_size]
tensor2: global state, i.e., final_time_state, [batch, hidden_size]
"""
raise NotImplementedError


class TransformerEncoder(BaseEncoder):
def __init__(self, vocab_size, embd_dim, pre_nconv, pre_hidden, pre_conv_kernel,
prenet_drop_rate, pre_activation, bn_before_act, pos_drop_rate, nblk,
class TransformerEncoder(nn.Module):
def __init__(self, n_symbols, embedding_dim, pre_nconv, pre_hidden, pre_conv_kernel,
prenet_drop_rate, pre_activation, bn_before_act, pos_drop_rate, n_blocks,
attention_dim, attention_heads, attention_temperature, ffn_hidden):
super(TransformerEncoder, self).__init__(vocab_size, embd_dim)
self.pos_weight = nn.Parameter(torch.tensor(1.0, device=device))
super(TransformerEncoder, self).__init__()
self.embedding = nn.Embedding(num_embeddings=n_symbols,
embedding_dim=embedding_dim,
padding_idx=0)
std = sqrt(2.0 / (n_symbols + embedding_dim))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)

self.prenet = ConvPreNet(nconv=pre_nconv, hidden=pre_hidden,
conv_kernel=pre_conv_kernel, drop_rate=prenet_drop_rate,
activation=pre_activation, bn_before_act=bn_before_act)

self.register_parameter("pos_weight", nn.Parameter(torch.tensor(1.0)))
self.pe = PositionalEncoding()
self.pe_dropout = nn.Dropout(p=pos_drop_rate)

self.self_attentions = nn.ModuleList(
[
SelfAttentionBLK(
SelfAttentionBlock(
input_dim=pre_hidden, attention_dim=attention_dim,
attention_heads=attention_heads, attention_temperature=attention_temperature,
ffn_hidden=ffn_hidden)
for i in range(nblk)
for i in range(n_blocks)
]
)

def forward(self, inputs, input_lengths=None, pos_step=1.0):
# print('tracing back at text encoding')
embs = self.emb_layer(inputs)
embs = self.embedding(inputs)
prenet_outs = self.prenet(embs)
max_time = prenet_outs.shape[1]
dim = prenet_outs.shape[2]
pos = self.pe.positional_encoding(max_time, dim, device, pos_step)
max_time, prenet_dim = prenet_outs.size(1), prenet_outs.size(2)
pos = self.pe.positional_encoding(max_time, prenet_dim, inputs.device, pos_step)
pos_embs = prenet_outs + self.pos_weight * pos
pos_embs = self.pe_dropout(pos_embs)
att_outs = pos_embs
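
Note: the encoder now owns its embedding table and initializes it with Xavier-style uniform bounds, and the positional-encoding scale is registered as a learnable parameter without a hard-coded device. A standalone sketch of that initialization (the sizes below are illustrative, not taken from the config):

```python
import torch
import torch.nn as nn
from math import sqrt

n_symbols, embedding_dim = 150, 256   # example sizes only

embedding = nn.Embedding(n_symbols, embedding_dim, padding_idx=0)
std = sqrt(2.0 / (n_symbols + embedding_dim))  # Xavier/Glorot std for the table
val = sqrt(3.0) * std                          # U(-val, val) has exactly that std
embedding.weight.data.uniform_(-val, val)

# learnable scale on the positional encoding, created device-agnostically
pos_weight = nn.Parameter(torch.tensor(1.0))
```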