Sample subtrees when training sentiment LSTMs
kaishengtai committed May 27, 2015
1 parent 89ede02 commit 7dfe022
Showing 3 changed files with 77 additions and 33 deletions.
73 changes: 40 additions & 33 deletions sentiment/LSTMSentiment.lua
@@ -11,11 +11,12 @@ function LSTMSentiment:__init(config)
   self.learning_rate = config.learning_rate or 0.05
   self.emb_learning_rate = config.emb_learning_rate or 0.1
   self.num_layers = config.num_layers or 1
-  self.batch_size = config.batch_size or 25
+  self.batch_size = config.batch_size or 5
   self.reg = config.reg or 1e-4
   self.structure = config.structure or 'lstm' -- {lstm, bilstm}
   self.fine_grained = (config.fine_grained == nil) and true or config.fine_grained
   self.dropout = (config.dropout == nil) and true or config.dropout
+  self.train_subtrees = 4 -- number of subtrees to sample during training
 
   -- word embedding
   self.emb_dim = config.emb_vecs:size(2)
@@ -118,41 +119,47 @@ function LSTMSentiment:train(dataset)
       local loss = 0
       for j = 1, batch_size do
         local idx = indices[i + j - 1]
+        local tree = dataset.trees[idx]
         local sent = dataset.sents[idx]
-        local label = dataset.labels[idx]
-        local inputs = self.emb:forward(sent)
-
-        -- get sentence representations
-        local rep
-        if self.structure == 'lstm' then
-          rep = self.lstm:forward(inputs)
-        elseif self.structure == 'bilstm' then
-          rep = {
-            self.lstm:forward(inputs),
-            self.lstm_b:forward(inputs, true), -- true => reverse
-          }
+        local subtrees = tree:depth_first_preorder()
+        for k = 1, self.train_subtrees + 1 do
+          local subtree = (k == 1) and tree or subtrees[math.ceil(torch.uniform(1, #subtrees))]
+          local span = sent[{{subtree.lo, subtree.hi}}]
+          local inputs = self.emb:forward(span)
+
+          -- get sentence representations
+          local rep
+          if self.structure == 'lstm' then
+            rep = self.lstm:forward(inputs)
+          elseif self.structure == 'bilstm' then
+            rep = {
+              self.lstm:forward(inputs),
+              self.lstm_b:forward(inputs, true), -- true => reverse
+            }
+          end
+
+          -- compute class log probabilities
+          local output = self.sentiment_module:forward(rep)
+
+          -- compute loss and backpropagate
+          local example_loss = self.criterion:forward(output, subtree.gold_label)
+          loss = loss + example_loss
+          local obj_grad = self.criterion:backward(output, subtree.gold_label)
+          local rep_grad = self.sentiment_module:backward(rep, obj_grad)
+          local input_grads
+          if self.structure == 'lstm' then
+            input_grads = self:LSTM_backward(sent, inputs, rep_grad)
+          elseif self.structure == 'bilstm' then
+            input_grads = self:BiLSTM_backward(sent, inputs, rep_grad)
+          end
+          self.emb:backward(span, input_grads)
         end
-
-        -- compute class log probabilities
-        local output = self.sentiment_module:forward(rep)
-
-        -- compute loss and backpropagate
-        local example_loss = self.criterion:forward(output, label)
-        loss = loss + example_loss
-        local obj_grad = self.criterion:backward(output, label)
-        local rep_grad = self.sentiment_module:backward(rep, obj_grad)
-        local input_grads
-        if self.structure == 'lstm' then
-          input_grads = self:LSTM_backward(sent, inputs, rep_grad)
-        elseif self.structure == 'bilstm' then
-          input_grads = self:BiLSTM_backward(sent, inputs, rep_grad)
-        end
-        self.emb:backward(sent, input_grads)
       end
 
-      loss = loss / batch_size
-      self.grad_params:div(batch_size)
-      self.emb.gradWeight:div(batch_size)
+      local batch_subtrees = batch_size * (self.train_subtrees + 1)
+      loss = loss / batch_subtrees
+      self.grad_params:div(batch_subtrees)
+      self.emb.gradWeight:div(batch_subtrees)
 
       -- regularization
       loss = loss + 0.5 * self.reg * self.params:norm() ^ 2
@@ -266,7 +273,7 @@ function LSTMSentiment:print_config()
   printf('%-25s = %s\n', 'LSTM structure', self.structure)
   printf('%-25s = %d\n', 'LSTM layers', self.num_layers)
   printf('%-25s = %.2e\n', 'regularization strength', self.reg)
-  printf('%-25s = %d\n', 'minibatch size', self.batch_size)
+  printf('%-25s = %d\n', 'minibatch size', self.batch_size * (self.train_subtrees + 1))
   printf('%-25s = %.2e\n', 'learning rate', self.learning_rate)
   printf('%-25s = %.2e\n', 'word vector learning rate', self.emb_learning_rate)
   printf('%-25s = %s\n', 'dropout', tostring(self.dropout))
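For reference, the lower default batch_size works together with train_subtrees: each sentence in a minibatch now contributes one full-sentence example plus train_subtrees sampled-subtree examples, so a batch of 5 sentences still yields 5 * (4 + 1) = 25 training examples, which is the figure print_config reports as the minibatch size. A minimal sketch of that arithmetic (illustrative only, reusing the defaults from the diff above):

    local batch_size = 5            -- new default for config.batch_size
    local train_subtrees = 4        -- subtrees sampled per sentence
    -- effective number of examples per minibatch, as used for loss/gradient averaging
    local batch_subtrees = batch_size * (train_subtrees + 1)
    print(batch_subtrees)           -- 25, matching the old default batch size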
16 changes: 16 additions & 0 deletions util/Tree.lua
@@ -41,3 +41,19 @@ function Tree:depth()
   end
   return depth
 end
+
+local function depth_first_preorder(tree, nodes)
+  if tree == nil then
+    return
+  end
+  table.insert(nodes, tree)
+  for i = 1, tree.num_children do
+    depth_first_preorder(tree.children[i], nodes)
+  end
+end
+
+function Tree:depth_first_preorder()
+  local nodes = {}
+  depth_first_preorder(self, nodes)
+  return nodes
+end
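For reference, a minimal usage sketch of the new traversal, mirroring how the updated training loop samples a random subtree (assuming tree is a Tree instance loaded elsewhere in the codebase):

    -- collect every node of the tree in depth-first preorder (root first),
    -- then draw one node uniformly at random, as LSTMSentiment:train() does
    local subtrees = tree:depth_first_preorder()
    local subtree = subtrees[math.ceil(torch.uniform(1, #subtrees))]
    print(#subtrees, subtree.num_children)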
21 changes: 21 additions & 0 deletions util/read_data.lua
@@ -142,6 +142,10 @@ function treelstm.read_sentiment_dataset(dir, vocab, fine_grained)
   dataset.vocab = vocab
   dataset.fine_grained = fine_grained
   local trees = treelstm.read_trees(dir .. 'parents.txt', dir .. 'labels.txt')
+  for _, tree in ipairs(trees) do
+    set_spans(tree)
+  end
+
   local sents = treelstm.read_sentences(dir .. 'sents.txt', vocab)
   if not fine_grained then
     dataset.trees = {}
@@ -166,6 +170,23 @@ function treelstm.read_sentiment_dataset(dir, vocab, fine_grained)
   return dataset
 end
 
+function set_spans(tree)
+  if tree.num_children == 0 then
+    tree.lo, tree.hi = tree.leaf_idx, tree.leaf_idx
+    return
+  end
+
+  for i = 1, tree.num_children do
+    set_spans(tree.children[i])
+  end
+
+  tree.lo, tree.hi = tree.children[1].lo, tree.children[1].hi
+  for i = 2, tree.num_children do
+    tree.lo = math.min(tree.lo, tree.children[i].lo)
+    tree.hi = math.max(tree.hi, tree.children[i].hi)
+  end
+end
+
 function remap_labels(tree, fine_grained)
   if fine_grained then
     tree.gold_label = tree.gold_label + 3
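For reference, a small sketch of how the spans computed by set_spans are consumed in training (assuming, as set_spans does, that each leaf's leaf_idx is its 1-based token position, and that tree and sent are a matching tree/sentence pair from the dataset):

    -- after set_spans, every node carries the [lo, hi] word span it covers,
    -- so the training loop can slice the sentence tensor for any subtree
    set_spans(tree)
    local span = sent[{{tree.lo, tree.hi}}]  -- for the root this is the whole sentence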
