#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
"""
Heuristic for finding a span in some passage that's close to the golden span.
"""
from difflib import SequenceMatcher as SM
import re
import string
from typing import List, Tuple

from nltk.util import ngrams

from evaluate_qa import compute_f1, compute_f1_from_tokens, get_tokens, normalize_answer

ARTICLES_RE = re.compile(r'\b(a|an|the)\b', re.UNICODE)
EXCLUDED_PUNCTS = set(string.punctuation)


def _find_approximate_matching_sequence(context: str, target: str) -> Tuple[str, float]:
    """Find the substring of the context that most closely matches the target,
    returning that substring together with a similarity score in [0, 1]."""
    # Exact containment is a perfect match; skip the sliding-window search.
    if target in context:
        return target, 1.0

    target_length = len(target.split())
    max_sim_val = 0.0
    max_sim_string = ''

    seq_matcher = SM()
    seq_matcher.set_seq2(target)
    # Slide a window over the context that is slightly longer than the target
    # (5% slack, rounded down) to tolerate small insertions.
    for ngram in ngrams(context.split(), target_length + int(0.05 * target_length)):
        candidate_ngram = ' '.join(ngram)
        seq_matcher.set_seq1(candidate_ngram)
        # quick_ratio() is a fast upper bound on SequenceMatcher.ratio().
        similarity = seq_matcher.quick_ratio()
        if similarity > max_sim_val:
            max_sim_val = similarity
            max_sim_string = candidate_ngram
            if similarity == 1.0:
                # Early exit: no candidate can score higher.
                break

    return max_sim_string, max_sim_val


def _normalize_tokens(tokens: List[str], keep_empty_str: bool = True) -> List[str]:
    """
    Normalize individual tokens by lowercasing and stripping punctuation and
    articles. If keep_empty_str is True, the overall number of tokens is
    preserved; a token may be normalized to an empty string.
    """
    normalized_tokens = []
    for token in tokens:
        token = token.lower()
        token = ''.join(ch for ch in token if ch not in EXCLUDED_PUNCTS)
        token = re.sub(ARTICLES_RE, '', token)
        if keep_empty_str or len(token):
            normalized_tokens.append(token)
    return normalized_tokens


def find_closest_span_match(passage: str, gold_answer: str) -> Tuple[str, float]:
    """Heuristic for finding the span of a passage closest to a golden answer,
    as measured by token-level F1 score."""
    # First narrow the search to one approximately matching window of the passage.
    closest_encompassing_span, _ = _find_approximate_matching_sequence(passage, gold_answer)
    closest_encompassing_span_tok = closest_encompassing_span.split()
    gold_answer_tok = gold_answer.split()

    # Normalize both sides; the encompassing span keeps empty strings so that
    # indices still line up with the original (unnormalized) tokens.
    closest_encompassing_span_tok_normalized = _normalize_tokens(closest_encompassing_span_tok)
    gold_answer_tok_normalized = _normalize_tokens(gold_answer_tok, keep_empty_str=False)

    # Exhaustively score every contiguous sub-span (O(n^2) in the window length).
    best_score, best_i, best_j = 0, None, None
    for i in range(0, len(closest_encompassing_span_tok_normalized)):
        for j in range(i + 1, len(closest_encompassing_span_tok_normalized) + 1):
            score = compute_f1_from_tokens(
                gold_answer_tok_normalized,
                [t for t in closest_encompassing_span_tok_normalized[i:j] if len(t)],
            )
            if score > best_score:
                best_score = score
                best_i, best_j = i, j

    # If no sub-span scored above zero, best_i and best_j are still None and
    # the [None:None] slice falls back to the whole encompassing span.
    best_span = ' '.join(closest_encompassing_span_tok[best_i:best_j])
    best_f1 = compute_f1(gold_answer, best_span)
    return best_span, best_f1
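

# --- Assumed contract of the `evaluate_qa` helpers (illustrative only) -------
# `compute_f1_from_tokens` and `compute_f1` are imported from `evaluate_qa`,
# which is not shown here. The heuristic above only relies on them behaving
# like the standard SQuAD-style token-overlap F1; a hypothetical reference
# version (named differently so it does not shadow the real import) might
# look like the sketch below.
def _reference_f1_from_tokens(gold_tokens: List[str], pred_tokens: List[str]) -> float:
    """SQuAD-style F1: harmonic mean of token precision and recall."""
    from collections import Counter

    # Multiset intersection counts how many tokens (with multiplicity) overlap.
    common = Counter(gold_tokens) & Counter(pred_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)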
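

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal demonstration of the span-matching heuristic on a toy passage.
# This assumes `evaluate_qa` is importable; the exact scores printed depend
# on that module's F1 implementation.
if __name__ == '__main__':
    passage = ('The Apollo program was the third United States human '
               'spaceflight program carried out by NASA.')
    gold = 'the Apollo programme'  # near match: "programme" vs. "program"
    span, f1 = find_closest_span_match(passage, gold)
    print(f'Closest span: {span!r} (F1 = {f1:.3f})')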