Tune Kim CNN for SST-2 and Improve SST-1 Results with Dataset/Initialization Changes (#133)

* SST change min_freq and Kim CNN init distribution

* Add tuned results for SST-1 and SST-2

* Fix typo
tuzhucheng authored Jul 12, 2018
1 parent 5ff980d commit 82bf90f
Showing 4 changed files with 69 additions and 18 deletions.
18 changes: 9 additions & 9 deletions datasets/sst.py
@@ -51,29 +51,29 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d

train, val, test = cls.splits(path)

- cls.TEXT_FIELD.build_vocab(train, val, test, min_freq=2, vectors=vectors)
+ cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors)

return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)

class SST2(TabularDataset):
NAME = 'SST-2'
NUM_CLASSES = 5

TEXT_FIELD = Field(batch_first=True, tokenize=clean_str_sst)
LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True)

@staticmethod
def sort_key(ex):
return len(ex.text)

@classmethod
def splits(cls, path, train='stsa.binary.phrases.train', validation='stsa.binary.dev', test='stsa.binary.test', **kwargs):
return super(SST2, cls).splits(
path, train=train, validation=validation, test=test,
format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
)

@classmethod
def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
unk_init=torch.Tensor.zero_):
@@ -89,11 +89,11 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
"""
if vectors is None:
vectors = Vectors(name=vectors_name, cache=vectors_cache, unk_init=unk_init)

train, val, test = cls.splits(path)
- cls.TEXT_FIELD.build_vocab(train, val, test, min_freq=2, vectors=vectors)

+ cls.TEXT_FIELD.build_vocab(train, val, test, vectors=vectors)

return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle,
sort_within_batch=True, device=device)

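The change above removes `min_freq=2` from `build_vocab`, so tokens that appear only once now keep their own vocabulary entries (and hence their own vectors) instead of collapsing to `<unk>`. A minimal sketch of the difference, assuming the legacy torchtext `Field` API this repo uses; the toy token lists are illustrative only:

```
from torchtext.data import Field  # legacy torchtext API, as in datasets/sst.py

# Toy tokenized "datasets"; 'dull' appears only once across all examples.
examples = [['good', 'movie'], ['good', 'plot'], ['dull']]

kept = Field(batch_first=True)
kept.build_vocab(examples)                # min_freq defaults to 1 (new behaviour)
assert 'dull' in kept.vocab.stoi          # singleton keeps its own entry and vector

pruned = Field(batch_first=True)
pruned.build_vocab(examples, min_freq=2)  # old behaviour
assert 'dull' not in pruned.vocab.stoi    # singleton collapses to <unk>
```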
46 changes: 40 additions & 6 deletions kim_cnn/README.md
@@ -35,6 +35,7 @@ python -m kim_cnn --trained_model kim_cnn/saves/SST-1/multichannel_best_model.pt
We experiment with the model on the following datasets.

- SST-1: Keep the original splits and train with phrase level dataset and test on sentence level dataset.
- SST-2: Same as SST-1 but with neutral reviews removed and binary labels.

## Settings

@@ -62,35 +63,68 @@ but this will take ~6-7x training time.
**Random**

```
- python -m kim_cnn --mode rand --lr 0.8337 --weight_decay 0.0008987 --dropout 0.4
+ python -m kim_cnn --dataset SST-1 --mode rand --lr 0.5777 --weight_decay 0.0007 --dropout 0
```

**Static**

```
- python -m kim_cnn --mode static --lr 0.8641 --weight_decay 1.44e-05 --dropout 0.3
+ python -m kim_cnn --dataset SST-1 --mode static --lr 0.3213 --weight_decay 0.0002 --dropout 0.4
```

**Non-static**

```
- python -m kim_cnn --mode non-static --lr 0.371 --weight_decay 1.84e-05 --dropout 0.4
+ python -m kim_cnn --dataset SST-1 --mode non-static --lr 0.388 --weight_decay 0.0004 --dropout 0.2
```

**Multichannel**

```
- python -m kim_cnn --mode multichannel --lr 0.2532 --weight_decay 3.95e-05 --dropout 0.1
+ python -m kim_cnn --dataset SST-1 --mode multichannel --lr 0.3782 --weight_decay 0.0002 --dropout 0.4
```

Using deterministic algorithms for cuDNN.

| Test Accuracy on SST-1 | rand | static | non-static | multichannel |
|:------------------------------:|:----------:|:------------:|:--------------:|:---------------:|
| Paper | 45.0 | 45.5 | 48.0 | 47.4 |
- | PyTorch using above configs | 41.5 | 44.7 | 47.4 | 47.5 |
+ | PyTorch using above configs | 44.3 | 47.9 | 48.6 | 49.2 |

## SST-2 Dataset Results

**Random**

```
python -m kim_cnn --dataset SST-2 --mode rand --lr 0.564 --weight_decay 0.0007 --dropout 0.5
```

**Static**

```
python -m kim_cnn --dataset SST-2 --mode static --lr 0.5589 --weight_decay 0.0004 --dropout 0.5
```

**Non-static**

```
python -m kim_cnn --dataset SST-2 --mode non-static --lr 0.5794 --weight_decay 0.0003 --dropout 0.3
```

**Multichannel**

```
python -m kim_cnn --dataset SST-2 --mode multichannel --lr 0.7373 --weight_decay 0.0001 --dropout 0.1
```

Using deterministic algorithms for cuDNN.

| Test Accuracy on SST-2 | rand | static | non-static | multichannel |
|:------------------------------:|:----------:|:------------:|:--------------:|:---------------:|
| Paper | 82.7 | 86.8 | 87.2 | 88.1 |
| PyTorch using above configs | 83.0 | 86.4 | 87.3 | 87.4 |

## TODO

- - More experiments on SST-2 and subjectivity
+ - More experiments on subjectivity
- Parameters tuning
20 changes: 18 additions & 2 deletions kim_cnn/__main__.py
@@ -12,6 +12,22 @@
from kim_cnn.args import get_args
from kim_cnn.model import KimCNN

class UnknownWordVecCache(object):
"""
Caches the first randomly generated word vector of a given size so that it is reused.
"""
cache = {}

@classmethod
def unk(cls, tensor):
size_tup = tuple(tensor.size())
if size_tup not in cls.cache:
cls.cache[size_tup] = torch.Tensor(tensor.size())
# choose 0.25 so unknown vectors have approximately same variance as pre-trained ones
# same as original implementation: https://github.com/yoonkim/CNN_sentence/blob/0a626a048757d5272a7e8ccede256a434a6529be/process_data.py#L95
cls.cache[size_tup].uniform_(-0.25, 0.25)
return cls.cache[size_tup]


def get_logger():
logger = logging.getLogger(__name__)
@@ -55,10 +71,10 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si

# Set up the data for training SST-1
if args.dataset == 'SST-1':
- train_iter, dev_iter, test_iter = SST1.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu)
+ train_iter, dev_iter, test_iter = SST1.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
# Set up the data for training SST-2
elif args.dataset == 'SST-2':
- train_iter, dev_iter, test_iter = SST2.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu)
+ train_iter, dev_iter, test_iter = SST2.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir, batch_size=args.batch_size, device=args.gpu, unk_init=UnknownWordVecCache.unk)
else:
raise ValueError('Unrecognized dataset')

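For context on where `UnknownWordVecCache.unk` gets called: in the torchtext version this repo targets, `Vectors` invokes `unk_init(torch.Tensor(dim))` for each token missing from the pretrained vector file, so keying the cache by size makes all such tokens share one U(-0.25, 0.25) vector within a run. A self-contained sketch of that behaviour (the class body is restated here only so the snippet runs on its own):

```
import torch

class UnknownWordVecCache(object):
    """Same idea as the class added to kim_cnn/__main__.py in this commit."""
    cache = {}

    @classmethod
    def unk(cls, tensor):
        size_tup = tuple(tensor.size())
        if size_tup not in cls.cache:
            # 0.25 keeps the variance close to that of the pretrained vectors
            cls.cache[size_tup] = torch.Tensor(*size_tup).uniform_(-0.25, 0.25)
        return cls.cache[size_tup]

# torchtext calls unk_init roughly like this for each out-of-vocabulary token:
a = UnknownWordVecCache.unk(torch.Tensor(300))
b = UnknownWordVecCache.unk(torch.Tensor(300))
assert a is b                  # same size -> the cached vector is reused
assert a.abs().max() <= 0.25   # values drawn from U(-0.25, 0.25)
```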
3 changes: 2 additions & 1 deletion kim_cnn/model.py
@@ -18,7 +18,8 @@ def __init__(self, config):
input_channel = 2
else:
input_channel = 1
- self.embed = nn.Embedding(words_num, words_dim)
+ rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
+ self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)

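The model change swaps PyTorch's default `nn.Embedding` initialization, which samples weights from N(0, 1), for U(-0.25, 0.25), matching the variance of the pretrained vectors and Kim's original implementation. A quick illustrative comparison, assuming PyTorch >= 0.4 for `from_pretrained` (not part of the commit):

```
import torch
import torch.nn as nn

words_num, words_dim = 10000, 300  # illustrative sizes

# Old: nn.Embedding's default init draws weights from N(0, 1).
old_embed = nn.Embedding(words_num, words_dim)

# New: U(-0.25, 0.25) init, wrapped in from_pretrained with freeze=False
# so the rand-mode embedding stays trainable.
rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
new_embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)

print(old_embed.weight.std().item())  # ~1.0
print(new_embed.weight.std().item())  # ~0.144, i.e. 0.25 / sqrt(3)
```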
