Add CharacterCNN for Document Classification (castorini#155)
* Add ReutersTrainer, ReutersEvaluator options in Factory classes
* Add Reuters to Kim-CNN command line arguments
* Fix SST dataset path according to changes in Kim-CNN args: the dataset path in args.py now points at the dataset folder rather than the dataset/SST folder, so the SST folder was added to the paths in the SST dataset class
* Add Reuters dataset class, and support in __main__
* Add Reuters dataset trainers and evaluators
* Remove debug print statement in reuters_evaluator
* Fix rounding bug in reuters_trainer and reuters_evaluator
* Add LSTM for baseline text classification measurements
* Add eval metrics for lstm_baseline
* Set batch_first param in lstm_baseline
* Remove onnx args from lstm_baseline
* Pack padded sequences in LSTM_baseline
* Add TensorBoardX support for Reuters trainer
* Add Arxiv Academic Paper Dataset (AAPD)
* Add Hidden Bottleneck Layer to BiLSTM
* Fix packing of padded tensors in Reuters
* Add cmdline args for Hidden Bottleneck Layer for BiLSTM
* Include pre-padding lengths in AAPD dataset
* Remove duplication of preprocessing code in AAPD
* Remove batch_size condition in ReutersTrainer
* Add ignore_lengths option to ReutersTrainer and ReutersEvaluator
* Add AAPDCharQuantized and ReutersCharQuantized
* Rename Reuters_hierarchical to ReutersHierarchical
* Add CharacterCNN for document classification
* Update README.md for CharacterCNN
* Fix table in README.md for CharacterCNN
* Add AAPDHierarchical for HAN
* Update HAN for changes in Reuters dataset endpoints
* Fix bug in CharCNN when running on CPU
1 parent 6daa5a1, commit cc275f6. Showing 12 changed files with 391 additions and 18 deletions.
@@ -0,0 +1,53 @@
## Character-level Convolutional Network

Implementation of Char-CNN from Character-level Convolutional Networks for Text Classification (http://papers.nips.cc/paper/5782-character-level-convolutional-networks-for-text-classification.pdf).
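
For orientation, the model consumes character-quantized input rather than word embeddings: each document is encoded as a sequence of one-hot character vectors over a fixed alphabet (the model in this commit expects 68 channels). The sketch below is illustrative only; the alphabet, helper name, and maximum length are assumptions, not the preprocessing code shipped here (the actual work is done by the ReutersCharQuantized and AAPDCharQuantized dataset classes).

```
# Illustrative character quantization: one-hot encode each character over a fixed
# alphabet, yielding a (max_length, alphabet_size) matrix per document.
# The 68-character alphabet below is an example, not necessarily the repository's.
import numpy as np

ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+=<>()[]{}"
CHAR_TO_IDX = {c: i for i, c in enumerate(ALPHABET)}

def quantize(text, max_length=1000):
    matrix = np.zeros((max_length, len(ALPHABET)), dtype=np.float32)
    for pos, char in enumerate(text.lower()[:max_length]):
        idx = CHAR_TO_IDX.get(char)
        if idx is not None:
            matrix[pos, idx] = 1.0  # characters outside the alphabet stay all-zero
    return matrix

print(quantize("Hello, world!").shape)  # (1000, 68) with this example alphabet
```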

## Quick Start

To train the model on the Reuters dataset, run the following from the Castor working directory:

```
python -m char_cnn --dataset Reuters --gpu 1 --batch_size 128 --lr 0.001
```

To evaluate a trained model, use the following command; passing --trained_model skips training and reports metrics on the dev and test splits:

```
python -m char_cnn --trained_model kim_cnn/saves/Reuters/best_model.pt
```

## Dataset

We evaluate the model on the following datasets:

- Reuters Newswire (RCV-1)
- Arxiv Academic Paper Dataset (AAPD)

## Settings

Adam is used for training.
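
For concreteness, the snippet below mirrors how the training script builds the optimizer: Adam over trainable parameters only, with the learning rate and weight decay taken from the command-line arguments (lr=0.001 and weight_decay=0 are the defaults in this commit; the placeholder module is just for illustration).

```
# Adam over trainable parameters only, as in the training script.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder module; the real script passes the CharCNN model
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters, lr=0.001, weight_decay=0)
```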

## Dataset Results

### RCV-1

```
python -m char_cnn --dataset Reuters --gpu 1 --batch_size 128 --lr 0.001
```

|                 | Accuracy | Avg. Precision | Avg. Recall | Avg. F1 |
|-----------------|----------|----------------|-------------|---------|
| Char-CNN (Dev)  | 0.585    | 0.702          | 0.569       | 0.628   |
| Char-CNN (Test) | 0.589    | 0.691          | 0.552       | 0.614   |

### AAPD

```
python -m char_cnn --dataset AAPD --gpu 1 --batch_size 128 --lr 0.001
```

|                 | Accuracy | Avg. Precision | Avg. Recall | Avg. F1 |
|-----------------|----------|----------------|-------------|---------|
| Char-CNN (Dev)  | 0.305    | 0.681          | 0.537       | 0.600   |
| Char-CNN (Test) | 0.294    | 0.681          | 0.526       | 0.593   |

## TODO

- Support ONNX export. Export currently throws an "ONNX export failed: Couldn't export Python operator forward_flattened_wrapper" exception (see the sketch below).
- Parameter tuning
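
For context, export is typically attempted with torch.onnx.export; a minimal sketch of such a call is shown below. The config values, dummy input shape, and output file name are assumptions for illustration, not code from this commit, and per the TODO above this export path currently fails.

```
# Hypothetical ONNX export attempt; per the TODO above, exporting this model currently fails
# with "Couldn't export Python operator forward_flattened_wrapper".
import torch
from types import SimpleNamespace
from char_cnn.model import CharCNN

config = SimpleNamespace(cuda=False, dataset=None, num_conv_filters=256, output_channel=256,
                         num_affine_neurons=1024, target_class=90, dropout=0.5)  # example values
model = CharCNN(config)
dummy_input = torch.zeros(1, 1014, 68)  # (batch, characters, alphabet channels) -- assumed shape
torch.onnx.export(model, dummy_input, 'char_cnn.onnx')
```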
Empty file.
@@ -0,0 +1,167 @@
```
from sklearn import metrics

import logging
import numpy as np
import random
import torch
import torch.nn.functional as F
from copy import deepcopy

from common.evaluation import EvaluatorFactory
from common.train import TrainerFactory
from datasets.aapd import AAPDCharQuantized as AAPD
from datasets.reuters import ReutersCharQuantized as Reuters
from char_cnn.args import get_args
from char_cnn.model import CharCNN


class UnknownWordVecCache(object):
    """
    Caches the first randomly generated word vector for a given size so that it is reused.
    """
    cache = {}

    @classmethod
    def unk(cls, tensor):
        size_tup = tuple(tensor.size())
        if size_tup not in cls.cache:
            cls.cache[size_tup] = torch.Tensor(tensor.size())
            # choose 0.25 so unknown vectors have approximately same variance as pre-trained ones
            # same as original implementation: https://github.com/yoonkim/CNN_sentence/blob/0a626a048757d5272a7e8ccede256a434a6529be/process_data.py#L95
            cls.cache[size_tup].uniform_(-0.25, 0.25)
        return cls.cache[size_tup]


def get_logger():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    return logger


def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device):
    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device)
    saved_model_evaluator.ignore_lengths = True
    scores, metric_names = saved_model_evaluator.get_scores()
    logger.info('Evaluation metrics for {}'.format(split_name))
    logger.info('\t'.join([' '] + metric_names))
    logger.info('\t'.join([split_name] + list(map(str, scores))))


if __name__ == '__main__':
    # Default configuration is set in args.py
    args = get_args()

    # Set random seed for reproducibility
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    if not args.cuda:
        args.gpu = -1
    if torch.cuda.is_available() and args.cuda:
        print('Note: You are using GPU for training')
        torch.cuda.set_device(args.gpu)
        torch.cuda.manual_seed(args.seed)
    if torch.cuda.is_available() and not args.cuda:
        print('Warning: CUDA is available but not being used. You are training on CPU.')
    np.random.seed(args.seed)
    random.seed(args.seed)
    logger = get_logger()

    # Set up the data for training
    if args.dataset == 'Reuters':
        train_iter, dev_iter, test_iter = Reuters.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
    elif args.dataset == 'AAPD':
        train_iter, dev_iter, test_iter = AAPD.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
    else:
        raise ValueError('Unrecognized dataset')

    config = deepcopy(args)
    config.dataset = train_iter.dataset
    config.target_class = train_iter.dataset.NUM_CLASSES

    print('LABEL.target_class:', train_iter.dataset.NUM_CLASSES)
    print('Train instance', len(train_iter.dataset))
    print('Dev instance', len(dev_iter.dataset))
    print('Test instance', len(test_iter.dataset))

    if args.resume_snapshot:
        if args.cuda:
            model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu))
        else:
            model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage)
    else:
        model = CharCNN(config)
        if args.cuda:
            model.cuda()
            print('Shift model to GPU')

    parameter = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)

    if args.dataset == 'Reuters':
        train_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, train_iter, args.batch_size, args.gpu)
        test_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, test_iter, args.batch_size, args.gpu)
        dev_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, dev_iter, args.batch_size, args.gpu)
    elif args.dataset == 'AAPD':
        train_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, train_iter, args.batch_size, args.gpu)
        test_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, test_iter, args.batch_size, args.gpu)
        dev_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, dev_iter, args.batch_size, args.gpu)
    else:
        raise ValueError('Unrecognized dataset')

    dev_evaluator.ignore_lengths = True
    test_evaluator.ignore_lengths = True
    trainer_config = {
        'optimizer': optimizer,
        'batch_size': args.batch_size,
        'log_interval': args.log_every,
        'dev_log_interval': args.dev_every,
        'patience': args.patience,
        'model_outfile': args.save_path,  # actually a directory; named model_outfile to conform to the Trainer naming convention
        'logger': logger,
        'ignore_lengths': True
    }
    trainer = TrainerFactory.get_trainer(args.dataset, model, None, train_iter, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

    if not args.trained_model:
        trainer.train(args.epochs)
    else:
        if args.cuda:
            model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu))
        else:
            model = torch.load(args.trained_model, map_location=lambda storage, location: storage)

    if args.dataset == 'Reuters':
        evaluate_dataset('dev', Reuters, model, None, dev_iter, args.batch_size, args.gpu)
        evaluate_dataset('test', Reuters, model, None, test_iter, args.batch_size, args.gpu)
    elif args.dataset == 'AAPD':
        evaluate_dataset('dev', AAPD, model, None, dev_iter, args.batch_size, args.gpu)
        evaluate_dataset('test', AAPD, model, None, test_iter, args.batch_size, args.gpu)
    else:
        raise ValueError('Unrecognized dataset')

    # Calculate dev and test metrics
    for data_loader in [dev_iter, test_iter]:
        predicted_labels = list()
        target_labels = list()
        for batch_idx, batch in enumerate(data_loader):
            scores_rounded = F.sigmoid(model(batch.text)).round().long()
            predicted_labels.extend(scores_rounded.cpu().detach().numpy())
            target_labels.extend(batch.label.cpu().detach().numpy())
        predicted_labels = np.array(predicted_labels)
        target_labels = np.array(target_labels)
        accuracy = metrics.accuracy_score(target_labels, predicted_labels)
        precision = metrics.precision_score(target_labels, predicted_labels, average='micro')
        recall = metrics.recall_score(target_labels, predicted_labels, average='micro')
        f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
        if data_loader == dev_iter:
            print("Dev metrics:")
        else:
            print("Test metrics:")
        print(accuracy, precision, recall, f1)
```
@@ -0,0 +1,34 @@
```
import os

from argparse import ArgumentParser


def get_args():
    parser = ArgumentParser(description="CharCNN")
    parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
    parser.add_argument('--gpu', type=int, default=0)  # Use -1 for CPU
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--seed', type=int, default=3435)
    parser.add_argument('--dataset', type=str, default='Reuters', choices=['Reuters', 'AAPD'])
    parser.add_argument('--resume_snapshot', type=str, default=None)
    parser.add_argument('--dev_every', type=int, default=30)
    parser.add_argument('--log_every', type=int, default=10)
    parser.add_argument('--patience', type=int, default=100)
    parser.add_argument('--save_path', type=str, default='kim_cnn/saves')
    parser.add_argument('--num_conv_filters', type=int, default=256)
    parser.add_argument('--num_affine_neurons', type=int, default=1024)
    parser.add_argument('--output_channel', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--epoch_decay', type=int, default=15)
    parser.add_argument('--data_dir', help='dataset directory',
                        default=os.path.join(os.pardir, 'Castor-data', 'datasets'))
    parser.add_argument('--word_vectors_dir', help='word vectors directory',
                        default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'word2vec'))
    parser.add_argument('--word_vectors_file', help='word vectors filename', default='GoogleNews-vectors-negative300.txt')
    parser.add_argument('--trained_model', type=str, default="")
    parser.add_argument('--weight_decay', type=float, default=0)

    args = parser.parse_args()
    return args
```
@@ -0,0 +1,45 @@
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class CharCNN(nn.Module):
    def __init__(self, config):
        super(CharCNN, self).__init__()
        self.is_cuda_enabled = config.cuda
        dataset = config.dataset
        num_conv_filters = config.num_conv_filters
        output_channel = config.output_channel
        num_affine_neurons = config.num_affine_neurons
        target_class = config.target_class
        input_channel = 68  # number of character channels in the quantized input

        self.conv1 = nn.Conv1d(input_channel, num_conv_filters, kernel_size=7)  # Default padding=0
        self.conv2 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=7)
        self.conv3 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
        self.conv4 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
        self.conv5 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
        self.conv6 = nn.Conv1d(num_conv_filters, output_channel, kernel_size=3)
        self.dropout = nn.Dropout(config.dropout)
        self.fc1 = nn.Linear(num_conv_filters, num_affine_neurons)
        self.fc2 = nn.Linear(num_affine_neurons, num_affine_neurons)
        self.fc3 = nn.Linear(num_affine_neurons, target_class)

    def forward(self, x, **kwargs):
        # x: (batch, sequence length, character channels); Conv1d expects (batch, channels, length)
        if torch.cuda.is_available() and self.is_cuda_enabled:
            x = x.transpose(1, 2).type(torch.cuda.FloatTensor)
        else:
            x = x.transpose(1, 2).type(torch.FloatTensor)
        x = F.max_pool1d(F.relu(self.conv1(x)), 3)
        x = F.max_pool1d(F.relu(self.conv2(x)), 3)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.max_pool1d(x, x.size(2)).squeeze(2)  # global max pool over the remaining length
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)
```
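
A quick shape check for the model above (a sketch only: the sequence length of 1014 characters and the 90-class output are illustrative assumptions; the other config values match the command-line defaults in args.py).

```
# Forward pass with a dummy character-quantized batch: input is (batch, characters, 68 channels),
# output is one logit per class. target_class=90 is just an example label count.
import torch
from types import SimpleNamespace
from char_cnn.model import CharCNN

config = SimpleNamespace(cuda=False, dataset=None, num_conv_filters=256, output_channel=256,
                         num_affine_neurons=1024, target_class=90, dropout=0.5)
model = CharCNN(config)

x = torch.zeros(4, 1014, 68)  # 4 documents, 1014 characters each, one-hot over 68 channels
logits = model(x)
print(logits.shape)  # torch.Size([4, 90])
```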