Commit
* add prediction/qrel files dump option
* fix comment
* upgrade to torchtext 0.3
* remove ngram
* update device
* minor update
* add pit2015
* revert update torchtext
* revert update torchtext
* revert update torchtext
1 parent: 8a00f9c
Commit: 5afb845
Showing 9 changed files with 199 additions and 5 deletions.
@@ -0,0 +1,32 @@ (new file: PIT2015Evaluator)
import torch
import torch.nn.functional as F

from .evaluator import Evaluator


class PIT2015Evaluator(Evaluator):

    def get_scores(self):
        self.model.eval()
        self.data_loader.init_epoch()
        # Running counts over the whole evaluation set:
        # acc_total = true positives, rel_total = gold positives, pre_total = predicted positives
        n_dev_correct = 0
        total_loss = 0
        acc_total = 0
        rel_total = 0
        pre_total = 0
        for batch_idx, batch in enumerate(self.data_loader):
            sent1, sent2 = self.get_sentence_embeddings(batch)
            scores = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
            prediction = torch.max(scores, 1)[1].view(batch.label.size()).data
            gold_label = batch.label.data
            n_dev_correct += (prediction == gold_label).sum().item()
            acc_total += ((prediction == gold_label) * (prediction == 1)).sum().item()
            # Sum (rather than average) the per-example NLL so avg_loss can be computed over the full set
            total_loss += F.nll_loss(scores, batch.label, size_average=False).item()
            rel_total += gold_label.sum().item()
            pre_total += prediction.sum().item()

        # Precision, recall, and F1 for the positive (paraphrase) class
        precision = acc_total / pre_total
        recall = acc_total / rel_total
        f1 = 2 * precision * recall / (precision + recall)
        accuracy = 100. * n_dev_correct / len(self.data_loader.dataset.examples)
        avg_loss = total_loss / len(self.data_loader.dataset.examples)
        return [accuracy, avg_loss, precision, recall, f1], ['accuracy', 'cross_entropy_loss', 'precision', 'recall', 'f1']
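
To make the counter bookkeeping in get_scores concrete, here is a minimal standalone sketch (not part of the commit) that reproduces the same precision/recall/F1 arithmetic on a toy batch of binary labels; the tensor values are illustrative only.

    import torch

    # Toy predictions and gold labels for one evaluation pass (illustrative values)
    prediction = torch.tensor([1, 0, 1, 1, 0])
    gold_label = torch.tensor([1, 0, 0, 1, 1])

    true_positives = ((prediction == gold_label) * (prediction == 1)).sum().item()  # acc_total
    gold_positives = gold_label.sum().item()                                        # rel_total
    predicted_positives = prediction.sum().item()                                   # pre_total

    precision = true_positives / predicted_positives
    recall = true_positives / gold_positives
    f1 = 2 * precision * recall / (precision + recall)
    print(precision, recall, f1)  # 0.667, 0.667, 0.667 for this toy batch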
@@ -0,0 +1,85 @@ (new file: PIT2015Trainer)
import time

import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from .trainer import Trainer
from utils.serialization import save_checkpoint


class PIT2015Trainer(Trainer):

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0
        for batch_idx, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Select the embeddings for both sentences in the pair
            sent1, sent2 = self.get_sentence_embeddings(batch)

            output = self.model(sent1, sent2, batch.ext_feats, batch.dataset.word_to_doc_cnt, batch.sentence_1_raw, batch.sentence_2_raw)
            # Summed (not averaged) NLL over the batch
            loss = F.nll_loss(output, batch.label, size_average=False)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
            if batch_idx % self.log_interval == 0:
                self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, min(batch_idx * self.batch_size, len(batch.dataset.examples)),
                    len(batch.dataset.examples),
                    100. * batch_idx / (len(self.train_loader)), loss.item() / len(batch))
                )

        accuracy, avg_loss, precision, recall, f1 = self.evaluate(self.train_evaluator, 'train')

        if self.use_tensorboard:
            self.writer.add_scalar('{}/train/cross_entropy_loss'.format(self.train_loader.dataset.NAME), avg_loss, epoch)
            self.writer.add_scalar('{}/train/accuracy'.format(self.train_loader.dataset.NAME), accuracy, epoch)
            self.writer.add_scalar('{}/train/precision'.format(self.train_loader.dataset.NAME), precision, epoch)
            self.writer.add_scalar('{}/train/recall'.format(self.train_loader.dataset.NAME), recall, epoch)
            self.writer.add_scalar('{}/train/f1'.format(self.train_loader.dataset.NAME), f1, epoch)

        return total_loss

    def train(self, epochs):
        scheduler = None
        # Only create a scheduler if learning-rate reduction is actually requested
        if self.lr_reduce_factor is not None and self.lr_reduce_factor != 1:
            scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.lr_reduce_factor, patience=self.patience)
        epoch_times = []
        prev_loss = -1
        best_dev_score = -1
        for epoch in range(1, epochs + 1):
            start = time.time()
            self.logger.info('Epoch {} started...'.format(epoch))
            self.train_epoch(epoch)

            dev_scores = self.evaluate(self.dev_evaluator, 'dev')
            accuracy, avg_loss, precision, recall, f1 = dev_scores

            test_scores = self.evaluate(self.test_evaluator, 'test')
            if self.use_tensorboard:
                self.writer.add_scalar('{}/lr'.format(self.train_loader.dataset.NAME), self.optimizer.param_groups[0]['lr'], epoch)
                self.writer.add_scalar('{}/dev/cross_entropy_loss'.format(self.train_loader.dataset.NAME), avg_loss, epoch)
                self.writer.add_scalar('{}/dev/accuracy'.format(self.train_loader.dataset.NAME), accuracy, epoch)
                self.writer.add_scalar('{}/dev/precision'.format(self.train_loader.dataset.NAME), precision, epoch)
                self.writer.add_scalar('{}/dev/recall'.format(self.train_loader.dataset.NAME), recall, epoch)
                self.writer.add_scalar('{}/dev/f1'.format(self.train_loader.dataset.NAME), f1, epoch)

            end = time.time()
            duration = end - start
            self.logger.info('Epoch {} finished in {:.2f} minutes'.format(epoch, duration / 60))
            epoch_times.append(duration)

            # Checkpoint on the best dev F1 seen so far
            if f1 > best_dev_score:
                best_dev_score = f1
                save_checkpoint(epoch, self.model.arch, self.model.state_dict(), self.optimizer.state_dict(), best_dev_score, self.model_outfile)

            # Early stopping once the dev loss plateaus
            if abs(prev_loss - avg_loss) <= 0.0002:
                self.logger.info('Early stopping. Loss changed by less than 0.0002.')
                break

            prev_loss = avg_loss
            if scheduler is not None:
                scheduler.step(f1)

        self.logger.info('Training took {:.2f} minutes overall...'.format(sum(epoch_times) / 60))
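
The scheduler wiring above follows the standard ReduceLROnPlateau pattern: with mode='max', the learning rate is reduced once the monitored score (dev F1 here) stops improving for `patience` epochs. A minimal standalone sketch of that pattern, with an illustrative model and a placeholder metric (neither taken from the commit), might look like this:

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(10, 2)                       # illustrative model, not the repo's
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # mode='max': cut the LR by `factor` when the tracked score stops increasing
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    for epoch in range(1, 11):
        dev_f1 = 0.5                  # placeholder for the dev-set F1 from the evaluator
        scheduler.step(dev_f1)        # same call shape as scheduler.step(f1) in train()
        print(epoch, optimizer.param_groups[0]['lr'])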
@@ -0,0 +1,65 @@ (new file: PIT2015 dataset)
import os

import torch
from torchtext.data.field import Field, RawField
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors

from datasets.castor_dataset import CastorPairDataset


class PIT2015(CastorPairDataset):
    NAME = 'pit2015'
    NUM_CLASSES = 2
    ID_FIELD = Field(sequential=False, tensor_type=torch.FloatTensor, use_vocab=False, batch_first=True)
    AID_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)
    # Tokenizer is the identity since the text was already tokenized to compute external features
    TEXT_FIELD = Field(batch_first=True, tokenize=lambda x: x)
    EXT_FEATS_FIELD = Field(tensor_type=torch.FloatTensor, use_vocab=False, batch_first=True, tokenize=lambda x: x)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)
    RAW_TEXT_FIELD = RawField()
    VOCAB_SIZE = 0

    @staticmethod
    def sort_key(ex):
        return len(ex.sentence_1)

    def __init__(self, path):
        """
        Create a PIT2015 dataset instance.
        """
        super(PIT2015, self).__init__(path)

    @classmethod
    def splits(cls, path, train='train', validation='dev', test='test', **kwargs):
        return super(PIT2015, cls).splits(path, train=train, validation=validation, test=test, **kwargs)

    @classmethod
    def iters(cls, path, vectors_name, vectors_dir, batch_size=64, shuffle=True, device=0, pt_file=False, vectors=None,
              unk_init=torch.Tensor.zero_):
        """
        :param path: directory containing train, dev, and test files
        :param vectors_name: name of the word vectors file
        :param vectors_dir: directory containing the word vectors file
        :param batch_size: batch size
        :param device: GPU device
        :param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
        :param pt_file: if True, load a cached embedding file from disk
        :param unk_init: function used to generate vectors for OOV words
        :return: train, validation, and test BucketIterators
        """
        train, validation, test = cls.splits(path)

        if not pt_file:
            if vectors is None:
                vectors = Vectors(name=vectors_name, cache=vectors_dir, unk_init=unk_init)
            cls.TEXT_FIELD.build_vocab(train, validation, test, vectors=vectors)
        else:
            cls.TEXT_FIELD.build_vocab(train, validation, test)
            cls.TEXT_FIELD = cls.set_vectors(cls.TEXT_FIELD, os.path.join(vectors_dir, vectors_name))

        cls.LABEL_FIELD.build_vocab(train, validation, test)

        cls.VOCAB_SIZE = len(cls.TEXT_FIELD.vocab)

        return BucketIterator.splits((train, validation, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
                                     sort_within_batch=True, device=device)
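
A hedged usage sketch for the new dataset class follows. The directory and word-vector file names are placeholders (the ones expected by the repository's argument parser are not shown in this diff), and the batch attribute names assume the CastorPairDataset fields suggested by the code above (sentence_1, label).

    # Illustrative only: paths and file names are assumptions, not taken from the commit
    train_iter, dev_iter, test_iter = PIT2015.iters(
        path='data/pit2015',               # directory with train/dev/test splits (assumed layout)
        vectors_name='word_vectors.txt',   # word-embedding file name (placeholder)
        vectors_dir='data/embeddings',     # directory containing that file (placeholder)
        batch_size=64,
        device=0)

    for batch in train_iter:
        sent1_ids = batch.sentence_1       # token-id tensor built by TEXT_FIELD (assumed field name)
        labels = batch.label               # 0/1 paraphrase labels from LABEL_FIELD
        break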