
Commit

Fix some bugs.
RowitZou committed Jun 16, 2019
1 parent 4663517 commit 4590b98
Showing 9 changed files with 401 additions and 1,198 deletions.
559 changes: 181 additions & 378 deletions main.py

Large diffs are not rendered by default.

75 changes: 37 additions & 38 deletions model/LGN.py
@@ -1,4 +1,7 @@
# -*- coding: utf-8 -*-
# @Author: Yicheng Zou
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn

import torch
import torch.nn as nn
import numpy as np
@@ -189,43 +192,43 @@ def forward(self, h, x):


class Graph(nn.Module):
def __init__(self, data):
def __init__(self, data, args):
super(Graph, self).__init__()

self.gpu = data.HP_gpu
self.gpu = args.use_gpu
self.word_alphabet = data.word_alphabet
self.char_emb_dim = data.char_emb_dim
self.word_emb_dim = data.word_emb_dim
self.gaz_emb_dim = data.gaz_emb_dim
self.hidden_dim = 200
self.num_head = 10 # 5 10 20
self.head_dim = 20 # 10 20
self.tf_dropout_rate = 0.1
self.iters = 4
self.hidden_dim = args.hidden_dim
self.num_head = args.num_head # 5 10 20
self.head_dim = args.head_dim # 10 20
self.tf_dropout_rate = args.tf_drop_rate
self.iters = args.iters
self.bmes_dim = 10
self.length_dim = 10
self.max_gaz_length = 5
self.emb_dropout_rate = 0.5
self.cell_dropout_rate = 0.2
self.emb_dropout_rate = args.emb_drop_rate
self.cell_dropout_rate = args.cell_drop_rate

# char embedding
self.char_embedding = nn.Embedding(data.char_alphabet.size(), self.char_emb_dim)
if data.pretrain_char_embedding is not None:
self.char_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_char_embedding))

# word embedding
self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim)
assert data.pretrain_word_embedding is not None
self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))

# gaz embedding
self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim)
assert data.pretrain_gaz_embedding is not None
scale = np.sqrt(3.0 / self.gaz_emb_dim)
data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim])
self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding))
if data.pretrain_word_embedding is not None:
scale = np.sqrt(3.0 / self.word_emb_dim)
data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim])
self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))

# position embedding
#self.pos_embedding = nn.Embedding(data.posi_alphabet_size, self.hidden_dim)
# lstm
self.emb_rnn_f = nn.LSTM(self.word_emb_dim, self.hidden_dim, batch_first=True)
self.emb_rnn_b = nn.LSTM(self.word_emb_dim, self.hidden_dim, batch_first=True)
self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
self.edge_emb_linear = nn.Sequential(
nn.Linear(self.gaz_emb_dim, self.hidden_dim),
nn.Linear(self.word_emb_dim, self.hidden_dim),
nn.ELU()
)
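
Editorial note: the constructor above now takes its hyperparameters from an `args` object instead of hard-coded constants. The main.py diff is not rendered in this commit view, so the argparse sketch below is only an assumption about how those options might be declared; the option names (use_gpu, hidden_dim, num_head, head_dim, tf_drop_rate, iters, emb_drop_rate, cell_drop_rate) are taken from the attribute reads in Graph.__init__, and the defaults mirror the values that were previously hard-coded.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_gpu', action='store_true')            # Graph reads args.use_gpu
parser.add_argument('--hidden_dim', type=int, default=200)
parser.add_argument('--num_head', type=int, default=10)          # 5 / 10 / 20 were tried before
parser.add_argument('--head_dim', type=int, default=20)          # 10 / 20 were tried before
parser.add_argument('--tf_drop_rate', type=float, default=0.1)
parser.add_argument('--iters', type=int, default=4)
parser.add_argument('--emb_drop_rate', type=float, default=0.5)
parser.add_argument('--cell_drop_rate', type=float, default=0.2)
args = parser.parse_args()

model = Graph(data, args)    # `data` is prepared by the training script (not shown here)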

@@ -279,12 +282,12 @@ def __init__(self, data):
self.glo_rnn_b = GLobal_Cell(self.hidden_dim, self.cell_dropout_rate)

self.layer_att_W = nn.Linear(self.hidden_dim * 2, 1)
self.hidden2tag = nn.Linear(self.hidden_dim * 2, data.label_alphabet_size + 2)
self.crf = CRF(data.label_alphabet_size, self.gpu)
self.hidden2tag = nn.Linear(self.hidden_dim * 2, data.label_alphabet.size() + 2)
self.crf = CRF(data.label_alphabet.size(), self.gpu)

if self.gpu:
self.char_embedding = self.char_embedding.cuda()
self.word_embedding = self.word_embedding.cuda()
self.gaz_embedding = self.gaz_embedding.cuda()
self.bmes_embedding = self.bmes_embedding.cuda()
self.length_embedding = self.length_embedding.cuda()
self.norm = self.norm.cuda()
@@ -315,7 +318,7 @@ def obtain_gaz_relation(self, batch_size, seq_len, gaz_list):
assert batch_size == 1

unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0)
unk_emb = self.gaz_embedding(unk_index)
unk_emb = self.word_embedding(unk_index)
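# Editorial note: after this commit the embedding tables match their names:
# self.char_embedding indexes characters (see get_tags below), while
# self.word_embedding indexes the lexicon (gaz) words, replacing the separate
# self.gaz_embedding table that was looked up here before.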

bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0)
bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1)
@@ -347,7 +350,7 @@ def obtain_gaz_relation(self, batch_size, seq_len, gaz_list):
for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]):

gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device)
gaz_embedding = self.gaz_embedding(gaz_index)
gaz_embedding = self.word_embedding(gaz_index)
sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0)

if gaz_len <= self.max_gaz_length:
@@ -402,7 +405,7 @@ def obtain_gaz_relation(self, batch_size, seq_len, gaz_list):
def get_tags(self, gaz_list, word_inputs, mask):

#mask = 1 - mask
node_embeds = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding
node_embeds = self.char_embedding(word_inputs) # batch_size, max_seq_len, embedding
B, L, _ = node_embeds.size()
gaz_match = []

@@ -502,15 +505,11 @@ def get_tags(self, gaz_list, word_inputs, mask):

return tags, gaz_match

def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label):

def forward(self, gaz_list, word_inputs, mask, batch_label=None):
tags, _ = self.get_tags(gaz_list, word_inputs, mask)
total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
scores, tag_seq = self.crf._viterbi_decode(tags, mask)

return total_loss, tag_seq # (batch_size,) ,(b,seqlen?)

def forward(self, gaz_list, word_inputs, word_seq_lengths, mask):
tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask)
if batch_label is not None:
total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label)
else:
total_loss = None
scores, tag_seq = self.crf._viterbi_decode(tags, mask)
return tag_seq, gaz_match
return total_loss, tag_seq
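
Editorial note: with neg_log_likelihood_loss folded into forward, one call now serves both training and decoding. A minimal usage sketch, assuming the batch tensors and optimizer come from the training script (whose diff is not shown):

# Training step: passing gold labels yields the CRF loss together with the Viterbi decode.
loss, pred_seq = model(gaz_list, word_inputs, mask, batch_label)
loss.backward()
optimizer.step()

# Inference: leave batch_label out and the loss slot comes back as None.
loss, pred_seq = model(gaz_list, word_inputs, mask)
assert loss is None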
15 changes: 6 additions & 9 deletions utils/alphabet.py
@@ -1,9 +1,6 @@
# -*- coding: utf-8 -*-
# @Author: Max
# @Date: 2018-01-19 11:33:37
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
# @Last Modified time: 2018-01-19 11:33:56

# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn

"""
Alphabet maps objects to integer ids. It provides two way mapping from the index to the objects.
@@ -64,10 +61,10 @@ def get_instance(self, index):
return self.instances[0]

def size(self):
# if self.label:
# return len(self.instances)
# else:
return len(self.instances) + 1
if self.label:
return len(self.instances)
else:
return len(self.instances) + 1

def iteritems(self):
return self.instance2index.items()
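
Editorial note: the size() change lets a label alphabet report the raw label count, while other alphabets keep the +1 offset reserved for index 0. This matches the LGN.py change above that sizes hidden2tag and the CRF directly from data.label_alphabet.size() (the extra two outputs are presumably the CRF start/stop tags, as in the NCRF++ lineage this code follows). A hypothetical illustration, assuming the constructor and add() method of the NCRF++-style Alphabet this file derives from:

labels = Alphabet('label', label=True)
for tag in ('B-PER', 'I-PER', 'O'):
    labels.add(tag)

labels.size()    # 3 after this commit; the old code always returned len(self.instances) + 1, i.e. 4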
@@ -101,7 +98,7 @@ def save(self, output_directory, name=None):
try:
json.dump(self.get_content(), open(os.path.join(output_directory, saving_name + ".json"), 'w'))
except Exception as e:
print("Exception: Alphabet is not saved: " % repr(e))
print("Exception: Alphabet is not saved: " + repr(e))

def load(self, input_directory, name=None):
"""