Performance optimization.
RowitZou committed Aug 14, 2019
1 parent eda450f commit 1d79d05
Showing 2 changed files with 87 additions and 67 deletions.
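
The commit lifts the batch-size-1 restriction: per-sentence edge tensors of different lengths are now padded to a shared max_edge_num and paired with an edges_mask, and plain means are replaced by mask-weighted sums. A minimal, self-contained sketch of that padding-and-masking pattern follows; the function and variable names are illustrative only and are not part of this repository.

import torch

def pad_edge_features(per_sentence_feats, device='cpu'):
    # Hypothetical helper, not from this commit.
    # per_sentence_feats: list of [edge_num_i, feat_dim] tensors, one per sentence.
    batch_size = len(per_sentence_feats)
    feat_dim = per_sentence_feats[0].size(1)
    max_edge_num = max(t.size(0) for t in per_sentence_feats)
    padded = torch.zeros(batch_size, max_edge_num, feat_dim, device=device)
    edges_mask = torch.zeros(batch_size, max_edge_num, device=device)
    for i, feats in enumerate(per_sentence_feats):
        n = feats.size(0)
        padded[i, :n] = feats      # copy the real edge features
        edges_mask[i, :n] = 1.     # mark valid (non-padding) positions
    return padded, edges_mask

# Usage: a masked mean over the edge dimension, analogous to the global-node update in the diff below.
feats = [torch.randn(3, 5), torch.randn(1, 5)]
padded, edges_mask = pad_edge_features(feats)
glo = (padded * edges_mask.unsqueeze(2)).sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze(2)
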
19 changes: 9 additions & 10 deletions main.py
@@ -315,12 +315,11 @@ def load_model_decode(model_dir, data, args, name):
model_dir = model_dir + "_best"
print("Load Model from file: ", model_dir)
model = Graph(data, args)
model.load_state_dict(torch.load(model_dir))

# Loading must handle a model trained on GPU being loaded on CPU, or vice versa
if not args.use_gpu:
model.load_state_dict(torch.load(model_dir, map_location=lambda storage, loc: storage))
else:
model.load_state_dict(torch.load(model_dir))
if args.use_gpu:
model = model.cuda()

print(("Decode %s data ..." % name))
start_time = time.time()
@@ -337,12 +336,12 @@ def load_model_decode(model_dir, data, args, name):
parser = argparse.ArgumentParser()
parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train')
parser.add_argument('--use_gpu', type=str2bool, default=True)
parser.add_argument('--train', help='Training set.', default='data/ontonote.cn/train.char.bmes')
parser.add_argument('--dev', help='Developing set.', default='data/ontonote.cn/dev.char.bmes')
parser.add_argument('--test', help='Testing set.', default='data/ontonote.cn/test.char.bmes')
parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes')
parser.add_argument('--dev', help='Developing set.', default='data/onto4ner.cn/dev.char.bmes')
parser.add_argument('--test', help='Testing set.', default='data/onto4ner.cn/test.char.bmes')
parser.add_argument('--raw', help='Raw file for decoding.')
parser.add_argument('--output', help='Output results for decoding.')
parser.add_argument('--saved_set', help='Path of saved data set.', default='data/ontonote.cn/saved.dset')
parser.add_argument('--saved_set', help='Path of saved data set.', default='data/onto4ner.cn/saved.dset')
parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model_ontonote")
parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec")
parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec")
@@ -353,7 +352,7 @@ def load_model_decode(model_dir, data, args, name):
parser.add_argument('--bidirectional', type=str2bool, default=True, help='If use bidirectional digraph.')

parser.add_argument('--seed', help='Random seed', default=1023, type=int)
parser.add_argument('--batch_size', help='Batch size. For now it only works when batch size is 1.', default=1, type=int)
parser.add_argument('--batch_size', help='Batch size. ', default=1, type=int)
parser.add_argument('--num_epoch',default=100, type=int, help="Epoch number.")
parser.add_argument('--iters', default=4, type=int, help='The number of Graph iterations.')
parser.add_argument('--hidden_dim', default=50, type=int, help='Hidden state size.')
@@ -379,6 +378,7 @@ def load_model_decode(model_dir, data, args, name):
torch.manual_seed(seed_num)
np.random.seed(seed_num)


train_file = args.train
dev_file = args.dev
test_file = args.test
@@ -390,7 +390,6 @@ def load_model_decode(model_dir, data, args, name):
word_file = args.word_emb

if status == 'train':
assert not (train_file is None or dev_file is None or test_file is None)
if os.path.exists(saved_set_path):
print('Loading saved data set...')
with open(saved_set_path, 'rb') as f:
135 changes: 78 additions & 57 deletions model/LGN.py
@@ -168,59 +168,52 @@ def __init__(self, data, args):

def construct_graph(self, batch_size, seq_len, word_list):

assert batch_size == 1

if self.cuda:
device = 'cuda'
else:
device = 'cpu'
if self.use_edge:
unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0)
unk_index = torch.tensor(0, device=device)
unk_emb = self.word_embedding(unk_index)

bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0)
bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1)
bmes_index_e = torch.tensor(2).cuda() if self.cuda else torch.tensor(2)
bmes_index_s = torch.tensor(3).cuda() if self.cuda else torch.tensor(3)
bmes_emb_b = self.bmes_embedding(torch.tensor(0, device=device))
bmes_emb_m = self.bmes_embedding(torch.tensor(1, device=device))
bmes_emb_e = self.bmes_embedding(torch.tensor(2, device=device))
bmes_emb_s = self.bmes_embedding(torch.tensor(3, device=device))

bmes_emb_b = self.bmes_embedding(bmes_index_b)
bmes_emb_m = self.bmes_embedding(bmes_index_m)
bmes_emb_e = self.bmes_embedding(bmes_index_e)
bmes_emb_s = self.bmes_embedding(bmes_index_s)
sen_nodes_mask_list = []
sen_words_length_list = []
sen_words_mask_f_list = []
sen_words_mask_b_list = []
sen_word_embed_list = []
sen_bmes_embed_list = []
max_edge_num = -1

for sen in range(batch_size):
sen_nodes_mask = torch.zeros([1, seq_len]).byte()
sen_words_length = torch.zeros([1, self.length_dim])
sen_words_mask_f = torch.zeros([1, seq_len]).byte()
sen_words_mask_b = torch.zeros([1, seq_len]).byte()
if self.gpu:
sen_nodes_mask = sen_nodes_mask.cuda()
sen_words_length = sen_words_length.cuda()
sen_words_mask_f = sen_words_mask_f.cuda()
sen_words_mask_b = sen_words_mask_b.cuda()
sen_nodes_mask = torch.zeros([1, seq_len], device=device).byte()
sen_words_length = torch.zeros([1, self.length_dim], device=device)
sen_words_mask_f = torch.zeros([1, seq_len], device=device).byte()
sen_words_mask_b = torch.zeros([1, seq_len], device=device).byte()

if self.use_edge:
sen_word_embed = unk_emb[None, :]
sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim])
if self.gpu:
sen_word_embed = sen_word_embed.cuda()
sen_bmes_embed = sen_bmes_embed.cuda()
sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)

for w in range(seq_len):
if w < len(word_list[sen]) and word_list[sen][w]:
for word, word_len in zip(word_list[sen][w][0], word_list[sen][w][1]):

if word_len <= self.max_word_length:
word_length_index = torch.tensor(word_len-1, device=sen_words_length.device)
word_length_index = torch.tensor(word_len-1, device=device)
else:
word_length_index = torch.tensor(self.max_word_length - 1, device=sen_words_length.device)
word_length_index = torch.tensor(self.max_word_length - 1, device=device)
word_length = self.length_embedding(word_length_index)
sen_words_length = torch.cat([sen_words_length, word_length[None, :]], 0)

# mask: Masked elements are marked by 1, batch_size * word_num * seq_len
nodes_mask = torch.ones([1, seq_len]).byte()
words_mask_f = torch.ones([1, seq_len]).byte()
words_mask_b = torch.ones([1, seq_len]).byte()
if self.gpu:
nodes_mask = nodes_mask.cuda()
words_mask_f = words_mask_f.cuda()
words_mask_b = words_mask_b.cuda()
nodes_mask = torch.ones([1, seq_len], device=device).byte()
words_mask_f = torch.ones([1, seq_len], device=device).byte()
words_mask_b = torch.ones([1, seq_len], device=device).byte()

words_mask_f[0, w + word_len - 1] = 0
sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0)
@@ -229,13 +222,11 @@ def construct_graph(self, batch_size, seq_len, word_list):
sen_words_mask_b = torch.cat([sen_words_mask_b, words_mask_b], 0)

if self.use_edge:
word_index = torch.tensor(word, device=sen_word_embed.device)
word_index = torch.tensor(word, device=device)
word_embedding = self.word_embedding(word_index)
sen_word_embed = torch.cat([sen_word_embed, word_embedding[None, :]], 0)

bmes_embed = torch.zeros([1, seq_len, self.bmes_dim])
if self.gpu:
bmes_embed = bmes_embed.cuda()
bmes_embed = torch.zeros([1, seq_len, self.bmes_dim], device=device)

for index in range(word_len):
nodes_mask[0, w + index] = 0
@@ -251,26 +242,50 @@ def construct_graph(self, batch_size, seq_len, word_list):
sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0)
sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0)

batch_words_mask_f = sen_words_mask_f.unsqueeze(0)
batch_words_mask_b = sen_words_mask_b.unsqueeze(0)
batch_words_length = sen_words_length.unsqueeze(0)
if sen_words_mask_f.size(0) > max_edge_num:
max_edge_num = sen_words_mask_f.size(0)
sen_words_mask_f_list.append(sen_words_mask_f.unsqueeze_(0))
sen_words_mask_b_list.append(sen_words_mask_b.unsqueeze_(0))
sen_words_length_list.append(sen_words_length.unsqueeze_(0))
if self.use_edge:
sen_nodes_mask_list.append(sen_nodes_mask.unsqueeze_(0))
sen_word_embed_list.append(sen_word_embed.unsqueeze_(0))
sen_bmes_embed_list.append(sen_bmes_embed.unsqueeze_(0))

edges_mask = torch.zeros([batch_size, max_edge_num], device=device)
batch_words_mask_f = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()
batch_words_mask_b = torch.ones([batch_size, max_edge_num, seq_len], device=device).byte()
batch_words_length = torch.zeros([batch_size, max_edge_num, self.length_dim], device=device)
if self.use_edge:
batch_nodes_mask = sen_nodes_mask.unsqueeze(0)
batch_word_embed = sen_word_embed.unsqueeze(0) # Only works when batch size is 1
batch_bmes_embed = sen_bmes_embed.unsqueeze(0)
batch_nodes_mask = torch.zeros([batch_size, max_edge_num, seq_len], device=device).byte()
batch_word_embed = torch.zeros([batch_size, max_edge_num, self.word_emb_dim], device=device)
batch_bmes_embed = torch.zeros([batch_size, max_edge_num, seq_len, self.bmes_dim], device=device)
else:
batch_word_embed = None
batch_bmes_embed = None
batch_nodes_mask = None

return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, batch_words_mask_b, batch_words_length
for index in range(batch_size):
curr_edge_num = sen_words_mask_f_list[index].size(1)
edges_mask[index, 0:curr_edge_num] = 1.
batch_words_mask_f[index, 0:curr_edge_num, :] = sen_words_mask_f_list[index]
batch_words_mask_b[index, 0:curr_edge_num, :] = sen_words_mask_b_list[index]
batch_words_length[index, 0:curr_edge_num, :] = sen_words_length_list[index]
if self.use_edge:
batch_nodes_mask[index, 0:curr_edge_num, :] = sen_nodes_mask_list[index]
batch_word_embed[index, 0:curr_edge_num, :] = sen_word_embed_list[index]
batch_bmes_embed[index, 0:curr_edge_num, :, :] = sen_bmes_embed_list[index]

def update_graph(self, word_list, word_inputs):
return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, \
batch_words_mask_b, batch_words_length, edges_mask

def update_graph(self, word_list, word_inputs, mask):
mask = mask.float()
node_embeds = self.char_embedding(word_inputs) # batch_size, max_seq_len, embedding
B, L, _ = node_embeds.size()

edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length = self.construct_graph(B, L, word_list)
edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length, edges_mask = \
self.construct_graph(B, L, word_list)

node_embeds = self.dropout(node_embeds)

@@ -281,20 +296,22 @@ def update_graph(self, word_list, word_inputs):

# forward direction digraph
nodes_f, _ = self.emb_rnn_f(node_embeds)
nodes_f = nodes_f * mask.unsqueeze(2)
nodes_f_cat = nodes_f[:, None, :, :]
_, _, H = nodes_f.size()

if self.use_edge:
edges_f = edge_embs
edges_f = edge_embs * edges_mask.unsqueeze(2)
edges_f_cat = edges_f[:, None, :, :]

if self.use_global:
glo_f = edges_f.mean(1, keepdim=True) + nodes_f.mean(1, keepdim=True)
glo_f = edges_f.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \
nodes_f.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
glo_f_cat = glo_f[:, None, :, :]

else:
if self.use_global:
glo_f = nodes_f.mean(1, keepdim=True)
glo_f = (nodes_f * mask.unsqueeze(2)).sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
glo_f_cat = glo_f[:, None, :, :]

for i in range(self.iters):
@@ -310,11 +327,12 @@ def update_graph(self, word_list, word_inputs):
if self.use_edge:
nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)
if self.use_global:
glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1)
glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte()),
self.glo_att_f_edge[i](glo_f, edges_f, (1 - edges_mask).byte())], -1)
else:
nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f)
if self.use_global:
glo_att_f = self.glo_att_f_node[i](glo_f, nodes_f)
glo_att_f = self.glo_att_f_node[i](glo_f, nodes_f, (1 - mask).byte())

# RNN-based update
if self.use_edge and N > 1:
@@ -348,18 +366,20 @@ def update_graph(self, word_list, word_inputs):
if self.bidirectional:
nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1]))
nodes_b = torch.flip(nodes_b, [1])
nodes_b = nodes_b * mask.unsqueeze(2)
nodes_b_cat = nodes_b[:, None, :, :]

if self.use_edge:
edges_b = edge_embs
edges_b = edge_embs * edges_mask.unsqueeze(2)
edges_b_cat = edges_b[:, None, :, :]
if self.use_global:
glo_b = nodes_b.mean(1, keepdim=True) + edges_b.mean(1, keepdim=True)
glo_b = edges_b.sum(1, keepdim=True) / edges_mask.sum(1, keepdim=True).unsqueeze_(2) + \
nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
glo_b_cat = glo_b[:, None, :, :]

else:
if self.use_global:
glo_b = nodes_b.mean(1, keepdim=True)
glo_b = nodes_b.sum(1, keepdim=True) / mask.sum(1, keepdim=True).unsqueeze_(2)
glo_b_cat = glo_b[:, None, :, :]

for i in range(self.iters):
@@ -375,11 +395,12 @@ def update_graph(self, word_list, word_inputs):
if self.use_edge:
nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([edges_b, nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)
if self.use_global:
glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b), self.glo_att_b_edge[i](glo_b, edges_b)], -1)
glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte()),
self.glo_att_b_edge[i](glo_b, edges_b, (1-edges_mask).byte())], -1)
else:
nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b)
if self.use_global:
glo_att_b = self.glo_att_b_node[i](glo_b, nodes_b)
glo_att_b = self.glo_att_b_node[i](glo_b, nodes_b, (1-mask).byte())

# RNN-based update
if self.use_edge and N > 1:
@@ -419,7 +440,7 @@ def update_graph(self, word_list, word_inputs):

def forward(self, word_list, batch_inputs, mask, batch_label=None):

tags = self.update_graph(word_list, batch_inputs)
tags = self.update_graph(word_list, batch_inputs, mask)

if batch_label is not None:
if self.use_crf:
