Neural Document Classification (castorini#159)
* Add ReutersTrainer, ReutersEvaluator options in Factory classes
* Add Reuters to Kim-CNN command line arguments
* Fix SST dataset path according to changes in Kim-CNN args: the dataset path in args.py was made to point at the dataset folder rather than the dataset/SST folder, so the SST folder was added to the paths in the SST dataset class
* Add Reuters dataset class, and support in __main__
* Add Reuters dataset trainers and evaluators
* Remove debug print statement in reuters_evaluator
* Fix rounding bug in reuters_trainer and reuters_evaluator
* Add LSTM for baseline text classification measurements
* Add eval metrics for lstm_baseline
* Set batch_first param in lstm_baseline
* Remove onnx args from lstm_baseline
* Pack padded sequences in LSTM_baseline (see the sketch after this list)
* Add TensorBoardX support for Reuters trainer
* Add arXiv Academic Paper Dataset (AAPD)
* Add Hidden Bottleneck Layer to BiLSTM
* Fix packing of padded tensors in Reuters
* Add cmdline args for Hidden Bottleneck Layer for BiLSTM
* Include pre-padding lengths in AAPD dataset
* Remove duplication of preprocessing code in AAPD
* Remove batch_size condition in ReutersTrainer
* Add ignore_lengths option to ReutersTrainer and ReutersEvaluator
* Add AAPDCharQuantized and ReutersCharQuantized
* Rename Reuters_hierarchical to ReutersHierarchical
* Add CharacterCNN for document classification
* Update README.md for CharacterCNN
* Fix table in README.md for CharacterCNN
* Add AAPDHierarchical for HAN
* Update HAN for changes in Reuters dataset endpoints
* Fix bug in CharCNN when running on CPU
* Add AAPD dataset support for KimCNN
* Fix dataset paths for SST-1
* Fix dimensions of FC1 in CharCNN
* Add model checkpointing for Reuters based on F1
* Refactor LSTM baseline __main__
* Add precision, recall and F1 to Reuters evaluator
* Checkpoint only at the end of an epoch for ReutersTrainer; add detailed log printing for dev evaluations
* Fix log_template and dev_log_template in ReutersTrainer
* Add IMDB dataset
* Add support for single_label datasets in ReutersTrainer
* Add support for IMDB dataset in lstm_baseline and lstm_reg
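One item above, packing padded sequences, is worth a brief illustration: it is the standard PyTorch idiom for letting an RNN skip padding positions in variable-length batches. The following is a minimal, hypothetical sketch of that idiom; the layer sizes and tensor names are illustrative assumptions, not the repository's actual lstm_baseline code.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Illustrative sizes; the real lstm_baseline hyperparameters may differ.
embed = nn.Embedding(num_embeddings=5000, embedding_dim=100)
lstm = nn.LSTM(input_size=100, hidden_size=256, batch_first=True, bidirectional=True)

# A padded batch (batch_first=True) and the true pre-padding lengths,
# sorted in descending order, as BucketIterator's sort_within_batch=True guarantees.
tokens = torch.randint(0, 5000, (4, 20))   # 4 documents, each padded to 20 tokens
lengths = torch.tensor([20, 17, 12, 5])

# Packing makes the LSTM skip the padding positions entirely.
packed = pack_padded_sequence(embed(tokens), lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to a padded tensor if per-timestep outputs are needed.
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)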
1 parent 3dd3ced · commit 951df4a
Showing 10 changed files with 190 additions and 87 deletions.
@@ -0,0 +1,89 @@
import os

import numpy as np
import torch
from datasets.reuters import clean_string, split_sents
from torchtext.data import NestedField, Field, TabularDataset
from torchtext.data.iterator import BucketIterator
from torchtext.vocab import Vectors

def char_quantize(string, max_length=1000):
    # One-hot encode each character against IMDBCharQuantized.ALPHABET,
    # then truncate or zero-pad the result to exactly max_length rows.
    identity = np.identity(len(IMDBCharQuantized.ALPHABET))
    quantized_string = np.array([identity[IMDBCharQuantized.ALPHABET[char]]
                                 for char in list(string.lower())
                                 if char in IMDBCharQuantized.ALPHABET], dtype=np.float32)
    if len(quantized_string) > max_length:
        return quantized_string[:max_length]
    return np.concatenate((quantized_string,
                           np.zeros((max_length - len(quantized_string),
                                     len(IMDBCharQuantized.ALPHABET)), dtype=np.float32)))

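# Hypothetical illustration (not part of the commit): char_quantize("Hi!", max_length=4)
# returns a 4 x len(IMDBCharQuantized.ALPHABET) float32 matrix: one-hot rows for
# 'h', 'i', and '!', followed by a single all-zero padding row.
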
def process_labels(string):
    """
    Converts a label string into a list of floats, e.g. '0101' -> [0.0, 1.0, 0.0, 1.0].
    :param string: the raw label string, one character per class
    :return: a list of floats suitable as a multi-label target vector
    """
    return [float(x) for x in string]

class IMDB(TabularDataset):
    NAME = 'IMDB'
    NUM_CLASSES = 10
    TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    @classmethod
    def splits(cls, path, train=os.path.join('IMDB', 'data', 'imdb_train.tsv'),
               validation=os.path.join('IMDB', 'data', 'imdb_validation.tsv'),
               test=os.path.join('IMDB', 'data', 'imdb_test.tsv'), **kwargs):
        return super(IMDB, cls).splits(
            path, train=train, validation=validation, test=test,
            format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)],
            **kwargs
        )

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
              unk_init=torch.Tensor.zero_):
        """
        :param path: directory containing train, test, dev files
        :param vectors_name: name of word vectors file
        :param vectors_cache: path to directory containing word vectors file
        :param batch_size: batch size
        :param shuffle: whether to shuffle batches between epochs
        :param device: GPU device
        :param vectors: custom vectors - either predefined torchtext vectors or your own custom Vector classes
        :param unk_init: function used to generate vector for OOV words
        :return: BucketIterator objects for the train, validation, and test splits
        """
        if vectors is None:
            vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

        train, val, test = cls.splits(path)
        cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors)
        return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
                                     sort_within_batch=True, device=device)

class IMDBCharQuantized(IMDB):
    # Maps each character in the alphabet to its index, e.g. {'a': 0, 'b': 1, ...}.
    ALPHABET = dict(map(lambda t: (t[1], t[0]),
                        enumerate(list("""abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"""))))
    TEXT_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=char_quantize)

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
              unk_init=torch.Tensor.zero_):
        """
        :param path: directory containing train, test, dev files
        :param batch_size: batch size
        :param device: GPU device
        :return: BucketIterator objects for the train, validation, and test splits

        The vector-related parameters are unused here (characters are one-hot
        encoded, so no pretrained word vectors are needed); they are kept only
        to match the signature of IMDB.iters.
        """
        train, val, test = cls.splits(path)
        return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
                                     device=device)

class IMDBHierarchical(IMDB):
    IN_FIELD = Field(batch_first=True, tokenize=clean_string)
    TEXT_FIELD = NestedField(IN_FIELD, tokenize=split_sents)
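For context, here is a minimal sketch of how these dataset classes are typically consumed. The directory path and vector-file names are placeholder assumptions, not values taken from the commit.

# Hypothetical usage sketch; paths and vector names are illustrative only.
train_iter, dev_iter, test_iter = IMDB.iters(
    path='data',                          # directory containing IMDB/data/*.tsv (assumed)
    vectors_name='glove.6B.300d.txt',     # word-vector file name (assumed)
    vectors_cache='data/word_vectors',    # directory holding that file (assumed)
    batch_size=32,
    device=-1                             # -1 selected the CPU in torchtext of this era
)

for batch in train_iter:
    text, lengths = batch.text            # include_lengths=True yields (data, lengths)
    labels = batch.label
    # a model's forward pass would go here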