Skip to content

Commit

Permalink
fix input shape for new Embedding (PaddlePaddle#4037)
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
songyouwei authored and phlrain committed Dec 6, 2019
1 parent 182d509 commit 7a50b68
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
*.pyc
*~
*.vscode
*.idea
8 changes: 4 additions & 4 deletions dygraph/transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)

trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
Expand Down

0 comments on commit 7a50b68

Please sign in to comment.