This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

Constrained decoding and tree acc code
litesaber15 committed Aug 5, 2019
1 parent 553e538 commit 5dedfe0
Showing 9 changed files with 2,214 additions and 2 deletions.
12 changes: 10 additions & 2 deletions README.md
@@ -41,7 +41,7 @@
Dataset     | BLEU  | TreeAcc(whole)   | TreeAcc(no-discourse)   | TreeAcc(discourse)
------------|-------|------------------|-------------------------|-----------
S2S-Tree | - | 94.00 | 96.66 | 86.59
S2S-Constr | - | 97.15 | 98.76 | 94.45

##### Weather Challenge Dataset
Dataset | BLEU | TreeAcc(whole) | TreeAcc(no-discourse) | TreeAcc(discourse)
------------|-------|------------------|-------------------------|-----------
@@ -54,7 +54,7 @@
S2S-Tree | 74.58 | 97.06 | 99.68 | 95.28
S2S-Constr | 74.69 | 99.25 | 99.89 | 97.78

-#### We are currently preparing code for release, and plan to add it to this repository as soon as possible. Stay tuned for updates!
+## Code

Computing tree accuracy:

```
python compute_tree_acc.py -tsv ~/seq2seq_out.tsv
```

The model output file passed to `-tsv` should be tab-separated, with columns `id, input, pred, target`.
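
For illustration only, one line of such a file might look like this (columns separated by tabs; the bracketed tokens follow the tree notation used in this repo, but the values are made up):

```
1	[__dg_inform__ [__temp__ cold ] ]	[__dg_inform__ it will be [__temp__ cold ] ]	[__dg_inform__ it will be [__temp__ cold ] ]
```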

### License
TreeNLG is released under [CC-BY-NC-4.0](https://creativecommons.org/licenses/by-nc/4.0/legalcode), see [LICENSE](LICENSE.md) for details.
27 changes: 27 additions & 0 deletions compute_tree_acc.py
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import argparse

from tree_accuracy.tree_accuracy import compare_trees, scenario_to_tree, sequence_to_tree


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute tree accuracy")
    # TSV file expected in format id, input, pred, target
    parser.add_argument("-tsv", type=str)
    args = parser.parse_args()
    with open(args.tsv, "r") as f:
        lines = [l.strip().split("\t") for l in f.readlines()]
    print("Loaded {} lines".format(len(lines)))
    correct = 0
    for line in lines:
        scenario_tree = scenario_to_tree(line[1].split(" "))
        pred_tree = sequence_to_tree(line[2].split(" "))
        if compare_trees(scenario_tree, pred_tree):
            correct += 1
    print(
        "Tree accuracy: {:.2f} ({} / {})".format(
            correct / len(lines) * 100, correct, len(lines)
        )
    )
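
The `tree_accuracy` module imported by this script is not shown on this page. As a rough, self-contained sketch of how the bracketed token format might be parsed into a tree (the `[__label__` opening-token convention comes from `bracketize` in the file below; the `]` closing token and every name here are assumptions):

```
def parse_brackets(tokens):
    # Parse tokens such as ["[__dg_inform__", "[__temp__", "cold", "]", "]"]
    # into a nested (label, children) tuple. Hypothetical sketch only.
    root = ("ROOT", [])
    stack = [root]
    for tok in tokens:
        if tok.startswith("[__"):
            node = (tok, [])
            stack[-1][1].append(node)
            stack.append(node)
        elif tok == "]":
            stack.pop()
        else:
            stack[-1][1].append(tok)
    return root

print(parse_brackets("[__dg_inform__ [__temp__ cold ] ]".split()))
# ('ROOT', [('[__dg_inform__', [('[__temp__', ['cold'])])])
```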
Empty file.
138 changes: 138 additions & 0 deletions constrained_decoding/constrained_sequence_generator.py
@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import re
from copy import deepcopy

import torch
from pytorch_translate import utils as pytorch_translate_utils

from constrained_decoding.constraint_checking import TreeConstraints
from constrained_decoding.sequence_generator import SequenceGenerator


# non-terminal prefix
NT_PREFIX = "b__"


def bracketize(s):
"""
Change the prefix of non-terminal tokens b__ to [__, i.e.,
b__dg_inform__ to [__dg_inform__.
"""
tokens = s.split()
if len(tokens) <= 1:
return re.sub(r"^%s" % NT_PREFIX, "[__", s)
else:
return " ".join([bracketize(t) for t in tokens])


class NLGFairseqSequenceGenerator(SequenceGenerator):
    def __init__(self, models, src_dict, tgt_dict, config):
        super().__init__(models, tgt_dict, **config._asdict())
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def generate_hypo(self, repacked_inputs, maxlen_a=0.0, maxlen_b=None):
        if maxlen_b is None:
            maxlen_b = self.maxlen
        src_tokens = repacked_inputs["src_tokens"]
        srclen = pytorch_translate_utils.get_source_tokens_tensor(src_tokens).size(1)
        hypos = self.generate(
            repacked_inputs,
            beam_size=self.beam_size,
            maxlen=int(maxlen_a * srclen + maxlen_b),
            # If we need to generate predictions with teacher forcing, this
            # won't work. Right now this is fine.
            prefix_tokens=None,
        )
        return self._pick_hypothesis_unpack_output(hypos)

    @staticmethod
    def _pack_input_for_fairseq(src_tokens, src_lengths):
        return {"src_tokens": src_tokens, "src_lengths": src_lengths}

    @staticmethod
    def _pick_hypothesis_unpack_output(all_hypos):
        """
        For now, we just pick the first hypothesis returned by fairseq and we
        return just the "tokens" as output
        """
        results = []
        for hypo in all_hypos:
            beam_results = []
            for prediction in hypo:
                beam_results.append(prediction["tokens"])
            results.append(beam_results)
        return results

    def _build_constraints(self, src_tokens, beam_size):
        """
        Returns list of constraint objects of size (bsz * beam_size, )
        """
        srcs = [" ".join([self.src_dict[tok] for tok in row]) for row in src_tokens]
        srcs = [s.replace(self.tgt_dict[self.tgt_dict.bos()], "") for s in srcs]
        srcs = [s.replace(self.tgt_dict[self.tgt_dict.eos()], "") for s in srcs]
        constraints = [TreeConstraints(bracketize(t)) for t in srcs]
        bbeam_constraints = []
        for constraint in constraints:
            bbeam_constraints.extend([deepcopy(constraint) for _ in range(beam_size)])
        self.constraint_penalty = [0.0] * len(bbeam_constraints)
        return bbeam_constraints

    def _apply_constraint_penalty(self, scores):
        """
        Penalize unmet constraints
        """
        assert len(self.constraint_penalty) == scores.size(0)
        scores += torch.tensor(self.constraint_penalty, device=scores.device).unsqueeze(
            1
        )

    def _update_constraints(self, constraints, next_tokens, idx):
        """
        Based on tokens consumed, update constraints and penalties for next step
        """
        assert len(constraints) == len(next_tokens)
        self.constraint_penalty = [
            0.0
            if constraint.next_token(bracketize(self.tgt_dict[token]), idx)
            else float("-Inf")
            for constraint, token in zip(constraints, next_tokens)
        ]

    def _reorder_constraints(self, constraints, new_indices):
        """
        Equivalent to constraints[new_indices] if both were Tensors.
        """
        # deepcopy is needed since the same candidate can appear in
        # multiple locations
        return [deepcopy(constraints[idx]) for idx in new_indices]

    def _apply_eos_constraints(self, constraints, eos_bbsz_idx, eos_scores):
        """
        Only allow EOS for candidates that satisfy all constraints.
        Returns filtered EOS indices and scores.
        """
        eos_constraints = self._reorder_constraints(constraints, eos_bbsz_idx)
        meets_constraints = []
        for i, con in enumerate(eos_constraints):
            if con.meets_all():
                meets_constraints.append(i)
        meets_constraints = torch.tensor(
            meets_constraints, device=eos_bbsz_idx.device, dtype=torch.long
        )
        return eos_bbsz_idx[meets_constraints], eos_scores[meets_constraints]

    def _finalize_constrained_results(self, finalized, device):
        """
        Deal with potentially empty results after beam search
        """
        for item in finalized:
            if len(item) == 0:
                # Fall back to a single EOS token with the worst possible
                # score, so downstream code always sees one hypothesis.
                item.append(
                    {
                        "tokens": torch.tensor(
                            [self.eos], dtype=torch.long, device=device
                        ),
                        "score": float("-inf"),
                    }
                )
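
The beam-search loop that calls these hooks lives in `SequenceGenerator` (imported above from `constrained_decoding/sequence_generator.py`, not shown on this page). Below is a minimal, runnable sketch of the penalty mechanism alone, with a toy stand-in for `TreeConstraints` (assumed semantics; the real checker is in `constrained_decoding/constraint_checking.py`):

```
import torch

class ToyConstraints:
    # Toy stand-in for TreeConstraints that only tracks bracket balance.
    def __init__(self):
        self.depth = 0

    def next_token(self, token, idx):
        # False means the token violates the constraint; the generator
        # converts that into a -inf penalty for the whole beam slot.
        if token.startswith("[__"):
            self.depth += 1
        elif token == "]":
            self.depth -= 1
        return self.depth >= 0

    def meets_all(self):
        # EOS is only allowed once every opened bracket has been closed.
        return self.depth == 0

c = ToyConstraints()
for i, tok in enumerate("[__temp__ cold ]".split()):
    assert c.next_token(tok, i)
print(c.meets_all())                           # True: tree complete, EOS allowed

# Two beam slots over a 4-token vocabulary; slot 1 violated a constraint.
scores = torch.zeros(2, 4)                     # (bsz * beam_size, vocab) log-probs
penalty = [0.0, float("-inf")]
scores += torch.tensor(penalty).unsqueeze(1)   # mirrors _apply_constraint_penalty
print(scores)                                  # slot 1 can never win a candidate
```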