Skip to content

Commit

Permalink
Fix KimCNN for SST, AAPD datasets (castorini#157)
Browse files Browse the repository at this point in the history
* Add ReutersTrainer, ReutersEvaluator options in Factory classes

* Add Reuters to Kim-CNN command line arguments

* Fix SST dataset path according to changes in Kim-CNN args

The dataset path in args.py was changed to point at the dataset folder rather than the dataset/SST folder. Hence, the SST folder was added to the paths in the SST dataset class.

* Add Reuters dataset class, and support in __main__

* Add Reuters dataset trainers and evaluators

* Remove debug print statement in reuters_evaluator

* Fix rounding bug in reuters_trainer and reuters_evaluator

* Add LSTM for baseline text classification measurements

* Add eval metrics for lstm_baseline

* Set batch_first param in lstm_baseline

* Remove onnx args from lstm_baseline

* Pack padded sequences in LSTM_baseline

* Add TensorBoardX support for Reuters trainer

* Add Arxiv Academic Paper Dataset (AAPD)

* Add Hidden Bottleneck Layer to BiLSTM

* Fix packing of padded tensors in Reuters

* Add cmdline args for Hidden Bottleneck Layer for BiLSTM

* Include pre-padding lengths in AAPD dataset

* Remove duplication of preprocessing code in AAPD

* Remove batch_size condition in ReutersTrainer

* Add ignore_lengths option to ReutersTrainer and ReutersEvaluator

* Add AAPDCharQuantized and ReutersCharQuantized

* Rename Reuters_hierarchical to ReutersHierarchical

* Add CharacterCNN for document classification

* Update README.md for CharacterCNN

* Fix table in README.md for CharacterCNN

* Add AAPDHierarchical for HAN

* Update HAN for changes in Reuters dataset endpoints

* Fix bug in CharCNN when running on CPU

* Add AAPD dataset support for KimCNN

* Fix dataset paths for SST-1
  • Loading branch information
Achyudh Ram authored and Impavidity committed Nov 5, 2018
1 parent cc275f6 commit f0a5c37
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 24 deletions.
4 changes: 3 additions & 1 deletion datasets/sst.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def sort_key(ex):
return len(ex.text)

@classmethod
def splits(cls, path, train='stsa.fine.phrases.train', validation='stsa.fine.dev', test='stsa.fine.test', **kwargs):
def splits(cls, path, train=os.path.join('SST', 'stsa.fine.phrases.train'),
validation=os.path.join('SST', 'stsa.fine.dev'), test= os.path.join('SST', 'stsa.fine.test'), **kwargs):
return super(SST1, cls).splits(
path, train=train, validation=validation, test=test,
format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
Expand Down Expand Up @@ -57,6 +58,7 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)


class SST2(TabularDataset):
NAME = 'SST-2'
NUM_CLASSES = 5
Expand Down
31 changes: 10 additions & 21 deletions kim_cnn/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from common.train import TrainerFactory
from datasets.sst import SST1
from datasets.sst import SST2
from datasets.aapd import AAPD
from datasets.reuters import Reuters
from kim_cnn.args import get_args
from kim_cnn.model import KimCNN


class UnknownWordVecCache(object):
"""
Caches the first randomly generated word vector for a certain size so that it is reused.
Expand Down Expand Up @@ -82,6 +82,8 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
train_iter, dev_iter, test_iter = SST2.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
elif args.dataset == 'Reuters':
train_iter, dev_iter, test_iter = Reuters.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
elif args.dataset == 'AAPD':
train_iter, dev_iter, test_iter = AAPD.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
else:
raise ValueError('Unrecognized dataset')

Expand Down Expand Up @@ -123,6 +125,10 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
train_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(Reuters, model, None, dev_iter, args.batch_size, args.gpu)
elif args.dataset == 'AAPD':
train_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, train_iter, args.batch_size, args.gpu)
test_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, test_iter, args.batch_size, args.gpu)
dev_evaluator = EvaluatorFactory.get_evaluator(AAPD, model, None, dev_iter, args.batch_size, args.gpu)
else:
raise ValueError('Unrecognized dataset')

Expand Down Expand Up @@ -154,29 +160,12 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
elif args.dataset == 'Reuters':
evaluate_dataset('dev', Reuters, model, None, dev_iter, args.batch_size, args.gpu)
evaluate_dataset('test', Reuters, model, None, test_iter, args.batch_size, args.gpu)
elif args.dataset == 'AAPD':
evaluate_dataset('dev', AAPD, model, None, dev_iter, args.batch_size, args.gpu)
evaluate_dataset('test', AAPD, model, None, test_iter, args.batch_size, args.gpu)
else:
raise ValueError('Unrecognized dataset')

# Calculate dev and test metrics
for data_loader in [dev_iter, test_iter]:
predicted_labels = list()
target_labels = list()
for batch_idx, batch in enumerate(data_loader):
scores_rounded = F.sigmoid(model(batch.text)).round().long()
predicted_labels.extend(scores_rounded.cpu().detach().numpy())
target_labels.extend(batch.label.cpu().detach().numpy())
predicted_labels = np.array(predicted_labels)
target_labels = np.array(target_labels)
accuracy = metrics.accuracy_score(target_labels, predicted_labels)
precision = metrics.precision_score(target_labels, predicted_labels, average='micro')
recall = metrics.recall_score(target_labels, predicted_labels, average='micro')
f1 = metrics.f1_score(target_labels, predicted_labels, average='micro')
if data_loader == dev_iter:
print("Dev metrics:")
else:
print("Test metrics:")
print(accuracy, precision, recall, f1)

if args.onnx:
device = torch.device('cuda') if torch.cuda.is_available() and args.cuda else torch.device('cpu')
dummy_input = torch.zeros(args.onnx_batch_size, args.onnx_sent_len, dtype=torch.long, device=device)
Expand Down
2 changes: 1 addition & 1 deletion kim_cnn/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def get_args():
parser.add_argument('--mode', type=str, default='multichannel', choices=['rand', 'static', 'non-static', 'multichannel'])
parser.add_argument('--lr', type=float, default=1.0)
parser.add_argument('--seed', type=int, default=3435)
parser.add_argument('--dataset', type=str, default='SST-1', choices=['SST-1', 'SST-2', 'Reuters'])
parser.add_argument('--dataset', type=str, default='SST-1', choices=['SST-1', 'SST-2', 'Reuters', 'AAPD'])
parser.add_argument('--resume_snapshot', type=str, default=None)
parser.add_argument('--dev_every', type=int, default=30)
parser.add_argument('--log_every', type=int, default=10)
Expand Down
2 changes: 1 addition & 1 deletion kim_cnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.dropout)
self.fc1 = nn.Linear(Ks * output_channel, target_class)

def forward(self, x):
def forward(self, x, **kwargs):
if self.mode == 'rand':
word_input = self.embed(x) # (batch, sent_len, embed_dim)
x = word_input.unsqueeze(1) # (batch, channel_input, sent_len, embed_dim)
Expand Down

0 comments on commit f0a5c37

Please sign in to comment.