#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
"""
Functions for evaluating passage retrieval.

This is used to compute MRR (mean reciprocal rank), Recall@10, and Recall@100 in Table 5 of the paper.
"""
from argparse import ArgumentParser
import json
import logging
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, List, Tuple

from evaluate_qa import compute_exact, compute_f1
from span_heuristic import find_closest_span_match

RELEVANCE_THRESHOLD = 0.8
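
# A minimal sketch of the expected input, assuming the field set described in the
# CLI help at the bottom of this file (all values here are made up for illustration):
#   {"Conversation-ID": 1, "Turn-ID": 2, "docid": "web_123", "rank": 1,
#    "content": "... passage text ...", "answer": "gold answer span"}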


def compute_f1_for_retrieved_passage(line: str) -> dict:
    """
    Given a serialized JSON line with fields 'content' and 'answer', find the closest span matching the answer,
    update the deserialized dict with the span and F1 score, and return the dict.
    """
    data = json.loads(line)
    content, answer = data['content'], data['answer']
    # If there is no answer, although the closest extractive answer is '', in the MRR and recall@k functions below
    # we do not count any passage for these questions as relevant.
    if len(answer) < 1:
        data['heuristic_answer'] = ''
        data['f1'] = compute_f1(answer, '')
        return data
    best_span, best_f1 = find_closest_span_match(content, answer)
    data['heuristic_answer'] = best_span
    data['f1'] = best_f1
    return data
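
# Illustrative example (hypothetical input): for a line such as
#   {"content": "Ada Lovelace published the first algorithm", "answer": "the first algorithm", ...}
# the gold answer appears verbatim in the passage, so the heuristic should
# recover that exact span with F1 = 1.0.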


def compute_mean_reciprocal_rank(
    question_id_to_docs: Dict[str, List[dict]], relevance_threshold: float
) -> float:
    """Given a dictionary mapping a question id to a list of docs, find the mean reciprocal rank."""
    recip_rank_sum = 0
    for qid, docs in question_id_to_docs.items():
        top_rank = float('inf')
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold:
                top_rank = min(top_rank, doc['rank'])
        recip_rank = 1 / top_rank if top_rank != float('inf') else 0
        recip_rank_sum += recip_rank
    return recip_rank_sum / len(question_id_to_docs)
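
# Worked example (illustrative): for three questions whose highest-ranked relevant
# passages appear at ranks 1, 4, and nowhere at all, the reciprocal ranks are
# 1/1, 1/4, and 0, so MRR = (1.0 + 0.25 + 0) / 3 ≈ 0.4167.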


def compute_recall_at_k(
    question_id_to_docs: Dict[str, List[dict]], k: int, relevance_threshold: float
) -> float:
    """
    Given a dictionary mapping a question id to a list of docs, find the recall@k.
    We define recall@k = 1.0 if any document in the top-k is relevant, and 0 otherwise.
    """
    relevant_doc_found_total = 0
    for qid, docs in question_id_to_docs.items():
        relevant_doc_found = 0
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold and doc['rank'] <= k:
                relevant_doc_found = 1
                break
        relevant_doc_found_total += relevant_doc_found
    return relevant_doc_found_total / len(question_id_to_docs)
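
# Worked example (illustrative): with k=10 and four questions, three of which have
# at least one passage with rank <= 10 and F1 >= relevance_threshold,
# recall@10 = 3/4 = 0.75.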


def compute_extractive_upper_bounds(
    question_id_to_docs: Dict[str, List[dict]], temp_files_directory: Path
) -> Tuple[float, float]:
    """Given a dictionary mapping a question id to a list of docs, find the extractive upper bounds of (EM, F1)."""
    total_em, total_f1 = 0, 0.0
    with open(temp_files_directory / 'retrieved-passages-relevant-f1.jsonl', 'w') as outfile:
        for qid, docs in question_id_to_docs.items():
            best_em, best_f1 = 0, 0.0
            best_doc = docs[0]
            for doc in docs:
                em = compute_exact(doc['answer'], doc['heuristic_answer'])
                f1 = compute_f1(doc['answer'], doc['heuristic_answer'])
                if f1 > best_f1:
                    best_doc = doc
                best_em = max(best_em, em)
                best_f1 = max(best_f1, f1)
                if best_em == 1 and best_f1 == 1.0:
                    break
            total_em += best_em
            total_f1 += best_f1
            outfile.write(json.dumps(best_doc) + '\n')
    return (
        total_em / len(question_id_to_docs),
        total_f1 / len(question_id_to_docs),
    )
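
# Note: this is an oracle score, i.e. the EM/F1 that a reader which always selected
# the best heuristic span among the retrieved passages would attain, so it
# upper-bounds the end-to-end extractive QA performance on these passages.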


def get_unique_relevant_docs_count(
    question_id_to_docs: Dict[str, List[dict]], relevance_threshold: float
) -> int:
    """Given a dictionary mapping a question id to a list of docs, find the number of unique relevant docs."""
    unique_relevant_docs = set()
    for qid, docs in question_id_to_docs.items():
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold:
                unique_relevant_docs.add(doc['docid'])
    return len(unique_relevant_docs)


def get_average_relevant_docs_per_question(
    question_id_to_docs: Dict[str, List[dict]], relevance_threshold: float
) -> float:
    """Given a dictionary mapping a question id to a list of docs, find the average number of relevant docs per question."""
    relevant_docs = 0
    for qid, docs in question_id_to_docs.items():
        for doc in docs:
            if len(doc['answer']) > 0 and doc['f1'] >= relevance_threshold:
                relevant_docs += 1
    return relevant_docs / len(question_id_to_docs)


def main(retrieved_passages_pattern: str, temp_files_directory: str, workers: int):
    retrieved_passages_files = Path().glob(retrieved_passages_pattern)
    temp_files_directory = Path(temp_files_directory)
    temp_files_directory.mkdir(exist_ok=True, parents=True)
    question_id_to_docs = {}
    for retrieved_passages_file in retrieved_passages_files:
        with open(retrieved_passages_file) as infile:
            with Pool(workers) as p:
                for i, passage_results in enumerate(
                    p.imap(compute_f1_for_retrieved_passage, infile)
                ):
                    if (i + 1) % 5000 == 0:
                        logging.info(
                            f'Processing {retrieved_passages_file.name}, {i + 1} lines done...'
                        )
                    qid = f"{passage_results['Conversation-ID']}_{passage_results['Turn-ID']}"
                    if qid not in question_id_to_docs:
                        question_id_to_docs[qid] = []
                    question_id_to_docs[qid].append(
                        {
                            'Conversation-ID': passage_results['Conversation-ID'],
                            'Turn-ID': passage_results['Turn-ID'],
                            'docid': passage_results['docid'],
                            'content': passage_results['content'],
                            'rank': passage_results['rank'],
                            'answer': passage_results['answer'],
                            'heuristic_answer': passage_results['heuristic_answer'],
                            'f1': passage_results['f1'],
                        }
                    )
    print('Final metrics:')
    unique_relevant_docs = get_unique_relevant_docs_count(question_id_to_docs, RELEVANCE_THRESHOLD)
    unique_docs_perfect_f1 = get_unique_relevant_docs_count(question_id_to_docs, 1.0)
    avg_relevant_docs_per_question = get_average_relevant_docs_per_question(
        question_id_to_docs, 1.0
    )
    print(f'Total number of unique queries: {len(question_id_to_docs)}')
    print(f'Total number of unique relevant docs: {unique_relevant_docs}')
    print(f'Total number of unique docs with F1=1.0: {unique_docs_perfect_f1}')
    print(f'Average number of relevant docs per query: {avg_relevant_docs_per_question}')
    mrr = compute_mean_reciprocal_rank(question_id_to_docs, RELEVANCE_THRESHOLD)
    recall_at_10 = compute_recall_at_k(question_id_to_docs, 10, RELEVANCE_THRESHOLD)
    recall_at_100 = compute_recall_at_k(question_id_to_docs, 100, RELEVANCE_THRESHOLD)
    print(f'Mean Reciprocal Rank (MRR): {mrr:.4f}')
    print(f'Recall@10: {recall_at_10 * 100:.2f}%')
    print(f'Recall@100: {recall_at_100 * 100:.2f}%')
    em_upper_bound, f1_upper_bound = compute_extractive_upper_bounds(
        question_id_to_docs, temp_files_directory
    )
    print(f'Extractive Upper Bound for EM (100 point scale): {em_upper_bound * 100:.2f}')
    print(f'Extractive Upper Bound for F1 (100 point scale): {f1_upper_bound * 100:.2f}')


if __name__ == '__main__':
    parser = ArgumentParser(description='Passage retrieval evaluation')
    parser.add_argument(
        '--retrieved-passages-pattern',
        required=True,
        help="""A globbing pattern to select .jsonl files containing retrieved passages.
        Each json line should contain the fields 'Conversation-ID', 'Turn-ID', 'docid', 'content', 'answer', 'rank'.
        'answer' is the gold answer given in the QReCC dataset and 'rank' is the rank of the document, starting from 1.""",
    )
    parser.add_argument(
        '--temp-files-directory',
        default='/tmp/qrecc-retrieval-eval',
        help='Directory to store temporary files containing F1 scores, which can be used for debugging and analysis',
    )
    parser.add_argument(
        '--workers', default=8, type=int, help='Number of workers for parallel processing',
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    main(args.retrieved_passages_pattern, args.temp_files_directory, args.workers)
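
# Example invocation (hypothetical paths):
#   python evaluate_retrieval.py \
#       --retrieved-passages-pattern 'runs/bm25/*.jsonl' \
#       --temp-files-directory /tmp/qrecc-retrieval-eval \
#       --workers 8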