Commit f032314 (0 parents): 29 changed files with 2,132 additions and 0 deletions.
README.md
Attention Guided Graph Convolutional Networks for Relation Extraction
==========

This repo contains the *PyTorch* code for the paper.

This paper/code introduces Attention Guided Graph Convolutional Networks (AGGCNs) over dependency trees for the large-scale sentence-level relation extraction task (TACRED).

## Requirements

- Python 3 (tested on 3.6.5)
- PyTorch (tested on 0.4.1)
- tqdm
- unzip, wget (for downloading only)

## Preparation

The code requires access to the TACRED dataset (an LDC license is required). Once you have the TACRED data, put the JSON files under the directory `dataset/tacred`.

First, download and unzip the GloVe vectors:
```
chmod +x download.sh; ./download.sh
```

Then prepare the vocabulary and initial word vectors with:
```
python prepare_vocab.py dataset/tacred dataset/vocab --glove_dir dataset/glove
```

This will write the vocabulary and word vectors as a numpy matrix into the directory `dataset/vocab`.
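To sanity-check the prepared files, you can load them directly. This is a minimal sketch; the output file names (`vocab.pkl` and `embedding.npy`) are assumptions, so check `prepare_vocab.py` if yours differ:
```
import pickle
import numpy as np

# Assumed output names -- verify against prepare_vocab.py.
with open('dataset/vocab/vocab.pkl', 'rb') as f:
    words = pickle.load(f)                    # list of vocabulary tokens (assumed format)
emb = np.load('dataset/vocab/embedding.npy')  # (vocab_size, 300) for GloVe 840B.300d
print(len(words), emb.shape)
```
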
## Training

To train the AGGCN model, run:
```
bash train_caggcn.sh 1
```

Model checkpoints and logs will be saved to `./saved_models/01`.

For details on the use of other parameters, please refer to `train.py`.

## Evaluation

To run evaluation on the test set, run:
```
python eval.py saved_models/01 --dataset test
```

This will use the `best_model.pt` file by default. Use `--model checkpoint_epoch_10.pt` to specify a model checkpoint file.
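For example, to score a specific checkpoint on the dev set (assuming such a checkpoint file exists under `saved_models/01`):
```
python eval.py saved_models/01 --model checkpoint_epoch_10.pt --dataset dev
```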
data/loader.py
""" | ||
Data loader for TACRED json files. | ||
""" | ||
|
||
import json | ||
import random | ||
import torch | ||
import numpy as np | ||
|
||
from utils import constant | ||
|
||
class DataLoader(object): | ||
""" | ||
Load data from json files, preprocess and prepare batches. | ||
""" | ||
def __init__(self, filename, batch_size, opt, vocab, evaluation=False): | ||
self.batch_size = batch_size | ||
self.opt = opt | ||
self.vocab = vocab | ||
self.eval = evaluation | ||
self.label2id = constant.LABEL_TO_ID | ||
|
||
with open(filename) as infile: | ||
data = json.load(infile) | ||
self.raw_data = data | ||
data = self.preprocess(data, vocab, opt) | ||
|
||
# shuffle for training | ||
if not evaluation: | ||
indices = list(range(len(data))) | ||
random.shuffle(indices) | ||
data = [data[i] for i in indices] | ||
self.id2label = dict([(v,k) for k,v in self.label2id.items()]) | ||
self.labels = [self.id2label[d[-1]] for d in data] | ||
self.num_examples = len(data) | ||
|
||
# chunk into batches | ||
data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)] | ||
self.data = data | ||
print("{} batches created for {}".format(len(data), filename)) | ||
|
||
def preprocess(self, data, vocab, opt): | ||
""" Preprocess the data and convert to ids. """ | ||
processed = [] | ||
for d in data: | ||
tokens = list(d['token']) | ||
if opt['lower']: | ||
tokens = [t.lower() for t in tokens] | ||
# anonymize tokens | ||
ss, se = d['subj_start'], d['subj_end'] | ||
os, oe = d['obj_start'], d['obj_end'] | ||
tokens[ss:se+1] = ['SUBJ-'+d['subj_type']] * (se-ss+1) | ||
tokens[os:oe+1] = ['OBJ-'+d['obj_type']] * (oe-os+1) | ||
tokens = map_to_ids(tokens, vocab.word2id) | ||
pos = map_to_ids(d['stanford_pos'], constant.POS_TO_ID) | ||
ner = map_to_ids(d['stanford_ner'], constant.NER_TO_ID) | ||
deprel = map_to_ids(d['stanford_deprel'], constant.DEPREL_TO_ID) | ||
head = [int(x) for x in d['stanford_head']] | ||
assert any([x == 0 for x in head]) | ||
l = len(tokens) | ||
subj_positions = get_positions(d['subj_start'], d['subj_end'], l) | ||
obj_positions = get_positions(d['obj_start'], d['obj_end'], l) | ||
subj_type = [constant.SUBJ_NER_TO_ID[d['subj_type']]] | ||
obj_type = [constant.OBJ_NER_TO_ID[d['obj_type']]] | ||
relation = self.label2id[d['relation']] | ||
processed += [(tokens, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, relation)] | ||
return processed | ||
|
||
    def gold(self):
        """ Return gold labels as a list. """
        return self.labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 10

        # sort all fields by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # word dropout
        if not self.eval:
            words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
        else:
            words = batch[0]

        # convert to tensors
        words = get_long_tensor(words, batch_size)
        masks = torch.eq(words, 0)
        pos = get_long_tensor(batch[1], batch_size)
        ner = get_long_tensor(batch[2], batch_size)
        deprel = get_long_tensor(batch[3], batch_size)
        head = get_long_tensor(batch[4], batch_size)
        subj_positions = get_long_tensor(batch[5], batch_size)
        obj_positions = get_long_tensor(batch[6], batch_size)
        subj_type = get_long_tensor(batch[7], batch_size)
        obj_type = get_long_tensor(batch[8], batch_size)

        rels = torch.LongTensor(batch[9])

        return (words, masks, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, rels, orig_idx)

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)


def map_to_ids(tokens, vocab):
    ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]
    return ids


def get_positions(start_idx, end_idx, length):
    """ Get subj/obj position sequence. """
    return list(range(-start_idx, 0)) + [0]*(end_idx - start_idx + 1) + \
            list(range(1, length-end_idx))
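# e.g. get_positions(2, 3, 6) == [-2, -1, 0, 0, 1, 2]:
# negative offsets before the entity span, 0 inside it, positive after it.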


def get_long_tensor(tokens_list, batch_size):
    """ Convert list of list of tokens to a padded LongTensor. """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(constant.PAD_ID)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens


def sort_all(batch, lens):
    """ Sort all fields by descending order of lens, and return the original indices. """
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    return sorted_all[2:], sorted_all[1]
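# e.g. lens = [3, 5, 4] reorders every field to lengths [5, 4, 3] and returns
# orig_idx = [1, 2, 0], i.e. the original position of each sorted example.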


def word_dropout(tokens, dropout):
    """ Randomly dropout tokens (IDs) and replace them with <UNK> tokens. """
    return [constant.UNK_ID if x != constant.UNK_ID and np.random.random() < dropout \
            else x for x in tokens]
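A minimal usage sketch for this loader (illustrative only: the paths and `opt` values below are placeholders, not settings taken from this repo; the real `opt` dictionary comes from `train.py`'s argument parser):
```
from data.loader import DataLoader
from utils.vocab import Vocab

# Placeholder paths and values -- adjust to your setup.
vocab = Vocab('dataset/vocab/vocab.pkl', load=True)
opt = {'lower': True, 'word_dropout': 0.04}
loader = DataLoader('dataset/tacred/train.json', 50, opt, vocab)
words, masks, pos, ner, deprel, head, subj_pos, obj_pos, \
    subj_type, obj_type, rels, orig_idx = loader[0]
```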
download.sh
#!/bin/bash

cd dataset; mkdir glove
cd glove

echo "==> Downloading glove vectors..."
wget http://nlp.stanford.edu/data/glove.840B.300d.zip

echo "==> Unzipping glove vectors..."
unzip glove.840B.300d.zip
rm glove.840B.300d.zip

echo "==> Done."
eval.py
""" | ||
Run evaluation with saved models. | ||
""" | ||
import random | ||
import argparse | ||
from tqdm import tqdm | ||
import torch | ||
|
||
from data.loader import DataLoader | ||
from model.trainer import GCNTrainer | ||
from utils import torch_utils, scorer, constant, helper | ||
from utils.vocab import Vocab | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('model_dir', type=str, help='Directory of the model.') | ||
parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.') | ||
parser.add_argument('--data_dir', type=str, default='dataset/tacred') | ||
parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.") | ||
|
||
parser.add_argument('--seed', type=int, default=1234) | ||
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available()) | ||
parser.add_argument('--cpu', action='store_true') | ||
args = parser.parse_args() | ||
|
||
torch.manual_seed(args.seed) | ||
random.seed(1234) | ||
if args.cpu: | ||
args.cuda = False | ||
elif args.cuda: | ||
torch.cuda.manual_seed(args.seed) | ||
|
||
# load opt | ||
model_file = args.model_dir + '/' + args.model | ||
print("Loading model from {}".format(model_file)) | ||
opt = torch_utils.load_config(model_file) | ||
trainer = GCNTrainer(opt) | ||
trainer.load(model_file) | ||
|
||
# load vocab | ||
vocab_file = args.model_dir + '/vocab.pkl' | ||
vocab = Vocab(vocab_file, load=True) | ||
assert opt['vocab_size'] == vocab.size, "Vocab size must match that in the saved model." | ||
|
||
# load data | ||
data_file = opt['data_dir'] + '/{}.json'.format(args.dataset) | ||
print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size'])) | ||
batch = DataLoader(data_file, opt['batch_size'], opt, vocab, evaluation=True) | ||
|
||
helper.print_config(opt) | ||
label2id = constant.LABEL_TO_ID | ||
id2label = dict([(v,k) for k,v in label2id.items()]) | ||
|
||
predictions = [] | ||
all_probs = [] | ||
batch_iter = tqdm(batch) | ||
for i, b in enumerate(batch_iter): | ||
preds, probs, _ = trainer.predict(b) | ||
predictions += preds | ||
all_probs += probs | ||
|
||
predictions = [id2label[p] for p in predictions] | ||
p, r, f1 = scorer.score(batch.gold(), predictions, verbose=True) | ||
print("{} set evaluate result: {:.2f}\t{:.2f}\t{:.2f}".format(args.dataset,p,r,f1)) | ||
|
||
print("Evaluation ended.") | ||
|