diff --git a/eval_scripts/test.sh b/eval_scripts/test.sh index dee61de..b47e15e 100755 --- a/eval_scripts/test.sh +++ b/eval_scripts/test.sh @@ -13,7 +13,7 @@ else collection="robust04_2cv" fi -if [ ${tune_params} == "True" ] ; then +if [ ${tune_params} ] ; then declare -a sents=("a" "ab" "abc") ./eval_scripts/train.qqsh ${experiment} ${num_folds} ${anserini_path} diff --git a/src/eval_bert.py b/src/eval_bert.py index 262e015..9f2a9fe 100644 --- a/src/eval_bert.py +++ b/src/eval_bert.py @@ -48,8 +48,6 @@ def load_bert_scores(pred_file, query_dict, sent_dict): def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict, topKSent, alpha, beta, gamma): run_file = open(os.path.join('runs', run_file), "w") - - doc_scores = {} for q in topics: doc_score_dict = {} for d in top_doc_dict[q]: @@ -64,11 +62,9 @@ def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict, score_list.append(s) sum_score += s * w doc_score_dict[d] = alpha * bm25_dict[q][d] + (1.0 - alpha) * sum_score - doc_scores[d] = (bm25_dict[q][d], sum_score, doc_score_dict[d]) # used only for interactive querying where len(topics) = 1 doc_score_dict = sorted(doc_score_dict.items(), key=operator.itemgetter(1), reverse=True) rank = 1 - # print(doc_score_dict) for doc, score in doc_score_dict: run_file.write("{} Q0 {} {} {} BERT\n".format(q, doc, rank, score)) rank += 1 diff --git a/src/main.py b/src/main.py index 5edc019..774336a 100644 --- a/src/main.py +++ b/src/main.py @@ -66,17 +66,14 @@ def main(): top_doc_dict, doc_bm25_dict, sent_dict, q_dict, doc_label_dict = eval_bm25(collection_path) score_dict = load_bert_scores(predictions_path, q_dict, sent_dict) - topics = test_topics if not args.interactive else list(q_dict.keys()) - - print(topics) if args.interactive: top_rank_docs = visualize_scores(collection_path, score_dict) with open(os.path.join(args.data_path, 'query_sent_scores.csv'), 'w') as scores_file: for doc in top_rank_docs[:100]: - scores_file.write('{} | {} | {} | {} | {}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT')) + scores_file.write('{}\t{}\t{}\t{}\t{}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT')) for doc in top_rank_docs[-100:]: - scores_file.write('{} | {} | {} | {} | {}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT')) + scores_file.write('{}\t{}\t{}\t{}\t{}\n'.format(doc[0], sentid2text[doc[0]], doc[1], doc[2], 'BM25' if doc[3] > 0 else 'BERT')) if not os.path.isdir('runs'): os.mkdir('runs') @@ -100,8 +97,7 @@ def main(): round(b, 2), round(g, 2), map_score) elif mode == 'test': - topics = test_topics if not args.interactive else list( - q_dict.keys()) + topics = test_topics if not args.interactive else list(q_dict.keys()) calc_q_doc_bert(score_dict, 'run.' + experiment + '.cv.test.' + str(test_folder_set), topics, top_doc_dict, doc_bm25_dict, topK, alpha,