Skip to content

Commit

Permalink
Merge pull request #1238 from sknadig/attcov2
Browse files Browse the repository at this point in the history
fixed att_to_numpy() function for AttCov, AttCovLoc
  • Loading branch information
kan-bayashi authored Oct 4, 2019
2 parents be24490 + 9480ffe commit 2b14d3b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion espnet/nets/pytorch_backend/rnn/attentions.py
Original file line number Diff line number Diff line change
@@ -1437,7 +1437,7 @@ def att_to_numpy(att_ws, att):
att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
elif isinstance(att, (AttCov, AttCovLoc)):
# att_ws => list of list of previous attentions
att_ws = torch.stack([aw[-1] for aw in att_ws], dim=1).cpu().numpy()
att_ws = torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
elif isinstance(att, AttLocRec):
# att_ws => list of tuple of attention and hidden states
att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()

0 comments on commit 2b14d3b

Please sign in to comment.