Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zeynepakkalyoncu committed Jul 1, 2019
1 parent 6f7fbcb commit d9c80af
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion eval_scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ else
collection="robust04_2cv"
fi

if [ ${tune_params} == "True" ] ; then
if [ ${tune_params} ] ; then
declare -a sents=("a" "ab" "abc")

./eval_scripts/train.qqsh ${experiment} ${num_folds} ${anserini_path}
Expand Down
4 changes: 0 additions & 4 deletions src/eval_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ def load_bert_scores(pred_file, query_dict, sent_dict):
def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict,
topKSent, alpha, beta, gamma):
run_file = open(os.path.join('runs', run_file), "w")

doc_scores = {}
for q in topics:
doc_score_dict = {}
for d in top_doc_dict[q]:
Expand All @@ -64,11 +62,9 @@ def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict,
score_list.append(s)
sum_score += s * w
doc_score_dict[d] = alpha * bm25_dict[q][d] + (1.0 - alpha) * sum_score
doc_scores[d] = (bm25_dict[q][d], sum_score, doc_score_dict[d]) # used only for interactive querying where len(topics) = 1

doc_score_dict = sorted(doc_score_dict.items(), key=operator.itemgetter(1), reverse=True)
rank = 1
# print(doc_score_dict)
for doc, score in doc_score_dict:
run_file.write("{} Q0 {} {} {} BERT\n".format(q, doc, rank, score))
rank += 1
Expand Down
10 changes: 3 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,14 @@ def main():

top_doc_dict, doc_bm25_dict, sent_dict, q_dict, doc_label_dict = eval_bm25(collection_path)
score_dict = load_bert_scores(predictions_path, q_dict, sent_dict)
topics = test_topics if not args.interactive else list(q_dict.keys())

print(topics)

if args.interactive:
top_rank_docs = visualize_scores(collection_path, score_dict)
with open(os.path.join(args.data_path, 'query_sent_scores.csv'), 'w') as scores_file:
for doc in top_rank_docs[:100]:
scores_file.write('{} | {} | {} | {} | {}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT'))
scores_file.write('{}\t{}\t{}\t{}\t{}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT'))
for doc in top_rank_docs[-100:]:
scores_file.write('{} | {} | {} | {} | {}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT'))
scores_file.write('{}\t{}\t{}\t{}\t{}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT'))

if not os.path.isdir('runs'):
os.mkdir('runs')
Expand All @@ -100,8 +97,7 @@ def main():
round(b, 2), round(g, 2), map_score)

elif mode == 'test':
topics = test_topics if not args.interactive else list(
q_dict.keys())
topics = test_topics if not args.interactive else list(q_dict.keys())
calc_q_doc_bert(score_dict,
'run.' + experiment + '.cv.test.' + str(test_folder_set),
topics, top_doc_dict, doc_bm25_dict, topK, alpha,
Expand Down

0 comments on commit d9c80af

Please sign in to comment.