Skip to content

Commit

Permalink
Add ablation experiment settings.
Browse files — browse the repository at this point in the history
  • Loading branch information
RowitZou committed Jun 19, 2019
1 parent 651d3d2 commit c7b03c1
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 798 deletions.
298 changes: 189 additions & 109 deletions model/LGN.py

Large diffs are not rendered by default.

383 changes: 0 additions & 383 deletions model/LGN_no_edge.py

This file was deleted.

275 changes: 0 additions & 275 deletions model/LGN_no_glo_no_edge.py

This file was deleted.

3 changes: 1 addition & 2 deletions model/crf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date: 2017-12-04 23:19:38
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn
# @Last Modified time: 2018-12-13 22:48:17

import torch
import torch.autograd as autograd
Expand All @@ -25,6 +23,7 @@ def log_sum_exp(vec, m_size):
max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M
return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) # B * M


class CRF(nn.Module):

def __init__(self, tagset_size, gpu):
Expand Down
38 changes: 15 additions & 23 deletions model/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,15 @@ def forward(self, x, y, mask=None):


class Nodes_Cell(nn.Module):
def __init__(self, hid_h, use_global=True, dropout=0.2):
def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
super(Nodes_Cell, self).__init__()

self.use_global = use_global
if self.use_global:
input_size = hid_h * 5
else:
input_size = hid_h * 3

self.Wix = nn.Linear(input_size, hid_h)
self.Wi2 = nn.Linear(input_size, hid_h)
self.Wf = nn.Linear(input_size, hid_h)
self.Wcx = nn.Linear(input_size, hid_h)
self.Wix = nn.Linear(input_h, hid_h)
self.Wi2 = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wcx = nn.Linear(input_h, hid_h)

self.drop = nn.Dropout(dropout)

Expand All @@ -144,18 +140,14 @@ def forward(self, h, h2, x, glo=None):


class Edges_Cell(nn.Module):
def __init__(self, hid_h, use_global=True, dropout=0.2):
def __init__(self, input_h, hid_h, use_global=True, dropout=0.2):
super(Edges_Cell, self).__init__()

self.use_global = use_global
if self.use_global:
input_size = hid_h * 4
else:
input_size = hid_h * 2

self.Wi = nn.Linear(input_size, hid_h)
self.Wf = nn.Linear(input_size, hid_h)
self.Wc = nn.Linear(input_size, hid_h)
self.Wi = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wc = nn.Linear(input_h, hid_h)

self.drop = nn.Dropout(dropout)

Expand All @@ -179,13 +171,13 @@ def forward(self, h, x, glo=None):
return output


class GLobal_Cell(nn.Module):
def __init__(self, hid_h, dropout=0.2):
super(GLobal_Cell, self).__init__()
class Global_Cell(nn.Module):
def __init__(self, input_h, hid_h, dropout=0.2):
super(Global_Cell, self).__init__()

self.Wi = nn.Linear(hid_h*3, hid_h)
self.Wf = nn.Linear(hid_h*3, hid_h)
self.Wc = nn.Linear(hid_h*3, hid_h)
self.Wi = nn.Linear(input_h, hid_h)
self.Wf = nn.Linear(input_h, hid_h)
self.Wc = nn.Linear(input_h, hid_h)

self.drop = nn.Dropout(dropout)

Expand Down
7 changes: 4 additions & 3 deletions utils/data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
# @Author: Jie
# @Author: Jie Yang
# @Last Modified by: Yicheng Zou, Contact: yczou18@fudan.edu.cn

import sys
from utils.alphabet import Alphabet
from utils.functions import *
from utils.word_trie import Word_Trie


class Data:
def __init__(self):
self.MAX_SENTENCE_LENGTH = 250
Expand Down Expand Up @@ -75,7 +76,7 @@ def build_alphabet(self, input_file):
self.char_alphabet.close()

def build_word_file(self, word_file):
## build word file, initial word embedding file
# build word file, initial word embedding file
with open(word_file, 'r', encoding="utf-8") as f:
for line in f:
word = line.strip().split()[0]
Expand Down Expand Up @@ -146,7 +147,7 @@ def write_decoded_results(self, output_file, predict_results, name):
for idx in range(sent_num):
sent_length = len(predict_results[idx])
for idy in range(sent_length):
## content_list[idx] is a list with [word, char, label]
# content_list[idx] is a list with [word, char, label]
fout.write(content_list[idx][0][idy] + " " + predict_results[idx][idy] + '\n')
fout.write('\n')
fout.close()
Expand Down
Loading

0 comments on commit c7b03c1

Please sign in to comment.