Add convolutional RNN for sentence classification (#57)
* Add SST data preprocessing
* Add ConvRNN model
* Add LR scheduler
* Add grid search on hyperparameters
* Add random search
* Add CLI options
* Add usage to README.md
* Refactor code
* Fix randomized search parameters
* Update README.md with results
* Use Dataset and DataLoader
1 parent 511f29a · commit e7ca33d
Showing 12 changed files with 2,556 additions and 0 deletions.
@@ -0,0 +1,23 @@
## Convolutional RNN

Implementation based on [[1]](http://dl.acm.org/citation.cfm?id=3098140).

### Usage

Run `./getData.sh` to fetch the data. The project structure should now look like this:

```
├── conv_rnn/
│   ├── data/
│   ├── saves/
│   └── *.*
```
You may then run `python train.py` and `python test.py` for training and testing, respectively. For more options, add the `-h` switch.

### Empirical results
Best dev | Test
-- | --
51.1 | 50.7

### References
[1] Chenglong Wang, Feijun Jiang, and Hongxia Yang. 2017. A Hybrid Framework for Text Modeling with Convolutional RNN. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD '17).
@@ -0,0 +1,61 @@
import os
import re

import numpy as np
import torch.utils.data as data


# Tokenize an SST sentence, dropping tokenization artifacts such as bracket
# markers (--lrb--/--rrb--), quote tokens, and stray punctuation.
def sst_tokenize(sentence):
    extraneous_pattern = re.compile(r"^(--lrb--|--rrb--|``|''|--|\.)$")
    words = []
    for word in sentence.split():
        if re.match(extraneous_pattern, word):
            continue
        words.append(word)
    return words


class SSTEmbeddingLoader(object):
    def __init__(self, dirname, fmt="stsa.fine.{}", word2vec_file="word2vec.sst-1"):
        self.dirname = dirname
        self.fmt = fmt
        self.word2vec_file = word2vec_file

    def load_embed_data(self):
        # Returns (word -> id map, pretrained embedding matrix, list of training
        # words that have no pretrained vector).
        weights = []
        id_dict = {}
        unk_vocab_set = set()
        with open(os.path.join(self.dirname, self.word2vec_file)) as f:
            for i, line in enumerate(f.readlines()):
                word, vec = line.replace("\n", "").split(" ", 1)
                word = word.replace("#", "")
                vec = np.array([float(v) for v in vec.split(" ")])
                weights.append(vec)
                id_dict[word] = i
        with open(os.path.join(self.dirname, self.fmt.format("phrases.train"))) as f:
            for line in f.readlines():
                for word in sst_tokenize(line):
                    if word not in id_dict and word not in unk_vocab_set:
                        unk_vocab_set.add(word)
        return (id_dict, np.array(weights), list(unk_vocab_set))


class SSTDataset(data.Dataset):
    # Thin Dataset wrapper around an array of (sentiment label, sentence) pairs.
    def __init__(self, sentences):
        super().__init__()
        self.sentences = sentences

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        return self.sentences[index]

    @classmethod
    def load_sst_sets(cls, dirname, fmt="stsa.fine.{}"):
        # Read the phrase-level train set plus the dev and test splits.
        set_names = ["phrases.train", "dev", "test"]

        def read_set(name):
            data_set = []
            with open(os.path.join(dirname, fmt.format(name))) as f:
                for line in f.readlines():
                    sentiment, sentence = line.replace("\n", "").split(" ", 1)
                    data_set.append((sentiment, sentence))
            return np.array(data_set)

        return [cls(read_set(name)) for name in set_names]
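For orientation, here is a minimal sketch of how these loaders might be driven, assuming this module is saved as `data.py` (as the imports in the other files suggest) and that `./getData.sh` has already populated `data/`:

```
import data

# Pretrained vectors plus the training words that have no pretrained vector.
id_dict, weights, unk_vocab = data.SSTEmbeddingLoader("data").load_embed_data()
print(len(id_dict), weights.shape, len(unk_vocab))

# Phrase-level train set and the dev/test splits, each a Dataset of (label, sentence) pairs.
train_set, dev_set, test_set = data.SSTDataset.load_sst_sets("data")
print(len(train_set), len(dev_set), len(test_set))
print(test_set[0])  # one (label, sentence) pair, stored as a numpy string array
```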
@@ -0,0 +1,7 @@
#!/bin/sh
mkdir -p data
mkdir -p saves
wget http://ocp59jkku.bkt.clouddn.com/sst-1.zip -P data/
wget http://ocp59jkku.bkt.clouddn.com/sst-2.zip -P data/
unzip data/sst-1.zip -d data/
unzip data/sst-2.zip -d data/
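The archives are expected to unpack directly into `data/`. Going by the defaults in the data module above (`fmt="stsa.fine.{}"`, `word2vec_file="word2vec.sst-1"`), a quick sanity check for the files the loaders will look for might read as follows; the exact archive layout is an assumption here, since it depends on the contents of the downloaded zips:

```
import os

# File names implied by the defaults in SSTEmbeddingLoader and SSTDataset.load_sst_sets.
expected = [
    "stsa.fine.phrases.train",
    "stsa.fine.dev",
    "stsa.fine.test",
    "word2vec.sst-1",
]
for name in expected:
    path = os.path.join("data", name)
    print(path, "ok" if os.path.exists(path) else "missing")
```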
@@ -0,0 +1,141 @@
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nn_func

import data


class ConvRNNModel(nn.Module):
    # Bidirectional RNN over word embeddings, a width-1 convolution over the RNN
    # outputs, max-pooling over time, then an MLP over the pooled feature maps
    # concatenated with the final hidden states of both directions.
    def __init__(self, word_model, **config):
        super().__init__()
        embedding_dim = word_model.dim
        self.word_model = word_model
        self.hidden_size = config["hidden_size"]
        fc_size = config["fc_size"]
        self.batch_size = config["mbatch_size"]
        dropout = config["dropout_prob"]
        n_fmaps = config["n_feature_maps"]
        self.rnn_type = config["rnn_type"]

        self.h_0_cache = torch.autograd.Variable(torch.zeros(2, self.batch_size, self.hidden_size))
        self.c_0_cache = torch.autograd.Variable(torch.zeros(2, self.batch_size, self.hidden_size))

        self.no_cuda = config["no_cuda"]
        if not self.no_cuda:
            self.h_0_cache = self.h_0_cache.cuda()
            self.c_0_cache = self.c_0_cache.cuda()

        if self.rnn_type.upper() == "LSTM":
            self.bi_rnn = nn.LSTM(embedding_dim, self.hidden_size, 1, batch_first=True, bidirectional=True)
        elif self.rnn_type.upper() == "GRU":
            self.bi_rnn = nn.GRU(embedding_dim, self.hidden_size, 1, batch_first=True, bidirectional=True)
        else:
            raise ValueError("RNN type must be one of LSTM or GRU")
        self.conv = nn.Conv2d(1, n_fmaps, (1, self.hidden_size * 2))
        if dropout:
            self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(n_fmaps + 2 * self.hidden_size, fc_size)
        self.fc2 = nn.Linear(fc_size, config["n_labels"])

    def convert_dataset(self, dataset):
        # Turn an array of (label, sentence) pairs into (padded index tensor, label tensor).
        dataset = np.stack(dataset)
        model_in = dataset[:, 1].reshape(-1)
        model_out = dataset[:, 0].flatten().astype(np.int)
        model_out = torch.autograd.Variable(torch.from_numpy(model_out))
        model_in = self.preprocess(model_in)
        model_in = torch.autograd.Variable(model_in)
        if not self.no_cuda:
            model_out = model_out.cuda()
            model_in = model_in.cuda()
        return (model_in, model_out)

    def preprocess(self, sentences):
        return torch.from_numpy(np.array(self.word_model.lookup(sentences)))

    def forward(self, x):
        x = self.word_model(x)  # shape: (batch, max sent, embed dim)
        if x.size(0) == self.batch_size:
            h_0 = self.h_0_cache
            c_0 = self.c_0_cache
        else:
            h_0 = torch.autograd.Variable(torch.zeros(2, x.size(0), self.hidden_size))
            c_0 = torch.autograd.Variable(torch.zeros(2, x.size(0), self.hidden_size))
            if not self.no_cuda:
                h_0 = h_0.cuda()
                c_0 = c_0.cuda()
        if self.rnn_type.upper() == "LSTM":
            rnn_seq, rnn_out = self.bi_rnn(x, (h_0, c_0))  # shape: (batch, seq len, 2 * hidden_size), (2, batch, hidden_size)
            rnn_out = rnn_out[0]  # (h_0, c_0)
        else:
            rnn_seq, rnn_out = self.bi_rnn(x, h_0)  # shape: (batch, 2, hidden_size)
        rnn_out.data = rnn_out.data.permute(1, 0, 2)
        x = self.conv(rnn_seq.unsqueeze(1)).squeeze(3)  # shape: (batch, channels, seq len)
        x = nn_func.relu(x)  # shape: (batch, channels, seq len)
        x = nn_func.max_pool1d(x, x.size(2))  # shape: (batch, channels)
        out = [t.squeeze(1) for t in rnn_out.chunk(2, 1)]
        out.append(x)
        x = torch.cat(out, 1).squeeze(2)
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        x = nn_func.relu(self.fc1(x))
        return self.fc2(x)


class WordEmbeddingModel(nn.Module):
    # Embedding layer initialized from pretrained vectors; words seen only in
    # the training data receive small random vectors appended to the weight matrix.
    def __init__(self, id_dict, weights, unknown_vocab=[], static=True, padding_idx=0):
        super().__init__()
        vocab_size = len(id_dict) + len(unknown_vocab)
        self.lookup_table = id_dict
        last_id = max(id_dict.values())
        for word in unknown_vocab:
            last_id += 1
            self.lookup_table[word] = last_id
        self.dim = weights.shape[1]
        self.weights = np.concatenate((weights, np.random.rand(len(unknown_vocab), self.dim) / 2 - 0.25))
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(vocab_size, self.dim, padding_idx=padding_idx)
        self.embedding.weight.data.copy_(torch.from_numpy(self.weights))
        if static:
            self.embedding.weight.requires_grad = False

    @classmethod
    def make_random_model(cls, id_dict, unknown_vocab=[], dim=300):
        weights = np.random.rand(len(id_dict), dim) - 0.5
        return cls(id_dict, weights, unknown_vocab, static=False)

    def forward(self, x):
        return self.embedding(x)

    def lookup(self, sentences):
        raise NotImplementedError


class SSTWordEmbeddingModel(WordEmbeddingModel):
    def __init__(self, id_dict, weights, unknown_vocab=[]):
        super().__init__(id_dict, weights, unknown_vocab, padding_idx=16259)

    def lookup(self, sentences):
        # Map each sentence to word ids, skipping out-of-vocabulary words, and
        # pad every sentence to the length of the longest one in the batch.
        indices_list = []
        max_len = 0
        for sentence in sentences:
            indices = []
            for word in data.sst_tokenize(sentence):
                try:
                    index = self.lookup_table[word]
                    indices.append(index)
                except KeyError:
                    continue
            indices_list.append(indices)
            if len(indices) > max_len:
                max_len = len(indices)
        for indices in indices_list:
            indices.extend([self.padding_idx] * (max_len - len(indices)))
        return indices_list


def set_seed(seed=0, no_cuda=False):
    # Seed Python, numpy, and torch (including CUDA unless disabled) for reproducibility.
    np.random.seed(seed)
    if not no_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    random.seed(seed)
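To see how this module fits together with the data module, here is a hedged construction sketch: the config keys are exactly the ones read in `ConvRNNModel.__init__`, but the values below (and the batch of eight dev sentences) are illustrative placeholders rather than the settings used by `train.py`, which is not part of the excerpt shown here.

```
import data
import model

# Vocabulary and pretrained vectors as produced by data.py
# (paths assume getData.sh has populated data/).
id_dict, weights, unk_vocab = data.SSTEmbeddingLoader("data").load_embed_data()
word_model = model.SSTWordEmbeddingModel(id_dict, weights, unknown_vocab=unk_vocab)

# Placeholder hyperparameters; the keys are those required by ConvRNNModel.__init__,
# the values here are illustrative only.
config = {
    "hidden_size": 200,
    "fc_size": 200,
    "mbatch_size": 64,
    "dropout_prob": 0.5,
    "n_feature_maps": 200,
    "rnn_type": "GRU",
    "no_cuda": True,
    "n_labels": 5,  # SST-1 has five sentiment classes
}
conv_rnn = model.ConvRNNModel(word_model, **config)

# convert_dataset turns raw (label, sentence) pairs into padded index tensors.
_, dev_set, _ = data.SSTDataset.load_sst_sets("data")
dev_in, dev_out = conv_rnn.convert_dataset(dev_set[:8])
# conv_rnn(dev_in) then yields one score per label for each of the eight
# sentences, exactly as in test.py's evaluation.
```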
@@ -0,0 +1,36 @@
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn

import data
import model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--no_cuda", action="store_true", default=False)
    parser.add_argument("--input_file", default="saves/model.pt", type=str)
    parser.add_argument("--data_dir", default="data", type=str)
    parser.add_argument("--gpu_number", default=0, type=int)
    args = parser.parse_args()

    model.set_seed(5, no_cuda=args.no_cuda)
    conv_rnn = torch.load(args.input_file)
    if not args.no_cuda:
        torch.cuda.set_device(args.gpu_number)
        conv_rnn.cuda()
    # Load the SST splits through the Dataset API defined in data.py; only the test split is used here.
    _, _, test_set = data.SSTDataset.load_sst_sets(args.data_dir)

    conv_rnn.eval()
    test_in, test_out = conv_rnn.convert_dataset(test_set)
    scores = conv_rnn(test_in)
    n_correct = (torch.max(scores, 1)[1].view(len(test_set)).data == test_out.data).sum()
    accuracy = n_correct / len(test_set)
    print("Test set accuracy: {}".format(accuracy))


if __name__ == "__main__":
    main()
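The script above pushes the entire test split through the model in a single batch. Since the commit also introduces the `Dataset` wrapper (and the commit message mentions `DataLoader`), evaluation can equally be done in minibatches. A rough sketch under the same assumptions (trained model at `saves/model.pt`, data under `data/`); the `batch_size` and the identity `collate_fn` are choices made here rather than taken from the repository:

```
import torch
import torch.utils.data as data_utils

import data

conv_rnn = torch.load("saves/model.pt")
conv_rnn.eval()
_, _, test_set = data.SSTDataset.load_sst_sets("data")

# The identity collate_fn hands the raw (label, sentence) pairs straight to
# convert_dataset, which indexes and pads each minibatch on its own.
loader = data_utils.DataLoader(test_set, batch_size=64, collate_fn=lambda batch: batch)
n_correct = 0
for batch in loader:
    model_in, model_out = conv_rnn.convert_dataset(batch)
    scores = conv_rnn(model_in)
    n_correct += (torch.max(scores, 1)[1].view(len(batch)).data == model_out.data).sum()
print("Test set accuracy: {}".format(n_correct / len(test_set)))
```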