Skip to content

Commit

Permalink
Made models use the regular way of detecting training vs eval
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed Jul 26, 2019
1 parent 62abd12 commit 90bb56d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
3 changes: 2 additions & 1 deletion models/fatchord_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,12 @@ def forward(self, x, mels):
return self.fc3(x)

def generate(self, mels, save_path: Union[str, Path], batched, target, overlap, mu_law):
self.eval()

device = next(self.parameters()).device # use same device as parameters

mu_law = mu_law if self.mode == 'RAW' else False

self.eval()
output = []
start = time.time()
rnn1 = self.get_gru_cell(self.rnn1)
Expand Down
22 changes: 6 additions & 16 deletions models/tacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ class Decoder(nn.Module):
def __init__(self, n_mels, decoder_dims, lstm_dims):
super().__init__()
self.register_buffer('r', torch.tensor(1, dtype=torch.int))
self.generating = False
self.n_mels = n_mels
self.prenet = PreNet(n_mels)
self.attn_net = LSA(decoder_dims)
Expand Down Expand Up @@ -257,15 +256,15 @@ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,

# Compute first Residual RNN
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
if not self.generating:
if self.training:
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
else:
rnn1_hidden = rnn1_hidden_next
x = x + rnn1_hidden

# Compute second Residual RNN
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
if not self.generating:
if self.training:
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
else:
rnn2_hidden = rnn2_hidden_next
Expand Down Expand Up @@ -313,13 +312,9 @@ def forward(self, x, m, generate_gta=False):
self.step += 1

if generate_gta:
self.encoder.eval()
self.postnet.eval()
self.decoder.generating = True
self.eval()
else:
self.encoder.train()
self.postnet.train()
self.decoder.generating = False
self.train()

batch_size, _, steps = m.size()

Expand Down Expand Up @@ -372,11 +367,8 @@ def forward(self, x, m, generate_gta=False):
return mel_outputs, linear, attn_scores

def generate(self, x, steps=2000):
self.eval()
device = next(self.parameters()).device # use same device as parameters

self.encoder.eval()
self.postnet.eval()
self.decoder.generating = True

batch_size = 1
x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0)
Expand Down Expand Up @@ -432,9 +424,7 @@ def generate(self, x, steps=2000):
attn_scores = torch.cat(attn_scores, 1)
attn_scores = attn_scores.cpu().data.numpy()[0]

self.encoder.train()
self.postnet.train()
self.decoder.generating = False
self.train()

return mel_outputs, linear, attn_scores

Expand Down

0 comments on commit 90bb56d

Please sign in to comment.