From 651d3d2ab2718bb996ac70bfffc775680306eec5 Mon Sep 17 00:00:00 2001 From: RowitZou <1094074685@qq.com> Date: Mon, 17 Jun 2019 17:52:48 +0800 Subject: [PATCH] Add ablation experiment settings. --- main.py | 20 +- model/LGN.py | 530 ++++++------------ ...{transformer_no_edge.py => LGN_no_edge.py} | 0 ...o_glo_no_edge.py => LGN_no_glo_no_edge.py} | 0 model/module.py | 204 +++++++ model/transformer_light.py | 413 -------------- model/transformer_no_crf.py | 511 ----------------- model/transformer_no_glo.py | 398 ------------- model/transformer_no_lstm.py | 510 ----------------- model/transformer_only_sw.py | 512 ----------------- model/transformer_single.py | 443 --------------- utils/data.py | 1 - 12 files changed, 405 insertions(+), 3137 deletions(-) rename model/{transformer_no_edge.py => LGN_no_edge.py} (100%) rename model/{transformer_no_glo_no_edge.py => LGN_no_glo_no_edge.py} (100%) create mode 100644 model/module.py delete mode 100644 model/transformer_light.py delete mode 100644 model/transformer_no_crf.py delete mode 100644 model/transformer_no_glo.py delete mode 100644 model/transformer_no_lstm.py delete mode 100644 model/transformer_only_sw.py delete mode 100644 model/transformer_single.py diff --git a/main.py b/main.py index 44e62ee..5fef0dd 100644 --- a/main.py +++ b/main.py @@ -18,6 +18,17 @@ from utils.data import Data +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + def data_initialization(data, word_file, train_file, dev_file, test_file): data.build_word_file(word_file) @@ -142,6 +153,8 @@ def train(data, args, saved_model_path): print( "Training model...") model = Graph(data, args) + if args.use_gpu: + model = model.cuda() print('# generated parameters:', sum(param.numel() for param in model.parameters())) print( "Finished built model.") @@ -293,7 +306,7 @@ def load_model_decode(model_dir, data, args, name): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--status', choices=['train', 'test', 'decode'], help='Function status.', default='train') - parser.add_argument('--use_gpu', type=bool, default=True) + parser.add_argument('--use_gpu', type=str2bool, default=True) parser.add_argument('--train', help='Training set.') parser.add_argument('--dev', help='Developing set.') parser.add_argument('--test', help='Testing set.') @@ -304,6 +317,11 @@ def load_model_decode(model_dir, data, args, name): parser.add_argument('--char_emb', help='Path of character embedding file.', default="data/gigaword_chn.all.a2b.uni.ite50.vec") parser.add_argument('--word_emb', help='Path of word embedding file.', default="data/ctb.50d.vec") + parser.add_argument('--use_crf', type=str2bool, default=True) + parser.add_argument('--use_edge', type=str2bool, default=True, help='If use lexicon embeddings (edge embeddings).') + parser.add_argument('--use_global', type=str2bool, default=True, help='If use the global node.') + parser.add_argument('--bidirectional', type=str2bool, default=True, help='If use bidirectional digraph.') + parser.add_argument('--seed', help='Random seed', default=47, type=int) parser.add_argument('--batch_size', help='Batch size. 
For now it only works when batch size is 1.', default=1, type=int) parser.add_argument('--num_epoch',default=2, type=int, help="Epoch number.") diff --git a/model/LGN.py b/model/LGN.py index 19acbeb..d708155 100644 --- a/model/LGN.py +++ b/model/LGN.py @@ -2,194 +2,8 @@ # @Author: Yicheng Zou # @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F from model.crf import CRF - - -class MultiHeadAtt(nn.Module): - def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): - super(MultiHeadAtt, self).__init__() - - if if_g: - self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) - else: - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, query_h, value, mask, query_g=None): - - if not (query_g is None): - query = torch.cat([query_h, query_g], -1) - else: - query = query_h - query = query.permute(0, 2, 1)[:, :, :, None] - value = value.permute(0, 3, 1, 2) - - residual = query_h - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - - B, QL, H = query_h.shape - - _, _, VL, VD = value.shape # VD = 1 or VD = QL - - assert VD == 1 or VD == QL - # q: (B, H, QL, 1) - # v: (B, H, VL, VD) - q, k, v = self.WQ(query), self.WK(value), self.WV(value) - - q = q.view(B, nhead, head_dim, 1, QL) - k = k.view(B, nhead, head_dim, VL, VD) - v = v.view(B, nhead, head_dim, VL, VD) - - alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) - alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf) - alpha = self.drop(F.softmax(alpha, 3)) - att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) - - output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) - output = self.norm(output + residual) - - return output - - -class GloAtt(nn.Module): - def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): - # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value - super(GloAtt, self).__init__() - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, x, y, mask=None): - # x: B, H, 1, 1, 1 y: B H L 1 - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - B, L, H = y.shape - - x = x.permute(0, 2, 1)[:, :, :, None] - y = y.permute(0, 2, 1)[:, :, :, None] - - residual = x - q, k, v = self.WQ(x), self.WK(y), self.WV(y) - - q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h - k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L - v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h - - pre_a = torch.matmul(q, k) / np.sqrt(head_dim) - if mask is not None: - pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) - alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L - att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 - output = F.leaky_relu(self.WO(att)) + residual - output = self.norm(output.permute(0, 2, 3, 
1)).view(B, 1, H) - - return output - - -class Nodes_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Nodes_Cell, self).__init__() - - self.Wix = nn.Linear(hid_h*5, hid_h) - #self.Wig = nn.Linear(hid_h*4, hid_h) - self.Wi2 = nn.Linear(hid_h*5, hid_h) - self.Wf = nn.Linear(hid_h*5, hid_h) - self.Wcx = nn.Linear(hid_h*5, hid_h) - #self.Wcg = nn.Linear(hid_h, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, h2, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, h2, x, glo], -1) - #cat_x = torch.cat([h, h2, x], -1) - #cat_g = torch.cat([glo], -1) - - ix = torch.sigmoid(self.Wix(cat_all)) - #ig = torch.sigmoid(self.Wig(cat_all)) - i2 = torch.sigmoid(self.Wi2(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - cx = torch.tanh(self.Wcx(cat_all)) - #cg = torch.tanh(self.Wcg(cat_g)) - - alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) - - return output - - -class Gazs_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Gazs_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*4, hid_h) - self.Wf = nn.Linear(hid_h*4, hid_h) - self.Wc = nn.Linear(hid_h*4, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, x, glo], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class GLobal_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(GLobal_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*3, hid_h) - self.Wf = nn.Linear(hid_h*3, hid_h) - self.Wc = nn.Linear(hid_h*3, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x): - - x = self.drop(x) - - cat_all = torch.cat([h, x], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - +from model.module import * class Graph(nn.Module): def __init__(self, data, args): @@ -205,9 +19,13 @@ def __init__(self, data, args): self.iters = args.iters self.bmes_dim = 10 self.length_dim = 10 - self.max_gaz_length = 5 + self.max_word_length = 5 self.emb_dropout_rate = args.emb_drop_rate self.cell_dropout_rate = args.cell_drop_rate + self.use_crf = args.use_crf + self.use_global = args.use_global + self.bidirectional = args.bidirectional + self.label_size = args.label_alphabet_size # char embedding self.char_embedding = nn.Embedding(args.char_alphabet_size, self.char_emb_dim) @@ -221,8 +39,6 @@ def __init__(self, data, args): data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim]) self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - # position embedding - # self.pos_embedding = nn.Embedding(data.posi_alphabet_size, self.hidden_dim) # lstm self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True) self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True) @@ -235,7 +51,7 @@ def __init__(self, data, args): self.bmes_embedding = nn.Embedding(4, self.bmes_dim) # length embedding - self.length_embedding = nn.Embedding(self.max_gaz_length, 
self.length_dim) + self.length_embedding = nn.Embedding(self.max_word_length, self.length_dim) self.dropout = nn.Dropout(self.emb_dropout_rate) self.norm = nn.LayerNorm(self.hidden_dim) @@ -248,71 +64,63 @@ def __init__(self, data, args): [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) for _ in range(self.iters)]) - self.glo_att_f_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) + if self.use_global: + self.glo_att_f_node = nn.ModuleList( + [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) + for _ in range(self.iters)]) - self.glo_att_f_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_f = Gazs_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.glo_rnn_f = GLobal_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) + self.glo_att_f_edge = nn.ModuleList( + [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) + for _ in range(self.iters)]) + self.glo_rnn_f = GLobal_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) + self.edge_rnn_f = Edges_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) + self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.edge2node_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) + else: + self.edge_rnn_f = Edges_Cell(self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) + self.node_rnn_f = Nodes_Cell(self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) + + if self.bidirectional: + self.edge2node_b = nn.ModuleList( + [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, + nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) + for _ in range(self.iters)]) + self.node2edge_b = nn.ModuleList( + [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) + for _ in range(self.iters)]) + + if self.use_global: + self.glo_att_b_node = nn.ModuleList( + [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) + for _ in range(self.iters)]) + + self.glo_att_b_edge = nn.ModuleList( + [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) + for _ in range(self.iters)]) + + self.glo_rnn_b = GLobal_Cell(self.hidden_dim, self.cell_dropout_rate) + self.edge_rnn_b = Edges_Cell(self.hidden_dim, self.cell_dropout_rate) + self.node_rnn_b = Nodes_Cell(self.hidden_dim, self.cell_dropout_rate) + + else: + self.edge_rnn_b = Edges_Cell(self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) + self.node_rnn_b = Nodes_Cell(self.hidden_dim, use_global=False, dropout=self.cell_dropout_rate) + + if self.bidirectional: + output_dim = self.hidden_dim * 2 + else: + output_dim = 
self.hidden_dim - self.glo_att_b_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) + self.layer_att_W = nn.Linear(output_dim, 1) - self.glo_att_b_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) + if self.use_crf: + self.hidden2tag = nn.Linear(output_dim, self.label_size + 2) + self.crf = CRF(self.label_size, self.gpu) + else: + self.hidden2tag = nn.Linear(output_dim, self.label_size) + self.criterion = nn.CrossEntropyLoss() - self.edge_rnn_b = Gazs_Cell(self.hidden_dim, self.cell_dropout_rate) - self.node_rnn_b = Nodes_Cell(self.hidden_dim, self.cell_dropout_rate) - self.glo_rnn_b = GLobal_Cell(self.hidden_dim, self.cell_dropout_rate) - - self.layer_att_W = nn.Linear(self.hidden_dim * 2, 1) - self.hidden2tag = nn.Linear(self.hidden_dim * 2, args.label_alphabet_size + 2) - self.crf = CRF(args.label_alphabet_size, self.gpu) - - if self.gpu: - self.char_embedding = self.char_embedding.cuda() - self.word_embedding = self.word_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.length_embedding = self.length_embedding.cuda() - self.norm = self.norm.cuda() - self.edge2node_f = self.edge2node_f.cuda() - self.node2edge_f = self.node2edge_f.cuda() - self.edge_rnn_f = self.edge_rnn_f.cuda() - self.node_rnn_f = self.node_rnn_f.cuda() - self.glo_rnn_f = self.glo_rnn_f.cuda() - self.glo_att_f_node = self.glo_att_f_node.cuda() - self.glo_att_f_edge = self.glo_att_f_edge.cuda() - self.edge2node_b = self.edge2node_b.cuda() - self.node2edge_b = self.node2edge_b.cuda() - self.edge_rnn_b = self.edge_rnn_b.cuda() - self.node_rnn_b = self.node_rnn_b.cuda() - self.glo_rnn_b = self.glo_rnn_b.cuda() - self.glo_att_b_node = self.glo_att_b_node.cuda() - self.glo_att_b_edge = self.glo_att_b_edge.cuda() - #self.pos_embedding = self.pos_embedding.cuda() - self.emb_rnn_f = self.emb_rnn_f.cuda() - self.emb_rnn_b = self.emb_rnn_b.cuda() - self.edge_emb_linear = self.edge_emb_linear.cuda() - self.layer_att_W = self.layer_att_W.cuda() - self.hidden2tag = self.hidden2tag.cuda() - self.crf = self.crf.cuda() - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): + def construct_graph(self, batch_size, seq_len, word_list): assert batch_size == 1 @@ -330,59 +138,59 @@ def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): bmes_emb_s = self.bmes_embedding(bmes_index_s) for sen in range(batch_size): - sen_gaz_embed = unk_emb[None, :] + sen_word_embed = unk_emb[None, :] sen_nodes_mask = torch.zeros([1, seq_len]).byte() - sen_gazs_length = torch.zeros([1, self.length_dim]) + sen_words_length = torch.zeros([1, self.length_dim]) sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - sen_gazs_mask_f = torch.zeros([1, seq_len]).byte() - sen_gazs_mask_b = torch.zeros([1, seq_len]).byte() + sen_words_mask_f = torch.zeros([1, seq_len]).byte() + sen_words_mask_b = torch.zeros([1, seq_len]).byte() if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() + sen_word_embed = sen_word_embed.cuda() sen_nodes_mask = sen_nodes_mask.cuda() - sen_gazs_length = sen_gazs_length.cuda() + sen_words_length = sen_words_length.cuda() sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask_f = sen_gazs_mask_f.cuda() - sen_gazs_mask_b = sen_gazs_mask_b.cuda() + sen_words_mask_f = sen_words_mask_f.cuda() + sen_words_mask_b = sen_words_mask_b.cuda() for w in range(seq_len): - if w < len(gaz_list[sen]) and 
gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): + if w < len(word_list[sen]) and word_list[sen][w]: + for word, word_len in zip(word_list[sen][w][0], word_list[sen][w][1]): - gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device) - gaz_embedding = self.word_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0) + word_index = torch.tensor(word, device=sen_word_embed.device) + word_embedding = self.word_embedding(word_index) + sen_word_embed = torch.cat([sen_word_embed, word_embedding[None, :]], 0) - if gaz_len <= self.max_gaz_length: - gaz_length_index = torch.tensor(gaz_len-1, device=sen_gazs_length.device) + if word_len <= self.max_word_length: + word_length_index = torch.tensor(word_len-1, device=sen_words_length.device) else: - gaz_length_index = torch.tensor(self.max_gaz_length-1, device=sen_gazs_length.device) - gaz_length = self.length_embedding(gaz_length_index) - sen_gazs_length = torch.cat([sen_gazs_length, gaz_length[None, :]], 0) + word_length_index = torch.tensor(self.max_word_length - 1, device=sen_words_length.device) + word_length = self.length_embedding(word_length_index) + sen_words_length = torch.cat([sen_words_length, word_length[None, :]], 0) - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len + # mask: 需要mask的地方置为1, batch_size * word_num * seq_len nodes_mask = torch.ones([1, seq_len]).byte() bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - gazs_mask_f = torch.ones([1, seq_len]).byte() - gazs_mask_b = torch.ones([1, seq_len]).byte() + words_mask_f = torch.ones([1, seq_len]).byte() + words_mask_b = torch.ones([1, seq_len]).byte() if self.cuda: nodes_mask = nodes_mask.cuda() bmes_embed = bmes_embed.cuda() - gazs_mask_f = gazs_mask_f.cuda() - gazs_mask_b = gazs_mask_b.cuda() + words_mask_f = words_mask_f.cuda() + words_mask_b = words_mask_b.cuda() - gazs_mask_f[0, w + gaz_len - 1] = 0 - sen_gazs_mask_f = torch.cat([sen_gazs_mask_f, gazs_mask_f], 0) + words_mask_f[0, w + word_len - 1] = 0 + sen_words_mask_f = torch.cat([sen_words_mask_f, words_mask_f], 0) - gazs_mask_b[0, w] = 0 - sen_gazs_mask_b = torch.cat([sen_gazs_mask_b, gazs_mask_b], 0) + words_mask_b[0, w] = 0 + sen_words_mask_b = torch.cat([sen_words_mask_b, words_mask_b], 0) - for index in range(gaz_len): + for index in range(word_len): nodes_mask[0, w + index] = 0 - if gaz_len == 1: + if word_len == 1: bmes_embed[0, w + index, :] = bmes_emb_s elif index == 0: bmes_embed[0, w + index, :] = bmes_emb_b - elif index == gaz_len - 1: + elif index == word_len - 1: bmes_embed[0, w + index, :] = bmes_emb_e else: bmes_embed[0, w + index, :] = bmes_emb_m @@ -390,44 +198,35 @@ def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) - #sen_gazs_mask_f[0, (1-sen_gazs_mask_f).sum(dim=0) == 0] = 0 - #sen_gazs_mask_b[0, (1-sen_gazs_mask_b).sum(dim=0) == 0] = 0 - - batch_gaz_embed = sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 + batch_word_embed = sen_word_embed.unsqueeze(0) # Only works when batch size is 1 batch_nodes_mask = sen_nodes_mask.unsqueeze(0) batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask_f = sen_gazs_mask_f.unsqueeze(0) - batch_gazs_mask_b = sen_gazs_mask_b.unsqueeze(0) - batch_gazs_length = sen_gazs_length.unsqueeze(0) - return batch_gaz_embed, batch_bmes_embed, batch_nodes_mask, batch_gazs_mask_f, batch_gazs_mask_b, batch_gazs_length + batch_words_mask_f = 
sen_words_mask_f.unsqueeze(0) + batch_words_mask_b = sen_words_mask_b.unsqueeze(0) + batch_words_length = sen_words_length.unsqueeze(0) + return batch_word_embed, batch_bmes_embed, batch_nodes_mask, batch_words_mask_f, batch_words_mask_b, batch_words_length - def get_tags(self, gaz_list, word_inputs, mask): + def get_tags(self, word_list, word_inputs): - #mask = 1 - mask node_embeds = self.char_embedding(word_inputs) # batch_size, max_seq_len, embedding B, L, _ = node_embeds.size() - gaz_match = [] - edge_embs, bmes_embs, nodes_mask, gazs_mask_f, gazs_mask_b, gazs_length = self.obtain_gaz_relation(B, L, gaz_list) + edge_embs, bmes_embs, nodes_mask, words_mask_f, words_mask_b, words_length = self.construct_graph(B, L, word_list) _, N, _ = edge_embs.size() - #smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) - - #P = self.pos_embedding(torch.arange(L, dtype=torch.long, device=node_embeds.device).view(1, L)) - #node_embeds = node_embeds + P node_embeds = self.dropout(node_embeds) edge_embs = self.dropout(edge_embs) - #nodes_f = node_embeds + ## forward direction update edges_f = self.edge_emb_linear(edge_embs) nodes_f, _ = self.emb_rnn_f(node_embeds) - _, _, H = nodes_f.size() - - glo_f = edges_f.mean(1, keepdim=True) + nodes_f.mean(1, keepdim=True) nodes_f_cat = nodes_f[:, None, :, :] edges_f_cat = edges_f[:, None, :, :] - glo_f_cat = glo_f[:, None, :, :] - #ex_mask = mask[:, None, :, None].expand(B, H, L, 1) + _, _, H = nodes_f.size() + + if self.use_global: + glo_f = edges_f.mean(1, keepdim=True) + nodes_f.mean(1, keepdim=True) + glo_f_cat = glo_f[:, None, :, :] for i in range(self.iters): @@ -435,80 +234,115 @@ def get_tags(self, gaz_list, word_inputs, mask): bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) - nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - gazs_mask_b)[:, :, :, None].float(), 2) + nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - words_mask_b)[:, :, :, None].float(), 2) nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) - nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, gazs_length], -1).unsqueeze(2), gazs_mask_f) + nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, words_length], -1).unsqueeze(2), words_mask_f) - glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1) + if self.use_global: + glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1) if N > 1: - edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], - edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) + if self.use_global: + edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], + edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) + else: + edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], edges_att_f[:, 1:N, :])], 1) + edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) - nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, H*2)) + + if self.use_global: + nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, 
nodes_att_f, glo_att_f.expand(B, L, H*2)) + else: + nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f) + nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) - glo_f = self.glo_rnn_f(glo_f, glo_att_f) - glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1) - glo_f = self.norm(torch.sum(glo_f_cat, 1)) - #nodes = nodes.masked_fill_(ex_mask, 0) + if self.use_global: + glo_f = self.glo_rnn_f(glo_f, glo_att_f) + glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1) + glo_f = self.norm(torch.sum(glo_f_cat, 1)) - #nodes_b = node_embeds - edges_b = self.edge_emb_linear(edge_embs) - nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) - nodes_b = torch.flip(nodes_b, [1]) + nodes_cat = nodes_f_cat - glo_b = nodes_b.mean(1, keepdim=True) + edges_b.mean(1, keepdim=True) - nodes_b_cat = nodes_b[:, None, :, :] - edges_b_cat = edges_b[:, None, :, :] - glo_b_cat = glo_b[:, None, :, :] + if self.bidirectional: + ## backward direction update + edges_b = self.edge_emb_linear(edge_embs) + nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) + nodes_b = torch.flip(nodes_b, [1]) + nodes_b_cat = nodes_b[:, None, :, :] + edges_b_cat = edges_b[:, None, :, :] - for i in range(self.iters): + if self.use_global: + glo_b = nodes_b.mean(1, keepdim=True) + edges_b.mean(1, keepdim=True) + glo_b_cat = glo_b[:, None, :, :] - if N > 1: - bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) + for i in range(self.iters): - nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - gazs_mask_f)[:, :, :, None].float(), 2) - nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) - nodes_att_b = self.edge2node_b[i](nodes_b, - torch.cat([edges_b, nodes_begin_b, gazs_length], -1).unsqueeze(2), gazs_mask_b) + if N > 1: + bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) + edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) - glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b), self.glo_att_b_edge[i](glo_b, edges_b)], -1) + nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - words_mask_f)[:, :, :, None].float(), 2) + nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) + nodes_att_b = self.edge2node_b[i](nodes_b, torch.cat([edges_b, nodes_begin_b, words_length], -1).unsqueeze(2), words_mask_b) - if N > 1: - edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :], - glo_att_b.expand(B, N-1, H*2))], 1) - edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1) - edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1) + if self.use_global: + glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b), self.glo_att_b_edge[i](glo_b, edges_b)], -1) + + if N > 1: + if self.use_global: + edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], + edges_att_b[:, 1:N, :], glo_att_b.expand(B, N-1, H*2))], 1) + else: + edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :])], 1) + + edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1) + edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1) + + nodes_b_r = 
torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1) + + if self.use_global: + nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, H*2)) + else: + nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b) + + nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1) + nodes_b = self.norm(torch.sum(nodes_b_cat, 1)) - nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1) - nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, H*2)) - nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1) - nodes_b = self.norm(torch.sum(nodes_b_cat, 1)) + if self.use_global: + glo_b = self.glo_rnn_b(glo_b, glo_att_b) + glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1) + glo_b = self.norm(torch.sum(glo_b_cat, 1)) - glo_b = self.glo_rnn_b(glo_b, glo_att_b) - glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1) - glo_b = self.norm(torch.sum(glo_b_cat, 1)) + nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1) - nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1) layer_att = torch.sigmoid(self.layer_att_W(nodes_cat)) layer_alpha = F.softmax(layer_att, 1) nodes = torch.sum(layer_alpha * nodes_cat, 1) tags = self.hidden2tag(nodes) - return tags, gaz_match + return tags + + def forward(self, word_list, batch_inputs, mask, batch_label=None): + + tags = self.get_tags(word_list, batch_inputs) - def forward(self, gaz_list, word_inputs, mask, batch_label=None): - tags, _ = self.get_tags(gaz_list, word_inputs, mask) if batch_label is not None: - total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) + if self.use_crf: + total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) + else: + total_loss = self.criterion(tags.view(-1, self.label_size), batch_label.view(-1)) else: total_loss = None - scores, tag_seq = self.crf._viterbi_decode(tags, mask) + + if self.use_crf: + _, tag_seq = self.crf._viterbi_decode(tags, mask) + else: + tag_seq = tags.argmax(-1) + return total_loss, tag_seq diff --git a/model/transformer_no_edge.py b/model/LGN_no_edge.py similarity index 100% rename from model/transformer_no_edge.py rename to model/LGN_no_edge.py diff --git a/model/transformer_no_glo_no_edge.py b/model/LGN_no_glo_no_edge.py similarity index 100% rename from model/transformer_no_glo_no_edge.py rename to model/LGN_no_glo_no_edge.py diff --git a/model/module.py b/model/module.py new file mode 100644 index 0000000..5314ffb --- /dev/null +++ b/model/module.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# @Author: Yicheng Zou +# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn + + +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + + +class MultiHeadAtt(nn.Module): + def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): + super(MultiHeadAtt, self).__init__() + + if if_g: + self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) + else: + self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) + self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) + self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) + + self.drop = nn.Dropout(dropout) + + self.norm = nn.LayerNorm(nhid) + + self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim + + def forward(self, query_h, value, mask, query_g=None): + + if not (query_g is None): + query = torch.cat([query_h, query_g], -1) + else: + query = query_h + query = 
query.permute(0, 2, 1)[:, :, :, None] + value = value.permute(0, 3, 1, 2) + + residual = query_h + nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim + + B, QL, H = query_h.shape + + _, _, VL, VD = value.shape # VD = 1 or VD = QL + + assert VD == 1 or VD == QL + # q: (B, H, QL, 1) + # v: (B, H, VL, VD) + q, k, v = self.WQ(query), self.WK(value), self.WV(value) + + q = q.view(B, nhead, head_dim, 1, QL) + k = k.view(B, nhead, head_dim, VL, VD) + v = v.view(B, nhead, head_dim, VL, VD) + + alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) + alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf) + alpha = self.drop(F.softmax(alpha, 3)) + att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) + + output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) + output = self.norm(output + residual) + + return output + + +class GloAtt(nn.Module): + def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): + # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value + super(GloAtt, self).__init__() + self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) + self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) + + self.drop = nn.Dropout(dropout) + + self.norm = nn.LayerNorm(nhid) + + # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) + self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim + + def forward(self, x, y, mask=None): + # x: B, H, 1, 1, 1 y: B H L 1 + nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim + B, L, H = y.shape + + x = x.permute(0, 2, 1)[:, :, :, None] + y = y.permute(0, 2, 1)[:, :, :, None] + + residual = x + q, k, v = self.WQ(x), self.WK(y), self.WV(y) + + q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h + k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L + v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h + + pre_a = torch.matmul(q, k) / np.sqrt(head_dim) + if mask is not None: + pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) + alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L + att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 + output = F.leaky_relu(self.WO(att)) + residual + output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H) + + return output + + +class Nodes_Cell(nn.Module): + def __init__(self, hid_h, use_global=True, dropout=0.2): + super(Nodes_Cell, self).__init__() + + self.use_global = use_global + if self.use_global: + input_size = hid_h * 5 + else: + input_size = hid_h * 3 + + self.Wix = nn.Linear(input_size, hid_h) + self.Wi2 = nn.Linear(input_size, hid_h) + self.Wf = nn.Linear(input_size, hid_h) + self.Wcx = nn.Linear(input_size, hid_h) + + self.drop = nn.Dropout(dropout) + + def forward(self, h, h2, x, glo=None): + + x = self.drop(x) + + if self.use_global: + glo = self.drop(glo) + cat_all = torch.cat([h, h2, x, glo], -1) + else: + cat_all = torch.cat([h, h2, x], -1) + + ix = torch.sigmoid(self.Wix(cat_all)) + i2 = torch.sigmoid(self.Wi2(cat_all)) + f = torch.sigmoid(self.Wf(cat_all)) + cx = torch.tanh(self.Wcx(cat_all)) + + alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) + output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) + + return output + + +class Edges_Cell(nn.Module): + def __init__(self, hid_h, use_global=True, dropout=0.2): + super(Edges_Cell, self).__init__() + + self.use_global = use_global + if self.use_global: + input_size 
= hid_h * 4 + else: + input_size = hid_h * 2 + + self.Wi = nn.Linear(input_size, hid_h) + self.Wf = nn.Linear(input_size, hid_h) + self.Wc = nn.Linear(input_size, hid_h) + + self.drop = nn.Dropout(dropout) + + def forward(self, h, x, glo=None): + + x = self.drop(x) + + if self.use_global: + glo = self.drop(glo) + cat_all = torch.cat([h, x, glo], -1) + else: + cat_all = torch.cat([h, x], -1) + + i = torch.sigmoid(self.Wi(cat_all)) + f = torch.sigmoid(self.Wf(cat_all)) + c = torch.tanh(self.Wc(cat_all)) + + alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) + output = (alpha[:, 0] * c) + (alpha[:, 1] * h) + + return output + + +class GLobal_Cell(nn.Module): + def __init__(self, hid_h, dropout=0.2): + super(GLobal_Cell, self).__init__() + + self.Wi = nn.Linear(hid_h*3, hid_h) + self.Wf = nn.Linear(hid_h*3, hid_h) + self.Wc = nn.Linear(hid_h*3, hid_h) + + self.drop = nn.Dropout(dropout) + + def forward(self, h, x): + + x = self.drop(x) + + cat_all = torch.cat([h, x], -1) + i = torch.sigmoid(self.Wi(cat_all)) + f = torch.sigmoid(self.Wf(cat_all)) + c = torch.tanh(self.Wc(cat_all)) + + alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) + output = (alpha[:, 0] * c) + (alpha[:, 1] * h) + + return output diff --git a/model/transformer_light.py b/model/transformer_light.py deleted file mode 100644 index e5e7c39..0000000 --- a/model/transformer_light.py +++ /dev/null @@ -1,413 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from model.crf import CRF -#from model.layers import MultiHeadAttention, PositionwiseFeedForward - -def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): - ''' Sinusoid position encoding table ''' - - def cal_angle(position, hid_idx): - return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) - - def get_posi_angle_vec(position): - return [cal_angle(position, hid_j) for hid_j in range(d_hid)] - - sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) - - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - - if padding_idx is not None: - # zero vector for padding dimension - sinusoid_table[padding_idx] = 0. 
- - return torch.FloatTensor(sinusoid_table) - - -class Vertices_Cell(nn.Module): - def __init__(self, hid_1, hid_2): - super(Vertices_Cell, self).__init__() - - self.i_e = nn.Linear(hid_1, hid_2) - self.i_h = nn.Linear(hid_2, hid_2) - - self.f_e = nn.Linear(hid_1, hid_2) - self.f_h = nn.Linear(hid_2, hid_2) - - self.c_e = nn.Linear(hid_1, hid_2) - self.c_h = nn.Linear(hid_2, hid_2) - - nn.init.xavier_normal_(self.i_e.weight) - nn.init.xavier_normal_(self.i_h.weight) - nn.init.xavier_normal_(self.f_e.weight) - nn.init.xavier_normal_(self.f_h.weight) - nn.init.xavier_normal_(self.c_e.weight) - nn.init.xavier_normal_(self.c_h.weight) - - - def forward(self, input_1, input_2, vertices): - - i = torch.sigmoid(self.i_e(input_1) + self.i_h(input_2)) - f = torch.sigmoid(self.f_e(input_1) + self.f_h(input_2)) - c_new = torch.tanh(self.c_e(input_1) + self.c_h(input_2)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], dim=1), dim=1) - new_vertices = (alpha[:, 0] * c_new) + (alpha[:, 1] * vertices) - - return new_vertices - -class ScaledDotProductAttention(nn.Module): - ''' Scaled Dot-Product Attention ''' - - def __init__(self, temperature, attn_dropout=0.1): - super().__init__() - self.temperature = temperature - self.dropout = nn.Dropout(attn_dropout) - self.softmax = nn.Softmax(dim=2) - - def forward(self, q, k, v, mask=None): - - attn = torch.bmm(q, k.transpose(1, 2)) #(head*b,len,h)*(head*b,h,len) = (head*b,len,len) attn=len的每个对leni的权重,也就是每个字对第i个字的权重 - attn = attn / self.temperature - - if mask is not None: - attn = attn.masked_fill(mask, -np.inf) - - attn = self.softmax(attn) - attn = self.dropout(attn) - output = torch.bmm(attn, v) - - return output, attn - -class MultiHeadAttention(nn.Module): - ''' Multi-Head Attention module ''' - - def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0): - super().__init__() - - self.n_head = n_head - self.d_k = d_k - self.d_v = d_v - - self.w_qs = nn.Linear(d_model, n_head * d_k) - self.w_ks = nn.Linear(d_model, n_head * d_k) - self.w_vs = nn.Linear(d_model, n_head * d_v) - nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) - nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) - nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) - - self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) - self.layer_norm = nn.LayerNorm(d_model) - - self.fc = nn.Linear(n_head * d_v, d_model) - nn.init.xavier_normal_(self.fc.weight) - - self.dropout = nn.Dropout(dropout) - - - def forward(self, q, k, v, mask=None): - - d_k, d_v, n_head = self.d_k, self.d_v, self.n_head - - sz_b, len_q, _ = q.size() - sz_b, len_k, _ = k.size() - sz_b, len_v, _ = v.size() - - residual = q - - q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) - k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) - v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) - - q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk - k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk - v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv - - if mask is not None: - mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
- output, attn = self.attention(q, k, v, mask=mask) - - output = output.view(n_head, sz_b, len_q, d_v) - output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) - - output = self.dropout(F.relu(self.fc(output))) - output = self.layer_norm(output + residual) - return output, attn - - -class Graph(nn.Module): - def __init__(self, data): - super(Graph, self).__init__() - self.data = data - self.gpu = data.HP_gpu - self.hidden_dim = 50 - self.num_layer = data.HP_num_layer - self.gaz_alphabet = data.gaz_alphabet - self.word_alphabet = data.word_alphabet - self.gaz_emb_dim = data.gaz_emb_dim - self.word_emb_dim = data.word_emb_dim - self.bmes_emb_dim = 25 - self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) - self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim) - self.bmes_embedding = nn.Embedding(4, self.bmes_emb_dim) - - assert data.pretrain_gaz_embedding is not None - scale = np.sqrt(3.0 / self.gaz_emb_dim) - data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim]) - self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) - - assert data.pretrain_word_embedding is not None - self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - - self.position_embedding = nn.Embedding.from_pretrained( - get_sinusoid_encoding_table(data.posi_alphabet_size, self.word_emb_dim, padding_idx=0), freeze=True) - - self.dropout = nn.Dropout(0.5) - self.rnn = nn.ModuleList([nn.LSTM(self.word_emb_dim, self.hidden_dim, batch_first=True) - for _ in range(self.num_layer)]).cuda() - - self.node_att = nn.ModuleList([MultiHeadAttention(4, self.hidden_dim, self.hidden_dim, self.hidden_dim) - for _ in range(self.num_layer)]) - #self.node_pooling = nn.ModuleList([PositionwiseFeedForward(self.hidden_dim, self.hidden_dim * 2, self.hidden_dim) - # for _ in range(self.num_layer)]) - #self.gaz_pooling = nn.ModuleList([MultiHeadAttention(1, self.hidden_dim, self.hidden_dim * 2, self.hidden_dim, self.hidden_dim) - # for _ in range(self.num_layer)]) - - self.glo_att = nn.ModuleList([MultiHeadAttention(4, self.hidden_dim, self.hidden_dim, self.hidden_dim) - for _ in range(self.num_layer)]) - - #self.node_cell = Vertices_Cell(self.hidden_dim, self.hidden_dim) - - self.hidden2tag = nn.Linear(self.hidden_dim, data.label_alphabet_size + 2) - self.crf = CRF(data.label_alphabet_size, self.gpu) - - if self.gpu: - self.gaz_embedding = self.gaz_embedding.cuda() - self.word_embedding = self.word_embedding.cuda() - self.position_embedding = self.position_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.node_att = self.node_att.cuda() - #self.node_pooling = self.node_pooling.cuda() - #self.gaz_pooling = self.gaz_pooling.cuda() - self.glo_att = self.glo_att.cuda() - #self.node_cell = self.node_cell.cuda() - self.hidden2tag = self.hidden2tag.cuda() - self.crf = self.crf.cuda() - - """ - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): - - assert batch_size == 1 - batch_gaz_embed = torch.tensor([]) - batch_nodes_mask = torch.tensor([], dtype=torch.uint8) - batch_bmes_embed = torch.tensor([]) - batch_gazs_mask = torch.tensor([]) - if self.cuda: - batch_gaz_embed = batch_gaz_embed.cuda() - batch_nodes_mask = batch_nodes_mask.cuda() - batch_bmes_embed = batch_bmes_embed.cuda() - batch_gazs_mask = batch_gazs_mask.cuda() - - for sen in range(batch_size): - sen_gaz_embed = torch.zeros([0, self.gaz_emb_dim]) - sen_nodes_mask = 
torch.zeros([0, seq_len], dtype=torch.uint8) - sen_bmes_embed = torch.zeros([0, seq_len, self.bmes_emb_dim]) - sen_gazs_mask = torch.zeros([0, seq_len], dtype=torch.uint8) - if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() - sen_nodes_mask = sen_nodes_mask.cuda() - sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask = sen_gazs_mask.cuda() - - for w in range(seq_len): - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - - gaz_index = torch.tensor(gaz).cuda() if self.cuda else torch.tensor(gaz) - gaz_embedding = self.gaz_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding.unsqueeze(0)], dim=0) - - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len - nodes_mask = torch.ones([1, seq_len], dtype=torch.uint8) - bmes_embed = torch.zeros([1, seq_len, self.bmes_emb_dim]) - - gazs_mask = torch.ones([1, seq_len], dtype=torch.uint8) - gazs_mask[0, w + gaz_len - 1] = 0 - sen_gazs_mask = torch.cat([sen_gazs_mask, gazs_mask.unsqueeze(0)], dim=0) - - if self.cuda: - nodes_mask = nodes_mask.cuda() - bmes_embed = bmes_embed.cuda() - - for index in range(gaz_len): - nodes_mask[0, w + index] = 0 - if gaz_len == 1: - bmes_index = torch.tensor(3).cuda() if self.cuda else torch.tensor(3) # S - elif index == 0: - bmes_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) # B - elif index == gaz_len - 1: - bmes_index = torch.tensor(2).cuda() if self.cuda else torch.tensor(2) # E - else: - bmes_index = torch.tensor(1).cuda() if self.cuda else torch.tensor(1) # M - bmes_embed[0, w + index, :] = self.bmes_embedding(bmes_index) - - sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], dim=0) - sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], dim=0) - - sen_gazs_unk_mask = torch.ones([1, seq_len], dtype=torch.uint8) - sen_gazs_unk_mask[0, (1-sen_gazs_mask).sum(dim=0) < 0.5] = 0. 
- sen_gazs_mask = torch.cat([sen_gazs_unk_mask, sen_gazs_mask], dim=0) - - if sen_gaz_embed.size(0) != 0: - batch_gaz_embed = sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 - batch_nodes_mask = sen_nodes_mask.unsqueeze(0) - batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask = sen_gazs_mask.unsqueeze(0) - - return batch_gaz_embed, batch_nodes_mask, batch_bmes_embed, batch_gazs_mask - """ - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list, mask): - - assert batch_size == 1 - adjoin_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - adjoin_emb = self.gaz_embedding(adjoin_index) - - gazs_send_embeds = torch.zeros(batch_size, seq_len+1, seq_len, self.gaz_emb_dim) - gazs_rec_embeds = torch.zeros(batch_size, seq_len+1, seq_len, self.gaz_emb_dim) - rel_send_gaz = torch.ones([batch_size, seq_len+1, seq_len], dtype=torch.uint8) - rel_rec_gaz = torch.ones([batch_size, seq_len+1, seq_len], dtype=torch.uint8) - - if self.cuda: - gazs_send_embeds = gazs_send_embeds.cuda() - gazs_rec_embeds = gazs_rec_embeds.cuda() - rel_send_gaz = rel_send_gaz.cuda() - rel_rec_gaz = rel_rec_gaz.cuda() - - for sen in range(batch_size): - for w in range(seq_len): - if not mask[sen][w]: - break - if w < len(gaz_list[sen]): - gazs_send_embeds[sen][w][w] = adjoin_emb - gazs_rec_embeds[sen][w + 1][w] = adjoin_emb - rel_send_gaz[sen][w][w] = 0 - rel_rec_gaz[sen][w + 1][w] = 0 - - for sen in range(batch_size): - for w in range(seq_len): - if not mask[sen][w]: - break - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - gaz_index = torch.tensor(gaz).cuda() if self.cuda else torch.tensor(gaz) - gaz_embedding = self.gaz_embedding(gaz_index) - - gazs_send_embeds[sen][w + 1][w + gaz_len - 1] = gaz_embedding - gazs_rec_embeds[sen][w + gaz_len - 1][w] = gaz_embedding - rel_send_gaz[sen][w + 1][w + gaz_len - 1] = 0 - rel_rec_gaz[sen][w + gaz_len - 1][w] = 0 - - return gazs_send_embeds, gazs_rec_embeds, rel_send_gaz, rel_rec_gaz - - def get_tags(self, gaz_list, word_inputs, mask): - - batch_size = word_inputs.size()[0] - seq_len = word_inputs.size()[1] - word_embs = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding - gaz_match = [] - - # position embedding - posi_inputs = torch.zeros(batch_size, seq_len).long() - for batch in range(batch_size): - posi_temp = torch.LongTensor([i + 1 for i in range(seq_len) if mask[batch][i]]) - posi_inputs[batch, 0:posi_temp.size(0)] = posi_temp - if self.gpu: - posi_inputs = posi_inputs.cuda() - position_embs = self.position_embedding(posi_inputs) - - # 节点的初始表示, batch_size * seq_len * emb_size - raw_nodes = self.dropout(word_embs + position_embs) - - # gaz的初始表示,batch_size * seq_len * seq_len * emb_size - #raw_send_embed, raw_rec_embed, rel_send_gaz, rel_rec_gaz = \ - # self.obtain_gaz_relation(batch_size, seq_len, gaz_list, mask) - - #gazs_send = self.dropout(raw_send_embed) - #gazs_rec = self.dropout(raw_rec_embed) - nodes_send = raw_nodes - #nodes_rec = raw_nodes - glo = torch.mean(nodes_send, dim=1) - - for layer in range(self.num_layer): - - layer_index = layer # layer_index = 0 表示只有一套参数 - - # SEND - nodes_att_list = [] - #gazs_pooling_list = [] - - # 前后padding的 nodes - padding_embed = torch.zeros([batch_size, 1, self.hidden_dim]).cuda() if self.cuda else \ - torch.zeros([batch_size, 1, self.hidden_dim]) - nodes_padding = torch.cat([padding_embed, nodes_send, padding_embed], dim=1) - #gazs_padding = torch.cat([nodes_padding.unsqueeze(2).expand(batch_size, seq_len+1, 
seq_len, self.hidden_dim), - # gazs_send], dim=-1) - - for i in range(seq_len): - - node_att_cell = torch.cat([nodes_padding[:, i, :].unsqueeze(1), - nodes_padding[:, (i+1), :].unsqueeze(1), - nodes_padding[:, (i+2), :].unsqueeze(1), - raw_nodes[:, i, :].unsqueeze(1), - glo.unsqueeze(1), - ], dim=1) - - node_att, _att = self.node_att[layer_index](nodes_send[:, i:(i+1), :], node_att_cell, node_att_cell) - nodes_att_list.append(node_att) - - #gaz_pooling_cell = gazs_padding[:, :, i, :] - #gaz_pooling_mask = rel_send_gaz[:, :, i].unsqueeze(1) - #gaz_pooling, _att = self.gaz_pooling[layer_index](nodes_send[:, i:(i+1), :], gaz_pooling_cell, gaz_pooling_cell, gaz_pooling_mask) - #gazs_pooling_list.append(gaz_pooling) - - nodes_send = torch.cat(nodes_att_list, dim=1) - #nodes_send = self.node_pooling[layer_index](nodes_att) - #gazs_pooling = torch.cat(gazs_pooling_list, dim=1) - - #nodes_send = self.node_cell(nodes_att, gazs_pooling, nodes_send) - - #nodes_send = self.dropout(nodes_att) - - glo_cell = torch.cat([glo.unsqueeze(1), - nodes_send], dim=1) - - glo, _att = self.glo_att[layer_index](glo.unsqueeze(1), glo_cell, glo_cell) - glo = glo.squeeze(1) - #nodes_send, _ = self.node_att[layer](nodes_send, nodes_send, nodes_send) - #nodes_send = self.node_pooling[layer](nodes_send) - #nodes_send, _ = self.rnn[layer](nodes_send) - - tags = self.hidden2tag(nodes_send) # (b,m,t) - - return tags, gaz_match - - def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label): - - tags, _ = self.get_tags(gaz_list, word_inputs, mask) - total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - - return total_loss, tag_seq # (batch_size,) ,(b,seqlen?) - - def forward(self, gaz_list, word_inputs, word_seq_lengths, mask): - tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask) - # tags_ = tags.transpose(0,1).contiguous() - # mask_ = mask.transpose(0,1).contiguous() - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - # tag_seq = self.crf_.decode(tags_, mask=mask_) - # tag_seq = self.crf_.decode(tags, mask=mask) - return tag_seq, gaz_match diff --git a/model/transformer_no_crf.py b/model/transformer_no_crf.py deleted file mode 100644 index ef7425f..0000000 --- a/model/transformer_no_crf.py +++ /dev/null @@ -1,511 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from model.crf import CRF - - -class MultiHeadAtt(nn.Module): - def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): - super(MultiHeadAtt, self).__init__() - - if if_g: - self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) - else: - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, query_h, value, mask, query_g=None): - - if not (query_g is None): - query = torch.cat([query_h, query_g], -1) - else: - query = query_h - query = query.permute(0, 2, 1)[:, :, :, None] - value = value.permute(0, 3, 1, 2) - - residual = query_h - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - - B, QL, H = query_h.shape - - _, _, VL, VD = value.shape # VD = 1 or VD = QL - - assert VD == 1 or VD == QL - # q: (B, H, QL, 1) - # v: (B, H, VL, VD) - q, k, 
v = self.WQ(query), self.WK(value), self.WV(value) - - q = q.view(B, nhead, head_dim, 1, QL) - k = k.view(B, nhead, head_dim, VL, VD) - v = v.view(B, nhead, head_dim, VL, VD) - - alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) - alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf) - alpha = self.drop(F.softmax(alpha, 3)) - att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) - - output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) - output = self.norm(output + residual) - - return output - - -class GloAtt(nn.Module): - def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): - # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value - super(GloAtt, self).__init__() - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, x, y, mask=None): - # x: B, H, 1, 1, 1 y: B H L 1 - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - B, L, H = y.shape - - x = x.permute(0, 2, 1)[:, :, :, None] - y = y.permute(0, 2, 1)[:, :, :, None] - - residual = x - q, k, v = self.WQ(x), self.WK(y), self.WV(y) - - q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h - k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L - v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h - - pre_a = torch.matmul(q, k) / np.sqrt(head_dim) - if mask is not None: - pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) - alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L - att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 - output = F.leaky_relu(self.WO(att)) + residual - output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H) - - return output - - -class Nodes_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Nodes_Cell, self).__init__() - - self.Wix = nn.Linear(hid_h*5, hid_h) - #self.Wig = nn.Linear(hid_h*4, hid_h) - self.Wi2 = nn.Linear(hid_h*5, hid_h) - self.Wf = nn.Linear(hid_h*5, hid_h) - self.Wcx = nn.Linear(hid_h*5, hid_h) - #self.Wcg = nn.Linear(hid_h, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, h2, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, h2, x, glo], -1) - #cat_x = torch.cat([h, h2, x], -1) - #cat_g = torch.cat([glo], -1) - - ix = torch.sigmoid(self.Wix(cat_all)) - #ig = torch.sigmoid(self.Wig(cat_all)) - i2 = torch.sigmoid(self.Wi2(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - cx = torch.tanh(self.Wcx(cat_all)) - #cg = torch.tanh(self.Wcg(cat_g)) - - alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) - - return output - - -class Gazs_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Gazs_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*4, hid_h) - self.Wf = nn.Linear(hid_h*4, hid_h) - self.Wc = nn.Linear(hid_h*4, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, x, glo], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = 
F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class GLobal_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(GLobal_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*3, hid_h) - self.Wf = nn.Linear(hid_h*3, hid_h) - self.Wc = nn.Linear(hid_h*3, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x): - - x = self.drop(x) - - cat_all = torch.cat([h, x], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class Graph(nn.Module): - def __init__(self, data): - super(Graph, self).__init__() - - self.gpu = data.HP_gpu - self.word_alphabet = data.word_alphabet - self.label_size = data.label_alphabet_size - self.word_emb_dim = data.word_emb_dim - self.gaz_emb_dim = data.gaz_emb_dim - self.hidden_dim = 50 - self.num_head = 10 # 5 10 20 - self.head_dim = 20 # 10 20 - self.tf_dropout_rate = 0.1 - self.iters = 4 - self.bmes_dim = 10 - self.length_dim = 10 - self.max_gaz_length = 5 - self.emb_dropout_rate = 0.5 - self.cell_dropout_rate = 0.2 - - # word embedding - self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim) - assert data.pretrain_word_embedding is not None - self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - - # gaz embedding - self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) - assert data.pretrain_gaz_embedding is not None - scale = np.sqrt(3.0 / self.gaz_emb_dim) - data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim]) - self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) - - # position embedding - #self.pos_embedding = nn.Embedding(data.posi_alphabet_size, self.hidden_dim) - # lstm - self.emb_rnn_f = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - self.emb_rnn_b = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - - # bmes embedding - self.bmes_embedding = nn.Embedding(4, self.bmes_dim) - - # length embedding - self.length_embedding = nn.Embedding(self.max_gaz_length, self.length_dim) - - self.dropout = nn.Dropout(self.emb_dropout_rate) - self.norm = nn.LayerNorm(self.hidden_dim) - - self.edge2node_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_f = Gazs_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.glo_rnn_f = GLobal_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - - self.edge2node_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - 
nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_b_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_b_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_b = Gazs_Cell(self.hidden_dim, self.cell_dropout_rate) - self.node_rnn_b = Nodes_Cell(self.hidden_dim, self.cell_dropout_rate) - self.glo_rnn_b = GLobal_Cell(self.hidden_dim, self.cell_dropout_rate) - - self.layer_att_W = nn.Linear(self.hidden_dim * 2, 1) - self.hidden2tag = nn.Linear(self.hidden_dim * 2, self.label_size) - self.criterion = nn.CrossEntropyLoss() - - if self.gpu: - self.word_embedding = self.word_embedding.cuda() - self.gaz_embedding = self.gaz_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.length_embedding = self.length_embedding.cuda() - self.norm = self.norm.cuda() - self.edge2node_f = self.edge2node_f.cuda() - self.node2edge_f = self.node2edge_f.cuda() - self.edge_rnn_f = self.edge_rnn_f.cuda() - self.node_rnn_f = self.node_rnn_f.cuda() - self.glo_rnn_f = self.glo_rnn_f.cuda() - self.glo_att_f_node = self.glo_att_f_node.cuda() - self.glo_att_f_edge = self.glo_att_f_edge.cuda() - self.edge2node_b = self.edge2node_b.cuda() - self.node2edge_b = self.node2edge_b.cuda() - self.edge_rnn_b = self.edge_rnn_b.cuda() - self.node_rnn_b = self.node_rnn_b.cuda() - self.glo_rnn_b = self.glo_rnn_b.cuda() - self.glo_att_b_node = self.glo_att_b_node.cuda() - self.glo_att_b_edge = self.glo_att_b_edge.cuda() - #self.pos_embedding = self.pos_embedding.cuda() - self.emb_rnn_f = self.emb_rnn_f.cuda() - self.emb_rnn_b = self.emb_rnn_b.cuda() - self.layer_att_W = self.layer_att_W.cuda() - self.hidden2tag = self.hidden2tag.cuda() - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): - - assert batch_size == 1 - - unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - unk_emb = self.gaz_embedding(unk_index) - - bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1) - bmes_index_e = torch.tensor(2).cuda() if self.cuda else torch.tensor(2) - bmes_index_s = torch.tensor(3).cuda() if self.cuda else torch.tensor(3) - - bmes_emb_b = self.bmes_embedding(bmes_index_b) - bmes_emb_m = self.bmes_embedding(bmes_index_m) - bmes_emb_e = self.bmes_embedding(bmes_index_e) - bmes_emb_s = self.bmes_embedding(bmes_index_s) - - for sen in range(batch_size): - sen_gaz_embed = unk_emb[None, :] - sen_nodes_mask = torch.zeros([1, seq_len]).byte() - sen_gazs_length = torch.zeros([1, self.length_dim]) - sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - sen_gazs_mask_f = torch.zeros([1, seq_len]).byte() - sen_gazs_mask_b = torch.zeros([1, seq_len]).byte() - if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() - sen_nodes_mask = sen_nodes_mask.cuda() - sen_gazs_length = sen_gazs_length.cuda() - sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask_f = sen_gazs_mask_f.cuda() - sen_gazs_mask_b = sen_gazs_mask_b.cuda() - - for w in range(seq_len): - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for 
gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - - gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device) - gaz_embedding = self.gaz_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0) - - if gaz_len <= self.max_gaz_length: - gaz_length_index = torch.tensor(gaz_len-1, device=sen_gazs_length.device) - else: - gaz_length_index = torch.tensor(self.max_gaz_length-1, device=sen_gazs_length.device) - gaz_length = self.length_embedding(gaz_length_index) - sen_gazs_length = torch.cat([sen_gazs_length, gaz_length[None, :]], 0) - - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len - nodes_mask = torch.ones([1, seq_len]).byte() - bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - gazs_mask_f = torch.ones([1, seq_len]).byte() - gazs_mask_b = torch.ones([1, seq_len]).byte() - if self.cuda: - nodes_mask = nodes_mask.cuda() - bmes_embed = bmes_embed.cuda() - gazs_mask_f = gazs_mask_f.cuda() - gazs_mask_b = gazs_mask_b.cuda() - - gazs_mask_f[0, w + gaz_len - 1] = 0 - sen_gazs_mask_f = torch.cat([sen_gazs_mask_f, gazs_mask_f], 0) - - gazs_mask_b[0, w] = 0 - sen_gazs_mask_b = torch.cat([sen_gazs_mask_b, gazs_mask_b], 0) - - for index in range(gaz_len): - nodes_mask[0, w + index] = 0 - if gaz_len == 1: - bmes_embed[0, w + index, :] = bmes_emb_s - elif index == 0: - bmes_embed[0, w + index, :] = bmes_emb_b - elif index == gaz_len - 1: - bmes_embed[0, w + index, :] = bmes_emb_e - else: - bmes_embed[0, w + index, :] = bmes_emb_m - - sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) - sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) - - #sen_gazs_mask_f[0, (1-sen_gazs_mask_f).sum(dim=0) == 0] = 0 - #sen_gazs_mask_b[0, (1-sen_gazs_mask_b).sum(dim=0) == 0] = 0 - - batch_gaz_embed = sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 - batch_nodes_mask = sen_nodes_mask.unsqueeze(0) - batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask_f = sen_gazs_mask_f.unsqueeze(0) - batch_gazs_mask_b = sen_gazs_mask_b.unsqueeze(0) - batch_gazs_length = sen_gazs_length.unsqueeze(0) - return batch_gaz_embed, batch_bmes_embed, batch_nodes_mask, batch_gazs_mask_f, batch_gazs_mask_b, batch_gazs_length - - def get_tags(self, gaz_list, word_inputs, mask): - - #mask = 1 - mask - node_embeds = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding - B, L, H = node_embeds.size() - gaz_match = [] - - edge_embs, bmes_embs, nodes_mask, gazs_mask_f, gazs_mask_b, gazs_length = self.obtain_gaz_relation(B, L, gaz_list) - _, N, _ = edge_embs.size() - #smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) - - #P = self.pos_embedding(torch.arange(L, dtype=torch.long, device=node_embeds.device).view(1, L)) - #node_embeds = node_embeds + P - - node_embeds = self.dropout(node_embeds) - edge_embs = self.dropout(edge_embs) - - #nodes_f = node_embeds - edges_f = edge_embs - nodes_f, _ = self.emb_rnn_f(node_embeds) - - glo_f = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_f_cat = nodes_f[:, None, :, :] - edges_f_cat = edges_f[:, None, :, :] - glo_f_cat = glo_f[:, None, :, :] - #ex_mask = mask[:, None, :, None].expand(B, H, L, 1) - - for i in range(self.iters): - - if N > 1: - bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) - - nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - gazs_mask_b)[:, :, :, None].float(), 2) - nodes_begin_f = 
torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) - nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, gazs_length], -1).unsqueeze(2), gazs_mask_f) - - glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1) - - if N > 1: - edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], - edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) - edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) - edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) - - nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) - nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, H*2)) - nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) - nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) - - glo_f = self.glo_rnn_f(glo_f, glo_att_f) - glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1) - glo_f = self.norm(torch.sum(glo_f_cat, 1)) - #nodes = nodes.masked_fill_(ex_mask, 0) - - #nodes_b = node_embeds - edges_b = edge_embs - nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) - nodes_b = torch.flip(nodes_b, [1]) - - glo_b = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_b_cat = nodes_b[:, None, :, :] - edges_b_cat = edges_b[:, None, :, :] - glo_b_cat = glo_b[:, None, :, :] - - for i in range(self.iters): - - if N > 1: - bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) - - nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - gazs_mask_f)[:, :, :, None].float(), 2) - nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) - nodes_att_b = self.edge2node_b[i](nodes_b, - torch.cat([edges_b, nodes_begin_b, gazs_length], -1).unsqueeze(2), gazs_mask_b) - - glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b), self.glo_att_b_edge[i](glo_b, edges_b)], -1) - - if N > 1: - edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :], - glo_att_b.expand(B, N-1, H*2))], 1) - edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1) - edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1) - - nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1) - nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, H*2)) - nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1) - nodes_b = self.norm(torch.sum(nodes_b_cat, 1)) - - glo_b = self.glo_rnn_b(glo_b, glo_att_b) - glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1) - glo_b = self.norm(torch.sum(glo_b_cat, 1)) - - nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1) - layer_att = torch.sigmoid(self.layer_att_W(nodes_cat)) - layer_alpha = F.softmax(layer_att, 1) - nodes = torch.sum(layer_alpha * nodes_cat, 1) - - tags = self.hidden2tag(nodes) - - return tags, gaz_match - - def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label): - - tags, _ = self.get_tags(gaz_list, word_inputs, mask) - - total_loss = self.criterion(tags.view(-1, self.label_size), batch_label.view(-1)) - tag_seq = tags.argmax(-1) - - return total_loss, tag_seq # 
(batch_size,) ,(b,seqlen?) - - def forward(self, gaz_list, word_inputs, word_seq_lengths, mask): - tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask) - tag_seq = tags.argmax(-1) - return tag_seq, gaz_match diff --git a/model/transformer_no_glo.py b/model/transformer_no_glo.py deleted file mode 100644 index 5c10ceb..0000000 --- a/model/transformer_no_glo.py +++ /dev/null @@ -1,398 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from model.crf import CRF - - -class MultiHeadAtt(nn.Module): - def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): - super(MultiHeadAtt, self).__init__() - - if if_g: - self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) - else: - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, query_h, value, mask, query_g=None): - - if not (query_g is None): - query = torch.cat([query_h, query_g], -1) - else: - query = query_h - query = query.permute(0, 2, 1)[:, :, :, None] - value = value.permute(0, 3, 1, 2) - - residual = query_h - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - - B, QL, H = query_h.shape - - _, _, VL, VD = value.shape # VD = 1 or VD = QL - - assert VD == 1 or VD == QL - # q: (B, H, QL, 1) - # v: (B, H, VL, VD) - q, k, v = self.WQ(query), self.WK(value), self.WV(value) - - q = q.view(B, nhead, head_dim, 1, QL) - k = k.view(B, nhead, head_dim, VL, VD) - v = v.view(B, nhead, head_dim, VL, VD) - - alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) - alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf) - alpha = self.drop(F.softmax(alpha, 3)) - att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) - - output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) - output = self.norm(output + residual) - - return output - - -class Nodes_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Nodes_Cell, self).__init__() - - self.Wix = nn.Linear(hid_h*3, hid_h) - #self.Wig = nn.Linear(hid_h*4, hid_h) - self.Wi2 = nn.Linear(hid_h*3, hid_h) - self.Wf = nn.Linear(hid_h*3, hid_h) - self.Wcx = nn.Linear(hid_h*3, hid_h) - #self.Wcg = nn.Linear(hid_h, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, h2, x): - - x = self.drop(x) - - cat_all = torch.cat([h, h2, x], -1) - #cat_x = torch.cat([h, h2, x], -1) - #cat_g = torch.cat([glo], -1) - - ix = torch.sigmoid(self.Wix(cat_all)) - #ig = torch.sigmoid(self.Wig(cat_all)) - i2 = torch.sigmoid(self.Wi2(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - cx = torch.tanh(self.Wcx(cat_all)) - #cg = torch.tanh(self.Wcg(cat_g)) - - alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) - - return output - - -class Gazs_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Gazs_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*2, hid_h) - self.Wf = nn.Linear(hid_h*2, hid_h) - self.Wc = nn.Linear(hid_h*2, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x): - - x = self.drop(x) - - cat_all = torch.cat([h, x], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = 
torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - -class Graph(nn.Module): - def __init__(self, data): - super(Graph, self).__init__() - - self.gpu = data.HP_gpu - self.word_alphabet = data.word_alphabet - self.word_emb_dim = data.word_emb_dim - self.gaz_emb_dim = data.gaz_emb_dim - self.hidden_dim = 50 - self.num_head = 10 # 5 10 20 - self.head_dim = 20 # 10 20 - self.tf_dropout_rate = 0.1 - self.iters = 4 - self.bmes_dim = 10 - self.length_dim = 10 - self.max_gaz_length = 5 - self.emb_dropout_rate = 0.5 - self.cell_dropout_rate = 0.2 - - # word embedding - self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim) - assert data.pretrain_word_embedding is not None - self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - - # gaz embedding - self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) - assert data.pretrain_gaz_embedding is not None - scale = np.sqrt(3.0 / self.gaz_emb_dim) - data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim]) - self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) - - # lstm - self.emb_rnn_f = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - self.emb_rnn_b = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - - # bmes embedding - self.bmes_embedding = nn.Embedding(4, self.bmes_dim) - - # length embedding - self.length_embedding = nn.Embedding(self.max_gaz_length, self.length_dim) - - self.dropout = nn.Dropout(self.emb_dropout_rate) - self.norm = nn.LayerNorm(self.hidden_dim) - - self.edge2node_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_f = Gazs_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - - self.edge2node_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_b = Gazs_Cell(self.hidden_dim, self.cell_dropout_rate) - self.node_rnn_b = Nodes_Cell(self.hidden_dim, self.cell_dropout_rate) - - self.layer_att_W = nn.Linear(self.hidden_dim * 2, 1) - self.hidden2tag = nn.Linear(self.hidden_dim * 2, data.label_alphabet_size + 2) - self.crf = CRF(data.label_alphabet_size, self.gpu) - - if self.gpu: - self.word_embedding = self.word_embedding.cuda() - self.gaz_embedding = self.gaz_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.length_embedding = self.length_embedding.cuda() - self.norm = self.norm.cuda() - self.edge2node_f = self.edge2node_f.cuda() - self.node2edge_f = self.node2edge_f.cuda() - self.edge_rnn_f = self.edge_rnn_f.cuda() - self.node_rnn_f = self.node_rnn_f.cuda() - self.edge2node_b = self.edge2node_b.cuda() - self.node2edge_b = 
self.node2edge_b.cuda() - self.edge_rnn_b = self.edge_rnn_b.cuda() - self.node_rnn_b = self.node_rnn_b.cuda() - #self.pos_embedding = self.pos_embedding.cuda() - self.emb_rnn_f = self.emb_rnn_f.cuda() - self.emb_rnn_b = self.emb_rnn_b.cuda() - self.layer_att_W = self.layer_att_W.cuda() - self.hidden2tag = self.hidden2tag.cuda() - self.crf = self.crf.cuda() - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): - - assert batch_size == 1 - - unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - unk_emb = self.gaz_embedding(unk_index) - - bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1) - bmes_index_e = torch.tensor(2).cuda() if self.cuda else torch.tensor(2) - bmes_index_s = torch.tensor(3).cuda() if self.cuda else torch.tensor(3) - - bmes_emb_b = self.bmes_embedding(bmes_index_b) - bmes_emb_m = self.bmes_embedding(bmes_index_m) - bmes_emb_e = self.bmes_embedding(bmes_index_e) - bmes_emb_s = self.bmes_embedding(bmes_index_s) - - for sen in range(batch_size): - sen_gaz_embed = unk_emb[None, :] - sen_nodes_mask = torch.zeros([1, seq_len]).byte() - sen_gazs_length = torch.zeros([1, self.length_dim]) - sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - sen_gazs_mask_f = torch.zeros([1, seq_len]).byte() - sen_gazs_mask_b = torch.zeros([1, seq_len]).byte() - if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() - sen_nodes_mask = sen_nodes_mask.cuda() - sen_gazs_length = sen_gazs_length.cuda() - sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask_f = sen_gazs_mask_f.cuda() - sen_gazs_mask_b = sen_gazs_mask_b.cuda() - - for w in range(seq_len): - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - - gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device) - gaz_embedding = self.gaz_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0) - - if gaz_len <= self.max_gaz_length: - gaz_length_index = torch.tensor(gaz_len-1, device=sen_gazs_length.device) - else: - gaz_length_index = torch.tensor(self.max_gaz_length-1, device=sen_gazs_length.device) - gaz_length = self.length_embedding(gaz_length_index) - sen_gazs_length = torch.cat([sen_gazs_length, gaz_length[None, :]], 0) - - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len - nodes_mask = torch.ones([1, seq_len]).byte() - bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - gazs_mask_f = torch.ones([1, seq_len]).byte() - gazs_mask_b = torch.ones([1, seq_len]).byte() - if self.cuda: - nodes_mask = nodes_mask.cuda() - bmes_embed = bmes_embed.cuda() - gazs_mask_f = gazs_mask_f.cuda() - gazs_mask_b = gazs_mask_b.cuda() - - gazs_mask_f[0, w + gaz_len - 1] = 0 - sen_gazs_mask_f = torch.cat([sen_gazs_mask_f, gazs_mask_f], 0) - - gazs_mask_b[0, w] = 0 - sen_gazs_mask_b = torch.cat([sen_gazs_mask_b, gazs_mask_b], 0) - - for index in range(gaz_len): - nodes_mask[0, w + index] = 0 - if gaz_len == 1: - bmes_embed[0, w + index, :] = bmes_emb_s - elif index == 0: - bmes_embed[0, w + index, :] = bmes_emb_b - elif index == gaz_len - 1: - bmes_embed[0, w + index, :] = bmes_emb_e - else: - bmes_embed[0, w + index, :] = bmes_emb_m - - sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) - sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) - - #sen_gazs_mask_f[0, (1-sen_gazs_mask_f).sum(dim=0) == 0] = 0 - #sen_gazs_mask_b[0, (1-sen_gazs_mask_b).sum(dim=0) == 0] = 0 - - batch_gaz_embed = 
sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 - batch_nodes_mask = sen_nodes_mask.unsqueeze(0) - batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask_f = sen_gazs_mask_f.unsqueeze(0) - batch_gazs_mask_b = sen_gazs_mask_b.unsqueeze(0) - batch_gazs_length = sen_gazs_length.unsqueeze(0) - return batch_gaz_embed, batch_bmes_embed, batch_nodes_mask, batch_gazs_mask_f, batch_gazs_mask_b, batch_gazs_length - - def get_tags(self, gaz_list, word_inputs, mask): - - #mask = 1 - mask - node_embeds = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding - B, L, H = node_embeds.size() - gaz_match = [] - - edge_embs, bmes_embs, nodes_mask, gazs_mask_f, gazs_mask_b, gazs_length = self.obtain_gaz_relation(B, L, gaz_list) - _, N, _ = edge_embs.size() - #smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) - - #P = self.pos_embedding(torch.arange(L, dtype=torch.long, device=node_embeds.device).view(1, L)) - #node_embeds = node_embeds + P - - node_embeds = self.dropout(node_embeds) - edge_embs = self.dropout(edge_embs) - - #nodes_f = node_embeds - edges_f = edge_embs - nodes_f, _ = self.emb_rnn_f(node_embeds) - - nodes_f_cat = nodes_f[:, None, :, :] - edges_f_cat = edges_f[:, None, :, :] - #ex_mask = mask[:, None, :, None].expand(B, H, L, 1) - - for i in range(self.iters): - - if N > 1: - bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) - - nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - gazs_mask_b)[:, :, :, None].float(), 2) - nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) - nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, gazs_length], -1).unsqueeze(2), gazs_mask_f) - - if N > 1: - edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], - edges_att_f[:, 1:N, :])], 1) - edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) - edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) - - nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) - nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f) - nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) - nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) - - #nodes = nodes.masked_fill_(ex_mask, 0) - - #nodes_b = node_embeds - edges_b = edge_embs - nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) - nodes_b = torch.flip(nodes_b, [1]) - - nodes_b_cat = nodes_b[:, None, :, :] - edges_b_cat = edges_b[:, None, :, :] - - for i in range(self.iters): - - if N > 1: - bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) - - nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - gazs_mask_f)[:, :, :, None].float(), 2) - nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) - nodes_att_b = self.edge2node_b[i](nodes_b, - torch.cat([edges_b, nodes_begin_b, gazs_length], -1).unsqueeze(2), gazs_mask_b) - - if N > 1: - edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :])], 1) - edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1) - edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 
1))], 1) - - nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1) - nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b) - nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1) - nodes_b = self.norm(torch.sum(nodes_b_cat, 1)) - - nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1) - layer_att = torch.sigmoid(self.layer_att_W(nodes_cat)) - layer_alpha = F.softmax(layer_att, 1) - nodes = torch.sum(layer_alpha * nodes_cat, 1) - - tags = self.hidden2tag(nodes) - - return tags, gaz_match - - def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label): - - tags, _ = self.get_tags(gaz_list, word_inputs, mask) - total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - - return total_loss, tag_seq # (batch_size,) ,(b,seqlen?) - - def forward(self, gaz_list, word_inputs, word_seq_lengths, mask): - tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - return tag_seq, gaz_match diff --git a/model/transformer_no_lstm.py b/model/transformer_no_lstm.py deleted file mode 100644 index b15a29c..0000000 --- a/model/transformer_no_lstm.py +++ /dev/null @@ -1,510 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from model.crf import CRF - - -class MultiHeadAtt(nn.Module): - def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): - super(MultiHeadAtt, self).__init__() - - if if_g: - self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) - else: - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, query_h, value, mask, query_g=None): - - if not (query_g is None): - query = torch.cat([query_h, query_g], -1) - else: - query = query_h - query = query.permute(0, 2, 1)[:, :, :, None] - value = value.permute(0, 3, 1, 2) - - residual = query_h - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - - B, QL, H = query_h.shape - - _, _, VL, VD = value.shape # VD = 1 or VD = QL - - assert VD == 1 or VD == QL - # q: (B, H, QL, 1) - # v: (B, H, VL, VD) - q, k, v = self.WQ(query), self.WK(value), self.WV(value) - - q = q.view(B, nhead, head_dim, 1, QL) - k = k.view(B, nhead, head_dim, VL, VD) - v = v.view(B, nhead, head_dim, VL, VD) - - alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) - alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf) - alpha = self.drop(F.softmax(alpha, 3)) - att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) - - output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) - output = self.norm(output + residual) - - return output - - -class GloAtt(nn.Module): - def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): - # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value - super(GloAtt, self).__init__() - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - # 
print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, x, y, mask=None): - # x: B, H, 1, 1, 1 y: B H L 1 - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - B, L, H = y.shape - - x = x.permute(0, 2, 1)[:, :, :, None] - y = y.permute(0, 2, 1)[:, :, :, None] - - residual = x - q, k, v = self.WQ(x), self.WK(y), self.WV(y) - - q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h - k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L - v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h - - pre_a = torch.matmul(q, k) / np.sqrt(head_dim) - if mask is not None: - pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) - alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L - att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 - output = F.leaky_relu(self.WO(att)) + residual - output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H) - - return output - - -class Nodes_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Nodes_Cell, self).__init__() - - self.Wix = nn.Linear(hid_h*5, hid_h) - #self.Wig = nn.Linear(hid_h*4, hid_h) - self.Wi2 = nn.Linear(hid_h*5, hid_h) - self.Wf = nn.Linear(hid_h*5, hid_h) - self.Wcx = nn.Linear(hid_h*5, hid_h) - #self.Wcg = nn.Linear(hid_h, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, h2, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, h2, x, glo], -1) - #cat_x = torch.cat([h, h2, x], -1) - #cat_g = torch.cat([glo], -1) - - ix = torch.sigmoid(self.Wix(cat_all)) - #ig = torch.sigmoid(self.Wig(cat_all)) - i2 = torch.sigmoid(self.Wi2(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - cx = torch.tanh(self.Wcx(cat_all)) - #cg = torch.tanh(self.Wcg(cat_g)) - - alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) - - return output - - -class Gazs_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Gazs_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*4, hid_h) - self.Wf = nn.Linear(hid_h*4, hid_h) - self.Wc = nn.Linear(hid_h*4, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, x, glo], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class GLobal_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(GLobal_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*3, hid_h) - self.Wf = nn.Linear(hid_h*3, hid_h) - self.Wc = nn.Linear(hid_h*3, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x): - - x = self.drop(x) - - cat_all = torch.cat([h, x], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class Graph(nn.Module): - def __init__(self, data): - super(Graph, self).__init__() - - self.gpu = data.HP_gpu - self.word_alphabet = data.word_alphabet - self.word_emb_dim = data.word_emb_dim - self.gaz_emb_dim = data.gaz_emb_dim - self.hidden_dim = 50 - self.num_head = 10 # 5 10 20 - 
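# A minimal sketch (not from the repository) of the gating recipe shared by the
# Nodes_Cell / Gazs_Cell / GLobal_Cell modules defined above: each gate is first
# squashed with a sigmoid, the stacked gates are re-normalized with a softmax, and
# the output is the resulting convex combination of the candidate state and the
# previous state.  The class name, sizes and the quick check below are illustrative
# assumptions only.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoGateFusion(nn.Module):
    def __init__(self, hid):
        super(TwoGateFusion, self).__init__()
        self.Wi = nn.Linear(hid * 2, hid)   # gate for the new candidate
        self.Wf = nn.Linear(hid * 2, hid)   # gate for keeping the previous state
        self.Wc = nn.Linear(hid * 2, hid)   # candidate state

    def forward(self, h, x):
        cat_all = torch.cat([h, x], -1)
        i = torch.sigmoid(self.Wi(cat_all))
        f = torch.sigmoid(self.Wf(cat_all))
        c = torch.tanh(self.Wc(cat_all))
        alpha = F.softmax(torch.stack([i, f], 1), 1)   # weights sum to 1 element-wise
        return alpha[:, 0] * c + alpha[:, 1] * h

cell = TwoGateFusion(50)
assert cell(torch.randn(4, 50), torch.randn(4, 50)).shape == (4, 50)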
self.head_dim = 20 # 10 20 - self.tf_dropout_rate = 0.1 - self.iters = 4 - self.bmes_dim = 10 - self.length_dim = 10 - self.max_gaz_length = 5 - self.emb_dropout_rate = 0.5 - self.cell_dropout_rate = 0.2 - - # word embedding - self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim) - assert data.pretrain_word_embedding is not None - self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - - # gaz embedding - self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) - assert data.pretrain_gaz_embedding is not None - scale = np.sqrt(3.0 / self.gaz_emb_dim) - data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim]) - self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) - - # position embedding - self.pos_embedding = nn.Embedding(data.posi_alphabet_size, self.hidden_dim) - # lstm - #self.emb_rnn_f = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - #self.emb_rnn_b = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - - # bmes embedding - self.bmes_embedding = nn.Embedding(4, self.bmes_dim) - - # length embedding - self.length_embedding = nn.Embedding(self.max_gaz_length, self.length_dim) - - self.dropout = nn.Dropout(self.emb_dropout_rate) - self.norm = nn.LayerNorm(self.hidden_dim) - - self.edge2node_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_f = Gazs_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.glo_rnn_f = GLobal_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - - self.edge2node_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_b_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_b_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_b = Gazs_Cell(self.hidden_dim, self.cell_dropout_rate) - self.node_rnn_b = Nodes_Cell(self.hidden_dim, self.cell_dropout_rate) - self.glo_rnn_b = GLobal_Cell(self.hidden_dim, self.cell_dropout_rate) - - self.layer_att_W = nn.Linear(self.hidden_dim * 2, 1) - self.hidden2tag = nn.Linear(self.hidden_dim * 2, data.label_alphabet_size + 2) - self.crf = CRF(data.label_alphabet_size, self.gpu) - - if 
self.gpu: - self.word_embedding = self.word_embedding.cuda() - self.gaz_embedding = self.gaz_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.length_embedding = self.length_embedding.cuda() - self.norm = self.norm.cuda() - self.edge2node_f = self.edge2node_f.cuda() - self.node2edge_f = self.node2edge_f.cuda() - self.edge_rnn_f = self.edge_rnn_f.cuda() - self.node_rnn_f = self.node_rnn_f.cuda() - self.glo_rnn_f = self.glo_rnn_f.cuda() - self.glo_att_f_node = self.glo_att_f_node.cuda() - self.glo_att_f_edge = self.glo_att_f_edge.cuda() - self.edge2node_b = self.edge2node_b.cuda() - self.node2edge_b = self.node2edge_b.cuda() - self.edge_rnn_b = self.edge_rnn_b.cuda() - self.node_rnn_b = self.node_rnn_b.cuda() - self.glo_rnn_b = self.glo_rnn_b.cuda() - self.glo_att_b_node = self.glo_att_b_node.cuda() - self.glo_att_b_edge = self.glo_att_b_edge.cuda() - self.pos_embedding = self.pos_embedding.cuda() - #self.emb_rnn_f = self.emb_rnn_f.cuda() - #self.emb_rnn_b = self.emb_rnn_b.cuda() - self.layer_att_W = self.layer_att_W.cuda() - self.hidden2tag = self.hidden2tag.cuda() - self.crf = self.crf.cuda() - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): - - assert batch_size == 1 - - unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - unk_emb = self.gaz_embedding(unk_index) - - bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1) - bmes_index_e = torch.tensor(2).cuda() if self.cuda else torch.tensor(2) - bmes_index_s = torch.tensor(3).cuda() if self.cuda else torch.tensor(3) - - bmes_emb_b = self.bmes_embedding(bmes_index_b) - bmes_emb_m = self.bmes_embedding(bmes_index_m) - bmes_emb_e = self.bmes_embedding(bmes_index_e) - bmes_emb_s = self.bmes_embedding(bmes_index_s) - - for sen in range(batch_size): - sen_gaz_embed = unk_emb[None, :] - sen_nodes_mask = torch.zeros([1, seq_len]).byte() - sen_gazs_length = torch.zeros([1, self.length_dim]) - sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - sen_gazs_mask_f = torch.zeros([1, seq_len]).byte() - sen_gazs_mask_b = torch.zeros([1, seq_len]).byte() - if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() - sen_nodes_mask = sen_nodes_mask.cuda() - sen_gazs_length = sen_gazs_length.cuda() - sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask_f = sen_gazs_mask_f.cuda() - sen_gazs_mask_b = sen_gazs_mask_b.cuda() - - for w in range(seq_len): - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - - gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device) - gaz_embedding = self.gaz_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0) - - if gaz_len <= self.max_gaz_length: - gaz_length_index = torch.tensor(gaz_len-1, device=sen_gazs_length.device) - else: - gaz_length_index = torch.tensor(self.max_gaz_length-1, device=sen_gazs_length.device) - gaz_length = self.length_embedding(gaz_length_index) - sen_gazs_length = torch.cat([sen_gazs_length, gaz_length[None, :]], 0) - - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len - nodes_mask = torch.ones([1, seq_len]).byte() - bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - gazs_mask_f = torch.ones([1, seq_len]).byte() - gazs_mask_b = torch.ones([1, seq_len]).byte() - if self.cuda: - nodes_mask = nodes_mask.cuda() - bmes_embed = bmes_embed.cuda() - gazs_mask_f = gazs_mask_f.cuda() - gazs_mask_b = gazs_mask_b.cuda() - - gazs_mask_f[0, w + gaz_len 
- 1] = 0 - sen_gazs_mask_f = torch.cat([sen_gazs_mask_f, gazs_mask_f], 0) - - gazs_mask_b[0, w] = 0 - sen_gazs_mask_b = torch.cat([sen_gazs_mask_b, gazs_mask_b], 0) - - for index in range(gaz_len): - nodes_mask[0, w + index] = 0 - if gaz_len == 1: - bmes_embed[0, w + index, :] = bmes_emb_s - elif index == 0: - bmes_embed[0, w + index, :] = bmes_emb_b - elif index == gaz_len - 1: - bmes_embed[0, w + index, :] = bmes_emb_e - else: - bmes_embed[0, w + index, :] = bmes_emb_m - - sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) - sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) - - #sen_gazs_mask_f[0, (1-sen_gazs_mask_f).sum(dim=0) == 0] = 0 - #sen_gazs_mask_b[0, (1-sen_gazs_mask_b).sum(dim=0) == 0] = 0 - - batch_gaz_embed = sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 - batch_nodes_mask = sen_nodes_mask.unsqueeze(0) - batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask_f = sen_gazs_mask_f.unsqueeze(0) - batch_gazs_mask_b = sen_gazs_mask_b.unsqueeze(0) - batch_gazs_length = sen_gazs_length.unsqueeze(0) - return batch_gaz_embed, batch_bmes_embed, batch_nodes_mask, batch_gazs_mask_f, batch_gazs_mask_b, batch_gazs_length - - def get_tags(self, gaz_list, word_inputs, mask): - - #mask = 1 - mask - node_embeds = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding - B, L, H = node_embeds.size() - gaz_match = [] - - edge_embs, bmes_embs, nodes_mask, gazs_mask_f, gazs_mask_b, gazs_length = self.obtain_gaz_relation(B, L, gaz_list) - _, N, _ = edge_embs.size() - #smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) - - P = self.pos_embedding(torch.arange(L, dtype=torch.long, device=node_embeds.device).view(1, L)) - node_embeds = node_embeds + P - - node_embeds = self.dropout(node_embeds) - edge_embs = self.dropout(edge_embs) - - nodes_f = node_embeds - edges_f = edge_embs - #nodes_f, _ = self.emb_rnn_f(node_embeds) - - glo_f = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_f_cat = nodes_f[:, None, :, :] - edges_f_cat = edges_f[:, None, :, :] - glo_f_cat = glo_f[:, None, :, :] - #ex_mask = mask[:, None, :, None].expand(B, H, L, 1) - - for i in range(self.iters): - - if N > 1: - bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) - - nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - gazs_mask_b)[:, :, :, None].float(), 2) - nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) - nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, gazs_length], -1).unsqueeze(2), gazs_mask_f) - - glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1) - - if N > 1: - edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], - edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) - edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) - edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) - - nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) - nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, H*2)) - nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) - nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) - - glo_f = self.glo_rnn_f(glo_f, glo_att_f) - glo_f_cat = torch.cat([glo_f_cat, 
glo_f[:, None, :, :]], 1) - glo_f = self.norm(torch.sum(glo_f_cat, 1)) - #nodes = nodes.masked_fill_(ex_mask, 0) - - nodes_b = node_embeds - edges_b = edge_embs - #nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) - #nodes_b = torch.flip(nodes_b, [1]) - - glo_b = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_b_cat = nodes_b[:, None, :, :] - edges_b_cat = edges_b[:, None, :, :] - glo_b_cat = glo_b[:, None, :, :] - - for i in range(self.iters): - - if N > 1: - bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) - - nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - gazs_mask_f)[:, :, :, None].float(), 2) - nodes_begin_b = torch.cat([torch.zeros([B, 1, H], device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) - nodes_att_b = self.edge2node_b[i](nodes_b, - torch.cat([edges_b, nodes_begin_b, gazs_length], -1).unsqueeze(2), gazs_mask_b) - - glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b), self.glo_att_b_edge[i](glo_b, edges_b)], -1) - - if N > 1: - edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :], - glo_att_b.expand(B, N-1, H*2))], 1) - edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1) - edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1) - - nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1) - nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, H*2)) - nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1) - nodes_b = self.norm(torch.sum(nodes_b_cat, 1)) - - glo_b = self.glo_rnn_b(glo_b, glo_att_b) - glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1) - glo_b = self.norm(torch.sum(glo_b_cat, 1)) - - nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1) - layer_att = torch.sigmoid(self.layer_att_W(nodes_cat)) - layer_alpha = F.softmax(layer_att, 1) - nodes = torch.sum(layer_alpha * nodes_cat, 1) - - tags = self.hidden2tag(nodes) - - return tags, gaz_match - - def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label): - - tags, _ = self.get_tags(gaz_list, word_inputs, mask) - total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - - return total_loss, tag_seq # (batch_size,) ,(b,seqlen?) 
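# A minimal sketch (not from the repository) of the layer-attention reduction used
# at the end of get_tags() above: the node states kept from every message-passing
# iteration are scored by a Linear(hidden*2 -> 1) layer, the scores are normalized
# over the iteration axis, and the states are mixed by that weighting before the
# tag projection.  B, K, L, H below are made-up sizes for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, K, L, H = 1, 5, 7, 100                      # batch, kept iterations, length, hidden*2
nodes_cat = torch.randn(B, K, L, H)            # per-iteration node states
layer_att_W = nn.Linear(H, 1)

layer_att = torch.sigmoid(layer_att_W(nodes_cat))    # (B, K, L, 1) scores
layer_alpha = F.softmax(layer_att, 1)                # normalize over the K iterations
nodes = torch.sum(layer_alpha * nodes_cat, 1)        # (B, L, H) mixed representation
assert nodes.shape == (B, L, H)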
- - def forward(self, gaz_list, word_inputs, word_seq_lengths, mask): - tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - return tag_seq, gaz_match diff --git a/model/transformer_only_sw.py b/model/transformer_only_sw.py deleted file mode 100644 index 87740ef..0000000 --- a/model/transformer_only_sw.py +++ /dev/null @@ -1,512 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from model.crf import CRF - - -class MultiHeadAtt(nn.Module): - def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): - super(MultiHeadAtt, self).__init__() - - if if_g: - self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) - else: - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, query_h, value, mask, query_g=None): - - if not (query_g is None): - query = torch.cat([query_h, query_g], -1) - else: - query = query_h - query = query.permute(0, 2, 1)[:, :, :, None] - value = value.permute(0, 3, 1, 2) - - residual = query_h - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - - B, QL, H = query_h.shape - - _, _, VL, VD = value.shape # VD = 1 or VD = QL - - assert VD == 1 or VD == QL - # q: (B, H, QL, 1) - # v: (B, H, VL, VD) - q, k, v = self.WQ(query), self.WK(value), self.WV(value) - - q = q.view(B, nhead, head_dim, 1, QL) - k = k.view(B, nhead, head_dim, VL, VD) - v = v.view(B, nhead, head_dim, VL, VD) - - alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) - alpha = alpha.masked_fill(mask[:, None, None, :, :], -np.inf) - alpha = self.drop(F.softmax(alpha, 3)) - att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) - - output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) - output = self.norm(output + residual) - - return output - - -class GloAtt(nn.Module): - def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): - # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value - super(GloAtt, self).__init__() - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, x, y, mask=None): - # x: B, H, 1, 1, 1 y: B H L 1 - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - B, L, H = y.shape - - x = x.permute(0, 2, 1)[:, :, :, None] - y = y.permute(0, 2, 1)[:, :, :, None] - - residual = x - q, k, v = self.WQ(x), self.WK(y), self.WV(y) - - q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h - k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L - v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h - - pre_a = torch.matmul(q, k) / np.sqrt(head_dim) - if mask is not None: - pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) - alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L - att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 - output = 
F.leaky_relu(self.WO(att)) + residual - output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H) - - return output - - -class Nodes_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Nodes_Cell, self).__init__() - - self.Wix = nn.Linear(hid_h*5, hid_h) - #self.Wig = nn.Linear(hid_h*4, hid_h) - self.Wi2 = nn.Linear(hid_h*5, hid_h) - self.Wf = nn.Linear(hid_h*5, hid_h) - self.Wcx = nn.Linear(hid_h*5, hid_h) - #self.Wcg = nn.Linear(hid_h, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, h2, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, h2, x, glo], -1) - #cat_x = torch.cat([h, h2, x], -1) - #cat_g = torch.cat([glo], -1) - - ix = torch.sigmoid(self.Wix(cat_all)) - #ig = torch.sigmoid(self.Wig(cat_all)) - i2 = torch.sigmoid(self.Wi2(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - cx = torch.tanh(self.Wcx(cat_all)) - #cg = torch.tanh(self.Wcg(cat_g)) - - alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) - - return output - - -class Gazs_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Gazs_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*4, hid_h) - self.Wf = nn.Linear(hid_h*4, hid_h) - self.Wc = nn.Linear(hid_h*4, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, x, glo], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class GLobal_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(GLobal_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*3, hid_h) - self.Wf = nn.Linear(hid_h*3, hid_h) - self.Wc = nn.Linear(hid_h*3, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x): - - x = self.drop(x) - - cat_all = torch.cat([h, x], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class Graph(nn.Module): - def __init__(self, data): - super(Graph, self).__init__() - - self.gpu = data.HP_gpu - self.word_alphabet = data.word_alphabet - self.word_emb_dim = data.word_emb_dim - self.gaz_emb_dim = data.gaz_emb_dim - self.hidden_dim = 50 - self.num_head = 10 # 5 10 20 - self.head_dim = 10 # 10 20 - self.tf_dropout_rate = 0.1 - self.iters = 5 - self.bmes_dim = 10 - self.length_dim = 10 - self.max_gaz_length = 5 - self.emb_dropout_rate = 0.5 - self.cell_dropout_rate = 0.2 - - # word embedding - self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim) - assert data.pretrain_word_embedding is not None - self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - - # gaz embedding - self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) - assert data.pretrain_gaz_embedding is not None - scale = np.sqrt(3.0 / self.gaz_emb_dim) - data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim]) - self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) - - # position embedding - #self.pos_embedding = nn.Embedding(data.posi_alphabet_size, 
self.hidden_dim) - # lstm - self.emb_rnn_f = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - self.emb_rnn_b = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - - # bmes embedding - self.bmes_embedding = nn.Embedding(4, self.bmes_dim) - - # length embedding - self.length_embedding = nn.Embedding(self.max_gaz_length, self.length_dim) - - self.dropout = nn.Dropout(self.emb_dropout_rate) - self.norm = nn.LayerNorm(self.hidden_dim) - - self.edge2node_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_f = Gazs_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.glo_rnn_f = GLobal_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - - self.edge2node_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_b = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_b_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_b_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_b = Gazs_Cell(self.hidden_dim, self.cell_dropout_rate) - self.node_rnn_b = Nodes_Cell(self.hidden_dim, self.cell_dropout_rate) - self.glo_rnn_b = GLobal_Cell(self.hidden_dim, self.cell_dropout_rate) - - self.layer_att_W = nn.Linear(self.hidden_dim * 2, 1) - self.hidden2tag = nn.Linear(self.hidden_dim * 2, data.label_alphabet_size + 2) - self.crf = CRF(data.label_alphabet_size, self.gpu) - - if self.gpu: - self.word_embedding = self.word_embedding.cuda() - self.gaz_embedding = self.gaz_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.length_embedding = self.length_embedding.cuda() - self.norm = self.norm.cuda() - self.edge2node_f = self.edge2node_f.cuda() - self.node2edge_f = self.node2edge_f.cuda() - self.edge_rnn_f = self.edge_rnn_f.cuda() - self.node_rnn_f = self.node_rnn_f.cuda() - self.glo_rnn_f = self.glo_rnn_f.cuda() - self.glo_att_f_node = self.glo_att_f_node.cuda() - self.glo_att_f_edge = self.glo_att_f_edge.cuda() - self.edge2node_b = self.edge2node_b.cuda() - self.node2edge_b = self.node2edge_b.cuda() - self.edge_rnn_b = self.edge_rnn_b.cuda() - self.node_rnn_b = self.node_rnn_b.cuda() - self.glo_rnn_b = self.glo_rnn_b.cuda() - self.glo_att_b_node = self.glo_att_b_node.cuda() - self.glo_att_b_edge = self.glo_att_b_edge.cuda() - 
#self.pos_embedding = self.pos_embedding.cuda() - self.emb_rnn_f = self.emb_rnn_f.cuda() - self.emb_rnn_b = self.emb_rnn_b.cuda() - self.layer_att_W = self.layer_att_W.cuda() - self.hidden2tag = self.hidden2tag.cuda() - self.crf = self.crf.cuda() - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): - - assert batch_size == 1 - - unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - unk_emb = self.gaz_embedding(unk_index) - - bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1) - bmes_index_e = torch.tensor(2).cuda() if self.cuda else torch.tensor(2) - bmes_index_s = torch.tensor(3).cuda() if self.cuda else torch.tensor(3) - - bmes_emb_b = self.bmes_embedding(bmes_index_b) - bmes_emb_m = self.bmes_embedding(bmes_index_m) - bmes_emb_e = self.bmes_embedding(bmes_index_e) - bmes_emb_s = self.bmes_embedding(bmes_index_s) - - for sen in range(batch_size): - sen_gaz_embed = unk_emb[None, :] - sen_nodes_mask = torch.zeros([1, seq_len]).byte() - sen_gazs_length = torch.zeros([1, self.length_dim]) - sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - sen_gazs_mask_f = torch.zeros([1, seq_len]).byte() - sen_gazs_mask_b = torch.zeros([1, seq_len]).byte() - if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() - sen_nodes_mask = sen_nodes_mask.cuda() - sen_gazs_length = sen_gazs_length.cuda() - sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask_f = sen_gazs_mask_f.cuda() - sen_gazs_mask_b = sen_gazs_mask_b.cuda() - - for w in range(seq_len): - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - - gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device) - gaz_embedding = self.gaz_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0) - - if gaz_len <= self.max_gaz_length: - gaz_length_index = torch.tensor(gaz_len-1, device=sen_gazs_length.device) - else: - gaz_length_index = torch.tensor(self.max_gaz_length-1, device=sen_gazs_length.device) - gaz_length = self.length_embedding(gaz_length_index) - sen_gazs_length = torch.cat([sen_gazs_length, gaz_length[None, :]], 0) - - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len - nodes_mask = torch.ones([1, seq_len]).byte() - bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - gazs_mask_f = torch.ones([1, seq_len]).byte() - gazs_mask_b = torch.ones([1, seq_len]).byte() - if self.cuda: - nodes_mask = nodes_mask.cuda() - bmes_embed = bmes_embed.cuda() - gazs_mask_f = gazs_mask_f.cuda() - gazs_mask_b = gazs_mask_b.cuda() - - gazs_mask_f[0, w + gaz_len - 1] = 0 - sen_gazs_mask_f = torch.cat([sen_gazs_mask_f, gazs_mask_f], 0) - - gazs_mask_b[0, w] = 0 - sen_gazs_mask_b = torch.cat([sen_gazs_mask_b, gazs_mask_b], 0) - - nodes_mask[0, w] = 0 - nodes_mask[0, w + gaz_len - 1] = 0 - for index in range(gaz_len): - - if gaz_len == 1: - bmes_embed[0, w + index, :] = bmes_emb_s - elif index == 0: - bmes_embed[0, w + index, :] = bmes_emb_b - elif index == gaz_len - 1: - bmes_embed[0, w + index, :] = bmes_emb_e - else: - bmes_embed[0, w + index, :] = bmes_emb_m - - sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) - sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) - - #sen_gazs_mask_f[0, (1-sen_gazs_mask_f).sum(dim=0) == 0] = 0 - #sen_gazs_mask_b[0, (1-sen_gazs_mask_b).sum(dim=0) == 0] = 0 - - batch_gaz_embed = sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 - batch_nodes_mask = sen_nodes_mask.unsqueeze(0) - 
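# A plain-Python sketch (not from the repository) of the BMES labelling rule applied
# in the loop above: every character covered by a matched lexicon word is tagged as
# Single, Begin, Middle or End before the corresponding bmes_embedding row is looked
# up.  The helper name and the example below are illustrative assumptions only.
def bmes_positions(start, length):
    """Return (char_index, tag) pairs for a lexicon match of `length` characters
    beginning at character index `start`."""
    if length == 1:
        return [(start, 'S')]
    tags = [(start, 'B')]
    tags += [(start + i, 'M') for i in range(1, length - 1)]
    tags.append((start + length - 1, 'E'))
    return tags

# A 3-character word starting at character 2 covers characters 2..4:
assert bmes_positions(2, 3) == [(2, 'B'), (3, 'M'), (4, 'E')]
assert bmes_positions(5, 1) == [(5, 'S')]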
batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask_f = sen_gazs_mask_f.unsqueeze(0) - batch_gazs_mask_b = sen_gazs_mask_b.unsqueeze(0) - batch_gazs_length = sen_gazs_length.unsqueeze(0) - return batch_gaz_embed, batch_bmes_embed, batch_nodes_mask, batch_gazs_mask_f, batch_gazs_mask_b, batch_gazs_length - - def get_tags(self, gaz_list, word_inputs, mask): - - #mask = 1 - mask - node_embeds = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding - B, L, H = node_embeds.size() - gaz_match = [] - - edge_embs, bmes_embs, nodes_mask, gazs_mask_f, gazs_mask_b, gazs_length = self.obtain_gaz_relation(B, L, gaz_list) - _, N, _ = edge_embs.size() - #smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) - - #P = self.pos_embedding(torch.arange(L, dtype=torch.long, device=node_embeds.device).view(1, L)) - #node_embeds = node_embeds + P - - node_embeds = self.dropout(node_embeds) - edge_embs = self.dropout(edge_embs) - - #nodes_f = node_embeds - edges_f = edge_embs - nodes_f, _ = self.emb_rnn_f(node_embeds) - - glo_f = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_f_cat = nodes_f[:, None, :, :] - edges_f_cat = edges_f[:, None, :, :] - glo_f_cat = glo_f[:, None, :, :] - #ex_mask = mask[:, None, :, None].expand(B, H, L, 1) - - for i in range(self.iters): - - if N > 1: - bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) - - nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - gazs_mask_b)[:, :, :, None].float(), 2) - nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) - nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, gazs_length], -1).unsqueeze(2), gazs_mask_f) - - glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1) - - if N > 1: - edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], - edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) - edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) - edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) - - nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) - nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, H*2)) - nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) - nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) - - glo_f = self.glo_rnn_f(glo_f, glo_att_f) - glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1) - glo_f = self.norm(torch.sum(glo_f_cat, 1)) - #nodes = nodes.masked_fill_(ex_mask, 0) - - #nodes_b = node_embeds - edges_b = edge_embs - nodes_b, _ = self.emb_rnn_b(torch.flip(node_embeds, [1])) - nodes_b = torch.flip(nodes_b, [1]) - - glo_b = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_b_cat = nodes_b[:, None, :, :] - edges_b_cat = edges_b[:, None, :, :] - glo_b_cat = glo_b[:, None, :, :] - - for i in range(self.iters): - - if N > 1: - bmes_nodes_b = torch.cat([nodes_b.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_b = self.node2edge_b[i](edges_b, bmes_nodes_b, nodes_mask.transpose(1, 2)) - - nodes_begin_b = torch.sum(nodes_b[:, None, :, :] * (1 - gazs_mask_f)[:, :, :, None].float(), 2) - nodes_begin_b = torch.cat([torch.zeros([B, 1, H], 
device=nodes_b.device), nodes_begin_b[:, 1:N, :]], 1) - nodes_att_b = self.edge2node_b[i](nodes_b, - torch.cat([edges_b, nodes_begin_b, gazs_length], -1).unsqueeze(2), gazs_mask_b) - - glo_att_b = torch.cat([self.glo_att_b_node[i](glo_b, nodes_b), self.glo_att_b_edge[i](glo_b, edges_b)], -1) - - if N > 1: - edges_b = torch.cat([edges_b[:, 0:1, :], self.edge_rnn_b(edges_b[:, 1:N, :], edges_att_b[:, 1:N, :], - glo_att_b.expand(B, N-1, H*2))], 1) - edges_b_cat = torch.cat([edges_b_cat, edges_b[:, None, :, :]], 1) - edges_b = torch.cat([edges_b[:, 0:1, :], self.norm(torch.sum(edges_b_cat[:, :, 1:N, :], 1))], 1) - - nodes_b_r = torch.cat([nodes_b[:, 1:L, :], torch.zeros([B, 1, self.hidden_dim], device=nodes_b.device)], 1) - nodes_b = self.node_rnn_b(nodes_b, nodes_b_r, nodes_att_b, glo_att_b.expand(B, L, H*2)) - nodes_b_cat = torch.cat([nodes_b_cat, nodes_b[:, None, :, :]], 1) - nodes_b = self.norm(torch.sum(nodes_b_cat, 1)) - - glo_b = self.glo_rnn_b(glo_b, glo_att_b) - glo_b_cat = torch.cat([glo_b_cat, glo_b[:, None, :, :]], 1) - glo_b = self.norm(torch.sum(glo_b_cat, 1)) - - nodes_cat = torch.cat([nodes_f_cat, nodes_b_cat], -1) - layer_att = torch.sigmoid(self.layer_att_W(nodes_cat)) - layer_alpha = F.softmax(layer_att, 1) - nodes = torch.sum(layer_alpha * nodes_cat, 1) - - tags = self.hidden2tag(nodes) - - return tags, gaz_match - - def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label): - - tags, _ = self.get_tags(gaz_list, word_inputs, mask) - total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - - return total_loss, tag_seq # (batch_size,) ,(b,seqlen?) - - def forward(self, gaz_list, word_inputs, word_seq_lengths, mask): - tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - return tag_seq, gaz_match diff --git a/model/transformer_single.py b/model/transformer_single.py deleted file mode 100644 index eb0bdb3..0000000 --- a/model/transformer_single.py +++ /dev/null @@ -1,443 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from model.crf import CRF - - -class MultiHeadAtt(nn.Module): - def __init__(self, nhid, keyhid, nhead=10, head_dim=10, dropout=0.1, if_g=False): - super(MultiHeadAtt, self).__init__() - - if if_g: - self.WQ = nn.Conv2d(nhid * 3, nhead * head_dim, 1) - else: - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(keyhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, query_h, value, mask, query_g=None): - - if not (query_g is None): - query = torch.cat([query_h, query_g], -1) - else: - query = query_h - query = query.permute(0, 2, 1)[:, :, :, None] - value = value.permute(0, 3, 1, 2) - - residual = query_h - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - - B, QL, H = query_h.shape - - _, _, VL, VD = value.shape # VD = 1 or VD = QL - - assert VD == 1 or VD == QL - # q: (B, H, QL, 1) - # v: (B, H, VL, VD) - q, k, v = self.WQ(query), self.WK(value), self.WV(value) - - q = q.view(B, nhead, head_dim, 1, QL) - k = k.view(B, nhead, head_dim, VL, VD) - v = v.view(B, nhead, head_dim, VL, VD) - - alpha = (q * k).sum(2, keepdim=True) / np.sqrt(head_dim) - alpha = 
alpha.masked_fill(mask[:, None, None, :, :], -np.inf) - alpha = self.drop(F.softmax(alpha, 3)) - att = (alpha * v).sum(3).view(B, nhead * head_dim, QL, 1) - - output = F.leaky_relu(self.WO(att)).permute(0, 2, 3, 1).view(B, QL, H) - output = self.norm(output + residual) - - return output - - -class GloAtt(nn.Module): - def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): - # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value - super(GloAtt, self).__init__() - self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) - self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) - - self.drop = nn.Dropout(dropout) - - self.norm = nn.LayerNorm(nhid) - - # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) - self.nhid, self.nhead, self.head_dim = nhid, nhead, head_dim - - def forward(self, x, y, mask=None): - # x: B, H, 1, 1, 1 y: B H L 1 - nhid, nhead, head_dim = self.nhid, self.nhead, self.head_dim - B, L, H = y.shape - - x = x.permute(0, 2, 1)[:, :, :, None] - y = y.permute(0, 2, 1)[:, :, :, None] - - residual = x - q, k, v = self.WQ(x), self.WK(y), self.WV(y) - - q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h - k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L - v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h - - pre_a = torch.matmul(q, k) / np.sqrt(head_dim) - if mask is not None: - pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) - alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L - att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 - output = F.leaky_relu(self.WO(att)) + residual - output = self.norm(output.permute(0, 2, 3, 1)).view(B, 1, H) - - return output - - -class Nodes_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Nodes_Cell, self).__init__() - - self.Wix = nn.Linear(hid_h*5, hid_h) - #self.Wig = nn.Linear(hid_h*4, hid_h) - self.Wi2 = nn.Linear(hid_h*5, hid_h) - self.Wf = nn.Linear(hid_h*5, hid_h) - self.Wcx = nn.Linear(hid_h*5, hid_h) - #self.Wcg = nn.Linear(hid_h, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, h2, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, h2, x, glo], -1) - #cat_x = torch.cat([h, h2, x], -1) - #cat_g = torch.cat([glo], -1) - - ix = torch.sigmoid(self.Wix(cat_all)) - #ig = torch.sigmoid(self.Wig(cat_all)) - i2 = torch.sigmoid(self.Wi2(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - cx = torch.tanh(self.Wcx(cat_all)) - #cg = torch.tanh(self.Wcg(cat_g)) - - alpha = F.softmax(torch.cat([ix.unsqueeze(1), i2.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * cx) + (alpha[:, 1] * h2) + (alpha[:, 2] * h) - - return output - - -class Gazs_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(Gazs_Cell, self).__init__() - - self.Wi = nn.Linear(hid_h*4, hid_h) - self.Wf = nn.Linear(hid_h*4, hid_h) - self.Wc = nn.Linear(hid_h*4, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x, glo): - - x = self.drop(x) - glo = self.drop(glo) - - cat_all = torch.cat([h, x, glo], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class GLobal_Cell(nn.Module): - def __init__(self, hid_h, dropout=0.2): - super(GLobal_Cell, self).__init__() - - 
self.Wi = nn.Linear(hid_h*3, hid_h) - self.Wf = nn.Linear(hid_h*3, hid_h) - self.Wc = nn.Linear(hid_h*3, hid_h) - - self.drop = nn.Dropout(dropout) - - def forward(self, h, x): - - x = self.drop(x) - - cat_all = torch.cat([h, x], -1) - i = torch.sigmoid(self.Wi(cat_all)) - f = torch.sigmoid(self.Wf(cat_all)) - c = torch.tanh(self.Wc(cat_all)) - - alpha = F.softmax(torch.cat([i.unsqueeze(1), f.unsqueeze(1)], 1), 1) - output = (alpha[:, 0] * c) + (alpha[:, 1] * h) - - return output - - -class Graph(nn.Module): - def __init__(self, data): - super(Graph, self).__init__() - - self.gpu = data.HP_gpu - self.word_alphabet = data.word_alphabet - self.word_emb_dim = data.word_emb_dim - self.gaz_emb_dim = data.gaz_emb_dim - self.hidden_dim = 50 - self.num_head = 10 # 5 10 20 - self.head_dim = 20 # 10 20 - self.tf_dropout_rate = 0.1 - self.iters = 4 - self.bmes_dim = 10 - self.length_dim = 10 - self.max_gaz_length = 5 - self.emb_dropout_rate = 0.5 - self.cell_dropout_rate = 0.2 - - # word embedding - self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.word_emb_dim) - assert data.pretrain_word_embedding is not None - self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding)) - - # gaz embedding - self.gaz_embedding = nn.Embedding(data.gaz_alphabet.size(), self.gaz_emb_dim) - assert data.pretrain_gaz_embedding is not None - scale = np.sqrt(3.0 / self.gaz_emb_dim) - data.pretrain_gaz_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.gaz_emb_dim]) - self.gaz_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_gaz_embedding)) - - # position embedding - #self.pos_embedding = nn.Embedding(data.posi_alphabet_size, self.hidden_dim) - # lstm - self.emb_rnn_f = nn.LSTM(self.hidden_dim, self.hidden_dim, batch_first=True) - - # bmes embedding - self.bmes_embedding = nn.Embedding(4, self.bmes_dim) - - # length embedding - self.length_embedding = nn.Embedding(self.max_gaz_length, self.length_dim) - - self.dropout = nn.Dropout(self.emb_dropout_rate) - self.norm = nn.LayerNorm(self.hidden_dim) - - self.edge2node_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim*2+self.length_dim, - nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - self.node2edge_f = nn.ModuleList( - [MultiHeadAtt(self.hidden_dim, self.hidden_dim+self.bmes_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_node = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.glo_att_f_edge = nn.ModuleList( - [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate) - for _ in range(self.iters)]) - - self.edge_rnn_f = Gazs_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.node_rnn_f = Nodes_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - self.glo_rnn_f = GLobal_Cell(self.hidden_dim, dropout=self.cell_dropout_rate) - - self.layer_att_W = nn.Linear(self.hidden_dim, 1) - self.hidden2tag = nn.Linear(self.hidden_dim, data.label_alphabet_size + 2) - self.crf = CRF(data.label_alphabet_size, self.gpu) - - if self.gpu: - self.word_embedding = self.word_embedding.cuda() - self.gaz_embedding = self.gaz_embedding.cuda() - self.bmes_embedding = self.bmes_embedding.cuda() - self.length_embedding = self.length_embedding.cuda() - self.norm = self.norm.cuda() - self.edge2node_f = self.edge2node_f.cuda() - 
self.node2edge_f = self.node2edge_f.cuda() - self.edge_rnn_f = self.edge_rnn_f.cuda() - self.node_rnn_f = self.node_rnn_f.cuda() - self.glo_rnn_f = self.glo_rnn_f.cuda() - self.glo_att_f_node = self.glo_att_f_node.cuda() - self.glo_att_f_edge = self.glo_att_f_edge.cuda() - #self.pos_embedding = self.pos_embedding.cuda() - self.emb_rnn_f = self.emb_rnn_f.cuda() - self.layer_att_W = self.layer_att_W.cuda() - self.hidden2tag = self.hidden2tag.cuda() - self.crf = self.crf.cuda() - - def obtain_gaz_relation(self, batch_size, seq_len, gaz_list): - - assert batch_size == 1 - - unk_index = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - unk_emb = self.gaz_embedding(unk_index) - - bmes_index_b = torch.tensor(0).cuda() if self.cuda else torch.tensor(0) - bmes_index_m = torch.tensor(1).cuda() if self.cuda else torch.tensor(1) - bmes_index_e = torch.tensor(2).cuda() if self.cuda else torch.tensor(2) - bmes_index_s = torch.tensor(3).cuda() if self.cuda else torch.tensor(3) - - bmes_emb_b = self.bmes_embedding(bmes_index_b) - bmes_emb_m = self.bmes_embedding(bmes_index_m) - bmes_emb_e = self.bmes_embedding(bmes_index_e) - bmes_emb_s = self.bmes_embedding(bmes_index_s) - - for sen in range(batch_size): - sen_gaz_embed = unk_emb[None, :] - sen_nodes_mask = torch.zeros([1, seq_len]).byte() - sen_gazs_length = torch.zeros([1, self.length_dim]) - sen_bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - sen_gazs_mask_f = torch.zeros([1, seq_len]).byte() - sen_gazs_mask_b = torch.zeros([1, seq_len]).byte() - if self.cuda: - sen_gaz_embed = sen_gaz_embed.cuda() - sen_nodes_mask = sen_nodes_mask.cuda() - sen_gazs_length = sen_gazs_length.cuda() - sen_bmes_embed = sen_bmes_embed.cuda() - sen_gazs_mask_f = sen_gazs_mask_f.cuda() - sen_gazs_mask_b = sen_gazs_mask_b.cuda() - - for w in range(seq_len): - if w < len(gaz_list[sen]) and gaz_list[sen][w]: - for gaz, gaz_len in zip(gaz_list[sen][w][0], gaz_list[sen][w][1]): - - gaz_index = torch.tensor(gaz, device=sen_gaz_embed.device) - gaz_embedding = self.gaz_embedding(gaz_index) - sen_gaz_embed = torch.cat([sen_gaz_embed, gaz_embedding[None, :]], 0) - - if gaz_len <= self.max_gaz_length: - gaz_length_index = torch.tensor(gaz_len-1, device=sen_gazs_length.device) - else: - gaz_length_index = torch.tensor(self.max_gaz_length-1, device=sen_gazs_length.device) - gaz_length = self.length_embedding(gaz_length_index) - sen_gazs_length = torch.cat([sen_gazs_length, gaz_length[None, :]], 0) - - # mask: 需要mask的地方置为1, batch_size * gaz_num * seq_len - nodes_mask = torch.ones([1, seq_len]).byte() - bmes_embed = torch.zeros([1, seq_len, self.bmes_dim]) - gazs_mask_f = torch.ones([1, seq_len]).byte() - gazs_mask_b = torch.ones([1, seq_len]).byte() - if self.cuda: - nodes_mask = nodes_mask.cuda() - bmes_embed = bmes_embed.cuda() - gazs_mask_f = gazs_mask_f.cuda() - gazs_mask_b = gazs_mask_b.cuda() - - gazs_mask_f[0, w + gaz_len - 1] = 0 - sen_gazs_mask_f = torch.cat([sen_gazs_mask_f, gazs_mask_f], 0) - - gazs_mask_b[0, w] = 0 - sen_gazs_mask_b = torch.cat([sen_gazs_mask_b, gazs_mask_b], 0) - - for index in range(gaz_len): - nodes_mask[0, w + index] = 0 - if gaz_len == 1: - bmes_embed[0, w + index, :] = bmes_emb_s - elif index == 0: - bmes_embed[0, w + index, :] = bmes_emb_b - elif index == gaz_len - 1: - bmes_embed[0, w + index, :] = bmes_emb_e - else: - bmes_embed[0, w + index, :] = bmes_emb_m - - sen_nodes_mask = torch.cat([sen_nodes_mask, nodes_mask], 0) - sen_bmes_embed = torch.cat([sen_bmes_embed, bmes_embed], 0) - - #sen_gazs_mask_f[0, 
(1-sen_gazs_mask_f).sum(dim=0) == 0] = 0 - #sen_gazs_mask_b[0, (1-sen_gazs_mask_b).sum(dim=0) == 0] = 0 - - batch_gaz_embed = sen_gaz_embed.unsqueeze(0) # 只有在batch_size=1时可以这么做 - batch_nodes_mask = sen_nodes_mask.unsqueeze(0) - batch_bmes_embed = sen_bmes_embed.unsqueeze(0) - batch_gazs_mask_f = sen_gazs_mask_f.unsqueeze(0) - batch_gazs_mask_b = sen_gazs_mask_b.unsqueeze(0) - batch_gazs_length = sen_gazs_length.unsqueeze(0) - return batch_gaz_embed, batch_bmes_embed, batch_nodes_mask, batch_gazs_mask_f, batch_gazs_mask_b, batch_gazs_length - - def get_tags(self, gaz_list, word_inputs, mask): - - #mask = 1 - mask - node_embeds = self.word_embedding(word_inputs) # batch_size, max_seq_len, embedding - B, L, H = node_embeds.size() - gaz_match = [] - - edge_embs, bmes_embs, nodes_mask, gazs_mask_f, gazs_mask_b, gazs_length = self.obtain_gaz_relation(B, L, gaz_list) - _, N, _ = edge_embs.size() - #smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) - - #P = self.pos_embedding(torch.arange(L, dtype=torch.long, device=node_embeds.device).view(1, L)) - #node_embeds = node_embeds + P - - node_embeds = self.dropout(node_embeds) - edge_embs = self.dropout(edge_embs) - - #nodes_f = node_embeds - edges_f = edge_embs - nodes_f, _ = self.emb_rnn_f(node_embeds) - - glo_f = node_embeds.mean(1, keepdim=True) + edge_embs.mean(1, keepdim=True) - nodes_f_cat = nodes_f[:, None, :, :] - edges_f_cat = edges_f[:, None, :, :] - glo_f_cat = glo_f[:, None, :, :] - #ex_mask = mask[:, None, :, None].expand(B, H, L, 1) - - for i in range(self.iters): - - if N > 1: - bmes_nodes_f = torch.cat([nodes_f.unsqueeze(2).expand(B, L, N, H), bmes_embs.transpose(1, 2)], -1) - edges_att_f = self.node2edge_f[i](edges_f, bmes_nodes_f, nodes_mask.transpose(1, 2)) - - nodes_begin_f = torch.sum(nodes_f[:, None, :, :] * (1 - gazs_mask_b)[:, :, :, None].float(), 2) - nodes_begin_f = torch.cat([torch.zeros([B, 1, H], device=nodes_f.device), nodes_begin_f[:, 1:N, :]], 1) - nodes_att_f = self.edge2node_f[i](nodes_f, torch.cat([edges_f, nodes_begin_f, gazs_length], -1).unsqueeze(2), gazs_mask_f) - - glo_att_f = torch.cat([self.glo_att_f_node[i](glo_f, nodes_f), self.glo_att_f_edge[i](glo_f, edges_f)], -1) - - if N > 1: - edges_f = torch.cat([edges_f[:, 0:1, :], self.edge_rnn_f(edges_f[:, 1:N, :], - edges_att_f[:, 1:N, :], glo_att_f.expand(B, N-1, H*2))], 1) - edges_f_cat = torch.cat([edges_f_cat, edges_f[:, None, :, :]], 1) - edges_f = torch.cat([edges_f[:, 0:1, :], self.norm(torch.sum(edges_f_cat[:, :, 1:N, :], 1))], 1) - - nodes_f_r = torch.cat([torch.zeros([B, 1, self.hidden_dim], device=nodes_f.device), nodes_f[:, 0:(L-1), :]], 1) - nodes_f = self.node_rnn_f(nodes_f, nodes_f_r, nodes_att_f, glo_att_f.expand(B, L, H*2)) - nodes_f_cat = torch.cat([nodes_f_cat, nodes_f[:, None, :, :]], 1) - nodes_f = self.norm(torch.sum(nodes_f_cat, 1)) - - glo_f = self.glo_rnn_f(glo_f, glo_att_f) - glo_f_cat = torch.cat([glo_f_cat, glo_f[:, None, :, :]], 1) - glo_f = self.norm(torch.sum(glo_f_cat, 1)) - #nodes = nodes.masked_fill_(ex_mask, 0) - - nodes_cat = nodes_f_cat - layer_att = torch.sigmoid(self.layer_att_W(nodes_cat)) - layer_alpha = F.softmax(layer_att, 1) - nodes = torch.sum(layer_alpha * nodes_cat, 1) - - tags = self.hidden2tag(nodes) - - return tags, gaz_match - - def neg_log_likelihood_loss(self, gaz_list, word_inputs, word_seq_lengths, mask, batch_label): - - tags, _ = self.get_tags(gaz_list, word_inputs, mask) - total_loss = self.crf.neg_log_likelihood_loss(tags, mask, batch_label) - scores, tag_seq = 
self.crf._viterbi_decode(tags, mask) - - return total_loss, tag_seq # (batch_size,) ,(b,seqlen?) - - def forward(self, gaz_list, word_inputs, word_seq_lengths, mask): - tags, gaz_match = self.get_tags(gaz_list, word_inputs, mask) - scores, tag_seq = self.crf._viterbi_decode(tags, mask) - return tag_seq, gaz_match diff --git a/utils/data.py b/utils/data.py index 0258f89..5116364 100644 --- a/utils/data.py +++ b/utils/data.py @@ -3,7 +3,6 @@ # @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn import sys -import re from utils.alphabet import Alphabet from utils.functions import * from utils.word_trie import Word_Trie
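
Note (not part of the patch): the removed obtain_gaz_relation builds two small lookup signals for every matched lexicon word ("gaz"): a BMES position index per covered character (B=0, M=1, E=2, S=3) and a length-embedding index clamped so that words longer than max_gaz_length share the last bucket. A minimal sketch of that indexing logic, with hypothetical helper names, is:

import torch  # not required for the helpers below; shown for context only

def bmes_indices(gaz_len):
    """One BMES index per character covered by a lexicon match (B=0, M=1, E=2, S=3)."""
    if gaz_len == 1:
        return [3]                              # S: single-character word
    return [0] + [1] * (gaz_len - 2) + [2]      # B, M..., E

def length_bucket(gaz_len, max_gaz_length=5):
    """Clamp the word length to the length-embedding table size (0-based index)."""
    return min(gaz_len, max_gaz_length) - 1

assert bmes_indices(1) == [3]
assert bmes_indices(4) == [0, 1, 1, 2]
assert length_bucket(7) == 4                    # lengths > max share the last bucket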
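
Note (not part of the patch): the deleted Gazs_Cell / GLobal_Cell classes all follow the same gating pattern: sigmoid "input" and "forget" gates are renormalised against each other with a softmax and then interpolate between a tanh candidate state and the previous state. A self-contained sketch of that pattern (hypothetical class name; it assumes, as in the removed code, that the global feature has twice the hidden width):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFusionCell(nn.Module):
    """Softmax-normalised two-gate fusion, as used by the removed aggregation cells."""
    def __init__(self, hid_h, dropout=0.2):
        super().__init__()
        in_dim = hid_h * 4                  # [h, x, glo] with glo of width 2*hid_h
        self.Wi = nn.Linear(in_dim, hid_h)  # "input" gate
        self.Wf = nn.Linear(in_dim, hid_h)  # "forget" gate
        self.Wc = nn.Linear(in_dim, hid_h)  # candidate state
        self.drop = nn.Dropout(dropout)

    def forward(self, h, x, glo):
        x, glo = self.drop(x), self.drop(glo)
        cat_all = torch.cat([h, x, glo], dim=-1)
        i = torch.sigmoid(self.Wi(cat_all))
        f = torch.sigmoid(self.Wf(cat_all))
        c = torch.tanh(self.Wc(cat_all))
        # Normalise the two gates so they sum to 1 at every position,
        # then mix the candidate state with the previous state.
        alpha = F.softmax(torch.stack([i, f], dim=0), dim=0)
        return alpha[0] * c + alpha[1] * h

# Example: hidden width 50, global feature width 100, sequence length 7.
cell = GatedFusionCell(hid_h=50)
h, x, glo = torch.randn(1, 7, 50), torch.randn(1, 7, 50), torch.randn(1, 7, 100)
out = cell(h, x, glo)                       # shape (1, 7, 50)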