Commit: Update trec_eval.py

Update to use pytrec_eval
sunnweiwei authored Feb 24, 2024
1 parent 2552334 commit 51dad78
Showing 1 changed file with 139 additions and 96 deletions.
235 changes: 139 additions & 96 deletions trec_eval.py
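The commit swaps the Java trec_eval / pyserini pipeline for the pytrec_eval package. For reference, here is a minimal sketch of the pytrec_eval pattern that the new trec_eval() helper builds on; the qrels/run dicts and the query/doc ids are toy placeholders, not data from this repository:

import pytrec_eval

qrels = {'q1': {'d1': 1, 'd2': 0}}      # query -> doc -> graded relevance label
run = {'q1': {'d1': 12.5, 'd2': 7.1}}   # query -> doc -> retrieval score

evaluator = pytrec_eval.RelevanceEvaluator(qrels, {'ndcg_cut.10', 'map_cut.10', 'recall.10'})
scores = evaluator.evaluate(run)        # per-query dict, e.g. scores['q1']['ndcg_cut_10']
print(scores['q1']['ndcg_cut_10'])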
@@ -1,19 +1,131 @@
import os
import re
import subprocess
import sys
import platform
import pandas as pd
import tempfile
import os
import copy
from typing import Dict, Tuple
import pytrec_eval


def trec_eval(qrels: Dict[str, Dict[str, int]],
              results: Dict[str, Dict[str, float]],
              k_values: Tuple[int] = (10, 50, 100, 200, 1000)) -> Dict[str, float]:
    ndcg, _map, recall = {}, {}, {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])

    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string})
    scores = evaluator.evaluate(results)

    for query_id in scores:
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]

    def _normalize(m: dict) -> dict:
        return {k: round(v / len(scores), 5) for k, v in m.items()}

    ndcg = _normalize(ndcg)
    _map = _normalize(_map)
    recall = _normalize(recall)

    all_metrics = {}
    for mt in [ndcg, _map, recall]:
        all_metrics.update(mt)

    return all_metrics


def get_qrels_file(name):
    THE_TOPICS = {
        'dl19': 'dl19-passage',
        'dl20': 'dl20-passage',
        'covid': 'beir-v1.0.0-trec-covid-test',
        'arguana': 'beir-v1.0.0-arguana-test',
        'touche': 'beir-v1.0.0-webis-touche2020-test',
        'news': 'beir-v1.0.0-trec-news-test',
        'scifact': 'beir-v1.0.0-scifact-test',
        'fiqa': 'beir-v1.0.0-fiqa-test',
        'scidocs': 'beir-v1.0.0-scidocs-test',
        'nfc': 'beir-v1.0.0-nfcorpus-test',
        'quora': 'beir-v1.0.0-quora-test',
        'dbpedia': 'beir-v1.0.0-dbpedia-entity-test',
        'fever': 'beir-v1.0.0-fever-test',
        'robust04': 'beir-v1.0.0-robust04-test',
        'signal': 'beir-v1.0.0-signal1m-test',
    }
    name = THE_TOPICS[name]
    name = name.replace('-test', '.test')
    name = 'data/label_file/qrels.' + name + '.txt'
    return name


def remove_duplicate(response):
    new_response = []
    for c in response:
        if c not in new_response:
            new_response.append(c)
        else:
            print('duplicate')
    return new_response

from pyserini.search import get_qrels_file
from pyserini.util import download_evaluation_script

def clean_response(response: str):
    new_response = ''
    for c in response:
        if not c.isdigit():
            new_response += ' '
        else:
            try:
                new_response += str(int(c))
            except:
                new_response += ' '
    new_response = new_response.strip()
    return new_response


class EvalFunction:
    @staticmethod
    def receive_responses(rank_results, responses, cut_start=0, cut_end=100):
        print('receive_responses', len(responses), len(rank_results))
        for i in range(len(responses)):
            response = responses[i]
            response = clean_response(response)
            response = [int(x) - 1 for x in response.split()]
            response = remove_duplicate(response)
            cut_range = copy.deepcopy(rank_results[i]['hits'][cut_start: cut_end])
            original_rank = [tt for tt in range(len(cut_range))]
            response = [ss for ss in response if ss in original_rank]
            response = response + [tt for tt in original_rank if tt not in response]
            for j, x in enumerate(response):
                rank_results[i]['hits'][j + cut_start] = {
                    'content': cut_range[x]['content'], 'qid': cut_range[x]['qid'], 'docid': cut_range[x]['docid'],
                    'rank': cut_range[j]['rank'], 'score': cut_range[j]['score']}
        return rank_results

    @staticmethod
    def write_file(rank_results, file):
        print('write_file')
        with open(file, 'w') as f:
            for i in range(len(rank_results)):
                rank = 1
                hits = rank_results[i]['hits']
                for hit in hits:
                    f.write(f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n")
                    rank += 1
        return True

    @staticmethod
    def trunc(qrels, run):
        qrels = get_qrels_file(qrels)
        # print(qrels)
        run = pd.read_csv(run, delim_whitespace=True, header=None)
        qrels = pd.read_csv(qrels, delim_whitespace=True, header=None)
        run[0] = run[0].astype(str)
@@ -25,92 +137,23 @@ def trunc(qrels, run):
        return temp_file

    @staticmethod
    def eval(args, trunc=True):
        script_path = download_evaluation_script('trec_eval')
        cmd_prefix = ['java', '-jar', script_path]
        # args = sys.argv

        # Option to discard non-judged hits in run file
        judged_docs_only = ''
        judged_result = []
        cutoffs = []

        if '-remove-unjudged' in args:
            judged_docs_only = args.pop(args.index('-remove-unjudged'))

        if any([i.startswith('judged.') for i in args]):
            # Find what position the arg is in.
            idx = [i.startswith('judged.') for i in args].index(True)
            cutoffs = args.pop(idx)
            cutoffs = list(map(int, cutoffs[7:].split(',')))
            # Get rid of the '-m' before the 'judged.xxx' option
            args.pop(idx - 1)

        temp_file = ''

        if len(args) > 1:
            if trunc:
                args[-2] = EvalFunction.trunc(args[-2], args[-1])
                print('Trunc', args[-2])

            if not os.path.exists(args[-2]):
                args[-2] = get_qrels_file(args[-2])
            if os.path.exists(args[-1]):
                # Convert run to trec if it's on msmarco
                with open(args[-1]) as f:
                    first_line = f.readline()
                if 'Q0' not in first_line:
                    temp_file = tempfile.NamedTemporaryFile(delete=False).name
                    print('msmarco run detected. Converting to trec...')
                    run = pd.read_csv(args[-1], delim_whitespace=True, header=None,
                                      names=['query_id', 'doc_id', 'rank'])
                    run['score'] = 1 / run['rank']
                    run.insert(1, 'Q0', 'Q0')
                    run['name'] = 'TEMPRUN'
                    run.to_csv(temp_file, sep='\t', header=None, index=None)
                    args[-1] = temp_file

            run = pd.read_csv(args[-1], delim_whitespace=True, header=None)
            qrels = pd.read_csv(args[-2], delim_whitespace=True, header=None)

            # cast doc_id column as string
            run[0] = run[0].astype(str)
            qrels[0] = qrels[0].astype(str)

            # Discard non-judged hits

            if judged_docs_only:
                if not temp_file:
                    temp_file = tempfile.NamedTemporaryFile(delete=False).name
                judged_indexes = pd.merge(run[[0, 2]].reset_index(), qrels[[0, 2]], on=[0, 2])['index']
                run = run.loc[judged_indexes]
                run.to_csv(temp_file, sep='\t', header=None, index=None)
                args[-1] = temp_file
            # Measure judged@cutoffs
            for cutoff in cutoffs:
                run_cutoff = run.groupby(0).head(cutoff)
                judged = len(pd.merge(run_cutoff[[0, 2]], qrels[[0, 2]], on=[0, 2])) / len(run_cutoff)
                metric_name = f'judged_{cutoff}'
                judged_result.append(f'{metric_name:22}\tall\t{judged:.4f}')
            cmd = cmd_prefix + args[1:]
        else:
            cmd = cmd_prefix

        print(f'Running command: {cmd}')
        shell = platform.system() == "Windows"
        process = subprocess.Popen(cmd,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   shell=shell)
        stdout, stderr = process.communicate()
        if stderr:
            print(stderr.decode("utf-8"))

        print('Results:')
        print(stdout.decode("utf-8").rstrip())

        for judged in judged_result:
            print(judged)

        if temp_file:
            os.remove(temp_file)
    def main(args_qrel, args_run):

        args_qrel = EvalFunction.trunc(args_qrel, args_run)

        assert os.path.exists(args_qrel)
        assert os.path.exists(args_run)

        with open(args_qrel, 'r') as f_qrel:
            qrel = pytrec_eval.parse_qrel(f_qrel)

        with open(args_run, 'r') as f_run:
            run = pytrec_eval.parse_run(f_run)

        all_metrics = trec_eval(qrel, run, k_values=(1, 5, 10))
        print(all_metrics)
        return all_metrics


if __name__ == '__main__':
    EvalFunction.main('dl19', 'ranking_results_file')
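
A hypothetical end-to-end use of the updated entry point follows; the rank_results record, the run-file name, and the qid/docid values are illustrative only, and the qrels file under data/label_file/ is assumed to already exist:

rank_results = [{'query': '...', 'hits': [
    {'qid': '264014', 'docid': '4834547', 'content': '...', 'rank': 1, 'score': 10.0},
]}]
EvalFunction.write_file(rank_results, 'ranking_results_file')  # writes a TREC-format run file
metrics = EvalFunction.main('dl19', 'ranking_results_file')    # e.g. {'NDCG@1': ..., 'MAP@10': ..., 'Recall@10': ...}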
