Commit

Add interactive querying
zeynepakkalyoncu committed Jun 29, 2019
1 parent 7e0f759 commit fcb26d0
Showing 6 changed files with 55 additions and 26 deletions.
5 changes: 3 additions & 2 deletions src/args.py
@@ -2,7 +2,7 @@

def get_args():
parser = ArgumentParser(description='birch')
parser.add_argument('--mode', default='training', help='[training, inference, retrieval]')
parser.add_argument('--mode', default='retrieval', help='[training, inference, retrieval]')
parser.add_argument('--output_path', default='out.tmp', help='Name of log file')
parser.add_argument('--data_path', default='data')
parser.add_argument('--qrels_file', default='qrels.robust2004.txt',
@@ -11,8 +11,9 @@ def get_args():
help='[mb, robust04, core17, core18]')

# Interactive
parser.add_argument('--interactive', action='store_true', default=False, help='Batch evaluation if not set')
parser.add_argument('--query', default='hubble space telescope', help='Query string')
parser.add_argument('--result_path', default='data/query.csv', help='Path to output sentence results from query')
parser.add_argument('--interactive_path', default='data/datasets/query_sents.csv', help='Path to output sentence results from query')

# Retrieval
parser.add_argument('--experiment', default='base_mb_robust04',
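For orientation, here is a minimal sketch of how the interactive switches above might be consumed. It assumes get_args() returns an (args, extras) pair, as it is unpacked in src/query.py further down; the flags mentioned in the comments are illustrative.

from args import get_args

# Hypothetical driver: e.g. invoked with --interactive --query 'hubble space telescope'
args, _ = get_args()
if args.interactive:
    # Single ad hoc query: retrieved sentences are expected at args.interactive_path
    # (data/datasets/query_sents.csv by default).
    print(args.query, '->', args.interactive_path)
else:
    # Batch evaluation over the configured collection ([mb, robust04, core17, core18]).
    print('Batch evaluation over', args.collection)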
25 changes: 13 additions & 12 deletions src/eval_bert.py
@@ -25,10 +25,6 @@ def eval_bm25(collection_file, topK=1000):
if rank <= topK:
top_doc_dict[qid].append(doc)
rank += 1
# for qid in top_doc_dict:
# print('qid: {}'.format(qid))
# print(len(top_doc_dict[qid]))
# assert(len(top_doc_dict[qid]) == topK)
return top_doc_dict, doc_score_dict, sent_dict, q_dict, doc_label_dict


@@ -37,25 +33,29 @@ def load_bert_scores(pred_file, query_dict, sent_dict):
with open(pred_file) as bF:
for line in bF:
q, _, d, _, score, _ = line.strip().split()
q = query_dict[q]
sent = sent_dict[d]
doc = sent.split('_')[0]
score = float(score)
if doc not in score_dict[q]:
score_dict[q][doc] = [score]
else:
score_dict[q][doc].append(score)
if q in query_dict.keys():
q = query_dict[q]
sent = sent_dict[d]
doc = sent.split('_')[0]
score = float(score)
if doc not in score_dict[q]:
score_dict[q][doc] = [score]
else:
score_dict[q][doc].append(score)
return score_dict


def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict,
topKSent, alpha, beta, gamma):
run_file = open(os.path.join('runs', run_file), "w")

for q in topics:
doc_score_dict = {}
for d in top_doc_dict[q]:
# print(d)
scores = score_dict[q][d]
scores.sort(reverse=True)
# print(scores)
sum_score = 0
score_list = []
weight_list = [1, beta, gamma]
@@ -65,6 +65,7 @@ def calc_q_doc_bert(score_dict, run_file, topics, top_doc_dict, bm25_dict,
doc_score_dict[d] = alpha * bm25_dict[q][d] + (1.0 - alpha) * sum_score
doc_score_dict = sorted(doc_score_dict.items(), key=operator.itemgetter(1), reverse=True)
rank = 1
# print(doc_score_dict)
for doc, score in doc_score_dict:
run_file.write("{} Q0 {} {} {} BERT\n".format(q, doc, rank, score))
rank += 1
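The document score computed in calc_q_doc_bert interpolates the BM25 score with a weighted sum of the document's top sentence-level BERT scores, using weight_list = [1, beta, gamma]. The summation loop itself is cut off by the diff view above, so the sketch below is a reconstruction under that assumption, with purely illustrative parameter values.

def interpolate(bm25_score, sent_scores, alpha=0.5, beta=0.5, gamma=0.25):
    # Weights for the best, second-best and third-best sentence, as in weight_list above.
    weights = [1, beta, gamma]
    top = sorted(sent_scores, reverse=True)[:len(weights)]
    bert_score = sum(w * s for w, s in zip(weights, top))
    return alpha * bm25_score + (1.0 - alpha) * bert_score

print(interpolate(12.3, [0.91, 0.64, 0.22, 0.05]))  # 6.7925 with the defaults above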
19 changes: 16 additions & 3 deletions src/main.py
@@ -10,6 +10,7 @@
from model.test import test
from model.utils import print_scores
from args import get_args
from query import query_sents

RANDOM_SEED = 12345
random.seed(RANDOM_SEED)
@@ -24,7 +25,6 @@ def main():

experiment = args.experiment
anserini_path = args.anserini_path
predictions_path = os.path.join(args.data_path, 'predictions', 'predict.' + experiment)
datasets_path = os.path.join(args.data_path, 'datasets')

if not os.path.isdir('log'):
@@ -57,8 +57,21 @@
else:
test_topics.extend(folds[i])

top_doc_dict, doc_bm25_dict, sent_dict, q_dict, doc_label_dict = eval_bm25(os.path.join(datasets_path, args.collection + '.csv'))
if args.interactive:
query_sents(args)
test(args) # inference over each sentence

collection_path = os.path.join(datasets_path,
args.collection + '.csv') if not args.interactive else args.interactive_path
predictions_path = os.path.join(args.data_path, 'predictions',
'predict.' + experiment) if not args.interactive else os.path.join(
args.data_path, 'predictions', args.predict_path)

top_doc_dict, doc_bm25_dict, sent_dict, q_dict, doc_label_dict = eval_bm25(collection_path)
score_dict = load_bert_scores(predictions_path, q_dict, sent_dict)
topics = test_topics if not args.interactive else list(q_dict.keys())

print(topics)

if not os.path.isdir('runs'):
os.mkdir('runs')
@@ -83,7 +96,7 @@ def main():
elif mode == 'test':
calc_q_doc_bert(score_dict,
'run.' + experiment + '.cv.test.' + str(test_folder_set),
test_topics, top_doc_dict, doc_bm25_dict, topK, alpha,
topics, top_doc_dict, doc_bm25_dict, topK, alpha,
beta, gamma)
else:
calc_q_doc_bert(score_dict, 'run.' + experiment + '.cv.all', all_topics,
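Taken together, the interactive branch of main() boils down to the outline below. This is a sketch, not the verbatim code: the eval_bert import path and the run-file name are assumptions, and the alpha/beta/gamma/topK plumbing is simplified.

import os

from eval_bert import eval_bm25, load_bert_scores, calc_q_doc_bert  # assumed import path
from model.test import test
from query import query_sents


def run_interactive(args, alpha, beta, gamma, topK):
    query_sents(args)   # BM25 retrieval for args.query (src/query.py)
    test(args)          # BERT inference over each retrieved sentence (src/model/test.py)
    top_docs, bm25, sents, queries, _ = eval_bm25(args.interactive_path)
    predictions_path = os.path.join(args.data_path, 'predictions', args.predict_path)
    scores = load_bert_scores(predictions_path, queries, sents)
    calc_q_doc_bert(scores, 'run.' + args.experiment + '.interactive',  # illustrative run name
                    list(queries.keys()), top_docs, bm25, topK, alpha, beta, gamma)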
3 changes: 3 additions & 0 deletions src/model/eval.py
@@ -12,6 +12,9 @@ def evaluate(trec_eval_path, predictions_file, qrels_file):
p = subprocess.Popen(pargs, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
pout, perr = p.communicate()

print(pout)
print(perr)

if sys.version_info[0] < 3:
lines = pout.split(b'\n')
else:
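For reference, evaluate() is invoked elsewhere in the repo roughly as follows; the import path, the trec_eval binary location, and the run-file name are assumptions, while the (map, mrr, p30) unpacking matches src/model/test.py below.

from model.eval import evaluate  # assumed import path

map_score, mrr, p30 = evaluate(trec_eval_path='eval/trec_eval.9.0.4/trec_eval',  # placeholder
                               predictions_file='runs/run.base_mb_robust04.cv.test.0',
                               qrels_file='data/qrels.robust2004.txt')
print(map_score, mrr, p30)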
15 changes: 10 additions & 5 deletions src/model/test.py
@@ -38,7 +38,8 @@ def test(args, split='test', model=None, tokenizer=None, training_or_lm=False):
device=args.device)
else:
# Load Robust04 data
test_dataset = load_trec_data(args.data_path, args.collection,
collection = args.collection if not args.interactive else 'query_sents'
test_dataset = load_trec_data(args.data_path, collection,
args.batch_size, tokenizer, split,
args.device)

@@ -72,8 +73,12 @@ def test(args, split='test', model=None, tokenizer=None, training_or_lm=False):
output_file.close()
predict_file.close()

map, mrr, p30 = evaluate(args.trec_eval_path,
predictions_file=predictions_path,
qrels_file=os.path.join(args.data_path, args.qrels_file))
if args.interactive:
return None

return [['map', 'mrr', 'p30'], [map, mrr, p30]]
else:
map, mrr, p30 = evaluate(args.trec_eval_path,
predictions_file=predictions_path,
qrels_file=os.path.join(args.data_path, args.qrels_file))

return [['map', 'mrr', 'p30'], [map, mrr, p30]]
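In other words, test() now short-circuits for interactive queries, presumably because an ad hoc query has no relevance judgments to score against, so a caller can treat its return value as optional. A small usage sketch, with the imports mirroring src/main.py and src/query.py:

from args import get_args
from model.test import test

args, _ = get_args()        # as in src/query.py
result = test(args)         # None when args.interactive is set
if result is not None:
    names, values = result  # ['map', 'mrr', 'p30'], [map, mrr, p30]
    print(dict(zip(names, values)))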
14 changes: 10 additions & 4 deletions src/query.py
@@ -2,14 +2,20 @@
from searcher import *
from args import get_args

if __name__ == '__main__':
args, _ = get_args()

def query_sents(args):
collection = args.collection
anserini_path = args.anserini_path
index_path = args.index_path
query = args.query
output_fn = args.result_path
output_fn = args.interactive_path

docsearch = Searcher(anserini_path)
searcher = docsearch.build_searcher(k1=0.9, b=0.4, index_path=index_path, rm3=True)
searcher = docsearch.build_searcher(k1=0.9, b=0.4, index_path=index_path,
rm3=True)
docsearch.search_query(searcher, query, output_fn, collection, K=1000)


if __name__ == '__main__':
args, _ = get_args()
query_sents(args)
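
One possible way to drive query_sents() directly with a hand-built namespace; the attribute names mirror exactly what the function reads, while the Anserini and index paths are placeholders.

from argparse import Namespace

from query import query_sents

args = Namespace(collection='robust04',
                 anserini_path='/path/to/anserini',     # placeholder
                 index_path='/path/to/index/robust04',  # placeholder
                 query='hubble space telescope',
                 interactive_path='data/datasets/query_sents.csv')

# Retrieves the top K=1000 BM25+RM3 hits for the query and writes the
# sentence candidates to args.interactive_path, ready for BERT inference.
query_sents(args)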
