Add document classification models and datasets (#171)
* Add ReutersTrainer, ReutersEvaluator options in Factory classes

* Add Reuters to Kim-CNN command line arguments

* Fix SST dataset path according to changes in Kim-CNN args

The dataset path in args.py was changed to point at the dataset folder rather than the dataset/SST folder, so the SST folder is now added to the paths inside the SST dataset class

* Add Reuters dataset class, and support in __main__

* Add Reuters dataset trainers and evaluators

* Remove debug print statement in reuters_evaluator

* Fix rounding bug in reuters_trainer and reuters_evaluator

* Add LSTM for baseline text classification measurements

* Add eval metrics for lstm_baseline

* Set batch_first param in lstm_baseline

* Remove onnx args from lstm_baseline

* Pack padded sequences in LSTM_baseline

* Add TensorBoardX support for Reuters trainer

* Add Arxiv Academic Paper Dataset (AAPD)

* Add Hidden Bottleneck Layer to BiLSTM

* Fix packing of padded tensors in Reuters

* Add cmdline args for Hidden Bottleneck Layer for BiLSTM

* Include pre-padding lengths in AAPD dataset

* Remove duplication of preprocessing code in AAPD

* Remove batch_size condition in ReutersTrainer

* Add ignore_lengths option to ReutersTrainer and ReutersEvaluator

* Add AAPDCharQuantized and ReutersCharQuantized

* Rename Reuters_hierarchical to ReutersHierarchical

* Add CharacterCNN for document classification

* Update README.md for CharacterCNN

* Fix table in README.md for CharacterCNN

* Add AAPDHierarchical for HAN

* Update HAN for changes in Reuters dataset endpoints

* Fix bug in CharCNN when running on CPU

* Add AAPD dataset support for KimCNN

* Fix dataset paths for SST-1

* Fix dimensions of FC1 in CharCNN

* Add model checkpointing for Reuters based on F1

* Refactor LSTM baseline __main__

* Add precision, recall and F1 to Reuters evaluator

* Checkpoint only at the end of an epoch for ReutersTrainer

Add detailed log printing for dev evaluations

* Fix log_template and dev_log_template in ReutersTrainer

* Add IMDB dataset

* Fix duplicate printing of header in ReutersTrainer

* Add support for single_label datasets in ReutersTrainer

* Add support for IMDB dataset in lstm_baseline and lstm_reg

* Fix evaluator call in main method of HAN

* Add IMDB for HAN

* Fix for single_label

* Fix evaluate_dataset method for single_label datasets

* Reduce default patience to 5 epochs before early stopping

* Revert change to save_state rather than the entire model

* Add Yelp 2018 dataset

* Integrate Yelp2018 with LSTM baseline

* Replace Yelp2018 with Yelp2014 dataset

* Add Yelp2014 to LSTM Baseline

* Integrate Yelp14 into LSTM Regularization

* Remove dropout in HBL for LSTM Baseline and Reg

* Add Yelp for HAN

* Fix the saving issue for HAN

* Fix loading for HAN

* Fix typo in ReutersEvaluator

* Print to STDOUT rather than logger

* Print XML-CNN eval to STDOUT rather than logger

* Update max_length for IMDB dataset

* Add single_label support for char_cnn

* Fix evaluation method for char_cnn

* Remove unwanted parameters from ReutersTrainer and ReutersEval

* Fix code formatting in lstm_reg/args

* Add support for IMDB and Yelp in KimCNN

* Fix single_label incorporation

* Remove unnecessary conditions

* Fix num_classes in Yelp2014

* Add single_label support for XML-CNN

* Fix call to evaluator in XML-CNN

* Address PEP8 issues

* Address PEP8 issues

* Address PEP8 issues

* Address PEP8 issues
Ashutosh-Adhikari authored and daemon committed Jan 25, 2019
1 parent 57f53a8 commit dc086e8
Showing 23 changed files with 311 additions and 257 deletions.
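One of the commits above adds precision, recall, and F1 to the Reuters evaluator. A minimal sketch of that computation as sklearn performs it on multi-label indicator matrices; the matrices here are illustrative, not taken from the repo:

import numpy as np
from sklearn import metrics

# Multi-label targets and predictions as 0/1 indicator matrices
# (shape: n_documents x n_labels), as the Reuters evaluator produces them.
target = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
predicted = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0]])

accuracy = metrics.accuracy_score(target, predicted)  # exact-match ratio
precision = metrics.precision_score(target, predicted, average='micro')
recall = metrics.recall_score(target, predicted, average='micro')
f1 = metrics.f1_score(target, predicted, average='micro')
print(accuracy, precision, recall, f1)  # 0.333..., 1.0, 0.6, 0.75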
101 changes: 46 additions & 55 deletions char_cnn/__main__.py
@@ -1,18 +1,22 @@
from sklearn import metrics

from copy import deepcopy
import logging
import numpy as np
import random

import numpy as np
from sklearn import metrics
import torch
import torch.nn.functional as F
from copy import deepcopy

from char_cnn.args import get_args
from char_cnn.model import CharCNN
from common.evaluation import EvaluatorFactory
from common.train import TrainerFactory
from datasets.aapd import AAPDCharQuantized as AAPD
from datasets.imdb import IMDBCharQuantized as IMDB
from datasets.reuters import ReutersCharQuantized as Reuters
from char_cnn.args import get_args
from char_cnn.model import CharCNN
from datasets.yelp2014 import Yelp2014CharQuantized as Yelp2014




class UnknownWordVecCache(object):
@@ -45,13 +49,14 @@ def get_logger():
return logger


def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device):
def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device, single_label):
saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device)
saved_model_evaluator.ignore_lengths = True
saved_model_evaluator.single_label = single_label
scores, metric_names = saved_model_evaluator.get_scores()
logger.info('Evaluation metrics for {}'.format(split_name))
logger.info('\t'.join([' '] + metric_names))
logger.info('\t'.join([split_name] + list(map(str, scores))))
print('Evaluation metrics for', split_name)
print(metric_names)
print(scores)


if __name__ == '__main__':
@@ -73,13 +78,20 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
random.seed(args.seed)
logger = get_logger()

# Set up the data for training SST-1
if args.dataset == 'Reuters':
train_iter, dev_iter, test_iter = Reuters.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
elif args.dataset == 'AAPD':
train_iter, dev_iter, test_iter = AAPD.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
else:
dataset_map = {
'Reuters': Reuters,
'AAPD': AAPD,
'IMDB': IMDB,
'Yelp2014': Yelp2014
}

if args.dataset not in dataset_map:
raise ValueError('Unrecognized dataset')
else:
train_iter, dev_iter, test_iter = dataset_map[args.dataset].iters(args.data_dir, args.word_vectors_file,
args.word_vectors_dir,
batch_size=args.batch_size, device=args.gpu,
unk_init=UnknownWordVecCache.unk)

config = deepcopy(args)
config.dataset = train_iter.dataset
@@ -104,19 +116,18 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
parameter = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)

if args.dataset == 'Reuters':
train_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, dev_iter, args.batch_size, args.gpu)
elif args.dataset == 'AAPD':
train_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, dev_iter, args.batch_size, args.gpu)
else:
if args.dataset not in dataset_map:
raise ValueError('Unrecognized dataset')
else:
train_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(dataset_map[args.dataset], model, None, dev_iter, args.batch_size, args.gpu)
train_evaluator.single_label = args.single_label
test_evaluator.single_label = args.single_label
dev_evaluator.single_label = args.single_label
dev_evaluator.ignore_lengths = True
test_evaluator.ignore_lengths = True

dev_evaluator.ignore_lengths = True
test_evaluator.ignore_lengths = True
trainer_config = {
'optimizer': optimizer,
'batch_size': args.batch_size,
@@ -125,7 +136,8 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
'patience': args.patience,
'model_outfile': args.save_path, # actually a directory, using model_outfile to conform to Trainer naming convention
'logger': logger,
'ignore_lengths': True
'ignore_lengths': True,
'single_label': args.single_label
}
trainer = TrainerFactory.get_trainer(args.dataset, model, None, train_iter, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

@@ -137,31 +149,10 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
else:
model = torch.load(args.trained_model, map_location=lambda storage, location: storage)

if args.dataset == 'Reuters':
evaluate_dataset('dev', Reuters, model, None, dev_iter, args.batch_size, args.gpu)
evaluate_dataset('test', Reuters, model, None, test_iter, args.batch_size, args.gpu)
elif args.dataset == 'AAPD':
evaluate_dataset('dev', AAPD, model, None, dev_iter, args.batch_size, args.gpu)
evaluate_dataset('test', AAPD, model, None, test_iter, args.batch_size, args.gpu)
else:
raise ValueError('Unrecognized dataset')

# Calculate dev and test metrics
for data_loader in [dev_iter, test_iter]:
predicted_labels = list()
target_labels = list()
for batch_idx, batch in enumerate(data_loader):
scores_rounded = F.sigmoid(model(batch.text)).round().long()
predicted_labels.extend(scores_rounded.cpu().detach().numpy())
target_labels.extend(batch.label.cpu().detach().numpy())
predicted_labels = np.array(predicted_labels)
target_labels = np.array(target_labels)
accuracy = metrics.accuracy_score(target_labels, predicted_labels)
precision = metrics.precision_score(target_labels, predicted_labels, average='micro')
recall = metrics.recall_score(target_labels, predicted_labels, average='micro')
f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
if data_loader == dev_iter:
print("Dev metrics:")
else:
print("Test metrics:")
print(accuracy, precision, recall, f1)
model = torch.load(trainer.snapshot_path)
if args.dataset not in dataset_map:
raise ValueError('Unrecognized dataset')
else:
evaluate_dataset('dev', dataset_map[args.dataset], model, None, dev_iter, args.batch_size, args.gpu, args.single_label)
evaluate_dataset('test', dataset_map[args.dataset], model, None, test_iter, args.batch_size, args.gpu, args.single_label)
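The dataset_map refactor above replaces per-dataset if/elif branches with a dictionary lookup, so supporting a new dataset means one new entry rather than three new branches. A minimal self-contained sketch of the pattern, using stand-in classes instead of the real torchtext datasets:

class Reuters(object):
    @classmethod
    def iters(cls, batch_size=64):
        # Stand-in for the real BucketIterator splits.
        return 'train_iter', 'dev_iter', 'test_iter'

class AAPD(Reuters):
    pass

dataset_map = {'Reuters': Reuters, 'AAPD': AAPD}

def get_iters(name, batch_size):
    if name not in dataset_map:
        raise ValueError('Unrecognized dataset')
    # One call site now serves every registered dataset.
    return dataset_map[name].iters(batch_size=batch_size)

train_iter, dev_iter, test_iter = get_iters('AAPD', 128)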
5 changes: 3 additions & 2 deletions char_cnn/args.py
@@ -11,12 +11,13 @@ def get_args():
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--seed', type=int, default=3435)
parser.add_argument('--dataset', type=str, default='Reuters', choices=['Reuters', 'AAPD'])
parser.add_argument('--single_label', action='store_true'),
parser.add_argument('--dataset', type=str, default='Reuters', choices=['Reuters', 'AAPD', 'IMDB', 'Yelp2014'])
parser.add_argument('--resume_snapshot', type=str, default=None)
parser.add_argument('--dev_every', type=int, default=30)
parser.add_argument('--log_every', type=int, default=10)
parser.add_argument('--patience', type=int, default=100)
parser.add_argument('--save_path', type=str, default='kim_cnn/saves')
parser.add_argument('--save_path', type=str, default='char_cnn/saves')
parser.add_argument('--num_conv_filters', type=int, default=256)
parser.add_argument('--num_affine_neurons', type=int, default=1024)
parser.add_argument('--output_channel', type=int, default=256)
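Given the argument changes above, a run on one of the newly supported datasets would look roughly like this; the flag names come from args.py, but the exact module invocation is an assumption:

python -m char_cnn --dataset Yelp2014 --single_label --batch_size 128 --patience 5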
1 change: 1 addition & 0 deletions common/evaluation.py
@@ -28,6 +28,7 @@ class EvaluatorFactory(object):
'Reuters': ReutersEvaluator,
'AAPD': ReutersEvaluator,
'IMDB': ReutersEvaluator,
'Yelp2014': ReutersEvaluator,
'SNLI': SNLIEvaluator,
'sts2014': STS2014Evaluator,
'Quora': QuoraEvaluator
6 changes: 3 additions & 3 deletions common/evaluators/reuters_evaluator.py
@@ -28,12 +28,12 @@ def get_scores(self):
for batch_idx, batch in enumerate(self.data_loader):
if hasattr(self.model, 'TAR') and self.model.TAR: # TAR condition
if self.ignore_lengths:
scores, rnn_outs = self.model(batch.text, lengths=batch.text)
scores, rnn_outs = self.model(batch.text)
else:
scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
else:
if self.ignore_lengths:
scores = self.model(batch.text, lengths=batch.text)
scores = self.model(batch.text)
else:
scores = self.model(batch.text[0], lengths=batch.text[1])

@@ -62,4 +62,4 @@ def get_scores(self):
if hasattr(self.model, 'beta_ema') and self.model.beta_ema > 0:
self.model.load_params(old_params)

return [accuracy, precision, recall, f1, avg_loss], ['accuracy', 'precision', 'recall', 'f1' 'cross_entropy_loss']
return [accuracy, precision, recall, f1, avg_loss], ['accuracy', 'precision', 'recall', 'f1', 'cross_entropy_loss']
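The one-character fix in the return statement above is easy to miss: Python concatenates adjacent string literals at compile time, so the missing comma silently merged two metric names into 'f1cross_entropy_loss' and left the names list one element shorter than the scores list. A self-contained illustration:

# A missing comma between adjacent string literals merges them silently:
buggy = ['accuracy', 'precision', 'recall', 'f1' 'cross_entropy_loss']
fixed = ['accuracy', 'precision', 'recall', 'f1', 'cross_entropy_loss']
assert len(buggy) == 4 and buggy[-1] == 'f1cross_entropy_loss'
assert len(fixed) == 5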
1 change: 1 addition & 0 deletions common/train.py
@@ -28,6 +28,7 @@ class TrainerFactory(object):
'Reuters': ReutersTrainer,
'AAPD': ReutersTrainer,
'IMDB': ReutersTrainer,
'Yelp2014': ReutersTrainer,
'snli': SNLITrainer,
'sts2014': STS2014Trainer,
'quora': QuoraTrainer
9 changes: 4 additions & 5 deletions common/trainers/reuters_trainer.py
@@ -35,12 +35,12 @@ def train_epoch(self, epoch):
self.optimizer.zero_grad()
if hasattr(self.model, 'TAR') and self.model.TAR:
if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
scores, rnn_outs = self.model(batch.text, lengths=batch.text)
scores, rnn_outs = self.model(batch.text)
else:
scores, rnn_outs = self.model(batch.text[0], lengths=batch.text[1])
else:
if 'ignore_lengths' in self.config and self.config['ignore_lengths']:
scores = self.model(batch.text, lengths=batch.text)
scores = self.model(batch.text)
else:
scores = self.model(batch.text[0], lengths=batch.text[1])

@@ -85,9 +85,9 @@ def train(self, epochs):
# model_outfile is actually a directory, using model_outfile to conform to Trainer naming convention
os.makedirs(self.model_outfile, exist_ok=True)
os.makedirs(os.path.join(self.model_outfile, self.train_loader.dataset.NAME), exist_ok=True)
print(header)

for epoch in range(1, epochs + 1):
print('\n' + header)
self.train_epoch(epoch)

# Evaluate performance on validation set
@@ -100,13 +100,12 @@ def train(self, epochs):
print('\n' + dev_header)
print(self.dev_log_template.format(time.time() - self.start, epoch, self.iterations, epoch, epochs,
dev_acc, dev_precision, dev_recall, dev_f1, dev_loss))
print('\n' + header)

# Update validation results
if dev_f1 > self.best_dev_f1:
self.iters_not_improved = 0
self.best_dev_f1 = dev_f1
torch.save(self.model.state_dict(), self.snapshot_path)
torch.save(self.model, self.snapshot_path)
else:
self.iters_not_improved += 1
if self.iters_not_improved >= self.patience:
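The checkpointing change above ('Revert change to save_state rather than the entire model' in the commit list) pickles the whole module instead of its state_dict, which is what lets the __main__ scripts reload a snapshot with a bare torch.load. A sketch of the trade-off with a stand-in module, consistent with the 2019-era PyTorch this code targets:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the real classifier

# Whole-module pickle: the loader needs no model class in scope, but the
# checkpoint is coupled to the exact source layout at save time (and newer
# PyTorch versions refuse it under the weights_only default).
torch.save(model, 'model.pt')
restored = torch.load('model.pt')

# state_dict: robust across refactors, but the loader must first construct
# an identically shaped module.
torch.save(model.state_dict(), 'weights.pt')
fresh = nn.Linear(10, 2)
fresh.load_state_dict(torch.load('weights.pt'))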
10 changes: 6 additions & 4 deletions datasets/aapd.py
@@ -1,12 +1,14 @@
import re
import os
import re

import numpy as np
import torch
from datasets.reuters import clean_string, char_quantize, clean_string_fl, split_sents
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors

from datasets.reuters import clean_string, clean_string_fl, split_sents


def process_labels(string):
"""
@@ -76,5 +78,5 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)

class AAPDHierarchical(AAPD):
In_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(In_FIELD, tokenize=split_sents)
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
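The rename from In_FIELD to NESTING_FIELD follows torchtext's own terminology for NestedField: the hierarchical dataset variants tokenize a document into sentences and each sentence into words, producing the [batch, sentences, words] layout HAN expects. A minimal sketch, with stand-ins for the repo's clean_string and split_sents:

from torchtext.data import Field, NestedField

def split_sents(text):
    # Stand-in for the repo's sentence splitter.
    return [s for s in text.split('.') if s.strip()]

def clean_string(sent):
    # Stand-in for the repo's word tokenizer.
    return sent.strip().lower().split()

NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
# At batching time each document pads to [n_sents, n_words], and a batch
# stacks to [batch, n_sents, n_words].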
12 changes: 7 additions & 5 deletions datasets/imdb.py
@@ -1,14 +1,16 @@
import numpy as np
import os
import re

import numpy as np
import torch
from datasets.reuters import clean_string, clean_string_fl, split_sents
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors

from datasets.reuters import clean_string, clean_string_fl, split_sents


def char_quantize(string, max_length=1000):
def char_quantize(string, max_length=500):
identity = np.identity(len(IMDBCharQuantized.ALPHABET))
quantized_string = np.array([identity[IMDBCharQuantized.ALPHABET[char]] for char in list(string.lower()) if char in IMDBCharQuantized.ALPHABET], dtype=np.float32)
if len(quantized_string) > max_length:
@@ -85,5 +87,5 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d


class IMDBHierarchical(IMDB):
In_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(In_FIELD, tokenize=split_sents)
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
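char_quantize, shown truncated above, one-hot encodes each character against a fixed alphabet and then truncates or zero-pads to max_length (note the IMDB default dropping from 1000 to 500 in this diff). A self-contained sketch with a shortened alphabet; the repo's real ALPHABET is larger:

import numpy as np

ALPHABET = {c: i for i, c in enumerate('abcdefghijklmnopqrstuvwxyz0123456789 ')}

def char_quantize(string, max_length=500):
    identity = np.identity(len(ALPHABET), dtype=np.float32)
    # One row of the identity matrix per in-alphabet character; all other
    # characters are dropped.
    quantized = np.array([identity[ALPHABET[c]] for c in string.lower()
                          if c in ALPHABET], dtype=np.float32)
    if len(quantized) >= max_length:
        return quantized[:max_length]
    padding = np.zeros((max_length - len(quantized), len(ALPHABET)), dtype=np.float32)
    return np.concatenate((quantized, padding))

print(char_quantize('Hello, world!', max_length=32).shape)  # (32, 37)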
7 changes: 4 additions & 3 deletions datasets/reuters.py
@@ -1,6 +1,7 @@
import numpy as np
import os
import re

import numpy as np
import torch
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
@@ -109,5 +110,5 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d


class ReutersHierarchical(Reuters):
In_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(In_FIELD, tokenize=split_sents)
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
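Several commits above deal with packing padded tensors and the ignore_lengths switch: the recurrent models consume (text, lengths) pairs through pack_padded_sequence, while CharCNN receives fixed-size quantized tensors and has no lengths to pass, which is why its trainer and evaluator set ignore_lengths. A minimal sketch of the packing side, with made-up dimensions:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Three padded sequences (batch_first), true lengths 5, 3, 2.
embedded = torch.randn(3, 5, 8)
lengths = torch.tensor([5, 3, 2])

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
packed = pack_padded_sequence(embedded, lengths, batch_first=True)
packed_out, _ = lstm(packed)
# Back to a padded (3, 5, 16) tensor; pad positions never reach the LSTM.
output, _ = pad_packed_sequence(packed_out, batch_first=True)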