
Commit

model.py: renaming variables, removing dropout from lstm cell state, removing conversions now handled by amp
rafaelvalle committed Apr 3, 2019
1 parent 087c867 commit 1480f82
Showing 1 changed file with 5 additions and 17 deletions.
model.py: 22 changes (5 additions & 17 deletions)

@@ -5,7 +5,6 @@
 from torch.nn import functional as F
 from layers import ConvNorm, LinearNorm
 from utils import to_gpu, get_mask_from_lengths
-from fp16_optimizer import fp32_to_fp16, fp16_to_fp32


 class LocationLayer(nn.Module):
@@ -355,8 +354,6 @@ def decode(self, decoder_input):
             cell_input, (self.attention_hidden, self.attention_cell))
         self.attention_hidden = F.dropout(
             self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)

         attention_weights_cat = torch.cat(
             (self.attention_weights.unsqueeze(1),
@@ -372,8 +369,6 @@ def decode(self, decoder_input):
             decoder_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(
             self.decoder_hidden, self.p_decoder_dropout, self.training)
-        self.decoder_cell = F.dropout(
-            self.decoder_cell, self.p_decoder_dropout, self.training)

         decoder_hidden_attention_context = torch.cat(
             (self.decoder_hidden, self.attention_context), dim=1)
@@ -489,10 +484,6 @@ def parse_batch(self, batch):
             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
             (mel_padded, gate_padded))

-    def parse_input(self, inputs):
-        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
-        return inputs
-
     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
             mask = ~get_mask_from_lengths(output_lengths)
@@ -503,20 +494,18 @@ def parse_output(self, outputs, output_lengths=None):
             outputs[1].data.masked_fill_(mask, 0.0)
             outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

-        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
         return outputs

     def forward(self, inputs):
-        inputs, input_lengths, targets, max_len, \
-            output_lengths = self.parse_input(inputs)
-        input_lengths, output_lengths = input_lengths.data, output_lengths.data
+        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
+        text_lengths, output_lengths = text_lengths.data, output_lengths.data

-        embedded_inputs = self.embedding(inputs).transpose(1, 2)
+        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

-        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

         mel_outputs, gate_outputs, alignments = self.decoder(
-            encoder_outputs, targets, memory_lengths=input_lengths)
+            encoder_outputs, mels, memory_lengths=text_lengths)

         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@@ -526,7 +515,6 @@ def forward(self, inputs):
             output_lengths)

     def inference(self, inputs):
-        inputs = self.parse_input(inputs)
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
         mel_outputs, gate_outputs, alignments = self.decoder.inference(
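Note on the removed casts: the commit message says the manual fp32_to_fp16 / fp16_to_fp32 conversions (and parse_input) are now handled by AMP. As a minimal sketch of what that typically looks like on the training side, assuming NVIDIA's apex AMP and illustrative names (model, optimizer, criterion, train_loader) rather than this repository's actual train.py, AMP patches the model and optimizer once and then manages the FP16/FP32 casting and loss scaling itself:

    from apex import amp  # assumption: apex AMP, not necessarily this repo's exact setup

    # Patch model and optimizer once; AMP inserts the FP16/FP32 casts that
    # parse_input / parse_output previously did by hand.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    for batch in train_loader:                    # illustrative training loop
        x, y = model.parse_batch(batch)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()                # loss scaling guards against FP16 underflow
        optimizer.step()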

2 comments on commit 1480f82

@zhanglq95

Why remove dropout from the LSTM unit?

@rafaelvalle
Contributor Author


Dropout is only removed from the cell state, not the hidden state. It is removed because at every decoder iteration it scales the cell state by 1/(1 - p), creating exponential growth in the cell state value that disturbs FP16 training.
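To make the compounding concrete, here is a minimal sketch (not part of the repository) of the rescaling F.dropout applies in training mode; the values are illustrative:

    import torch
    import torch.nn.functional as F

    p = 0.1
    x = torch.ones(8)
    # In training mode, dropout zeroes elements with probability p and rescales
    # the survivors by 1 / (1 - p), so survivors become ~1.111 here.
    print(F.dropout(x, p=p, training=True))

    # A cell-state coordinate that survives k consecutive decoder steps picks up
    # a factor of (1 / (1 - p)) ** k on top of the usual LSTM update.
    k = 120
    print((1.0 / (1.0 - p)) ** k)  # ~3e5, beyond the FP16 max of ~65504

The hidden state, by contrast, is recomputed each step as o ⊙ tanh(c) and is therefore bounded before its dropout is applied, so the rescaling does not compound there in the same way.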
