#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
"""
Heuristic for finding a span in some passage that's close to the golden span.
"""
from difflib import SequenceMatcher as SM
import re
import string
from typing import List, Tuple

from nltk.util import ngrams

from evaluate_qa import compute_f1, compute_f1_from_tokens, get_tokens, normalize_answer

ARTICLES_RE = re.compile(r'\b(a|an|the)\b', re.UNICODE)
EXCLUDED_PUNCTS = set(string.punctuation)


def _find_approximate_matching_sequence(context: str, target: str) -> Tuple[str, float]:
    """Find some substring in the context which closely matches the target, returning this substring with a score."""
    if target in context:
        return target, 1.0

    # Slide a window slightly wider than the target (5% extra tokens) over the
    # context and keep the candidate with the highest difflib similarity.
    target_length = len(target.split())
    max_sim_val = 0
    max_sim_string = ''
    seq_matcher = SM()
    seq_matcher.set_seq2(target)
    for ngram in ngrams(context.split(), target_length + int(0.05 * target_length)):
        candidate_ngram = ' '.join(ngram)
        seq_matcher.set_seq1(candidate_ngram)
        # quick_ratio() is a fast upper bound on ratio().
        similarity = seq_matcher.quick_ratio()
        if similarity > max_sim_val:
            max_sim_val = similarity
            max_sim_string = candidate_ngram
            if similarity == 1.0:
                # early exiting
                break
    return max_sim_string, max_sim_val
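
# Illustrative trace (hypothetical strings, not from this repo): with
#   context = 'the quick brown fox jumps over the lazy dog'
#   target  = 'quick brown foxes'
# the window size is 3 tokens (3 + int(0.05 * 3)), and the candidate
# 'quick brown fox' scores highest: quick_ratio() = 2 * 15 / (15 + 17) = 0.9375,
# since all 15 characters of the candidate also occur in the 17-character target.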


def _normalize_tokens(tokens: List[str], keep_empty_str=True) -> List[str]:
    """
    Normalize individual tokens.
    If keep_empty_str is True, this keeps the overall number of tokens the same.
    A particular token could be normalized to an empty string.
    """
    normalized_tokens = []
    for token in tokens:
        token = token.lower()
        token = ''.join(ch for ch in token if ch not in EXCLUDED_PUNCTS)
        token = re.sub(ARTICLES_RE, '', token)
        if keep_empty_str or len(token):
            normalized_tokens.append(token)
    return normalized_tokens
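
# For example (hypothetical input): _normalize_tokens(['The', 'Big', 'Apple!'])
# lowercases, strips punctuation, and drops articles, yielding ['', 'big', 'apple'];
# with keep_empty_str=False the empty string is dropped, yielding ['big', 'apple'].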


def find_closest_span_match(passage: str, gold_answer: str) -> Tuple[str, float]:
    """Heuristic for finding the closest span in a passage relative to some golden answer based on F1 score."""
    # First narrow the search to one approximately matching window in the passage.
    closest_encompassing_span, closest_encompassing_span_score = _find_approximate_matching_sequence(passage, gold_answer)

    closest_encompassing_span_tok = closest_encompassing_span.split()
    gold_answer_tok = gold_answer.split()
    closest_encompassing_span_tok_normalized = _normalize_tokens(closest_encompassing_span_tok)
    gold_answer_tok_normalized = _normalize_tokens(gold_answer_tok, keep_empty_str=False)

    # Exhaustively score every sub-span of the window against the gold answer
    # and keep the one with the highest token-level F1.
    best_span, best_score, best_i, best_j = '', 0, None, None
    for i in range(0, len(closest_encompassing_span_tok_normalized)):
        for j in range(i + 1, len(closest_encompassing_span_tok_normalized) + 1):
            score = compute_f1_from_tokens(
                gold_answer_tok_normalized,
                [t for t in closest_encompassing_span_tok_normalized[i:j] if len(t)],
            )
            if score > best_score:
                best_score = score
                best_i, best_j = i, j
                # Recover the span from the original (unnormalized) tokens.
                best_span = ' '.join(closest_encompassing_span_tok[best_i:best_j])
    best_f1 = compute_f1(gold_answer, best_span)
    return best_span, best_f1
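

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file): the exact
    # scores depend on compute_f1 / compute_f1_from_tokens in evaluate_qa, so
    # we only print the results rather than assert specific values.
    passage = ("The Eiffel Tower was completed in March 1889 and served as the "
               "entrance arch to the 1889 World's Fair in Paris.")
    gold_answer = 'completed in 1889'

    window, window_score = _find_approximate_matching_sequence(passage, gold_answer)
    print(f'approximate window: {window!r} (similarity {window_score:.3f})')

    span, f1 = find_closest_span_match(passage, gold_answer)
    print(f'closest span: {span!r} (F1 {f1:.3f})')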