Skip to content

Commit

Permalink
Change the word normalization implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
RowitZou committed Jun 24, 2019
1 parent 3cf2f2d commit 33aa30b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
8 changes: 4 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ def load_model_decode(model_dir, data, args, name):
parser = argparse.ArgumentParser()
parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train')
parser.add_argument('--use_gpu', type=str2bool, default=True)
parser.add_argument('--train', help='Training set.')
parser.add_argument('--dev', help='Developing set.')
parser.add_argument('--test', help='Testing set.')
parser.add_argument('--train', help='Training set.', default='data/onto4ner.cn/train.char.bmes')
parser.add_argument('--dev', help='Developing set.', default='data/onto4ner.cn/dev.char.bmes')
parser.add_argument('--test', help='Testing set.', default='data/onto4ner.cn/test.char.bmes')
parser.add_argument('--raw', help='Raw file for decoding.')
parser.add_argument('--output', help='Output results for decoding.')
parser.add_argument('--saved_set', help='Path of saved data set.')
parser.add_argument('--saved_set', help='Path of saved data set.', default='data/onto4ner.cn/saved.dset')
parser.add_argument('--saved_model', help='Path of saved model.', default="saved_model/model")
parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec")
parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec")
Expand Down
8 changes: 4 additions & 4 deletions model/LGN.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def __init__(self, data, args):

# bmes embedding
self.bmes_embedding = nn.Embedding(4, self.bmes_dim)

"""
self.edge_emb_linear = nn.Sequential(
nn.Linear(self.word_emb_dim, self.hidden_dim),
nn.ELU()
)

"""
# lstm
self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
Expand Down Expand Up @@ -285,7 +285,7 @@ def update_graph(self, word_list, word_inputs):
_, _, H = nodes_f.size()

if self.use_edge:
edges_f = self.edge_emb_linear(edge_embs)
edges_f = edge_embs
edges_f_cat = edges_f[:, None, :, :]

if self.use_global:
Expand Down Expand Up @@ -351,7 +351,7 @@ def update_graph(self, word_list, word_inputs):
nodes_b_cat = nodes_b[:, None, :, :]

if self.use_edge:
edges_b = self.edge_emb_linear(edge_embs)
edges_b = edge_embs
edges_b_cat = edges_b[:, None, :, :]
if self.use_global:
glo_b = nodes_b.mean(1, keepdim=True) + edges_b.mean(1, keepdim=True)
Expand Down
4 changes: 2 additions & 2 deletions utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def build_alphabet(self, input_file):
char = pair[0]
if self.number_normalized:
# Mapping numbers to 0
char = re.sub('[0-9]', '0', char)
char = normalize_word(char)
label = pair[-1]
self.label_alphabet.add(label)
self.char_alphabet.add(char)
Expand All @@ -94,7 +94,7 @@ def build_word_alphabet(self, input_file):
if len(line) > 0:
word = line.split()[0]
if self.number_normalized:
word = re.sub('[0-9]', '0', word)
word = normalize_word(word)
word_list.append(word)
else:
for idx in range(len(word_list)):
Expand Down
12 changes: 11 additions & 1 deletion utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
import re


def normalize_word(word):
    """Return *word* with every digit character replaced by '0'.

    Uses str.isdigit(), so this covers Unicode digits as well as ASCII
    0-9 (intentionally broader than the old ``re.sub('[0-9]', '0', ...)``).
    Non-digit characters are passed through unchanged.

    Args:
        word: Input token (a str); may be empty.

    Returns:
        A new str of the same length with digits mapped to '0'.
    """
    # join over a generator avoids the quadratic += string-building loop
    return "".join('0' if char.isdigit() else char for char in word)


def read_instance_with_gaz(input_file, word_dict, char_alphabet, word_alphabet, label_alphabet, number_normalized, max_sent_length):
instence_texts = []
instence_Ids = []
Expand All @@ -22,7 +32,7 @@ def read_instance_with_gaz(input_file, word_dict, char_alphabet, word_alphabet,
pairs = line.strip().split()
char = pairs[0]
if number_normalized:
char = re.sub('[0-9]', '0', char)
char = normalize_word(char)
label = pairs[-1]
chars.append(char)
labels.append(label)
Expand Down

0 comments on commit 33aa30b

Please sign in to comment.