Add SSE model (castorini#168)
* runnable

* add util file

* update readme

* update final layer and add model name

* update argument

* update readme, delete useless args

* fix comments

* fix more comments
Victor0118 authored Dec 18, 2018
1 parent 7c3c156 commit 57f53a8
Showing 5 changed files with 370 additions and 0 deletions.
68 changes: 68 additions & 0 deletions sse/README.md
@@ -0,0 +1,68 @@
# SSE

This is a PyTorch reimplementation of the following paper:

```
@InProceedings{nie-bansal:2017:RepEval,
  author    = {Nie, Yixin and Bansal, Mohit},
  title     = {Shortcut-Stacked Sentence Encoders for Multi-Domain Inference},
  booktitle = {Proceedings of the 2nd Workshop on Evaluating Vector Space Representations for NLP},
  year      = {2017}
}
```


Please ensure you have followed the instructions in the main [README](../README.md) before running any of the commands below.
The commands in this doc assume you are in the root directory of the Castor repo.

## SICK Dataset

To run SSE on the SICK dataset, use the following command. Setting `--dropout 0` mimics the original paper, but adding dropout can improve results, which is why the command below uses `--dropout 0.5`.

```
python -m sse sse.sick.model.castor --dataset sick --epochs 19 --dropout 0.5 --lr 0.0002 --regularization 1e-4
```

| Implementation and config        | Pearson's r | Spearman's ρ |  MSE   |
| -------------------------------- |:-----------:|:------------:|:------:|
| PyTorch using above config       |   0.8812    |    0.8292    | 0.2295 |

## TrecQA Dataset

To run SSE on the TrecQA dataset, use the following command:
```
python -m sse sse.trecqa.model --dataset trecqa --epochs 5 --lr 0.00018 --regularization 0.0006405 --dropout 0
```

| Implementation and config | map | mrr |
| -------------------------------- |:------:|:------:|
| PyTorch using above config | | |

These are results on the raw version of the TrecQA dataset. The paper's results are reported in [Noise-Contrastive Estimation for Answer Selection with Deep Neural Networks](https://dl.acm.org/citation.cfm?id=2983872).

## WikiQA Dataset

As with TrecQA, you also need `trec_eval` to evaluate on this dataset.
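
Once `trec_eval` is built, the output score and qrel files written by `--keep-results` can also be scored by hand. A typical invocation looks something like the following (the file names here are placeholders):

```
./trec_eval -m map -m recip_rank qrels.txt predictions.txt
```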

Then, you can run:
```
python -m sse sse.wikiqa.model --dataset wikiqa --epochs 5 --lr 0.00042 --regularization 0.0001683 --dropout 0
```
| Implementation and config | map | mrr |
| -------------------------------- |:------:|:------:|
| PyTorch using above config | | |


To see all options available, use
```
python -m sse --help
```
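
Because training ends by saving a checkpoint to `model_outfile`, a previously trained model can be re-evaluated without retraining by passing `--skip-training` together with the same model path used during training, for example:

```
python -m sse sse.sick.model.castor --dataset sick --skip-training
```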

## Optional Dependencies

To visualize the learning curve during training, we use https://github.com/lanpa/tensorboard-pytorch to connect to [TensorBoard](https://github.com/tensorflow/tensorboard). These projects require TensorFlow as a dependency, so install TensorFlow before running the commands below. Once everything is installed, add `--tensorboard` to the training command and open TensorBoard in the browser.

```sh
pip install tensorboardX
pip install tensorflow-tensorboard
```
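
For example, a SICK training run with visualization enabled might look like this, assuming `tensorboardX` writes its event files to the default `runs/` directory:

```sh
python -m sse sse.sick.model.castor --dataset sick --tensorboard
tensorboard --logdir runs
```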
Empty file added sse/__init__.py
Empty file.
144 changes: 144 additions & 0 deletions sse/__main__.py
@@ -0,0 +1,144 @@
import argparse
import logging
import os
import pprint
import random

import numpy as np
import torch
import torch.optim as optim

from common.dataset import DatasetFactory
from common.evaluation import EvaluatorFactory
from common.train import TrainerFactory
from utils.serialization import load_checkpoint
from .model import StackBiLSTMMaxout


def get_logger():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    return logger


def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_size, device, keep_results=False):
    saved_model_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, loader, batch_size, device,
                                                           keep_results=keep_results)
    scores, metric_names = saved_model_evaluator.get_scores()
    logger.info('Evaluation metrics for {}'.format(split_name))
    logger.info('\t'.join([' '] + metric_names))
    logger.info('\t'.join([split_name] + list(map(str, scores))))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch implementation of the SSE (Shortcut-Stacked Sentence Encoder) model')
    parser.add_argument('model_outfile', help='file to save final model')
    parser.add_argument('--dataset', help='dataset to use, one of [sick, msrvid, trecqa, wikiqa]', default='sick')
    parser.add_argument('--word-vectors-dir', help='word vectors directory',
                        default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'GloVe'))
    parser.add_argument('--word-vectors-file', help='word vectors filename', default='glove.840B.300d.txt')
    parser.add_argument('--word-vectors-dim', type=int, default=300,
                        help='number of dimensions of word vectors (default: 300)')
    parser.add_argument('--skip-training', help='will load pre-trained model', action='store_true')
    parser.add_argument('--device', type=int, default=0, help='GPU device, -1 for CPU (default: 0)')
    parser.add_argument('--wide-conv', action='store_true', default=False,
                        help='use wide convolution instead of narrow convolution (default: false)')
    parser.add_argument('--sparse-features', action='store_true',
                        default=False, help='use sparse features (default: false)')
    parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train (default: 10)')
    parser.add_argument('--optimizer', type=str, default='adam', help='optimizer to use: adam or sgd (default: adam)')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    parser.add_argument('--lr-reduce-factor', type=float, default=0.3,
                        help='learning rate reduce factor after plateau (default: 0.3)')
    parser.add_argument('--patience', type=float, default=2,
                        help='learning rate patience after seeing plateau (default: 2)')
    parser.add_argument('--momentum', type=float, default=0, help='momentum (default: 0)')
    parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimizer epsilon (default: 1e-8)')
    parser.add_argument('--log-interval', type=int, default=10,
                        help='how many batches to wait before logging training status (default: 10)')
    parser.add_argument('--regularization', type=float, default=0.0001,
                        help='Regularization for the optimizer (default: 0.0001)')
    parser.add_argument('--mlpD', type=int, default=1600, help='MLP dimension (default: 1600)')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout probability (default: 0.1)')
    parser.add_argument('--maxlen', type=int, default=30, help='maximum length of text (default: 30)')
    parser.add_argument('--seed', type=int, default=1234, help='random seed (default: 1234)')
    parser.add_argument('--tensorboard', action='store_true', default=False,
                        help='use TensorBoard to visualize training (default: false)')
    parser.add_argument('--run-label', type=str, help='label to describe run')
    parser.add_argument('--keep-results', action='store_true',
                        help='store the output score and qrel files into disk for the test set')

    args = parser.parse_args()

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() and args.device >= 0 else 'cpu')

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device != -1:
        torch.cuda.manual_seed(args.seed)

    logger = get_logger()
    logger.info(pprint.pformat(vars(args)))

    dataset_cls, embedding, train_loader, test_loader, dev_loader \
        = DatasetFactory.get_dataset(args.dataset, args.word_vectors_dir, args.word_vectors_file, args.batch_size, args.device)

    ext_feats = dataset_cls.EXT_FEATS if args.sparse_features else 0

    model = StackBiLSTMMaxout(d=args.word_vectors_dim, mlp_d=args.mlpD,
                              num_classes=dataset_cls.NUM_CLASSES, dropout_r=args.dropout, max_l=args.maxlen)

    model = model.to(device)
    embedding = embedding.to(device)

    optimizer = None
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.regularization, eps=args.epsilon)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.regularization)
    else:
        raise ValueError('optimizer not recognized: it should be either adam or sgd')

    train_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, train_loader, args.batch_size,
                                                     args.device)
    test_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, test_loader, args.batch_size,
                                                    args.device)
    dev_evaluator = EvaluatorFactory.get_evaluator(dataset_cls, model, embedding, dev_loader, args.batch_size,
                                                   args.device)

    trainer_config = {
        'optimizer': optimizer,
        'batch_size': args.batch_size,
        'log_interval': args.log_interval,
        'model_outfile': args.model_outfile,
        'lr_reduce_factor': args.lr_reduce_factor,
        'patience': args.patience,
        'tensorboard': args.tensorboard,
        'run_label': args.run_label,
        'logger': logger
    }
    trainer = TrainerFactory.get_trainer(args.dataset, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator)

    if not args.skip_training:
        total_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
        logger.info('Total number of trainable parameters: %s', total_params)
        trainer.train(args.epochs)

    _, _, state_dict, _, _ = load_checkpoint(args.model_outfile)

    for k, tensor in state_dict.items():
        state_dict[k] = tensor.to(device)

    model.load_state_dict(state_dict)
    if dev_loader:
        evaluate_dataset('dev', dataset_cls, model, embedding, dev_loader, args.batch_size, args.device)
    evaluate_dataset('test', dataset_cls, model, embedding, test_loader, args.batch_size, args.device, args.keep_results)
87 changes: 87 additions & 0 deletions sse/model.py
@@ -0,0 +1,87 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import torch_util


class StackBiLSTMMaxout(nn.Module):
    def __init__(self, h_size=[512, 1024, 2048], d=300, mlp_d=1600, dropout_r=0.1, max_l=60, num_classes=3):
        super().__init__()

        self.arch = "SSE"
        # Shortcut-stacked encoding: each BiLSTM layer reads the word embeddings
        # concatenated with the outputs of all previous BiLSTM layers.
        self.lstm = nn.LSTM(input_size=d, hidden_size=h_size[0],
                            num_layers=1, bidirectional=True)

        self.lstm_1 = nn.LSTM(input_size=(d + h_size[0] * 2), hidden_size=h_size[1],
                              num_layers=1, bidirectional=True)

        self.lstm_2 = nn.LSTM(input_size=(d + (h_size[0] + h_size[1]) * 2), hidden_size=h_size[2],
                              num_layers=1, bidirectional=True)

        self.max_l = max_l
        self.h_size = h_size

        self.mlp_1 = nn.Linear(h_size[2] * 2 * 4, mlp_d)
        self.mlp_2 = nn.Linear(mlp_d, mlp_d)
        self.sm = nn.Linear(mlp_d, num_classes)

        self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(dropout_r),
                                          self.mlp_2, nn.ReLU(), nn.Dropout(dropout_r),
                                          self.sm])

    def display(self):
        for param in self.parameters():
            print(param.data.size())

    def forward(self, sent1, sent2, ext_feats=None, word_to_doc_count=None, raw_sent1=None, raw_sent2=None):
        sent1 = sent1.permute(2, 0, 1)  # from [B * D * T] to [T * B * D]
        sent2 = sent2.permute(2, 0, 1)
        sent1_lengths = torch.tensor([len(s.split(" ")) for s in raw_sent1])
        sent2_lengths = torch.tensor([len(s.split(" ")) for s in raw_sent2])
        if self.max_l:
            sent1_lengths = sent1_lengths.clamp(max=self.max_l)
            sent2_lengths = sent2_lengths.clamp(max=self.max_l)
            if sent1.size(0) > self.max_l:
                sent1 = sent1[:self.max_l, :]
            if sent2.size(0) > self.max_l:
                sent2 = sent2[:self.max_l, :]
        # p_sent1 = self.Embd(sent1)
        # p_sent2 = self.Embd(sent2)
        sent1_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, sent1, sent1_lengths)
        sent2_layer1_out = torch_util.auto_rnn_bilstm(self.lstm, sent2, sent2_lengths)

        # Length truncate
        len1 = sent1_layer1_out.size(0)
        len2 = sent2_layer1_out.size(0)
        p_sent1 = sent1[:len1, :, :]  # [T, B, D]
        p_sent2 = sent2[:len2, :, :]  # [T, B, D]

        # Using residual connection
        sent1_layer2_in = torch.cat([p_sent1, sent1_layer1_out], dim=2)
        sent2_layer2_in = torch.cat([p_sent2, sent2_layer1_out], dim=2)

        sent1_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, sent1_layer2_in, sent1_lengths)
        sent2_layer2_out = torch_util.auto_rnn_bilstm(self.lstm_1, sent2_layer2_in, sent2_lengths)

        sent1_layer3_in = torch.cat([p_sent1, sent1_layer1_out, sent1_layer2_out], dim=2)
        sent2_layer3_in = torch.cat([p_sent2, sent2_layer1_out, sent2_layer2_out], dim=2)

        sent1_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, sent1_layer3_in, sent1_lengths)
        sent2_layer3_out = torch_util.auto_rnn_bilstm(self.lstm_2, sent2_layer3_in, sent2_lengths)

        sent1_layer3_maxout = torch_util.max_along_time(sent1_layer3_out, sent1_lengths)
        sent2_layer3_maxout = torch_util.max_along_time(sent2_layer3_out, sent2_lengths)

        # Only use the last layer
        features = torch.cat([sent1_layer3_maxout, sent2_layer3_maxout,
                              torch.abs(sent1_layer3_maxout - sent2_layer3_maxout),
                              sent1_layer3_maxout * sent2_layer3_maxout],
                             dim=1)

        out = self.classifier(features)
        out = F.log_softmax(out, dim=1)
        return out
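
For orientation, here is a minimal smoke test of the model above. This is a sketch rather than part of the commit: it assumes the Castor repo root is on `PYTHONPATH`, and the sentences, tensor sizes, and random "embeddings" below are made-up toy data in the `[B, D, T]` layout that `forward()` expects.

```python
import torch

from sse.model import StackBiLSTMMaxout

# Two toy sentence pairs; lengths are derived from whitespace tokenization of the raw sentences.
raw_a = ["a man is playing a guitar", "two dogs run"]
raw_b = ["a person plays an instrument", "the dogs are running"]

B, D, T = 2, 300, 6           # batch size, embedding dimension, max tokens
sent1 = torch.randn(B, D, T)  # stand-in for the embedded first sentences, [B, D, T]
sent2 = torch.randn(B, D, T)  # stand-in for the embedded second sentences, [B, D, T]

model = StackBiLSTMMaxout(d=D, mlp_d=1600, num_classes=3, dropout_r=0.0, max_l=30)
model.eval()
with torch.no_grad():
    log_probs = model(sent1, sent2, raw_sent1=raw_a, raw_sent2=raw_b)

print(log_probs.shape)  # expected: torch.Size([2, 3]), one row of log-probabilities per pair
```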

71 changes: 71 additions & 0 deletions utils/torch_util.py
@@ -0,0 +1,71 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def auto_rnn_bilstm(lstm: nn.LSTM, seqs, lengths):
    """
    Run a BiLSTM over a padded batch, handling packing and unpacking internally.

    :param seqs: [T * B * D]
    :param lengths: [B]
    :return: [T' * B * 2H] outputs in the original batch order, where T' is the longest length in the batch
    """
    batch_size = seqs.size(1)
    state_shape = lstm.num_layers * 2, batch_size, lstm.hidden_size
    h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())

    packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths)
    output, (hn, cn) = lstm(packed_pinputs, (h0, c0))
    output = unpack_from_rnn_seq(output, r_index)

    return output

def pack_for_rnn_seq(inputs, lengths):
    """
    Sort the batch by decreasing length and pack it for an RNN.

    :param inputs: [T * B * D]
    :param lengths: [B]
    :return: (packed sequence, indices that restore the original batch order)
    """
    _, sorted_indices = lengths.sort()
    # Reverse to decreasing order, as required by pack_padded_sequence
    r_index = reversed(list(sorted_indices))
    s_inputs_list = []
    lengths_list = []
    reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)

    for j, i in enumerate(r_index):
        s_inputs_list.append(inputs[:, i, :].unsqueeze(1))
        lengths_list.append(lengths[i])
        reverse_indices[i] = j

    reverse_indices = list(reverse_indices)

    s_inputs = torch.cat(s_inputs_list, 1)
    packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list)

    return packed_seq, reverse_indices

def unpack_from_rnn_seq(packed_seq, reverse_indices):
    """Pad the packed sequence and restore the original batch order."""
    unpacked_seq, _ = nn.utils.rnn.pad_packed_sequence(packed_seq)
    s_inputs_list = []

    for i in reverse_indices:
        s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1))
    return torch.cat(s_inputs_list, 1)

def max_along_time(inputs, lengths):
    """
    :param inputs: [T * B * D]
    :param lengths: [B]
    :return: [B * D] max_along_time
    """
    ls = list(lengths)

    b_seq_max_list = []
    for i, l in enumerate(ls):
        # Max-pool over the valid (non-padded) timesteps of each sequence
        seq_i = inputs[:l, i, :]
        seq_i_max, _ = seq_i.max(dim=0)
        seq_i_max = seq_i_max.squeeze()
        b_seq_max_list.append(seq_i_max)

    return torch.stack(b_seq_max_list)
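
As a quick illustration of these helpers (again a sketch, assuming the repo root is on `PYTHONPATH` and using made-up toy tensors), packing and then unpacking should return a padded batch to its original order, and `max_along_time` pools only over the valid timesteps:

```python
import torch

from utils.torch_util import max_along_time, pack_for_rnn_seq, unpack_from_rnn_seq

# Toy padded batch: T=4 timesteps, B=2 sequences, D=3 features; true lengths are 4 and 2.
inputs = torch.randn(4, 2, 3)
inputs[2:, 1, :] = 0                     # zero out the padding of the shorter sequence
lengths = torch.tensor([4, 2])

packed, r_index = pack_for_rnn_seq(inputs, lengths)
restored = unpack_from_rnn_seq(packed, r_index)
print(torch.allclose(restored, inputs))  # True: the original batch order is recovered

pooled = max_along_time(inputs, lengths)
print(pooled.shape)                      # torch.Size([2, 3]): per-sequence max over valid steps
```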
