first commit

Cartus committed May 24, 2019
0 parents commit f032314
Showing 29 changed files with 2,132 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .idea/AGGCN_TACRED.iml


22 changes: 22 additions & 0 deletions .idea/deployment.xml


7 changes: 7 additions & 0 deletions .idea/misc.xml


8 changes: 8 additions & 0 deletions .idea/modules.xml


456 changes: 456 additions & 0 deletions .idea/workspace.xml


50 changes: 50 additions & 0 deletions README.md
@@ -0,0 +1,50 @@
Attention Guided Graph Convolutional Networks for Relation Extraction
==========

This repo contains the *PyTorch* code for the paper.

This paper/code introduces Attention Guided Graph Convolutional Networks (AGGCNs), which operate over dependency trees for the large-scale, sentence-level relation extraction task (TACRED).

## Requirements

- Python 3 (tested on 3.6.5)
- PyTorch (tested on 0.4.1)
- tqdm
- unzip, wget (for downloading only)

## Preparation

The code requires that you have access to the TACRED dataset (LDC license required). Once you have the TACRED data, please put the JSON files under the directory `dataset/tacred`.
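The scripts expect one JSON file per split. A layout along these lines should work (the split names `dev` and `test` follow `eval.py`; `train.json` is assumed for training):
```
dataset/tacred
├── train.json
├── dev.json
└── test.json
```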

First, download and unzip GloVe vectors:
```
chmod +x download.sh; ./download.sh
```

Then prepare vocabulary and initial word vectors with:
```
python prepare_vocab.py dataset/tacred dataset/vocab --glove_dir dataset/glove
```

This will write the vocabulary and the initial word vectors (as a numpy matrix) into the directory `dataset/vocab`.
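The exact output file names are determined by `prepare_vocab.py`; an illustrative result might look like the listing below (`vocab.pkl` matches the name `eval.py` later loads next to a saved model; the embedding file name is only a guess):
```
dataset/vocab
├── vocab.pkl        # pickled vocabulary
└── embedding.npy    # numpy matrix of initial word vectors (name illustrative)
```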

## Training

To train the AGGCN model, run:
```
bash train_caggcn.sh 1
```

Model checkpoints and logs will be saved to `./saved_models/01`.

For details on the use of other parameters, please refer to `train.py`.

## Evaluation

To run evaluation on the test set, run:
```
python eval.py saved_models/01 --dataset test
```

This will use the `best_model.pt` file by default. Use `--model checkpoint_epoch_10.pt` to specify a model checkpoint file.
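For example, to score a specific checkpoint on the dev set instead (the checkpoint file name here is illustrative; the flags are those defined in `eval.py`):
```
python eval.py saved_models/01 --model checkpoint_epoch_10.pt --dataset dev
```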

Binary file added data/__pycache__/loader.cpython-36.pyc
149 changes: 149 additions & 0 deletions data/loader.py
@@ -0,0 +1,149 @@
"""
Data loader for TACRED json files.
"""

import json
import random
import torch
import numpy as np

from utils import constant

class DataLoader(object):
"""
Load data from json files, preprocess and prepare batches.
"""
def __init__(self, filename, batch_size, opt, vocab, evaluation=False):
self.batch_size = batch_size
self.opt = opt
self.vocab = vocab
self.eval = evaluation
self.label2id = constant.LABEL_TO_ID

with open(filename) as infile:
data = json.load(infile)
self.raw_data = data
data = self.preprocess(data, vocab, opt)

# shuffle for training
if not evaluation:
indices = list(range(len(data)))
random.shuffle(indices)
data = [data[i] for i in indices]
self.id2label = dict([(v,k) for k,v in self.label2id.items()])
self.labels = [self.id2label[d[-1]] for d in data]
self.num_examples = len(data)

# chunk into batches
data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
self.data = data
print("{} batches created for {}".format(len(data), filename))

def preprocess(self, data, vocab, opt):
""" Preprocess the data and convert to ids. """
processed = []
for d in data:
tokens = list(d['token'])
if opt['lower']:
tokens = [t.lower() for t in tokens]
# anonymize tokens
ss, se = d['subj_start'], d['subj_end']
os, oe = d['obj_start'], d['obj_end']
tokens[ss:se+1] = ['SUBJ-'+d['subj_type']] * (se-ss+1)
tokens[os:oe+1] = ['OBJ-'+d['obj_type']] * (oe-os+1)
tokens = map_to_ids(tokens, vocab.word2id)
pos = map_to_ids(d['stanford_pos'], constant.POS_TO_ID)
ner = map_to_ids(d['stanford_ner'], constant.NER_TO_ID)
deprel = map_to_ids(d['stanford_deprel'], constant.DEPREL_TO_ID)
head = [int(x) for x in d['stanford_head']]
assert any([x == 0 for x in head])
l = len(tokens)
subj_positions = get_positions(d['subj_start'], d['subj_end'], l)
obj_positions = get_positions(d['obj_start'], d['obj_end'], l)
subj_type = [constant.SUBJ_NER_TO_ID[d['subj_type']]]
obj_type = [constant.OBJ_NER_TO_ID[d['obj_type']]]
relation = self.label2id[d['relation']]
processed += [(tokens, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, relation)]
return processed

def gold(self):
""" Return gold labels as a list. """
return self.labels

def __len__(self):
return len(self.data)

def __getitem__(self, key):
""" Get a batch with index. """
if not isinstance(key, int):
raise TypeError
if key < 0 or key >= len(self.data):
raise IndexError
batch = self.data[key]
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 10

# sort all fields by lens for easy RNN operations
lens = [len(x) for x in batch[0]]
batch, orig_idx = sort_all(batch, lens)

# word dropout
if not self.eval:
words = [word_dropout(sent, self.opt['word_dropout']) for sent in batch[0]]
else:
words = batch[0]

# convert to tensors
words = get_long_tensor(words, batch_size)
masks = torch.eq(words, 0)
pos = get_long_tensor(batch[1], batch_size)
ner = get_long_tensor(batch[2], batch_size)
deprel = get_long_tensor(batch[3], batch_size)
head = get_long_tensor(batch[4], batch_size)
subj_positions = get_long_tensor(batch[5], batch_size)
obj_positions = get_long_tensor(batch[6], batch_size)
subj_type = get_long_tensor(batch[7], batch_size)
obj_type = get_long_tensor(batch[8], batch_size)

rels = torch.LongTensor(batch[9])

return (words, masks, pos, ner, deprel, head, subj_positions, obj_positions, subj_type, obj_type, rels, orig_idx)

def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)


def map_to_ids(tokens, vocab):
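    """ Map tokens to ids using the given vocab dict, falling back to constant.UNK_ID for unknown tokens. """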
ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens]
return ids


def get_positions(start_idx, end_idx, length):
""" Get subj/obj position sequence. """
return list(range(-start_idx, 0)) + [0]*(end_idx - start_idx + 1) + \
list(range(1, length-end_idx))


def get_long_tensor(tokens_list, batch_size):
""" Convert list of list of tokens to a padded LongTensor. """
token_len = max(len(x) for x in tokens_list)
tokens = torch.LongTensor(batch_size, token_len).fill_(constant.PAD_ID)
for i, s in enumerate(tokens_list):
tokens[i, :len(s)] = torch.LongTensor(s)
return tokens


def sort_all(batch, lens):
""" Sort all fields by descending order of lens, and return the original indices. """
unsorted_all = [lens] + [range(len(lens))] + list(batch)
sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
return sorted_all[2:], sorted_all[1]


def word_dropout(tokens, dropout):
""" Randomly dropout tokens (IDs) and replace them with <UNK> tokens. """
return [constant.UNK_ID if x != constant.UNK_ID and np.random.random() < dropout \
else x for x in tokens]
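# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original loader.py): how a
# training/eval script would typically drive this DataLoader. The `opt` values
# and file paths below are placeholder assumptions; see train.py / eval.py for
# the real configuration.
#
#     from utils.vocab import Vocab
#     from data.loader import DataLoader
#
#     vocab = Vocab('dataset/vocab/vocab.pkl', load=True)            # assumed path
#     opt = {'lower': True, 'word_dropout': 0.04, 'batch_size': 50}  # placeholders
#     train_data = DataLoader('dataset/tacred/train.json',
#                             opt['batch_size'], opt, vocab, evaluation=False)
#     for batch in train_data:
#         (words, masks, pos, ner, deprel, head, subj_pos, obj_pos,
#          subj_type, obj_type, rels, orig_idx) = batch
#         # words, pos, ner, etc. are padded LongTensors; masks flags PAD
#         # positions; orig_idx recovers the pre-sorting example order
# ---------------------------------------------------------------------------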

14 changes: 14 additions & 0 deletions download.sh
@@ -0,0 +1,14 @@
#!/bin/bash

cd dataset; mkdir glove
cd glove

echo "==> Downloading glove vectors..."
wget http://nlp.stanford.edu/data/glove.840B.300d.zip

echo "==> Unzipping glove vectors..."
unzip glove.840B.300d.zip
rm glove.840B.300d.zip

echo "==> Done."

67 changes: 67 additions & 0 deletions eval.py
@@ -0,0 +1,67 @@
"""
Run evaluation with saved models.
"""
import random
import argparse
from tqdm import tqdm
import torch

from data.loader import DataLoader
from model.trainer import GCNTrainer
from utils import torch_utils, scorer, constant, helper
from utils.vocab import Vocab


parser = argparse.ArgumentParser()
parser.add_argument('model_dir', type=str, help='Directory of the model.')
parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.')
parser.add_argument('--data_dir', type=str, default='dataset/tacred')
parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.")

parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--cuda', type=bool, default=torch.cuda.is_available())
parser.add_argument('--cpu', action='store_true')
args = parser.parse_args()

torch.manual_seed(args.seed)
random.seed(1234)
if args.cpu:
args.cuda = False
elif args.cuda:
torch.cuda.manual_seed(args.seed)

# load opt
model_file = args.model_dir + '/' + args.model
print("Loading model from {}".format(model_file))
opt = torch_utils.load_config(model_file)
trainer = GCNTrainer(opt)
trainer.load(model_file)

# load vocab
vocab_file = args.model_dir + '/vocab.pkl'
vocab = Vocab(vocab_file, load=True)
assert opt['vocab_size'] == vocab.size, "Vocab size must match that in the saved model."

# load data
data_file = opt['data_dir'] + '/{}.json'.format(args.dataset)
print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size']))
batch = DataLoader(data_file, opt['batch_size'], opt, vocab, evaluation=True)

helper.print_config(opt)
label2id = constant.LABEL_TO_ID
id2label = dict([(v,k) for k,v in label2id.items()])

predictions = []
all_probs = []
batch_iter = tqdm(batch)
for i, b in enumerate(batch_iter):
preds, probs, _ = trainer.predict(b)
predictions += preds
all_probs += probs

predictions = [id2label[p] for p in predictions]
p, r, f1 = scorer.score(batch.gold(), predictions, verbose=True)
print("{} set evaluate result: {:.2f}\t{:.2f}\t{:.2f}".format(args.dataset,p,r,f1))

print("Evaluation ended.")

Binary file added model/__pycache__/layers.cpython-36.pyc
Binary file added model/__pycache__/trainer.cpython-36.pyc
Binary file added model/__pycache__/tree.cpython-36.pyc