Neural Document Classification (castorini#159)

* Add ReutersTrainer, ReutersEvaluator options in Factory classes

* Add Reuters to Kim-CNN command line arguments

* Fix SST dataset path according to changes in Kim-CNN args

The dataset path in args.py was changed to point at the dataset folder rather than the dataset/SST folder, so the SST folder was added to the paths in the SST dataset class

* Add Reuters dataset class, and support in __main__

* Add Reuters dataset trainers and evaluators

* Remove debug print statement in reuters_evaluator

* Fix rounding bug in reuters_trainer and reuters_evaluator

* Add LSTM for baseline text classification measurements

* Add eval metrics for lstm_baseline

* Set batch_first param in lstm_baseline

* Remove onnx args from lstm_baseline

* Pack padded sequences in LSTM_baseline (see the sketch after this list)

* Add TensorBoardX support for Reuters trainer

* Add Arxiv Academic Paper Dataset (AAPD)

* Add Hidden Bottleneck Layer to BiLSTM (see the sketch after this list)

* Fix packing of padded tensors in Reuters

* Add cmdline args for Hidden Bottleneck Layer for BiLSTM

* Include pre-padding lengths in AAPD dataset

* Remove duplication of preprocessing code in AAPD

* Remove batch_size condition in ReutersTrainer

* Add ignore_lengths option to ReutersTrainer and ReutersEvaluator

* Add AAPDCharQuantized and ReutersCharQuantized

* Rename Reuters_hierarchical to ReutersHierarchical

* Add CharacterCNN for document classification

* Update README.md for CharacterCNN

* Fix table in README.md for CharacterCNN

* Add AAPDHierarchical for HAN

* Update HAN for changes in Reuters dataset endpoints

* Fix bug in CharCNN when running on CPU

* Add AAPD dataset support for KimCNN

* Fix dataset paths for SST-1

* Fix dimensions of FC1 in CharCNN

* Add model checkpointing for Reuters based on F1

* Refactor LSTM baseline __main__

* Add precision, recall and F1 to Reuters evaluator

* Checkpoint only at the end of an epoch for ReutersTrainer

Add detailed log printing for dev evaluations

* Fix log_template and dev_log_template in ReutersTrainer

* Add IMDB dataset

* Add support for single_label datasets in ReutersTrainer

* Add support for IMDB dataset in lstm_baseline and lstm_reg
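
The "pack padded sequences" bullet refers to PyTorch's packing utilities. A minimal sketch of the pattern, assuming batch-first inputs and the lengths tensor that torchtext yields with include_lengths=True (names and sizes are illustrative, not the repository's actual model code):

```python
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

embed = nn.Embedding(num_embeddings=10000, embedding_dim=300)
lstm = nn.LSTM(input_size=300, hidden_size=256, batch_first=True,
               bidirectional=True)

def encode(text, lengths):
    # text: (batch, max_len) padded token ids, sorted by decreasing length;
    # BucketIterator's sort_within_batch=True provides that ordering
    packed = pack_padded_sequence(embed(text), lengths, batch_first=True)
    packed_out, _ = lstm(packed)
    # Unpack so downstream layers see an ordinary padded tensor again
    rnn_outs, _ = pad_packed_sequence(packed_out, batch_first=True)
    return rnn_outs
```

The "Hidden Bottleneck Layer" model code is not shown in the diffs below; one generic reading of the term (a hypothetical sketch, not necessarily this commit's implementation) is a small fully connected layer that compresses the BiLSTM output before classification:

```python
import torch.nn as nn
import torch.nn.functional as F

class BottleneckHead(nn.Module):
    # Hypothetical bottleneck head; all dimensions are illustrative
    def __init__(self, rnn_dim=512, bottleneck_dim=64, num_classes=90):
        super().__init__()
        self.bottleneck = nn.Linear(rnn_dim, bottleneck_dim)
        self.classifier = nn.Linear(bottleneck_dim, num_classes)

    def forward(self, features):
        return self.classifier(F.relu(self.bottleneck(features)))
```
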
Achyudh Ram authored and Impavidity committed Nov 11, 2018
1 parent 3dd3ced commit 951df4a
Showing 10 changed files with 190 additions and 87 deletions.
9 changes: 7 additions & 2 deletions common/dataset.py
@@ -13,6 +13,8 @@
from datasets.quora import Quora
from datasets.reuters import Reuters
from datasets.aapd import AAPD
from datasets.imdb import IMDB


class UnknownWordVecCache(object):
"""
@@ -71,8 +73,6 @@ def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, d
embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
return PIT2015, embedding, train_loader, test_loader, dev_loader



elif dataset_name == 'snli':
dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'snli_1.0/')
train_loader, dev_loader, test_loader = SNLI.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
@@ -98,6 +98,11 @@ def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, d
train_loader, dev_loader, test_loader = AAPD.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
embedding = nn.Embedding.from_pretrained(AAPD.TEXT_FIELD.vocab.vectors)
return AAPD, embedding, train_loader, test_loader, dev_loader
elif dataset_name == 'imdb':
dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'IMDB/')
train_loader, dev_loader, test_loader = IMDB.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
embedding = nn.Embedding.from_pretrained(IMDB.TEXT_FIELD.vocab.vectors)
return IMDB, embedding, train_loader, test_loader, dev_loader
else:
raise ValueError('{} is not a valid dataset.'.format(dataset_name))
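
A hypothetical call into this factory (argument values are illustrative, and the truncated final parameter above is assumed to be `device`):

```python
# Illustrative only: the embeddings directory and vectors file names are
# assumptions, not paths taken from the repository.
dataset_cls, embedding, train_loader, test_loader, dev_loader = get_dataset(
    'imdb',
    word_vectors_dir='Castor-data/embeddings',
    word_vectors_file='glove.840B.300d.txt',
    batch_size=32,
    device=0)
```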

1 change: 1 addition & 0 deletions common/evaluation.py
@@ -27,6 +27,7 @@ class EvaluatorFactory(object):
'twitterurl': PIT2015Evaluator,
'Reuters': ReutersEvaluator,
'AAPD': ReutersEvaluator,
'IMDB': ReutersEvaluator,
'SNLI': SNLIEvaluator,
'sts2014': STS2014Evaluator,
'Quora': QuoraEvaluator
18 changes: 12 additions & 6 deletions common/evaluators/reuters_evaluator.py
@@ -11,6 +11,7 @@ class ReutersEvaluator(Evaluator):
def __init__(self, dataset_cls, model, embedding, data_loader, batch_size, device, keep_results=False):
super().__init__(dataset_cls, model, embedding, data_loader, batch_size, device, keep_results)
self.ignore_lengths = False
self.single_label = False

def get_scores(self):
self.model.eval()
@@ -36,13 +37,18 @@ def get_scores(self):
else:
scores = self.model(batch.text[0], lengths=batch.text[1])

total_loss += F.binary_cross_entropy_with_logits(scores, batch.label.float(), size_average=False).item()
if hasattr(self.model, 'TAR') and self.model.TAR: # TAR condition
total_loss += (rnn_outs[1:]-rnn_outs[:-1]).pow(2).mean()
if self.single_label:
predicted_labels.extend(torch.argmax(scores, dim=1).cpu().detach().numpy())
target_labels.extend(torch.argmax(batch.label, dim=1).cpu().detach().numpy())
total_loss += F.cross_entropy(scores, torch.argmax(batch.label, dim=1), size_average=False).item()
else:
scores_rounded = F.sigmoid(scores).round().long()
predicted_labels.extend(scores_rounded.cpu().detach().numpy())
target_labels.extend(batch.label.cpu().detach().numpy())
total_loss += F.binary_cross_entropy_with_logits(scores, batch.label.float(), size_average=False).item()

scores_rounded = F.sigmoid(scores).round().long()
predicted_labels.extend(scores_rounded.cpu().detach().numpy())
target_labels.extend(batch.label.cpu().detach().numpy())
if hasattr(self.model, 'TAR') and self.model.TAR: # TAR condition
total_loss += (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()

predicted_labels = np.array(predicted_labels)
target_labels = np.array(target_labels)
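
The single-label/multi-label branch above reduces to two standard decoding-plus-loss pairings; a self-contained sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(4, 10)                  # raw model logits (batch, classes)
labels = torch.zeros(4, 10)
labels[[0, 1, 2, 3], [1, 3, 3, 7]] = 1       # one-hot targets

# Single-label path: argmax decoding, cross-entropy over class indices
preds = torch.argmax(scores, dim=1)
loss = F.cross_entropy(scores, torch.argmax(labels, dim=1))

# Multi-label path: independent per-class sigmoid decisions, BCE-with-logits
preds = torch.sigmoid(scores).round().long()
loss = F.binary_cross_entropy_with_logits(scores, labels.float())
```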
1 change: 1 addition & 0 deletions common/train.py
@@ -27,6 +27,7 @@ class TrainerFactory(object):
'twitterurl': PIT2015Trainer,
'Reuters': ReutersTrainer,
'AAPD': ReutersTrainer,
'IMDB': ReutersTrainer,
'snli': SNLITrainer,
'sts2014': STS2014Trainer,
'quora': QuoraTrainer
96 changes: 51 additions & 45 deletions common/trainers/reuters_trainer.py
@@ -21,8 +21,8 @@ def __init__(self, model, embedding, train_loader, trainer_config, train_evaluat
self.iters_not_improved = 0
self.start = None
self.log_template = ' '.join(
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
self.dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
'{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:12.4f}'.split(','))
self.dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.4f},{:>8.4f},{:8.4f},{:12.4f},{:12.4f}'.split(','))
self.writer = SummaryWriter(log_dir="tensorboard_logs/" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
self.snapshot_path = os.path.join(self.model_outfile, self.train_loader.dataset.NAME, 'best_model.pt')

@@ -34,76 +34,82 @@ def train_epoch(self, epoch):
self.model.train()
self.optimizer.zero_grad()
if hasattr(self.model, 'TAR') and self.model.TAR:
if 'ignore_lengths' in self.config and self.config['ignore_lengths'] == True:
if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
scores, rnn_outs = self.model(batch.text, lengths=batch.text)
else:
scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
else:
if 'ignore_lengths' in self.config and self.config['ignore_lengths'] == True:
if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
scores = self.model(batch.text, lengths=batch.text)
else:
scores = self.model(batch.text[0], lengths=batch.text[1])
# Using binary accuracy
for tensor1, tensor2 in zip(F.sigmoid(scores).round().long(), batch.label):
if np.array_equal(tensor1, tensor2):
n_correct += 1
n_total += batch.batch_size
train_acc = 100. * n_correct / n_total
loss = F.binary_cross_entropy_with_logits(scores, batch.label.float())

if 'single_label' in self.config and self.config['single_label']:
for tensor1, tensor2 in zip(torch.argmax(scores, dim=1), torch.argmax(batch.label.data, dim=1)):
if np.array_equal(tensor1, tensor2):
n_correct += 1
loss = F.cross_entropy(scores, torch.argmax(batch.label.data, dim=1))
else:
predictions = F.sigmoid(scores).round().long()
# Computing binary accuracy
for tensor1, tensor2 in zip(predictions, batch.label):
if np.array_equal(tensor1, tensor2):
n_correct += 1
loss = F.binary_cross_entropy_with_logits(scores, batch.label.float())

if hasattr(self.model, 'TAR') and self.model.TAR:
loss = loss + (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()
loss.backward()

n_total += batch.batch_size
train_acc = 100. * n_correct / n_total
loss.backward()
self.optimizer.step()

# Temp Ave
if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
self.model.update_ema()

# Evaluate performance on validation set
if self.iterations % self.dev_log_interval == 1:
dev_acc, dev_precision, dev_recall, dev_f1, dev_loss = self.dev_evaluator.get_scores()[0]
if self.iterations % self.log_interval == 1:
niter = epoch * len(self.train_loader) + batch_idx
self.writer.add_scalar('Train/Loss', loss.data[0], niter)
self.writer.add_scalar('Dev/Loss', dev_loss, niter)
self.writer.add_scalar('Train/Loss', loss.data.item(), niter)
self.writer.add_scalar('Train/Accuracy', train_acc, niter)
self.writer.add_scalar('Dev/Accuracy', dev_acc, niter)
self.writer.add_scalar('Dev/Precision', dev_precision, niter)
self.writer.add_scalar('Dev/Recall', dev_recall, niter)
self.writer.add_scalar('Dev/F-measure', dev_f1, niter)
print(self.dev_log_template.format(time.time() - self.start,
epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
100. * (1 + batch_idx) / len(self.train_loader), loss.item(),
dev_loss, train_acc, dev_acc))

# Update validation results
if dev_f1 > self.best_dev_f1:
self.iters_not_improved = 0
self.best_dev_f1 = dev_f1
torch.save(self.model.state_dict(), self.snapshot_path)
else:
self.iters_not_improved += 1
if self.iters_not_improved >= self.patience:
self.early_stop = True
break

if self.iterations % self.log_interval == 1:
# print progress message
print(self.log_template.format(time.time() - self.start,
epoch, self.iterations, 1 + batch_idx, len(self.train_loader),
100. * (1 + batch_idx) / len(self.train_loader), loss.item(), ' ' * 8,
train_acc, ' ' * 12))
100. * (1 + batch_idx) / len(self.train_loader), loss.item(),
train_acc))

def train(self, epochs):
self.start = time.time()
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
header = ' Time Epoch Iteration Progress (%Epoch) Loss Accuracy'
dev_header = ' Time Epoch Iteration Progress Dev/Acc. Dev/Pr. Dev/Recall Dev/F1 Dev/Loss'
# model_outfile is actually a directory, using model_outfile to conform to Trainer naming convention
os.makedirs(self.model_outfile, exist_ok=True)
os.makedirs(os.path.join(self.model_outfile, self.train_loader.dataset.NAME), exist_ok=True)
print(header)

for epoch in range(1, epochs + 1):
if self.early_stop:
print("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, self.best_dev_f1))
break
self.train_epoch(epoch)

# Evaluate performance on validation set
dev_acc, dev_precision, dev_recall, dev_f1, dev_loss = self.dev_evaluator.get_scores()[0]
self.writer.add_scalar('Dev/Loss', dev_loss, epoch)
self.writer.add_scalar('Dev/Accuracy', dev_acc, epoch)
self.writer.add_scalar('Dev/Precision', dev_precision, epoch)
self.writer.add_scalar('Dev/Recall', dev_recall, epoch)
self.writer.add_scalar('Dev/F-measure', dev_f1, epoch)
print('\n' + dev_header)
print(self.dev_log_template.format(time.time() - self.start, epoch, self.iterations, epoch, epochs,
dev_acc, dev_precision, dev_recall, dev_f1, dev_loss))
print('\n' + header)

# Update validation results
if dev_f1 > self.best_dev_f1:
self.iters_not_improved = 0
self.best_dev_f1 = dev_f1
torch.save(self.model.state_dict(), self.snapshot_path)
else:
self.iters_not_improved += 1
if self.iters_not_improved >= self.patience:
self.early_stop = True
print("Early Stopping. Epoch: {}, Best Dev F1: {}".format(epoch, self.best_dev_f1))
break
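
The TAR term that both the trainer and evaluator add matches temporal activation regularization (Merity et al., 2017); a standalone sketch of the penalty, assuming rnn_outs is laid out as (seq_len, batch, hidden) so that slicing along dim 0 compares consecutive time steps:

```python
import torch

def tar_penalty(rnn_outs: torch.Tensor) -> torch.Tensor:
    # Mean squared difference between consecutive hidden states; penalizing
    # it encourages temporally smooth RNN activations
    return (rnn_outs[1:] - rnn_outs[:-1]).pow(2).mean()

# loss = task_loss + tar_penalty(rnn_outs)   # as in train_epoch above
```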
89 changes: 89 additions & 0 deletions datasets/imdb.py
@@ -0,0 +1,89 @@
import numpy as np
import os
import re
import torch
from datasets.reuters import clean_string, clean_string_fl, split_sents
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors


def char_quantize(string, max_length=1000):
identity = np.identity(len(IMDBCharQuantized.ALPHABET))
quantized_string = np.array([identity[IMDBCharQuantized.ALPHABET[char]] for char in list(string.lower()) if char in IMDBCharQuantized.ALPHABET], dtype=np.float32)
if len(quantized_string) > max_length:
return quantized_string[:max_length]
else:
return np.concatenate((quantized_string, np.zeros((max_length - len(quantized_string), len(IMDBCharQuantized.ALPHABET)), dtype=np.float32)))


def process_labels(string):
"""
Converts the label string (one digit per class) into a list of floats
:param string: raw label string, e.g. '0100000000' (a hypothetical value)
:return: list of 0.0/1.0 floats
"""
return [float(x) for x in string]


class IMDB(TabularDataset):
NAME = 'IMDB'
NUM_CLASSES = 10
TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

@staticmethod
def sort_key(ex):
return len(ex.text)

@classmethod
def splits(cls, path, train=os.path.join('IMDB', 'data', 'imdb_train.tsv'),
validation=os.path.join('IMDB', 'data', 'imdb_validation.tsv'),
test=os.path.join('IMDB', 'data', 'imdb_test.tsv'), **kwargs):
return super(IMDB, cls).splits(
path, train=train, validation=validation, test=test,
format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
)

@classmethod
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
unk_init=torch.Tensor.zero_):
"""
:param path: directory containing train, test, dev files
:param vectors_name: name of word vectors file
:param vectors_cache: path to directory containing word vectors file
:param batch_size: batch size
:param device: GPU device
:param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
:param unk_init: function used to generate vector for OOV words
:return:
"""
if vectors is None:
vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

train, val, test = cls.splits(path)
cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)


class IMDBCharQuantized(IMDB):
ALPHABET = dict(map(lambda t: (t[1], t[0]), enumerate(list("""abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"""))))
TEXT_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=char_quantize)

@classmethod
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
unk_init=torch.Tensor.zero_):
"""
:param path: directory containing train, test, dev files
:param batch_size: batch size
:param device: GPU device
:return:
"""
train, val, test = cls.splits(path)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)


class IMDBHierarchical(IMDB):
In_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(In_FIELD, tokenize=split_sents)
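
A quick illustration of what char_quantize produces (a hypothetical call; the second dimension depends on the alphabet size):

```python
# Characters outside the alphabet (here the space) are dropped, each kept
# character becomes a one-hot row, and the result is zero-padded or
# truncated to exactly max_length rows.
q = char_quantize("Hello, world!", max_length=16)
print(q.shape)   # (16, len(IMDBCharQuantized.ALPHABET))
print(q.dtype)   # float32
```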
10 changes: 8 additions & 2 deletions lstm_baseline/__main__.py
@@ -12,6 +12,7 @@
from datasets.sst import SST1
from datasets.sst import SST2
from datasets.reuters import Reuters
from datasets.imdb import IMDB
from datasets.aapd import AAPD
from lstm_baseline.args import get_args
from lstm_baseline.model import LSTMBaseline
@@ -78,7 +79,8 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
'SST-1': SST1,
'SST-2': SST2,
'Reuters': Reuters,
'AAPD': AAPD
'AAPD': AAPD,
'IMDB': IMDB
}

if args.dataset not in dataset_map:
@@ -118,6 +120,9 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
train_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, dev_iter, args.batch_size, args.gpu)
train_evaluator.single_label = args.single_label
test_evaluator.single_label = args.single_label
dev_evaluator.single_label = args.single_label

trainer_config = {
'optimizer': optimizer,
@@ -126,7 +131,8 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
'dev_log_interval': args.dev_every,
'patience': args.patience,
'model_outfile': args.save_path, # actually a directory, using model_outfile to conform to Trainer naming convention
'logger': logger
'logger': logger,
'single_label': args.single_label
}
trainer = TrainerFactory.get_trainer(args.dataset, model, None, train_iter, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

3 changes: 2 additions & 1 deletion lstm_baseline/args.py
@@ -11,12 +11,13 @@ def get_args():
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--bidirectional', action='store_true'),
parser.add_argument('--bottleneck_layer', action='store_true'),
parser.add_argument('--single_label', action='store_true'),
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--mode', type=str, default='static', choices=['rand', 'static', 'non-static'])
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--seed', type=int, default=3435)
parser.add_argument('--dataset', type=str, default='Reuters', choices=['SST-1', 'SST-2', 'Reuters', 'AAPD'])
parser.add_argument('--dataset', type=str, default='Reuters', choices=['SST-1', 'SST-2', 'Reuters', 'AAPD', 'IMDB'])
parser.add_argument('--resume_snapshot', type=str, default=None)
parser.add_argument('--dev_every', type=int, default=30)
parser.add_argument('--log_every', type=int, default=10)
(diffs for the remaining changed files were not loaded)