From 7dfe022474f57cbe78d1a6a385897019d39bf428 Mon Sep 17 00:00:00 2001
From: Kai Sheng Tai
Date: Wed, 27 May 2015 14:33:45 -0700
Subject: [PATCH] Sample subtrees when training sentiment LSTMs

---
 sentiment/LSTMSentiment.lua | 73 ++++++++++++++++++++-----------------
 util/Tree.lua               | 16 ++++++++
 util/read_data.lua          | 21 +++++++++++
 3 files changed, 77 insertions(+), 33 deletions(-)

diff --git a/sentiment/LSTMSentiment.lua b/sentiment/LSTMSentiment.lua
index d4e73a2..003d9a1 100644
--- a/sentiment/LSTMSentiment.lua
+++ b/sentiment/LSTMSentiment.lua
@@ -11,11 +11,12 @@ function LSTMSentiment:__init(config)
   self.learning_rate = config.learning_rate or 0.05
   self.emb_learning_rate = config.emb_learning_rate or 0.1
   self.num_layers = config.num_layers or 1
-  self.batch_size = config.batch_size or 25
+  self.batch_size = config.batch_size or 5
   self.reg = config.reg or 1e-4
   self.structure = config.structure or 'lstm' -- {lstm, bilstm}
   self.fine_grained = (config.fine_grained == nil) and true or config.fine_grained
   self.dropout = (config.dropout == nil) and true or config.dropout
+  self.train_subtrees = 4 -- number of subtrees to sample during training
 
   -- word embedding
   self.emb_dim = config.emb_vecs:size(2)
@@ -118,41 +119,47 @@ function LSTMSentiment:train(dataset)
       local loss = 0
       for j = 1, batch_size do
         local idx = indices[i + j - 1]
+        local tree = dataset.trees[idx]
         local sent = dataset.sents[idx]
-        local label = dataset.labels[idx]
-        local inputs = self.emb:forward(sent)
-
-        -- get sentence representations
-        local rep
-        if self.structure == 'lstm' then
-          rep = self.lstm:forward(inputs)
-        elseif self.structure == 'bilstm' then
-          rep = {
-            self.lstm:forward(inputs),
-            self.lstm_b:forward(inputs, true), -- true => reverse
-          }
+        local subtrees = tree:depth_first_preorder()
+        for k = 1, self.train_subtrees + 1 do
+          local subtree = (k == 1) and tree or subtrees[math.ceil(torch.uniform(1, #subtrees))]
+          local span = sent[{{subtree.lo, subtree.hi}}]
+          local inputs = self.emb:forward(span)
+
+          -- get sentence representations
+          local rep
+          if self.structure == 'lstm' then
+            rep = self.lstm:forward(inputs)
+          elseif self.structure == 'bilstm' then
+            rep = {
+              self.lstm:forward(inputs),
+              self.lstm_b:forward(inputs, true), -- true => reverse
+            }
+          end
+
+          -- compute class log probabilities
+          local output = self.sentiment_module:forward(rep)
+
+          -- compute loss and backpropagate
+          local example_loss = self.criterion:forward(output, subtree.gold_label)
+          loss = loss + example_loss
+          local obj_grad = self.criterion:backward(output, subtree.gold_label)
+          local rep_grad = self.sentiment_module:backward(rep, obj_grad)
+          local input_grads
+          if self.structure == 'lstm' then
+            input_grads = self:LSTM_backward(sent, inputs, rep_grad)
+          elseif self.structure == 'bilstm' then
+            input_grads = self:BiLSTM_backward(sent, inputs, rep_grad)
+          end
+          self.emb:backward(span, input_grads)
         end
-
-        -- compute class log probabilities
-        local output = self.sentiment_module:forward(rep)
-
-        -- compute loss and backpropagate
-        local example_loss = self.criterion:forward(output, label)
-        loss = loss + example_loss
-        local obj_grad = self.criterion:backward(output, label)
-        local rep_grad = self.sentiment_module:backward(rep, obj_grad)
-        local input_grads
-        if self.structure == 'lstm' then
-          input_grads = self:LSTM_backward(sent, inputs, rep_grad)
-        elseif self.structure == 'bilstm' then
-          input_grads = self:BiLSTM_backward(sent, inputs, rep_grad)
-        end
-        self.emb:backward(sent, input_grads)
       end
 
-      loss = loss / batch_size
-      self.grad_params:div(batch_size)
-      self.emb.gradWeight:div(batch_size)
+      local batch_subtrees = batch_size * (self.train_subtrees + 1)
+      loss = loss / batch_subtrees
+      self.grad_params:div(batch_subtrees)
+      self.emb.gradWeight:div(batch_subtrees)
 
       -- regularization
       loss = loss + 0.5 * self.reg * self.params:norm() ^ 2
@@ -266,7 +273,7 @@ function LSTMSentiment:print_config()
   printf('%-25s = %s\n', 'LSTM structure', self.structure)
   printf('%-25s = %d\n', 'LSTM layers', self.num_layers)
   printf('%-25s = %.2e\n', 'regularization strength', self.reg)
-  printf('%-25s = %d\n', 'minibatch size', self.batch_size)
+  printf('%-25s = %d\n', 'minibatch size', self.batch_size * (self.train_subtrees + 1))
   printf('%-25s = %.2e\n', 'learning rate', self.learning_rate)
   printf('%-25s = %.2e\n', 'word vector learning rate', self.emb_learning_rate)
   printf('%-25s = %s\n', 'dropout', tostring(self.dropout))
diff --git a/util/Tree.lua b/util/Tree.lua
index ded1e01..7c677ee 100644
--- a/util/Tree.lua
+++ b/util/Tree.lua
@@ -41,3 +41,19 @@ function Tree:depth()
   end
   return depth
 end
+
+local function depth_first_preorder(tree, nodes)
+  if tree == nil then
+    return
+  end
+  table.insert(nodes, tree)
+  for i = 1, tree.num_children do
+    depth_first_preorder(tree.children[i], nodes)
+  end
+end
+
+function Tree:depth_first_preorder()
+  local nodes = {}
+  depth_first_preorder(self, nodes)
+  return nodes
+end
diff --git a/util/read_data.lua b/util/read_data.lua
index 71b1374..20e91f2 100644
--- a/util/read_data.lua
+++ b/util/read_data.lua
@@ -142,6 +142,10 @@ function treelstm.read_sentiment_dataset(dir, vocab, fine_grained)
   dataset.vocab = vocab
   dataset.fine_grained = fine_grained
   local trees = treelstm.read_trees(dir .. 'parents.txt', dir .. 'labels.txt')
+  for _, tree in ipairs(trees) do
+    set_spans(tree)
+  end
+
   local sents = treelstm.read_sentences(dir .. 'sents.txt', vocab)
   if not fine_grained then
     dataset.trees = {}
@@ -166,6 +170,23 @@
   return dataset
 end
 
+function set_spans(tree)
+  if tree.num_children == 0 then
+    tree.lo, tree.hi = tree.leaf_idx, tree.leaf_idx
+    return
+  end
+
+  for i = 1, tree.num_children do
+    set_spans(tree.children[i])
+  end
+
+  tree.lo, tree.hi = tree.children[1].lo, tree.children[1].hi
+  for i = 2, tree.num_children do
+    tree.lo = math.min(tree.lo, tree.children[i].lo)
+    tree.hi = math.max(tree.hi, tree.children[i].hi)
+  end
+end
+
 function remap_labels(tree, fine_grained)
   if fine_grained then
     tree.gold_label = tree.gold_label + 3
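
Illustration (not part of the patch): the snippet below mimics how set_spans assigns a (lo, hi) token span to each tree node, which is what lets the patched training loop slice a sampled subtree's words out of the sentence with sent[{{subtree.lo, subtree.hi}}]. It uses plain Lua tables carrying only the fields the patch reads (num_children, children, leaf_idx); the leaf/node helpers and the three-word example are made up for illustration and run with stock Lua, no Torch required.

-- Hypothetical stand-in for util/read_data.lua's set_spans, applied to plain
-- tables instead of Tree objects.
local function set_spans(tree)
  if tree.num_children == 0 then
    tree.lo, tree.hi = tree.leaf_idx, tree.leaf_idx
    return
  end
  for i = 1, tree.num_children do
    set_spans(tree.children[i])
  end
  tree.lo, tree.hi = tree.children[1].lo, tree.children[1].hi
  for i = 2, tree.num_children do
    tree.lo = math.min(tree.lo, tree.children[i].lo)
    tree.hi = math.max(tree.hi, tree.children[i].hi)
  end
end

-- Toy parse tree over the three-word sentence "not very good":
-- leaves carry leaf_idx, their 1-based position in the sentence.
local function leaf(idx) return { num_children = 0, children = {}, leaf_idx = idx } end
local function node(...) local c = { ... } return { num_children = #c, children = c } end
local root = node(leaf(1), node(leaf(2), leaf(3)))

set_spans(root)
print(root.lo, root.hi)                          --> 1  3   (the whole sentence)
print(root.children[2].lo, root.children[2].hi)  --> 2  3   ("very good")

In the patched training loop, each sampled subtree's span then indexes the sentence tensor and is trained against subtree.gold_label, so node-level SST labels contribute extra training examples beyond the root label.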