
Commit

Resolve conflict
zeynepakkalyoncu committed Jun 29, 2019
2 parents a037292 + cc28c98 commit ce05ece
Showing 2 changed files with 18 additions and 2 deletions.
9 changes: 9 additions & 0 deletions src/eval_bert.py
@@ -49,6 +49,8 @@ def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict,
                    topKSent, alpha, beta, gamma):
    run_file = open(os.path.join('runs', run_file), "w")

    top_docs = {}

    for q in topics:
        doc_score_dict = {}
        for d in top_doc_dict[q]:
@@ -63,10 +65,17 @@ def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict,
                score_list.append(s)
                sum_score += s * w
            doc_score_dict[d] = alpha * bm25_dict[q][d] + (1.0 - alpha) * sum_score
            top_docs[d] = (bm25_dict[q][d], sum_score, doc_score_dict[d])  # used only for interactive querying where len(topics) = 1

        doc_score_dict = sorted(doc_score_dict.items(), key=operator.itemgetter(1), reverse=True)
        rank = 1
        # print(doc_score_dict)
        for doc, score in doc_score_dict:
            if rank <= 10:
                print('doc id: {} | doc score: {} | BERT score: {} | overall score: {}'.format(doc, top_docs[doc][0], top_docs[doc][1], top_docs[doc][2]))
            run_file.write("{} Q0 {} {} {} BERT\n".format(q, doc, rank, score))
            rank += 1

    run_file.close()
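
The document score computed above interpolates the query's BM25 score for a document with a weighted sum of its top sentence-level BERT scores. A minimal sketch of that combination follows, using hypothetical scores and weights; the visible diff only shows "sum_score += s * w", so the particular weighting below is an assumption, not the repository's.

# Minimal sketch, not the repository's code: how calc_q_doc_bert composes the
# final document score. The sentence weights here are illustrative assumptions.
def interpolate(bm25_score, sentence_scores, weights, alpha):
    # alpha * BM25 + (1 - alpha) * weighted sum of the top sentence scores
    sum_score = sum(s * w for s, w in zip(sentence_scores, weights))
    return alpha * bm25_score + (1.0 - alpha) * sum_score

# Hypothetical query-document pair: BM25 score 7.2, top-3 BERT sentence scores,
# weights 1.0 / 0.5 / 0.3, alpha = 0.6.
print(interpolate(7.2, [0.91, 0.85, 0.40], [1.0, 0.5, 0.3], alpha=0.6))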


11 changes: 9 additions & 2 deletions src/main.py
@@ -77,12 +77,13 @@ def main():
        os.mkdir('runs')

    if mode == 'train':
        topics = train_topics if not args.interactive else list(q_dict.keys())
        # Grid search for best parameters
        for a in np.arange(0.0, alpha, 0.1):
            for b in np.arange(0.0, beta, 0.1):
                for g in np.arange(0.0, gamma, 0.1):
                    calc_q_doc_bert(score_dict, 'run.' + experiment + '.cv.train',
                                    train_topics, top_doc_dict, doc_bm25_dict,
                                    topics, top_doc_dict, doc_bm25_dict,
                                    topK, a, b, g)
                    base = 'runs/run.' + experiment + '.cv.train'
                    os.system('{}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map {} {}> eval.base'.format(anserini_path, qrels_path, base))
@@ -94,12 +95,18 @@
                            round(b, 2), round(g, 2), map_score)

    elif mode == 'test':
        topics = test_topics if not args.interactive else list(
            q_dict.keys())
        calc_q_doc_bert(score_dict,
                        'run.' + experiment + '.cv.test.' + str(test_folder_set),
                        topics, top_doc_dict, doc_bm25_dict, topK, alpha,
                        beta, gamma)
    else:
        calc_q_doc_bert(score_dict, 'run.' + experiment + '.cv.all', all_topics,
        topics = all_topics if not args.interactive else list(
            q_dict.keys())
        if args.interactive:
            print('Top 10 documents for query: "{}"'.format(args.query))
        calc_q_doc_bert(score_dict, 'run.' + experiment + '.cv.all', topics,
                        top_doc_dict, doc_bm25_dict, topK, alpha, beta, gamma)
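
The train branch above sweeps (alpha, beta, gamma) on a 0.1 grid, writes a run file for each setting, and uses trec_eval's MAP to pick the best parameters. A minimal sketch of that pattern follows, with a stand-in evaluate() callable in place of the calc_q_doc_bert / trec_eval pipeline; evaluate and grid_search are hypothetical helpers, not part of the repository.

import numpy as np

def grid_search(evaluate, alpha_max, beta_max, gamma_max, step=0.1):
    # Try every (a, b, g) triple on the grid and keep the one with the best MAP.
    best_map, best_params = -1.0, None
    for a in np.arange(0.0, alpha_max, step):
        for b in np.arange(0.0, beta_max, step):
            for g in np.arange(0.0, gamma_max, step):
                map_score = evaluate(a, b, g)
                if map_score > best_map:
                    best_map = map_score
                    best_params = (round(a, 2), round(b, 2), round(g, 2))
    return best_params, best_map

# Example with a dummy evaluator that peaks around a = 0.6:
best, score = grid_search(lambda a, b, g: 1.0 - abs(a - 0.6), 1.0, 1.0, 1.0)
print(best, score)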


