Skip to content

Commit

Permalink
Update run_evaluation.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
sunnweiwei authored Oct 24, 2023
1 parent 6c2d724 commit 01e4100
Showing 1 changed file with 5 additions and 20 deletions.
25 changes: 5 additions & 20 deletions run_evaluation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -67,31 +67,22 @@

openai_key = os.environ.get("OPENAI_API_KEY", None)

# for data in ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04']:
for data in ['signal', 'news', 'robust04']:

for data in ['dl19', 'dl20', 'covid', 'nfc', 'touche', 'dbpedia', 'scifact', 'signal', 'news', 'robust04']:
print('#' * 20)
print(f'Evaluation on {data}')
print('#' * 20)


# Retrieve passages using pyserini BM25.
# Get a specific doc:
# * searcher.num_docs
# * json.loads(searcher.object.reader.document(4).fields[1].fieldsData) -> {"id": "1", "contents": ""}
try:
searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data])
topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20')
qrels = get_qrels(THE_TOPICS[data])
rank_results = run_retriever(topics, searcher, qrels, k=100)

# Store JSON in rank_results to a file
with open(f'rank_results_{data}.json', 'w') as f:
json.dump(rank_results, f, indent=2)
# Store the QRELS of the dataset
with open(f'qrels_{data}.json', 'w') as f:
json.dump(qrels, f, indent=2)
except:
print(f'Failed to retrieve passages for {data}')
continue

# Run sliding window permutation generation
new_results = []
for item in tqdm(rank_results):
Expand All @@ -110,7 +101,6 @@
shutil.move(output_file, f'eval_{data}.txt')



for data in ['mrtydi-ar', 'mrtydi-bn', 'mrtydi-fi', 'mrtydi-id', 'mrtydi-ja', 'mrtydi-ko', 'mrtydi-ru', 'mrtydi-sw', 'mrtydi-te', 'mrtydi-th']:
print('#' * 20)
print(f'Evaluation on {data}')
Expand All @@ -124,14 +114,9 @@
rank_results = run_retriever(topics, searcher, qrels, k=100)
rank_results = rank_results[:100]

# Store JSON in rank_results to a file
with open(f'rank_results_{data}.json', 'w') as f:
json.dump(rank_results, f, indent=2)
# Store the QRELS of the dataset
with open(f'qrels_{data}.json', 'w') as f:
json.dump(qrels, f, indent=2)
except:
print(f'Failed to retrieve passages for {data}')
continue

# Run sliding window permutation generation
new_results = []
Expand Down

0 comments on commit 01e4100

Please sign in to comment.