Skip to content

Commit

Permalink
Return hidden with lengths and masks
Browse files Browse the repository at this point in the history
  • Loading branch information
uoo723 committed Nov 23, 2020
1 parent 64c13b3 commit 11c4d91
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions deepxml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def train_step(self, train_x: torch.Tensor, train_y: torch.Tensor):
emb, train_y = self.mixup_fn(emb, train_y)
scores = self.model((emb, lengths, masks), pass_emb=True)
else:
hidden = self.model(train_x, return_hidden=True)
hidden, lengths, masks = self.model(train_x, return_hidden=True)
hidden, train_y = self.mixup_fn(hidden, train_y)
scores = self.model(hidden, pass_hidden=True)
scores = self.model((hidden, lengths, masks), pass_hidden=True)
else:
scores = self.model(train_x)

Expand Down
4 changes: 2 additions & 2 deletions deepxml/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def forward(self, inputs, return_emb=False, pass_emb=False,
if not pass_hidden:
rnn_out = self.lstm(emb_out, lengths) # N, L, hidden_size * 2
else:
rnn_out = inputs
rnn_out, lengths, masks = inputs

if return_hidden:
return rnn_out
return rnn_out, lengths, masks

attn_out = self.attention(rnn_out, masks) # N, labels_num, hidden_size * 2

Expand Down

0 comments on commit 11c4d91

Please sign in to comment.