From 919d1962204baf8e752c96a337e00597e74bab80 Mon Sep 17 00:00:00 2001 From: keonlee9420 Date: Thu, 22 Jul 2021 12:37:28 +0900 Subject: [PATCH] apply #001: fix prior and add glow --- config/LJSpeech/train.yaml | 10 +- evaluate.py | 5 +- model/VAENAR.py | 85 ++++++++--------- model/attention.py | 18 ++-- model/decoder.py | 14 +-- model/encoder.py | 54 +++++------ model/flow.py | 184 +++++++++++-------------------------- model/glow.py | 92 +++++++++++++++++++ model/posterior.py | 56 +++++------ model/prior.py | 153 ++++++++---------------------- model/transform.py | 30 +++--- model/utils.py | 36 ++++---- synthesize.py | 26 ++---- train.py | 5 +- utils/tools.py | 3 +- 15 files changed, 349 insertions(+), 422 deletions(-) create mode 100644 model/glow.py diff --git a/config/LJSpeech/train.yaml b/config/LJSpeech/train.yaml index 525a3fc..e0503da 100644 --- a/config/LJSpeech/train.yaml +++ b/config/LJSpeech/train.yaml @@ -1,7 +1,7 @@ path: - ckpt_path: "./output/ckpt/LJSpeech" - log_path: "./output/log/LJSpeech" - result_path: "./output/result/LJSpeech" + ckpt_path: "./output/ckpt/LJSpeech_pr" + log_path: "./output/log/LJSpeech_pr" + result_path: "./output/result/LJSpeech_pr" optimizer: batch_size: 32 betas: [0.9, 0.999] @@ -22,6 +22,6 @@ length: length_weight: 1. kl: kl_weight: 1. - kl_weight_init: 0.00001 - kl_weight_increase_epoch: 1 + kl_weight_init: 0.00000001 + kl_weight_increase_epoch: 1000 kl_weight_end: 0.00001 diff --git a/evaluate.py b/evaluate.py index 249449d..17579bd 100644 --- a/evaluate.py +++ b/evaluate.py @@ -24,7 +24,8 @@ def evaluate( logger=None, vocoder=None, audio_processor=None, - losses_len=4): + losses_len=4, + device="cuda:0"): preprocess_config, model_config, train_config = configs # Get dataset @@ -46,7 +47,7 @@ def evaluate( batch = to_device(batch, device) with torch.no_grad(): # Forward - predictions, mel_l2, kl_divergence, length_l2, dec_alignments, reduced_mel_lens = model( + (predictions, mel_l2, kl_divergence, length_l2, dec_alignments, reduced_mel_lens, *_) = model( *(batch[2:]), reduce_loss=True, reduction_factor=reduction_factor diff --git a/model/VAENAR.py b/model/VAENAR.py index a2f23cc..0600a35 100644 --- a/model/VAENAR.py +++ b/model/VAENAR.py @@ -1,6 +1,3 @@ -import os -import json - import torch import torch.nn as nn import torch.nn.functional as F @@ -26,8 +23,8 @@ def __init__(self, preprocess_config, model_config): self.max_reduction_factor = model_config["common"]["max_reduction_factor"] self.text_encoder = TransformerEncoder( - vocab_size=len(symbols) + 1, - embd_dim=model_config["transformer"]["encoder"]["embd_dim"], + n_symbols=len(symbols) + 1, + embedding_dim=model_config["transformer"]["encoder"]["embd_dim"], pre_nconv=model_config["transformer"]["encoder"]["n_conv"], pre_hidden=model_config["transformer"]["encoder"]["pre_hidden"], pre_conv_kernel=model_config["transformer"]["encoder"]["conv_kernel"], @@ -35,7 +32,7 @@ def __init__(self, preprocess_config, model_config): prenet_drop_rate=model_config["transformer"]["encoder"]["pre_drop_rate"], bn_before_act=model_config["transformer"]["encoder"]["bn_before_act"], pos_drop_rate=model_config["transformer"]["encoder"]["pos_drop_rate"], - nblk=model_config["transformer"]["encoder"]["n_blk"], + n_blocks=model_config["transformer"]["encoder"]["n_blk"], attention_dim=model_config["transformer"]["encoder"]["attention_dim"], attention_heads=model_config["transformer"]["encoder"]["attention_heads"], attention_temperature=model_config["transformer"]["encoder"]["attention_temperature"], @@ -78,8 +75,8 
@@ def __init__(self, preprocess_config, model_config): attention_dim=model_config["transformer"]["prior"]["attention_dim"], attention_heads=model_config["transformer"]["prior"]["attention_heads"], temperature=model_config["transformer"]["prior"]["temperature"], - ffn_hidden=model_config["transformer"]["prior"]["ffn_hidden"], - inverse=model_config["transformer"]["prior"]["inverse"]) + ffn_hidden=model_config["transformer"]["prior"]["ffn_hidden"] + ) @staticmethod def _get_activation(activation): @@ -115,6 +112,8 @@ def _kl_divergence(p, q, reduce=None): return torch.mean(kl) else: return kl + # kl = F.kl_div(p, F.softmax(q, dim=1)) + # return kl @staticmethod def _length_l2_loss(predicted_lengths, target_lengths, reduce=False): @@ -126,16 +125,16 @@ def _length_l2_loss(predicted_lengths, target_lengths, reduce=False): return torch.square(log_pre_lengths - log_tgt_lengths) def forward( - self, - speakers, - inputs, - text_lengths, - max_src_len, - mel_targets=None, - mel_lengths=None, - max_mel_len=None, - reduction_factor=2, - reduce_loss=False, + self, + speakers, + inputs, + text_lengths, + max_src_len, + mel_targets=None, + mel_lengths=None, + max_mel_len=None, + reduction_factor=2, + reduce_loss=False, ): """ :param speakers: speaker inputs, [batch, ] @@ -168,36 +167,36 @@ def forward( length_loss = self._length_l2_loss( predicted_lengths, mel_lengths, reduce=reduce_loss) logvar, mu, post_alignments = self.posterior(reduced_mels, text_embd, - src_lengths=text_lengths, - target_lengths=reduced_mel_lens) + src_lengths=text_lengths, + target_lengths=reduced_mel_lens) # prepare batch # samples, eps: [batch, n_sample, mel_max_time, dim] samples, eps = self.posterior.reparameterize(mu, logvar, self.n_sample) # [batch, n_sample] - posterior_logprobs = self.posterior.log_probability( - mu, logvar, eps=eps, seq_lengths=reduced_mel_lens) + posterior_logprobs = self.posterior.log_probability(mu, logvar, eps=eps, seq_lengths=reduced_mel_lens) + # [batch*n_sample, mel_max_len, dim] batched_samples = samples.view(batch_size * self.n_sample, reduced_mel_max_len, -1) # [batch*n_sample, text_max_len, dim] batched_text_embd = torch.tile( - text_embd.unsqueeze(1), - [1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, text_max_len, -1) + text_embd.unsqueeze(1), + [1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, text_max_len, -1) batched_mel_targets = torch.tile( - mel_targets.unsqueeze(1), - [1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, mel_max_len, -1) + mel_targets.unsqueeze(1), + [1, self.n_sample, 1, 1]).view(batch_size * self.n_sample, mel_max_len, -1) # [batch*n_sample, ] batched_mel_lengths = torch.tile( - mel_lengths.unsqueeze(1), - [1, self.n_sample]).view(-1) + mel_lengths.unsqueeze(1), + [1, self.n_sample]).view(-1) # [batch*n_sample, ] batched_r_mel_lengths = torch.tile( - reduced_mel_lens.unsqueeze(1), - [1, self.n_sample]).view(-1) + reduced_mel_lens.unsqueeze(1), + [1, self.n_sample]).view(-1) # [batch*n_sample, ] batched_text_lengths = torch.tile( - text_lengths.unsqueeze(1), - [1, self.n_sample]).view(-1) + text_lengths.unsqueeze(1), + [1, self.n_sample]).view(-1) # decoding decoded_initial, decoded_outs, dec_alignments = self.decoder( @@ -216,27 +215,31 @@ def forward( z_lengths=batched_r_mel_lengths, condition_lengths=batched_text_lengths) prior_logprobs = prior_logprobs.view(batch_size, self.n_sample) + kl_divergence = self._kl_divergence(posterior_logprobs, prior_logprobs, reduce_loss) - return decoded_outs, l2_loss, kl_divergence, length_loss, 
dec_alignments, reduced_mel_lens + return (decoded_outs, l2_loss, kl_divergence, length_loss, dec_alignments, reduced_mel_lens, + posterior_logprobs, prior_logprobs) - def inference(self, inputs, mel_lengths, text_lengths=None, reduction_factor=2): - reduced_mel_lens = (mel_lengths + reduction_factor - 1) // reduction_factor + def inference(self, inputs, text_lengths, reduction_factor=2): text_pos_step = self.mel_text_len_ratio / float(reduction_factor) text_embd = self.text_encoder(inputs, text_lengths, pos_step=text_pos_step) - prior_latents, prior_logprobs = self.prior.sample(reduced_mel_lens, - text_embd, - text_lengths) + predicted_mel_lengths = (self.length_predictor(text_embd, text_lengths) + 80).long() + reduced_mel_lens = (predicted_mel_lengths + reduction_factor - 1) // reduction_factor + + prior_latents, prior_logprobs = self.prior(reduced_mel_lens, text_embd, text_lengths) + _, predicted_mel, dec_alignments = self.decoder( inputs=prior_latents, text_embd=text_embd, z_lengths=reduced_mel_lens, text_lengths=text_lengths, reduction_factor=reduction_factor) - return predicted_mel, dec_alignments + return predicted_mel, predicted_mel_lengths, reduced_mel_lens, dec_alignments, prior_logprobs def init(self, text_inputs, mel_lengths, text_lengths=None): - reduced_mel_lens = (mel_lengths + self.max_reduction_factor - 1) // self.max_reduction_factor text_pos_step = self.mel_text_len_ratio / float(self.max_reduction_factor) text_embd = self.text_encoder(text_inputs, text_lengths, pos_step=text_pos_step) - prior_latents, prior_logprobs = self.prior.init(conditions=text_embd, + reduced_mel_lens = (mel_lengths + self.max_reduction_factor - 1) // self.max_reduction_factor + + prior_latents, prior_logprobs = self.prior.init(inputs=text_embd, targets_lengths=reduced_mel_lens, condition_lengths=text_lengths) _, predicted_mel, _ = self.decoder(inputs=prior_latents, diff --git a/model/attention.py b/model/attention.py index 68ad7a2..efa6cd1 100644 --- a/model/attention.py +++ b/model/attention.py @@ -6,7 +6,7 @@ from .utils import LinearNorm, FFN from utils.tools import get_mask_from_lengths -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class BaseAttention(nn.Module): @@ -27,7 +27,7 @@ def forward(self, inputs, memory, memory_lengths, query_lengths): raise NotImplementedError @staticmethod - def _get_key_mask(batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths): + def _get_key_mask(batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths, device): memory_lengths = (memory_lengths if memory_lengths is not None else torch.ones(batch_size, dtype=torch.int32, device=device) * memory_max_time) memeory_mask = get_mask_from_lengths(memory_lengths, memory_max_time) @@ -81,7 +81,7 @@ def _merge_head(self, inputs): return reshaped def _get_key_mask(self, batch_size, memory_max_time, query_max_time, - memory_lengths, query_lengths): + memory_lengths, query_lengths, device): memory_lengths = (memory_lengths if memory_lengths is not None else torch.ones(batch_size, dtype=torch.int32, device=device) * memory_max_time) memory_mask = get_mask_from_lengths(memory_lengths, memory_max_time) # [batch, m_max_time] @@ -100,7 +100,7 @@ def _get_key_mask(self, batch_size, memory_max_time, query_max_time, @staticmethod def _get_causal_mask(logits): - causal_mask = torch.tril(torch.ones(logits.shape, dtype=torch.bool, device=device)) + causal_mask = torch.tril(torch.ones(logits.shape, 
dtype=torch.bool, device=logits.device)) return causal_mask def forward(self, inputs, memory, memory_lengths=None, query_lengths=None, causality=None): @@ -120,7 +120,7 @@ def forward(self, inputs, memory, memory_lengths=None, query_lengths=None, causa memory_max_time = memory.shape[1] query_max_time = inputs.shape[1] length_mask = self._get_key_mask( - batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths) + batch_size, memory_max_time, query_max_time, memory_lengths, query_lengths, inputs.device) if causality: causal_mask = self._get_causal_mask(logits) length_mask = torch.logical_and(length_mask, causal_mask) @@ -134,10 +134,10 @@ def forward(self, inputs, memory, memory_lengths=None, query_lengths=None, causa return contexts, alignments -class SelfAttentionBLK(nn.Module): +class SelfAttentionBlock(nn.Module): def __init__(self, input_dim, attention_dim, attention_heads, attention_temperature, ffn_hidden): - super(SelfAttentionBLK, self).__init__() + super(SelfAttentionBlock, self).__init__() self.input_dim = input_dim self.attention_dim = attention_dim self.attention = MultiHeadScaledProductAttention(attention_dim=attention_dim, @@ -161,10 +161,10 @@ def forward(self, inputs, memory, query_lengths, memory_lengths, causality=None) return ffn_outs, alignments -class CrossAttentionBLK(nn.Module): +class CrossAttentionBlock(nn.Module): def __init__(self, input_dim, memory_dim, attention_dim, attention_heads, attention_temperature, ffn_hidden, name=None): - super(CrossAttentionBLK, self).__init__() + super(CrossAttentionBlock, self).__init__() self.name = name self.input_dim = input_dim self.attention_dim = attention_dim diff --git a/model/decoder.py b/model/decoder.py index 65fc5e0..d399d81 100644 --- a/model/decoder.py +++ b/model/decoder.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from .utils import LinearNorm, PostNet -from .attention import CrossAttentionBLK +from .attention import CrossAttentionBlock from utils.tools import get_mask_from_lengths @@ -67,12 +67,12 @@ def __init__(self, nblk, embd_dim, attention_dim, attention_heads, self.pre_projection = LinearNorm(latent_dim, attention_dim) self.attentions = nn.ModuleList( [ - CrossAttentionBLK(input_dim=attention_dim, - memory_dim=embd_dim, - attention_dim=attention_dim, - attention_heads=attention_heads, - attention_temperature=temperature, - ffn_hidden=ffn_hidden, name='decoder-attention-{}'.format(i)) + CrossAttentionBlock(input_dim=attention_dim, + memory_dim=embd_dim, + attention_dim=attention_dim, + attention_heads=attention_heads, + attention_temperature=temperature, + ffn_hidden=ffn_hidden, name='decoder-attention-{}'.format(i)) for i in range(nblk) ] ) diff --git a/model/encoder.py b/model/encoder.py index d886d15..e8f4477 100644 --- a/model/encoder.py +++ b/model/encoder.py @@ -1,59 +1,51 @@ import torch import torch.nn as nn from torch.nn import functional as F +from math import sqrt -from .utils import Conv1D, ConvPreNet, PositionalEncoding -from .attention import SelfAttentionBLK +from .utils import ConvPreNet, PositionalEncoding +from .attention import SelfAttentionBlock -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -class BaseEncoder(nn.Module): - def __init__(self, vocab_size, embd_dim): - super(BaseEncoder, self).__init__() - self.emb_layer = nn.Embedding(num_embeddings=vocab_size, - embedding_dim=embd_dim, - padding_idx=0) - def forward(self, inputs, input_lengths=None): - """ - :param inputs: 
text inputs, [batch, max_time] - :param input_lengths: text inputs' lengths, [batch] - :return: (tensor1, tensor2) - tensor1: text encoding, [batch, max_time, hidden_size] - tensor2: global state, i.e., final_time_state, [batch, hidden_size] - """ - raise NotImplementedError - - -class TransformerEncoder(BaseEncoder): - def __init__(self, vocab_size, embd_dim, pre_nconv, pre_hidden, pre_conv_kernel, - prenet_drop_rate, pre_activation, bn_before_act, pos_drop_rate, nblk, +class TransformerEncoder(nn.Module): + def __init__(self, n_symbols, embedding_dim, pre_nconv, pre_hidden, pre_conv_kernel, + prenet_drop_rate, pre_activation, bn_before_act, pos_drop_rate, n_blocks, attention_dim, attention_heads, attention_temperature, ffn_hidden): - super(TransformerEncoder, self).__init__(vocab_size, embd_dim) - self.pos_weight = nn.Parameter(torch.tensor(1.0, device=device)) + super(TransformerEncoder, self).__init__() + self.embedding = nn.Embedding(num_embeddings=n_symbols, + embedding_dim=embedding_dim, + padding_idx=0) + std = sqrt(2.0 / (n_symbols + embedding_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.embedding.weight.data.uniform_(-val, val) + self.prenet = ConvPreNet(nconv=pre_nconv, hidden=pre_hidden, conv_kernel=pre_conv_kernel, drop_rate=prenet_drop_rate, activation=pre_activation, bn_before_act=bn_before_act) + + self.register_parameter("pos_weight", nn.Parameter(torch.tensor(1.0))) self.pe = PositionalEncoding() self.pe_dropout = nn.Dropout(p=pos_drop_rate) + self.self_attentions = nn.ModuleList( [ - SelfAttentionBLK( + SelfAttentionBlock( input_dim=pre_hidden, attention_dim=attention_dim, attention_heads=attention_heads, attention_temperature=attention_temperature, ffn_hidden=ffn_hidden) - for i in range(nblk) + for i in range(n_blocks) ] ) def forward(self, inputs, input_lengths=None, pos_step=1.0): # print('tracing back at text encoding') - embs = self.emb_layer(inputs) + embs = self.embedding(inputs) prenet_outs = self.prenet(embs) - max_time = prenet_outs.shape[1] - dim = prenet_outs.shape[2] - pos = self.pe.positional_encoding(max_time, dim, device, pos_step) + max_time, prenet_dim = prenet_outs.size(1), prenet_outs.size(2) + pos = self.pe.positional_encoding(max_time, prenet_dim, inputs.device, pos_step) pos_embs = prenet_outs + self.pos_weight * pos pos_embs = self.pe_dropout(pos_embs) att_outs = pos_embs diff --git a/model/flow.py b/model/flow.py index e46bf09..06b4b2c 100644 --- a/model/flow.py +++ b/model/flow.py @@ -7,52 +7,34 @@ from .transform import TransformerTransform from utils.tools import get_mask_from_lengths -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -class BaseFlow(nn.Module): - def __init__(self, inverse): - super(BaseFlow, self).__init__() - self.inverse = inverse - def _forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - """ +class Flow(nn.Module): + def __init__(self): + super(Flow, self).__init__() + def forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """ Args: *inputs: input [batch, *input_size] - Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch] out, the output of the flow logdet, the log determinant of :math:`\partial output / \partial input` """ raise NotImplementedError - def _backward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: """ - Args: *input: input [batch, *input_size] - 
Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch] out, the output of the flow logdet, the log determinant of :math:`\partial output / \partial input` """ raise NotImplementedError - def forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - Args: - *inputs: input [batch, *input_size] - - Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch] - out, the output of the flow - logdet, the log determinant of :math:`\partial output / \partial input` - """ - if self.inverse: - return self._backward(*inputs, **kwargs) - return self._forward(*inputs, **kwargs) - def init(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: """ Initiate the weights according to the initial input data @@ -62,144 +44,84 @@ def init(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: """ raise NotImplementedError - def fwd_pass(self, inputs, *h, init=False, init_scale=1.0, - **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - Args: - inputs: Tensor - The random variable before flow - h: list of object - other conditional inputs - init: bool - perform initialization or not (default: False) - init_scale: float - initial scale (default: 1.0) - - Returns: y: Tensor, logdet: Tensor - y, the random variable after flow - logdet, the log determinant of :math:`\partial y / \partial x` - Then the density :math:`\log(p(y)) = \log(p(x)) - logdet` - - """ - if self.inverse: - if init: - raise RuntimeError( - 'inverse flow shold be initialized with backward pass') - else: - return self._backward(inputs, *h, **kwargs) - else: - if init: - return self.init(inputs, *h, init_scale=init_scale, **kwargs) - else: - return self._forward(inputs, *h, **kwargs) - - def bwd_pass(self, inputs, *h, init=False, init_scale=1.0, - **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - """ - :param inputs: the random variable after the flow - :param h: other conditional inputs - :param init: bool, whether perform initialization or not - :param init_scale: float (default: 1.0) - :param kwargs: - :return: x: the random variable before the flow, - log_det: the log determinant of :math:`\partial x / \partial y` - Then the density :math:`\log(p(y)) = \log(p(x)) + logdet` - """ - if self.inverse: - if init: - return self.init(inputs, *h, init_scale=init_scale, **kwargs) - else: - return self._forward(inputs, *h, **kwargs) - else: - if init: - raise RuntimeError( - 'forward flow should be initialzed with forward pass') - else: - return self._backward(inputs, *h, **kwargs) - -class InvertibleLinearFlow(BaseFlow): - def __init__(self, channels, inverse): - super(InvertibleLinearFlow, self).__init__(inverse) +class InvertibleLinearFlow(Flow): + def __init__(self, channels): + super(InvertibleLinearFlow, self).__init__() self.channels = channels w_init = np.linalg.qr(np.random.randn(channels, channels))[0].astype(np.float32) - self.weight = nn.Parameter(torch.from_numpy(w_init).type(torch.float32).to(device)) + self.register_parameter("weight", nn.Parameter(torch.from_numpy(w_init))) - def _forward(self, inputs, inputs_lengths=None - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, inputs_lengths=None) -> Tuple[torch.Tensor, torch.Tensor]: input_shape = inputs.shape outputs = torch.matmul(inputs, self.weight) - logdet = torch.linalg.slogdet( - self.weight.type(torch.float64))[1].type(torch.float32) + logdet = torch.linalg.slogdet(self.weight.double())[1].float() if inputs_lengths is None: - logdet = torch.ones(input_shape[0], device=device) * 
float(input_shape[1]) * logdet + logdet = torch.ones(input_shape[0], device=inputs.device) * float(input_shape[1]) * logdet else: - logdet = inputs_lengths.type(torch.float32) * logdet + logdet = inputs_lengths.float() * logdet return outputs, logdet - def _backward(self, inputs, inputs_lengths=None - ) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse(self, inputs: torch.Tensor, inputs_lengths=None) -> Tuple[torch.Tensor, torch.Tensor]: input_shape = inputs.shape outputs = torch.matmul(inputs, torch.linalg.inv(self.weight)) - logdet = torch.linalg.slogdet( - torch.linalg.inv( - self.weight.type(torch.float64)))[1].type(torch.float32) + logdet = torch.linalg.slogdet(torch.linalg.inv(self.weight.double()))[1].float() if inputs_lengths is None: - logdet = torch.ones(input_shape[0], device=device) * float(input_shape[1]) * logdet + logdet = torch.ones(input_shape[0], device=inputs.device) * float(input_shape[1]) * logdet else: - logdet = inputs_lengths.type(torch.float32) * logdet + logdet = inputs_lengths.float() * logdet return outputs, logdet - def init(self, inputs, inputs_lengths=None): - return self._forward(inputs, inputs_lengths) + def init(self, inputs: torch.Tensor, inputs_lengths=None): + return self.forward(inputs, inputs_lengths) -class ActNormFlow(BaseFlow): - def __init__(self, channels, inverse): - super(ActNormFlow, self).__init__(inverse) +class ActNormFlow(Flow): + def __init__(self, channels): + super(ActNormFlow, self).__init__() self.channels = channels - self.log_scale = nn.Parameter(torch.normal(0.0, 0.05, [self.channels, ]).to(device)) - self.bias = nn.Parameter(torch.zeros(self.channels, device=device)) + self.register_parameter("log_scale", nn.Parameter(torch.normal(0.0, 0.05, [self.channels, ]))) + self.register_parameter("bias", nn.Parameter(torch.zeros(self.channels))) - def _forward(self, inputs, input_lengths=None) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, input_lengths=None) -> Tuple[torch.Tensor, torch.Tensor]: input_shape = inputs.shape outputs = inputs * torch.exp(self.log_scale) + self.bias logdet = torch.sum(self.log_scale) if input_lengths is None: - logdet = torch.ones(input_shape[0], device=device) * float(input_shape[1]) * logdet + logdet = torch.ones(input_shape[0], device=inputs.device) * float(input_shape[1]) * logdet else: - logdet = input_lengths.type(torch.float32) * logdet + logdet = input_lengths.float() * logdet return outputs, logdet - def _backward(self, inputs, input_lengths=None, epsilon=1e-8 - ) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse(self, inputs: torch.Tensor, input_lengths=None, epsilon=1e-8) -> Tuple[torch.Tensor, torch.Tensor]: input_shape = inputs.shape outputs = (inputs - self.bias) / (torch.exp(self.log_scale) + epsilon) logdet = -torch.sum(self.log_scale) if input_lengths is None: - logdet = torch.ones(input_shape[0], device=device) * float(input_shape[1]) * logdet + logdet = torch.ones(input_shape[0], device=inputs.device) * float(input_shape[1]) * logdet else: - logdet = input_lengths.type(torch.float32) * logdet + logdet = input_lengths.float() * logdet return outputs, logdet - def init(self, inputs, input_lengths=None, init_scale=1.0, epsilon=1e-8): + def init(self, inputs: torch.Tensor, input_lengths=None, init_scale=1.0, epsilon=1e-8): + # initialize from batch statistics _mean = torch.mean(inputs.view(-1, self.channels), dim=0) _std = torch.std(inputs.view(-1, self.channels), dim=0) self.log_scale.copy_(torch.log(init_scale / (_std + epsilon))) self.bias.copy_(-_mean / 
(_std + epsilon)) - return self._forward(inputs, input_lengths) + return self.forward(inputs, input_lengths) -class TransformerCoupling(BaseFlow): - def __init__(self, channels, inverse, nblk, embd_dim, attention_dim, attention_heads, +class TransformerCoupling(Flow): + def __init__(self, channels, nblk, embd_dim, attention_dim, attention_heads, temperature, ffn_hidden, order='upper'): + super(TransformerCoupling, self).__init__() # assert channels % 2 == 0 out_dim = channels // 2 self.channels = channels - super(TransformerCoupling, self).__init__(inverse) self.net = TransformerTransform( - nblk=nblk, channels=channels, embd_dim=embd_dim, attention_dim=attention_dim, attention_heads=attention_heads, + nblk=nblk, channels=channels, embd_dim=embd_dim, attention_dim=attention_dim, + attention_heads=attention_heads, temperature=temperature, ffn_hidden=ffn_hidden, out_dim=out_dim) self.upper = (order == 'upper') @@ -215,41 +137,39 @@ def _affine(inputs, scale, shift): def _inverse_affine(inputs, scale, shift, epsilon=1e-12): return (inputs - shift) / (scale + epsilon) - def _forward(self, inputs, condition_inputs, - inputs_lengths=None, condition_lengths=None - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor, condition_inputs: torch.Tensor, inputs_lengths=None, + condition_lengths=None) -> Tuple[torch.Tensor, torch.Tensor]: # assert inputs.shape[-1] == self.channels lower_pt, upper_pt = self._split(inputs) z, zp = (lower_pt, upper_pt) if self.upper else (upper_pt, lower_pt) - log_scale, shift = self.net(z, condition_inputs, condition_lengths, - inputs_lengths) + log_scale, shift = self.net(z, condition_inputs, condition_lengths, inputs_lengths) scale = torch.sigmoid(log_scale + 2.0) + zp = self._affine(zp, scale, shift) inputs_max_time = inputs.shape[1] mask = (get_mask_from_lengths(inputs_lengths, inputs_max_time).unsqueeze(-1) if inputs_lengths is not None else torch.ones_like(log_scale)) + logdet = torch.sum(torch.log(scale) * mask, dim=[1, 2]) # [batch, ] outputs = torch.cat([z, zp], dim=-1) if self.upper else torch.cat([zp, z], dim=-1) return outputs, logdet - def _backward(self, inputs, condition_inputs, - inputs_lengths=None, condition_lengths=None - ) -> Tuple[torch.Tensor, torch.Tensor]: + def inverse(self, inputs: torch.Tensor, condition_inputs: torch.Tensor, inputs_lengths=None, + condition_lengths=None) -> Tuple[torch.Tensor, torch.Tensor]: # assert inputs.shape[-1] == self.channels lower_pt, upper_pt = self._split(inputs) z, zp = (lower_pt, upper_pt) if self.upper else (upper_pt, lower_pt) - log_scale, shift = self.net(z, condition_inputs, condition_lengths, - inputs_lengths) + log_scale, shift = self.net(z, condition_inputs, condition_lengths, inputs_lengths) scale = torch.sigmoid(log_scale + 2.0) + zp = self._inverse_affine(zp, scale, shift) inputs_max_time = inputs.shape[1] mask = (get_mask_from_lengths(inputs_lengths, inputs_max_time).unsqueeze(-1) if inputs_lengths is not None else torch.ones_like(log_scale)) + log_det = -torch.sum(torch.log(scale) * mask, dim=[1, 2]) # [batch,] outputs = torch.cat([z, zp], dim=-1) if self.upper else torch.cat([zp, z], dim=-1) return outputs, log_det - def init(self, inputs, condition_inputs, inputs_lengths=None, - condition_lengths=None): - return self._forward( - inputs, condition_inputs, inputs_lengths, condition_lengths) + def init(self, inputs: torch.Tensor, condition_inputs: torch.Tensor, inputs_lengths=None, condition_lengths=None): + return self.forward(inputs, condition_inputs, inputs_lengths, 
condition_lengths) diff --git a/model/glow.py b/model/glow.py new file mode 100644 index 0000000..6c6adad --- /dev/null +++ b/model/glow.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +from .flow import InvertibleLinearFlow, ActNormFlow, TransformerCoupling, Flow + + +class GlowBlock(Flow): + def __init__(self, channels, n_transformer_blk, embd_dim, attention_dim, + attention_heads, temperature, ffn_hidden, order): + super(GlowBlock, self).__init__() + self.actnorm = ActNormFlow(channels) + self.linear = InvertibleLinearFlow(channels) + self.affine_coupling = TransformerCoupling(channels=channels, + nblk=n_transformer_blk, + embd_dim=embd_dim, + attention_dim=attention_dim, + attention_heads=attention_heads, + temperature=temperature, + ffn_hidden=ffn_hidden, + order=order) + + def forward(self, z, inputs: torch.Tensor, targets_lengths, condition_lengths): + total_logdet = torch.zeros([z.size(0), ], device=z.device) + + z, logdet = self.actnorm(z, targets_lengths) + total_logdet += logdet + z, logdet = self.linear(z, targets_lengths) + total_logdet += logdet + z, logdet = self.affine_coupling(inputs=z, condition_inputs=inputs, + inputs_lengths=targets_lengths, + condition_lengths=condition_lengths) + total_logdet += logdet + return z, total_logdet + + def inverse(self, z, inputs: torch.Tensor, targets_lengths, condition_lengths): + total_logdet = torch.zeros([z.size(0), ], device=z.device) + + # reverse order + z, logdet = self.affine_coupling.inverse(inputs=z, condition_inputs=inputs, + inputs_lengths=targets_lengths, + condition_lengths=condition_lengths) + total_logdet += logdet + z, logdet = self.linear.inverse(z, targets_lengths) + total_logdet += logdet + z, logdet = self.actnorm.inverse(z, targets_lengths) + total_logdet += logdet + + return z, total_logdet + + def init(self, z, inputs: torch.Tensor, targets_lengths, condition_lengths): + total_logdet = torch.zeros([z.size(0), ], device=z.device) + + z, logdet = self.actnorm.init(z, targets_lengths) + total_logdet += logdet + z, logdet = self.linear.init(z, targets_lengths) + total_logdet += logdet + z, logdet = self.affine_coupling.init(inputs=z, condition_inputs=inputs, + inputs_lengths=targets_lengths, + condition_lengths=condition_lengths) + total_logdet += logdet + return z, total_logdet + + +class Glow(Flow): + def __init__(self, n_blocks, channels, n_transformer_blk, embd_dim, attention_dim, + attention_heads, temperature, ffn_hidden): + super(Glow, self).__init__() + orders = ['upper', 'lower'] + self.flows = nn.ModuleList([GlowBlock(channels, n_transformer_blk, embd_dim, attention_dim, + attention_heads, temperature, ffn_hidden, orders[i % 2]) + for i in range(n_blocks)]) + + def forward(self, z, inputs: torch.Tensor, targets_lengths, condition_lengths): + total_logdet = torch.zeros([z.size(0), ], device=z.device) + for flow in self.flows: + z, logdet = flow(z, inputs, targets_lengths, condition_lengths) + total_logdet += logdet + return z, total_logdet + + def inverse(self, z, inputs: torch.Tensor, targets_lengths, condition_lengths): + total_logdet = torch.zeros([z.size(0), ], device=z.device) + for i in range(len(self.flows) - 1, -1, -1): + z, logdet = self.flows[i].inverse(z, inputs, targets_lengths, condition_lengths) + total_logdet += logdet + return z, total_logdet + + def init(self, z, inputs: torch.Tensor, targets_lengths, condition_lengths): + total_logdet = torch.zeros([z.size(0), ], device=z.device) + for flow in self.flows: + z, logdet = flow.init(z, inputs, targets_lengths, condition_lengths) + 
total_logdet += logdet + return z, total_logdet diff --git a/model/posterior.py b/model/posterior.py index 1ef19e6..b0154e3 100644 --- a/model/posterior.py +++ b/model/posterior.py @@ -1,14 +1,15 @@ import torch import torch.nn as nn import torch.nn.functional as F + import numpy as np import math - from .utils import LinearNorm, PreNet, PositionalEncoding -from .attention import CrossAttentionBLK +from .attention import CrossAttentionBlock from utils.tools import get_mask_from_lengths -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class BasePosterior(nn.Module): @@ -18,7 +19,7 @@ def __init__(self): super(BasePosterior, self).__init__() def forward(self, inputs, src_enc, src_lengths=None, target_lengths=None - ): + ): raise NotImplementedError @staticmethod @@ -34,9 +35,9 @@ def reparameterize(mu, logvar, nsamples=1, random=True): batch, max_time, dim = mu.shape std = torch.exp(0.5 * logvar) if random: - eps = torch.normal(0.0, 1.0, [batch, nsamples, max_time, dim]).to(device) + eps = torch.normal(0.0, 1.0, [batch, nsamples, max_time, dim]).to(mu.device) else: - eps = torch.zeros([batch, nsamples, max_time, dim], device=device) + eps = torch.zeros([batch, nsamples, max_time, dim], device=mu.device) samples = eps * std.unsqueeze(1) + mu.unsqueeze(1) return samples, eps @@ -53,22 +54,22 @@ def log_probability(mu, logvar, z=None, eps=None, seq_lengths=None, epsilon=1e-8 """ # print('tracing back at posterior log-probability') batch, max_time, dim = mu.shape - std = torch.exp(0.5 * logvar) + + # random noise + # std = torch.exp(0.5 * logvar) normalized_samples = (eps if eps is not None else (z - mu.unsqueeze(1)) - / (std.unsqueeze(1) + epsilon)) + / (torch.exp(0.5 * logvar).unsqueeze(1) + epsilon)) + expanded_logvar = logvar.unsqueeze(1) # time_level_log_probs [batch, nsamples, max_time] - time_level_log_probs = -0.5 * ( - float(dim) * math.log(2 * np.pi) - + torch.sum(expanded_logvar + normalized_samples ** 2., - dim=3)) + time_level_log_probs = -0.5 * (float(dim) * math.log(2 * np.pi) + + torch.sum(expanded_logvar + normalized_samples ** 2, dim=3)) seq_mask = (get_mask_from_lengths(seq_lengths, max_time) if seq_lengths is not None - else torch.ones([batch, max_time], device=device)) + else torch.ones([batch, max_time], device=mu.device)) seq_mask = seq_mask.unsqueeze(1) # [batch, 1, max_time] - sample_level_log_probs = torch.sum(seq_mask * time_level_log_probs, - dim=2) # [batch, nsamples] + sample_level_log_probs = torch.sum(seq_mask * time_level_log_probs, dim=2) # [batch, nsamples] return sample_level_log_probs def sample(self, inputs, src_enc, input_lengths, src_lengths, @@ -92,35 +93,36 @@ def __init__(self, num_mels, embd_dim, pre_hidden, pre_drop_rate, pre_activation pos_drop_rate, nblk, attention_dim, attention_heads, temperature, ffn_hidden, latent_dim): super(TransformerPosterior, self).__init__() - self.pos_weight = nn.Parameter(torch.tensor(1.0, device=device)) + # self.pos_weight = nn.Parameter(torch.tensor(1.0, device=device)) + self.register_parameter("pos_weight", nn.Parameter(torch.tensor(1.0))) self.prenet = PreNet(in_features=num_mels, units=pre_hidden, drop_rate=pre_drop_rate, activation=pre_activation) self.pe = PositionalEncoding() self.pe_dropout = nn.Dropout(p=pos_drop_rate) self.attentions = nn.ModuleList( [ - CrossAttentionBLK(input_dim=pre_hidden, - memory_dim=embd_dim, - attention_dim=attention_dim, - attention_heads=attention_heads, - 
attention_temperature=temperature, - ffn_hidden=ffn_hidden) + CrossAttentionBlock(input_dim=pre_hidden, + memory_dim=embd_dim, + attention_dim=attention_dim, + attention_heads=attention_heads, + attention_temperature=temperature, + ffn_hidden=ffn_hidden) for i in range(nblk) ] ) self.mu_projection = LinearNorm(attention_dim, latent_dim, - kernel_initializer='zeros') + kernel_initializer='none') self.logvar_projection = LinearNorm(attention_dim, latent_dim, - kernel_initializer='zeros') + kernel_initializer='none') def forward(self, inputs, src_enc, src_lengths=None, target_lengths=None): # print('tracing back at posterior call') prenet_outs = self.prenet(inputs) max_time = prenet_outs.shape[1] dim = prenet_outs.shape[2] - pos = self.pe.positional_encoding(max_time, dim, device) + pos = self.pe.positional_encoding(max_time, dim, inputs.device) pos_embs = prenet_outs + self.pos_weight * pos pos_embs = self.pe_dropout(pos_embs) att_outs = pos_embs @@ -128,12 +130,14 @@ def forward(self, inputs, src_enc, src_lengths=None, target_lengths=None): att_outs, alignments = att( inputs=att_outs, memory=src_enc, query_lengths=target_lengths, memory_lengths=src_lengths) + + # [batch, target_lengths, latent_dim] mu = self.mu_projection(att_outs) logvar = self.logvar_projection(att_outs) return mu, logvar, None def sample(self, inputs, src_enc, input_lengths, src_lengths, - nsamples=1, random=True, training=None): + nsamples=1, random=True): mu, logvar, _ = self.forward(inputs, src_enc, input_lengths, src_lengths) samples, eps = self.reparameterize(mu, logvar, nsamples, random) log_probs = self.log_probability(mu, logvar, eps, input_lengths) diff --git a/model/prior.py b/model/prior.py index 7e01aa0..b630569 100644 --- a/model/prior.py +++ b/model/prior.py @@ -4,30 +4,18 @@ import math from numpy import pi as PI -from .flow import InvertibleLinearFlow, ActNormFlow, TransformerCoupling +from .glow import Glow from utils.tools import get_mask_from_lengths -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -class BasePrior(nn.Module): +class Prior(nn.Module): """ P(z|x): prior that generate the latent variables conditioned on x """ def __init__(self, channels): - super(BasePrior, self).__init__() + super(Prior, self).__init__() self.channels = channels - def forward(self, inputs, targets_lengths, condition_lengths - ): - """ - :param targets_lengths: [batch, ] - :param inputs: condition_inputs - :param condition_lengths: - :return: tensor1: outputs, tensor2: log_probabilities - """ - raise NotImplementedError - def _initial_sample(self, targets_lengths, temperature=1.0): """ :param targets_lengths: [batch,] @@ -36,13 +24,23 @@ def _initial_sample(self, targets_lengths, temperature=1.0): log-probabilities: [batch, ] """ batch_size = targets_lengths.shape[0] - length = torch.max(targets_lengths).type(torch.int32) - epsilon = torch.normal(0.0, temperature, [batch_size, length, self.channels]).to(device) - logprobs = -0.5 * (math.log(2. 
* PI) + epsilon ** 2) + length = torch.max(targets_lengths).long() + epsilon = torch.normal(0.0, temperature, [batch_size, length, self.channels]).to(targets_lengths.device) + + logprobs = -0.5 * (epsilon ** 2 + math.log(2 * PI)) seq_mask = get_mask_from_lengths(targets_lengths).unsqueeze(-1) # [batch, max_time, 1] logprobs = torch.sum(seq_mask * logprobs, dim=[1, 2]) # [batch, ] return epsilon, logprobs + def forward(self, inputs, targets_lengths, condition_lengths): + """ + :param targets_lengths: [batch, ] + :param inputs: condition_inputs + :param condition_lengths: + :return: tensor1: outputs, tensor2: log_probabilities + """ + raise NotImplementedError + def log_probability(self, z, condition_inputs, z_lengths=None, condition_lengths=None ): """ @@ -65,69 +63,24 @@ def init(self, *inputs, **kwargs): """ raise NotImplementedError - def sample(self, targets_lengths, n_samples, condition_inputs, condition_lengths=None - ): - """ - :param targets_lengths: - :param n_samples: - :param condition_inputs: - :param condition_lengths: - :return: tensor1: samples: [batch, n_samples, max_lengths, dim] - tensor2: log-probabilities: [batch, n_samples] - """ - raise NotImplementedError - -class TransformerPrior(BasePrior): +class TransformerPrior(Prior): def __init__(self, n_blk, channels, n_transformer_blk, embd_dim, attention_dim, - attention_heads, temperature, ffn_hidden, inverse=False): + attention_heads, temperature, ffn_hidden): super(TransformerPrior, self).__init__(channels) - orders = ['upper', 'lower'] - self.actnorms = nn.ModuleList( - [ - ActNormFlow(channels, inverse) - for i in range(n_blk) - ] - ) - self.linears = nn.ModuleList( - [ - InvertibleLinearFlow(channels, inverse) - for i in range(n_blk) - ] - ) - self.affine_couplings = nn.ModuleList( - [ - TransformerCoupling(channels=channels, inverse=inverse, - nblk=n_transformer_blk, - embd_dim=embd_dim, - attention_dim=attention_dim, - attention_heads=attention_heads, - temperature=temperature, - ffn_hidden=ffn_hidden, - order=orders[i % 2]) - for i in range(n_blk) - ] - ) - - def forward(self, inputs, targets_lengths, condition_lengths, temperature=1.0 - ): - # 1. 
get initial noise + self.glow = Glow(n_blk, channels, n_transformer_blk, embd_dim, attention_dim, + attention_heads, temperature, ffn_hidden) + + def forward(self, targets_lengths, conditional_inputs: torch.Tensor, condition_lengths, temperature=1.0): + # get initial noise epsilon, logprobs = self._initial_sample(targets_lengths, temperature=temperature) + z = epsilon - for _, (actnorm, linear, affine_coupling) in enumerate(zip(self.actnorms, self.linears, self.affine_couplings)): - z, logdet = actnorm(z, targets_lengths) - logprobs -= logdet - z, logdet = linear(z, targets_lengths) - logprobs -= logdet - z, logdet = affine_coupling(inputs=z, condition_inputs=inputs, - inputs_lengths=targets_lengths, - condition_lengths=condition_lengths) - logprobs -= logdet + z, logdet = self.glow(z, conditional_inputs, targets_lengths, condition_lengths) + logprobs += logdet return z, logprobs - def log_probability(self, z, condition_inputs, z_lengths=None, - condition_lengths=None - ): + def log_probability(self, z, condition_inputs, z_lengths=None, condition_lengths=None): """ :param z: [batch, max_time, dim] :param condition_inputs: @@ -135,53 +88,19 @@ def log_probability(self, z, condition_inputs, z_lengths=None, :param condition_lengths: :return: log-probabilities of z, [batch] """ - # print('tracing back at prior log-probability') - epsilon = z - batch_size = z.shape[0] + epsilon, logdet = self.glow.inverse(z, condition_inputs, z_lengths, condition_lengths) + + logprobs = -0.5 * (epsilon ** 2 + math.log(2 * PI)) max_time = z.shape[1] - accum_logdet = torch.zeros([batch_size, ], dtype=torch.float32, device=device) - for _, (actnorm, linear, affine_coupling) in enumerate(zip(self.actnorms, self.linears, self.affine_couplings)): - epsilon, logdet = affine_coupling.bwd_pass(inputs=epsilon, - condition_inputs=condition_inputs, - inputs_lengths=z_lengths, - condition_lengths=condition_lengths) - accum_logdet += logdet - epsilon, logdet = linear.bwd_pass(epsilon, z_lengths) - accum_logdet += logdet - epsilon, logdet = actnorm.bwd_pass(epsilon, z_lengths) - accum_logdet += logdet - logprobs = -0.5 * (math.log(2. * PI) + epsilon ** 2) seq_mask = get_mask_from_lengths(z_lengths, max_time).unsqueeze(-1) # [batch, max_time] logprobs = torch.sum(seq_mask * logprobs, dim=[1, 2]) # [batch, ] - logprobs += accum_logdet + logprobs += logdet return logprobs - def sample(self, targets_lengths, condition_inputs, condition_lengths=None, temperature=1.0): - # 1. get initial noise - epsilon, logprobs = self._initial_sample(targets_lengths, temperature=temperature) # [batch*n_samples, ] - z = epsilon - for _, (actnorm, linear, affine_coupling) in enumerate(zip(self.actnorms, self.linears, self.affine_couplings)): - z, logdet = actnorm(z, targets_lengths) - logprobs -= logdet - z, logdet = linear(z, targets_lengths) - logprobs -= logdet - z, logdet = affine_coupling.fwd_pass(inputs=z, condition_inputs=condition_inputs, - inputs_lengths=targets_lengths, - condition_lengths=condition_lengths) - logprobs -= logdet - return z, logprobs - - def init(self, conditions, targets_lengths, condition_lengths): - # 1. 
get initial noise + def init(self, inputs: torch.Tensor, targets_lengths, condition_lengths): + # get initial noise epsilon, logprobs = self._initial_sample(targets_lengths) + z = epsilon - for _, (actnorm, linear, affine_coupling) in enumerate(zip(self.actnorms, self.linears, self.affine_couplings)): - z, logdet = actnorm.init(z, targets_lengths) - logprobs -= logdet - z, logdet = linear(z, targets_lengths) - logprobs -= logdet - z, logdet = affine_coupling.init(inputs=z, condition_inputs=conditions, - inputs_lengths=targets_lengths, - condition_lengths=condition_lengths) - logprobs -= logdet - return z, logprobs \ No newline at end of file + z, logdet = self.glow.init(z, inputs, targets_lengths, condition_lengths) + return z, logprobs diff --git a/model/transform.py b/model/transform.py index b797b73..7857e16 100644 --- a/model/transform.py +++ b/model/transform.py @@ -3,9 +3,10 @@ import torch.nn.functional as F from .utils import LinearNorm, PositionalEncoding -from .attention import CrossAttentionBLK +from .attention import CrossAttentionBlock -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class BaseTransform(nn.Module): @@ -13,12 +14,12 @@ def __init__(self, in_dim, out_dim): super(BaseTransform, self).__init__() self.out_dim = out_dim self.log_scale_proj = LinearNorm(in_dim, self.out_dim, - kernel_initializer='zeros') + kernel_initializer='none') self.shift_proj = LinearNorm(in_dim, self.out_dim, - kernel_initializer='zeros') + kernel_initializer='none') def forward(self, inputs, condition_inputs, condition_lengths=None - ): + ): """ :param inputs: xa inputs :param condition_inputs: @@ -33,26 +34,27 @@ def __init__(self, nblk, channels, embd_dim, attention_dim, attention_heads, tem ffn_hidden, out_dim): super(TransformerTransform, self).__init__(in_dim=attention_dim, out_dim=out_dim) self.pos_emb_layer = PositionalEncoding() - self.pos_weight = nn.Parameter(torch.tensor(1.0, device=device)) + # self.pos_weight = nn.Parameter(torch.tensor(1.0, device=device)) + self.register_parameter("pos_weight", nn.Parameter(torch.tensor(1.0))) self.pre_projection = LinearNorm(channels // 2, attention_dim) self.attentions = nn.ModuleList( [ - CrossAttentionBLK(input_dim=attention_dim, - memory_dim=embd_dim, - attention_dim=attention_dim, - attention_heads=attention_heads, - attention_temperature=temperature, - ffn_hidden=ffn_hidden) + CrossAttentionBlock(input_dim=attention_dim, + memory_dim=embd_dim, + attention_dim=attention_dim, + attention_heads=attention_heads, + attention_temperature=temperature, + ffn_hidden=ffn_hidden) for i in range(nblk) ] ) def forward(self, inputs, condition_inputs, condition_lengths=None, - target_lengths=None): + target_lengths=None): att_outs = self.pre_projection(inputs) max_time = att_outs.shape[1] dim = att_outs.shape[2] - pos_embd = self.pos_emb_layer.positional_encoding(max_time, dim, device) + pos_embd = self.pos_emb_layer.positional_encoding(max_time, dim, inputs.device) att_outs += self.pos_weight * pos_embd for att in self.attentions: att_outs, _ = att(inputs=att_outs, memory=condition_inputs, diff --git a/model/utils.py b/model/utils.py index 26d945c..e12b302 100644 --- a/model/utils.py +++ b/model/utils.py @@ -4,25 +4,24 @@ class LinearNorm(nn.Module): - def __init__(self, in_features, out_features, activation=None, - use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros'): + def __init__(self, in_features, out_features, 
activation=None, + use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros'): super(LinearNorm, self).__init__() self.linear = nn.Linear( in_features=in_features, out_features=out_features, bias=use_bias) - # init weight if kernel_initializer == 'glorot_uniform': nn.init.xavier_uniform_(self.linear.weight) elif kernel_initializer == 'zeros': nn.init.zeros_(self.linear.weight) - # init bias if use_bias: if bias_initializer == 'zeros': nn.init.constant_(self.linear.bias, 0.0) - + else: + raise NotImplementedError self.activation = activation if activation is not None else nn.Identity() def forward(self, x): @@ -32,14 +31,14 @@ def forward(self, x): class ConvNorm(nn.Module): def __init__( - self, in_channels, out_channels, kernel_size=1, stride=1, - padding=None, dilation=1, activation=None, + self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros'): super(ConvNorm, self).__init__() self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, - stride=stride, padding=padding, dilation=dilation, bias=use_bias - ) + stride=stride, padding=padding, dilation=dilation, bias=use_bias + ) # init weight if kernel_initializer == 'glorot_uniform': @@ -82,7 +81,7 @@ def __init__(self, nconv, hidden, conv_kernel, drop_rate, self.conv_stack = nn.ModuleList( [ Conv1D(in_channels=hidden, out_channels=hidden, kernel_size=conv_kernel, activation=activation, - drop_rate=drop_rate, bn_before_act=bn_before_act) + drop_rate=drop_rate, bn_before_act=bn_before_act) for i in range(nconv) ] ) @@ -118,12 +117,12 @@ def __init__(self, in_channels, out_channels, kernel_size, activation, drop_rate bn_before_act=False, strides=1): super(Conv1D, self).__init__() self.conv1d = ConvNorm(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=strides, - padding=int((kernel_size - 1) / 2), - dilation=1, - activation=None) + out_channels=out_channels, + kernel_size=kernel_size, + stride=strides, + padding=int((kernel_size - 1) / 2), + dilation=1, + activation=None) self.activation = activation if activation is not None else nn.Identity() self.bn = nn.BatchNorm1d(out_channels) self.dropout = nn.Dropout(p=drop_rate) @@ -152,8 +151,9 @@ def __init__(self, n_conv, hidden, conv_filters, conv_kernel, activations = [nn.Tanh()] * (n_conv - 1) + [nn.Identity()] self.conv_stack = nn.ModuleList( [ - Conv1D(in_channels=hidden if i==0 else conv_filters, out_channels=conv_filters, kernel_size=conv_kernel, - activation=activations[i], drop_rate=drop_rate) + Conv1D(in_channels=hidden if i == 0 else conv_filters, out_channels=conv_filters, + kernel_size=conv_kernel, + activation=activations[i], drop_rate=drop_rate) for i in range(n_conv) ] ) diff --git a/synthesize.py b/synthesize.py index 7f3c12e..b9d5cd0 100644 --- a/synthesize.py +++ b/synthesize.py @@ -43,26 +43,18 @@ def synthesize(model, step, configs, vocoder, audio_processor, batchs, temperatu for batch in batchs: batch = to_device(batch, device) with torch.no_grad(): - t, t_l = batch[3], batch[4] - text_pos_step = model.mel_text_len_ratio / float(final_reduction_factor) - text_embd = model.text_encoder(t, t_l, pos_step=text_pos_step) - predicted_lengths = model.length_predictor( - text_embd.detach(), t_l) - predicted_m_l = predicted_lengths.type(torch.int32) - reduced_pred_ml = (predicted_m_l + 80 + final_reduction_factor - 1 - ) // final_reduction_factor - prior_latents, prior_logprobs = 
model.prior.sample( - reduced_pred_ml, text_embd, t_l, temperature=temperature) - _, prior_dec_outs, prior_dec_alignments = model.decoder( - prior_latents, text_embd, reduced_pred_ml, t_l) + texts, text_lengths = batch[3], batch[4] + + mel, mel_lengths, reduced_mel_lengths, alignments, *_ = model.inference( + inputs=texts, text_lengths=text_lengths, reduction_factor=final_reduction_factor) synth_samples( batch, - prior_dec_outs, - predicted_m_l + 80, - reduced_pred_ml, - t_l, - prior_dec_alignments, + mel, + mel_lengths, + reduced_mel_lengths, + text_lengths, + alignments, vocoder, audio_processor, model_config, diff --git a/train.py b/train.py index 7c83421..8f0d1b6 100644 --- a/train.py +++ b/train.py @@ -101,7 +101,7 @@ def _get_reduction_factor(ep): model.init(text_inputs=batch[2:][1], mel_lengths=batch[2:][5], text_lengths=batch[2:][2]) # Forward - predictions, mel_l2, kl_divergence, length_l2, dec_alignments, reduced_mel_lens = model( + (predictions, mel_l2, kl_divergence, length_l2, dec_alignments, reduced_mel_lens, *_) = model( *(batch[2:]), reduce_loss=True, reduction_factor=reduction_factor @@ -186,7 +186,8 @@ def _get_reduction_factor(ep): val_logger, vocoder, audio_processor, - len(losses) + len(losses), + device, ) with open(os.path.join(val_log_path, "log.txt"), "a") as f: f.write(message + "\n") diff --git a/utils/tools.py b/utils/tools.py index 79927d6..eab626d 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -207,7 +207,7 @@ def synth_samples( src_len = text_lens[i].item() mel_len = pred_lens[i].item() reduced_mel_len = reduced_pred_lens[i].item() - mel_prediction = predictions[i, :mel_len].transpose(0, 1) + mel_prediction = np.transpose(predictions[i, :mel_len], [1, 0]) attn_keys, attn_values = list(), list() for key, value in sorted(dec_alignments.items()): @@ -256,6 +256,7 @@ def plot_mel(data, titles, save_dir=None): for i in range(len(data)): mel = data[i] + print(mel.shape) axes[i][0].imshow(mel, origin="lower") axes[i][0].set_aspect(2.5, adjustable="box") axes[i][0].set_ylim(0, mel.shape[0])
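
A minimal, self-contained sketch (plain PyTorch, no repository code) of the log-determinant bookkeeping that the new Glow module centralizes: every sub-flow returns (output, logdet), GlowBlock/Glow sum the logdets, and inverse() walks the sub-flows in reverse order so the accumulated logdets cancel; log_probability then scores z as the standard-normal log-density of glow.inverse(z) plus that inverse logdet. ToyActNorm and all shapes below are illustrative assumptions, not the repository's classes.

    import math
    import torch

    class ToyActNorm:
        """Elementwise affine flow, analogous in spirit to ActNormFlow."""
        def __init__(self, channels):
            self.log_scale = torch.randn(channels) * 0.05
            self.bias = torch.randn(channels) * 0.05

        def forward(self, x):                       # x: [batch, time, channels]
            y = x * torch.exp(self.log_scale) + self.bias
            logdet = torch.ones(x.shape[0]) * x.shape[1] * self.log_scale.sum()
            return y, logdet

        def inverse(self, y):
            x = (y - self.bias) * torch.exp(-self.log_scale)
            logdet = -torch.ones(y.shape[0]) * y.shape[1] * self.log_scale.sum()
            return x, logdet

    flows = [ToyActNorm(4) for _ in range(3)]

    eps = torch.randn(2, 5, 4)
    z, logdet_fwd = eps, torch.zeros(2)
    for f in flows:                                  # forward: f1 -> f2 -> f3
        z, logdet = f.forward(z)
        logdet_fwd = logdet_fwd + logdet

    x, logdet_bwd = z, torch.zeros(2)
    for f in reversed(flows):                        # inverse: f3 -> f2 -> f1
        x, logdet = f.inverse(x)
        logdet_bwd = logdet_bwd + logdet

    print(torch.allclose(x, eps, atol=1e-5))         # round trip recovers eps
    print(torch.allclose(logdet_fwd, -logdet_bwd))   # log-determinants cancel

    # change of variables, as used by the prior's density evaluation:
    # log p(z) = log N(eps; 0, I) + logdet of the inverse map
    log_p_eps = (-0.5 * (eps ** 2 + math.log(2 * math.pi))).sum(dim=[1, 2])
    log_p_z = log_p_eps + logdet_bwd

The two invariants checked here (round-trip identity and logdet cancellation) are what keep the prior's sampling path and its density evaluation consistent with each other once the per-component loops are folded into Glow.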
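
The forward pass now also surfaces posterior_logprobs and prior_logprobs, and the KL term appears to be a Monte-Carlo estimate, the mean over reparameterized samples of log q(z) - log p(z), computed from those two tensors. A stand-alone illustration of that estimator against the closed-form Gaussian KL (the distributions and sample count here are arbitrary choices for the example, not values from the model):

    import torch

    torch.manual_seed(0)
    q = torch.distributions.Normal(0.5, 1.2)   # stand-in for the posterior
    p = torch.distributions.Normal(0.0, 1.0)   # stand-in for the prior
    z = q.sample((100000,))                    # reparameterized samples

    mc_kl = (q.log_prob(z) - p.log_prob(z)).mean()
    exact_kl = torch.distributions.kl_divergence(q, p)
    print(float(mc_kl), float(exact_kl))       # the two values should be close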
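
The explicit embedding initialization added to TransformerEncoder draws weights from U(-val, val) with val = sqrt(3) * sqrt(2 / (n_symbols + embedding_dim)); that bound coincides with the Xavier/Glorot-uniform bound sqrt(6 / (fan_in + fan_out)) at gain 1. A quick check with assumed sizes (the real vocabulary size depends on text.symbols):

    from math import sqrt

    n_symbols, embedding_dim = 150, 512        # illustrative sizes only
    std = sqrt(2.0 / (n_symbols + embedding_dim))
    val = sqrt(3.0) * std                      # bound used in the patch
    xavier_bound = sqrt(6.0 / (n_symbols + embedding_dim))
    print(abs(val - xavier_bound) < 1e-12)     # True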