Fix some bugs and modify default settings.
RowitZou committed Jun 21, 2019
1 parent 3a620c7 commit 3cf2f2d
Showing 2 changed files with 8 additions and 8 deletions.

main.py (5 additions, 5 deletions)
@@ -324,11 +324,11 @@ def load_model_decode(model_dir, data, args, name):
 
 parser.add_argument('--seed', help='Random seed', default=47, type=int)
 parser.add_argument('--batch_size', help='Batch size. For now it only works when batch size is 1.', default=1, type=int)
-parser.add_argument('--num_epoch',default=2, type=int, help="Epoch number.")
-parser.add_argument('--iters', default=3, type=int, help='The number of Graph iterations.')
+parser.add_argument('--num_epoch',default=50, type=int, help="Epoch number.")
+parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.')
 parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.')
 parser.add_argument('--num_head', default=10, type=int, help='Number of transformer head.')
-parser.add_argument('--head_dim', default=10, type=int, help='Head dimension of transformer.')
+parser.add_argument('--head_dim', default=20, type=int, help='Head dimension of transformer.')
 parser.add_argument('--tf_drop_rate', default=0.1, type=float, help='Transformer dropout rate.')
 parser.add_argument('--emb_drop_rate', default=0.5, type=float, help='Embedding dropout rate.')
 parser.add_argument('--cell_drop_rate', default=0.2, type=float, help='Aggregation module dropout rate.')
@@ -337,8 +337,8 @@ def load_model_decode(model_dir, data, args, name):
 parser.add_argument('--label_alphabet_size', type=int, help='Label alphabet size.')
 parser.add_argument('--char_dim', type=int, help='Char embedding size.')
 parser.add_argument('--word_dim', type=int, help='Word embedding size.')
-parser.add_argument('--lr', type=float, default=5e-04)
-parser.add_argument('--weight_decay', type=float, default=1e-08)
+parser.add_argument('--lr', type=float, default=5e-05)
+parser.add_argument('--weight_decay', type=float, default=0)
 
 args = parser.parse_args()
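
For quick reference, here is a minimal sketch that reproduces only the argparse options touched by this commit, with their new defaults; the rest of main.py (data loading, training loop, remaining options) is omitted. A bare "python main.py" run now picks up these values unless they are overridden on the command line.

# Sketch: updated defaults only; not the full option list from main.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_epoch', default=50, type=int, help="Epoch number.")
parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.')
parser.add_argument('--head_dim', default=20, type=int, help='Head dimension of transformer.')
parser.add_argument('--lr', type=float, default=5e-05)
parser.add_argument('--weight_decay', type=float, default=0)

args = parser.parse_args([])  # empty argv, so the defaults are used
print(args.num_epoch, args.iters, args.head_dim, args.lr, args.weight_decay)
# -> 50 4 20 5e-05 0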

model/LGN.py (3 additions, 3 deletions)
@@ -222,8 +222,6 @@ def construct_graph(self, batch_size, seq_len, word_list):
 words_mask_f = words_mask_f.cuda()
 words_mask_b = words_mask_b.cuda()
 
-sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)
-
 words_mask_f[0, w + word_len - 1] = 0
 sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0)
 
@@ -251,17 +249,19 @@ def construct_graph(self, batch_size, seq_len, word_list):
 bmes_embed[0, w + index, :] = bmes_emb_m
 
 sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0)
+sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)
 
-batch_nodes_mask = sen_nodes_mask.unsqueeze(0)
 batch_words_mask_f = sen_words_mask_f.unsqueeze(0)
 batch_words_mask_b = sen_words_mask_b.unsqueeze(0)
 batch_words_length = sen_words_length.unsqueeze(0)
 if self.use_edge:
+    batch_nodes_mask = sen_nodes_mask.unsqueeze(0)
     batch_word_embed = sen_word_embed.unsqueeze(0)  # Only works when batch size is 1
     batch_bmes_embed = sen_bmes_embed.unsqueeze(0)
 else:
     batch_word_embed = None
     batch_bmes_embed = None
+    batch_nodes_mask = None
 
 return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, batch_words_mask_b, batch_words_length

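
The model/LGN.py change corrects where sen_nodes_mask is accumulated: it was previously concatenated in an earlier block, out of step with the per-word tensors built in the same loop, and batch_nodes_mask was built unconditionally even though the other edge-related batch tensors are None when self.use_edge is off. Below is a small, self-contained sketch of the corrected pattern; the (start, length) word list, mask semantics, and function name are hypothetical stand-ins, not the actual construct_graph method.

# Hypothetical sketch of the corrected pattern: accumulate the node mask
# alongside the other per-word tensors, and only materialize the batched
# mask when edges are in use.
import torch

def build_node_mask(word_list, seq_len, use_edge):
    sen_nodes_mask = torch.zeros(0, seq_len, dtype=torch.bool)
    for w, word_len in word_list:  # hypothetical (start, length) pairs
        nodes_mask = torch.zeros(1, seq_len, dtype=torch.bool)
        nodes_mask[0, w:w + word_len] = True
        # Concatenate in the same loop step as the other per-word tensors.
        sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)
    # Mirror batch_word_embed / batch_bmes_embed: None when edges are unused.
    return sen_nodes_mask.unsqueeze(0) if use_edge else None

mask = build_node_mask([(0, 2), (1, 3)], seq_len=5, use_edge=True)
print(mask.shape)  # torch.Size([1, 2, 5])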
