Commit

Upload essential files.
RowitZou committed Jun 13, 2019
1 parent 8d72a2a commit 4663517
Showing 17 changed files with 5,888 additions and 0 deletions.
567 changes: 567 additions & 0 deletions main.py

Large diffs are not rendered by default.

516 changes: 516 additions & 0 deletions model/LGN.py

Large diffs are not rendered by default.

260 changes: 260 additions & 0 deletions model/crf.py
@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2017-12-04 23:19:38
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
# @Last Modified time: 2018-12-13 22:48:17

import torch
import torch.autograd as autograd
import torch.nn as nn
START_TAG = -2
STOP_TAG = -1


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec, m_size):
"""
calculate the log of the sum of exponentials (log-sum-exp) over dim 1 in a numerically stable way
args:
vec (batch_size, vanishing_dim, hidden_dim) : input tensor
m_size : hidden_dim
return:
batch_size, hidden_dim
"""
_, idx = torch.max(vec, 1) # B * 1 * M
max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M
return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) # B * M
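# Illustrative note (added for clarity, not part of the original file): for a score tensor
# of shape (batch_size, tag_size, tag_size), log_sum_exp reduces dim 1 and should match
# torch.logsumexp(vec, dim=1) up to numerical precision, e.g.
#   vec = torch.randn(4, 7, 7)
#   out = log_sum_exp(vec, 7)    # -> shape (4, 7)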

class CRF(nn.Module):

def __init__(self, tagset_size, gpu):
super(CRF, self).__init__()
print ("build batched crf...")
self.gpu = gpu
# Matrix of transition parameters. Entry i,j is the score of transitioning *from* i *to* j.
self.average_batch = False
self.tagset_size = tagset_size
# # We add 2 here, because of START_TAG and STOP_TAG
# # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag
init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
# init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
# init_transitions[:,START_TAG] = -1000.0
# init_transitions[STOP_TAG,:] = -1000.0
# init_transitions[:,0] = -1000.0
# init_transitions[0,:] = -1000.0
if self.gpu:
init_transitions = init_transitions.cuda()
self.transitions = nn.Parameter(init_transitions) #(t+2,t+2)

# self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))
# self.transitions.data.zero_()

def _calculate_PZ(self, feats, mask):
"""
input:
feats: (batch, seq_len, self.tagset_size+2) (b,m,t+2)
mask: (batch, seq_len) (b,m)
"""
batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(2)
# print feats.view(seq_len, tag_size)
assert(tag_size == self.tagset_size+2)
mask = mask.transpose(1,0).contiguous() #(m,b)
ins_num = seq_len * batch_size
## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) # (ins_num, t+2, t+2); every slice along dim 1 is identical (emission scores broadcast over the from-tag dim)
## need to consider start
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
# build iter
seq_iter = enumerate(scores)
_, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size (b,t,t); inivalues holds the scores at the first token of each sentence
# only need start from start_tag
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size (b,t,1)

## add start score (from start to all tag, duplicate to batch_size)
# partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
# iter over last scores
for idx, cur_values in seq_iter:
# previous to_target is current from_target
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
# cur_values: bat_size * from_target * to_target

cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
cur_partition = log_sum_exp(cur_values, tag_size) #(b,t)
# print cur_partition.data

# (bat_size * from_target * to_target) -> (bat_size * to_target)
# partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)

## effective updated partition part, only keep the partition value of mask value = 1
masked_cur_partition = cur_partition.masked_select(mask_idx)
## let mask_idx broadcastable, to disable warning
mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)

## replace the partition where the maskvalue=1, other partition value keeps the same
partition.masked_scatter_(mask_idx, masked_cur_partition)
# until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
cur_partition = log_sum_exp(cur_values, tag_size) #(batch_size,hidden_dim)
final_partition = cur_partition[:, STOP_TAG] #(batch_size)
return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size)
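# Note (added for clarity, not part of the original commit): _calculate_PZ is the CRF forward
# algorithm in log space. With combined scores s_t(i, j) = emit_t(j) + trans(i, j), it iterates
#   alpha_t(j) = logsumexp_i( alpha_{t-1}(i) + s_t(i, j) )
# over unmasked positions and returns the batch sum of alpha_T(STOP_TAG), i.e. the summed log partition values.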


def _viterbi_decode(self, feats, mask):
"""
input:
feats: (batch, seq_len, self.tag_size+2)
mask: (batch, seq_len)
output:
decode_idx: (batch, seq_len) decoded sequence
path_score: (batch, 1) corresponding score for each sequence (to be implemented)
"""
batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(2)
assert(tag_size == self.tagset_size+2)
## calculate sentence length for each sentence
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() # (batch_size,1) number of unmasked tokens in each sentence
## mask to (seq_len, batch_size)
mask = mask.transpose(1,0).contiguous() #(seq_len,b)
ins_num = seq_len * batch_size
## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) #(ins_num, tag_size, tag_size)
## need to consider start
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)

# build iter
seq_iter = enumerate(scores)
## record the position of best score
back_points = list()
partition_history = list()

## invert the mask (mask = 1 - mask triggers a bug here, so use the conversion below as an alternative)
# mask = 1 + (-1)*mask
mask = (1 - mask.long()).byte()
_, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size
# only need start from start_tag
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size
partition_history.append(partition) #(seqlen,batch_size,tag_size,1)
# iter over last scores
for idx, cur_values in seq_iter:
# previous to_target is current from_target
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
# cur_values: batch_size * from_target * to_target
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
partition, cur_bp = torch.max(cur_values,dim=1)
partition_history.append(partition.unsqueeze(2))
## cur_bp: (batch_size, tag_size) max source score position in current tag
## set padded label as 0, which will be filtered in post processing
cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
back_points.append(cur_bp)
### add score to final STOP_TAG
partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)
### get the last position of each sentence, and select the last partitions using gather()
last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
### calculate the score from last partition to end state (and then select the STOP_TAG from it)
last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
_, last_bp = torch.max(last_values, 1) #(batch_size,tag_size)
pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
if self.gpu:
pad_zero = pad_zero.cuda()
back_points.append(pad_zero)
back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)

## select end ids in STOP_TAG
pointer = last_bp[:, STOP_TAG] #(batch_size)
insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
back_points = back_points.transpose(1,0).contiguous() #(batch_size,sq_len,tag_size)
## move the end ids (expanded to tag_size) to the corresponding positions of back_points, replacing the 0 values
# print "lp:",last_position
# print "il:",insert_last
back_points.scatter_(1, last_position, insert_last) ##(batch_size,sq_len,tag_size)
# print "bp:",back_points
# exit(0)
back_points = back_points.transpose(1,0).contiguous() #(seq_len, batch_size, tag_size)
## decode from the end; padded position ids are 0, which will be filtered out in the following evaluation
decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
if self.gpu:
decode_idx = decode_idx.cuda()
decode_idx[-1] = pointer.data
for idx in range(len(back_points)-2, -1, -1):
pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1)
decode_idx[idx] = pointer.squeeze(1).data
path_score = None
decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len)
return path_score, decode_idx
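# Note (added for clarity, not part of the original commit): decoding runs the max-product
# recursion forward while storing argmax back-pointers; the best final tag (highest score of
# transitioning into STOP_TAG) is scattered into each sentence's last valid position, and the
# label sequence is then read off backwards through back_points.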


def forward(self, feats, mask):
path_score, best_path = self._viterbi_decode(feats, mask)
return path_score, best_path


def _score_sentence(self, scores, mask, tags):
"""
input:
scores: variable (seq_len, batch, tag_size, tag_size)
mask: (batch, seq_len)
tags: tensor (batch, seq_len)
output:
score: sum of score for gold sequences within whole batch
"""
# Gives the score of a provided tag sequence
batch_size = scores.size(1)
seq_len = scores.size(0)
tag_size = scores.size(2)
## convert tag values into a new format that records label-bigram information in a single index
new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
if self.gpu:
new_tags = new_tags.cuda()
for idx in range(seq_len):
if idx == 0:
## start -> first score
new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0]

else:
new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx]

## transition for label to STOP_TAG
end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
## length for batch, last word position = length - 1
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
## index the label id of last word
end_ids = torch.gather(tags, 1, length_mask - 1)

## index the transition score for end_id to STOP_TAG
end_energy = torch.gather(end_transition, 1, end_ids)

## convert tag as (seq_len, batch_size, 1)
new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
### convert tag ids so they index into the flattened tag_size * tag_size positions of scores
tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size
## mask transpose to (seq_len, batch_size)
tg_energy = tg_energy.masked_select(mask.transpose(1,0))

# ## calculate the score from START_TAG to first label
# start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
# start_energy = torch.gather(start_transition, 1, tags[0,:])

## add all score together
# gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
gold_score = tg_energy.sum() + end_energy.sum()
return gold_score
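# Note (added for clarity, not part of the original commit): the gold score is the sum, over
# unmasked positions, of emission + transition scores along the gold tag path, plus the
# transition from the last gold tag into STOP_TAG; the START_TAG transition is folded in at
# idx == 0 via the (tag_size - 2) * tag_size offset, since START_TAG occupies row tag_size - 2.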

def neg_log_likelihood_loss(self, feats, mask, tags):
# negative log likelihood
batch_size = feats.size(0)
forward_score, scores = self._calculate_PZ(feats, mask) # forward_score: scalar tensor; scores: (seq_len, batch, tag_size, tag_size)
gold_score = self._score_sentence(scores, mask, tags)
#print ("batch, f:", forward_score.data, " g:", gold_score.data, " dis:", forward_score.data - gold_score.data)
# exit(0)
if self.average_batch:
return (forward_score - gold_score)/batch_size
else:
return forward_score - gold_score
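A minimal usage sketch (not part of this commit), assuming the repository's module path model/crf.py and the older PyTorch API the file itself uses (byte masks, autograd.Variable); the shapes and values below are made up for illustration:

import torch
from model.crf import CRF

tagset_size = 5                                               # number of real labels, without START/STOP
crf = CRF(tagset_size, gpu=False)

batch_size, seq_len = 2, 4
feats = torch.randn(batch_size, seq_len, tagset_size + 2)     # emission scores incl. START/STOP columns
mask = torch.ones(batch_size, seq_len).byte()                 # 1 = real token, 0 = padding
tags = torch.randint(0, tagset_size, (batch_size, seq_len))   # gold label ids

loss = crf.neg_log_likelihood_loss(feats, mask, tags)         # scalar loss to backpropagate
_, best_paths = crf(feats, mask)                              # (batch_size, seq_len) decoded label ids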