Add colbert version of IBM reranker #918

Merged
merged 20 commits on Jan 3, 2022
Changes from 4 commits
95 changes: 26 additions & 69 deletions scripts/rank_ibm.py
@@ -17,16 +17,11 @@
import argparse
import os
import json
-import sys
-import time
-sys.path.append('..')
-sys.path.append('../pyserini')
from multiprocessing.pool import ThreadPool
import subprocess
from pyserini.pyclass import autoclass, JString

-from typing import List
-from typing import List, Set
+from typing import List, Set, Dict


import struct
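Reviewer note: the script drives Lucene through pyserini's JVM bridge, so the JSimpleSearcher, JIndexReader and JTerm names used further down are Java classes reached via autoclass. For readers outside the repo, a minimal sketch of the bindings this code relies on; the Anserini/Lucene class paths are my assumption from pyserini of this vintage, not part of the diff:

from pyserini.pyclass import autoclass

# Hypothetical sketch of the assumed bindings, not part of this PR.
JSimpleSearcher = autoclass('io.anserini.search.SimpleSearcher')
JIndexReader = autoclass('io.anserini.index.IndexReaderUtils')
JTerm = autoclass('org.apache.lucene.index.Term')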
@@ -42,7 +37,7 @@
lambdaValue = 0.3
minCollectProb=1e-9

-def normalize(scores):
+def normalize(scores: List[float]):
    low = min(scores)
    high = max(scores)
    width = high - low
@@ -52,39 +47,6 @@ def normalize(scores):
    return scores
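normalize() is plain min-max scaling into [0, 1]. Assuming the collapsed middle applies the usual (s - low) / width update per element, a quick sanity check:

# Hypothetical check, assuming the elided body rescales each score in place.
assert normalize([3.0, 1.0, 2.0]) == [1.0, 0.0, 0.5]

A constant score list would make width zero, so presumably the inputs never degenerate to that case.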


-def get_lines_by_topic(path, topic, tag):
-    res = []
-    with open(path, 'r') as f:
-        for line in f:
-            tokens = line.split()
-            if tokens[0] != topic:
-                continue
-            tokens[-1] = tag
-            new_line = ' '.join(tokens)
-            res.append(new_line)
-
-    return res
-
-
-def read_qrels(path: str):
-    qrels = []
-
-    with open(path, 'r') as f:
-        for line in f:
-            line = line.strip()
-            tokens = line.split()
-            topic = tokens[0]
-            doc_id = tokens[-2]
-            relevance = int(tokens[-1])
-            qrels.append({
-                'topic': topic,
-                'doc_id': doc_id,
-                'relevance': relevance
-            })
-
-    return qrels
-
-
def get_docs_from_qrun_by_topic(path: str):
    result_dic={}
    with open(path, 'r') as f:
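The body of this loader is collapsed in the diff view. For context: each line of a TREC run file reads `topic Q0 doc_id rank score tag`, and judging from the call site `[test_docs, base_scores] = doc_dic[topic]` further down, the function groups doc ids and scores per topic, roughly like this sketch (the name and exact layout are my reconstruction):

# Hypothetical reconstruction of the collapsed body, inferred from the call site.
def get_docs_from_qrun_by_topic_sketch(path: str):
    result_dic = {}
    with open(path, 'r') as f:
        for line in f:
            topic, _, doc_id, _, score, _ = line.split()
            result_dic.setdefault(topic, [[], []])
            result_dic[topic][0].append(doc_id)
            result_dic[topic][1].append(float(score))
    return result_dic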
@@ -125,10 +87,9 @@ def evaluate(qrels_path: str, run_path: str, options: str = ''):
    cmd2 = f"{prefix} {run_path} {options} | grep 'map '"
    ndcg_score = str(subprocess.check_output(cmd1, shell=True)).split('\\t')[-1].split('\\n')[0]
    map_score = str(subprocess.check_output(cmd2, shell=True)).split('\\t')[-1].split('\\n')[0]
-    print(str(map_score),str(ndcg_score))
    return str(map_score),str(ndcg_score)
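evaluate() shells out to trec_eval twice and greps one metric line from each invocation. The odd split on the literal two-character '\\t' works because check_output returns bytes and str() on bytes yields the repr, in which real tabs surface as backslash-t:

# trec_eval prints e.g. 'map \tall\t0.1912'; via str(bytes) the tab becomes '\\t'.
raw = str(b'map \tall\t0.1912\n')
assert raw.split('\\t')[-1].split('\\n')[0] == '0.1912'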

-def sort_dual_list(pred, docs):
+def sort_dual_list(pred: List[float], docs: List[str]):
    zipped_lists = zip(pred, docs)
    sorted_pairs = sorted(zipped_lists)

@@ -140,7 +101,6 @@ def sort_dual_list(pred, docs):
    return pred, docs
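sort_dual_list() reorders both lists together by score, ascending, so the highest-scoring doc comes last. Assuming the collapsed middle unzips sorted_pairs back into two lists, a quick check:

# Hypothetical check, assuming the elided lines unzip the sorted pairs.
pred, docs = sort_dual_list([0.3, 0.9, 0.1], ['d1', 'd2', 'd3'])
assert pred == [0.1, 0.3, 0.9] and docs == ['d3', 'd1', 'd2']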



def get_ibm_score(arguments):
    query_text_lst = arguments['query_text_lst']
    test_doc = arguments['test_doc']
@@ -150,18 +110,16 @@ def get_ibm_score(arguments):
    targetLookup = arguments['targetLookup']
    tran = arguments['tran']
    collectProbs = arguments['collectProbs']
-    #print(time.time())
+
    if searcher.documentRaw(test_doc) ==None:
-        print(test_doc)
+        print(f'{test_doc} is not found in searcher')
    document_text= json.loads(searcher.documentRaw(test_doc))[fieldName]
-    #print(time.time())
    doc_token_lst = document_text.split(" ")
    totalQueryProb = 0
    docSize = len(doc_token_lst)
    querySize = len(query_text_lst)
    for querytoken in query_text_lst:
        targetMap = {}
-        #print(time.time())
        totTranProb = 0
        collectProb = collectProbs[querytoken]
        if querytoken in targetLookup.keys():
@@ -180,13 +138,13 @@ def get_ibm_score(arguments):
                    totTranProb += (tranProb/docSize)

        queryWordProb=math.log((1 - lambdaValue) * totTranProb + lambdaValue * collectProb)
+
        totalQueryProb += queryWordProb
-    #print(time.time())
    return totalQueryProb /querySize
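This is IBM Model 1 query likelihood with Jelinek-Mercer-style smoothing: per query token, the translation probability mass collected from the document, (1 - lambdaValue) * totTranProb, is interpolated with the collection probability lambdaValue * collectProb before taking the log, and the result is averaged over query tokens. Worked numerically for one token, with assumed toy values:

import math

# Assumed: totTranProb = 0.02, collectProb = 1e-4, lambdaValue = 0.3.
queryWordProb = math.log((1 - 0.3) * 0.02 + 0.3 * 1e-4)  # log(0.01403) ~= -4.27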



-def query_loader(query_path):
+def query_loader(query_path: str):
    queries = {}
    with open(query_path) as f:
        for line in f:
@@ -200,11 +158,12 @@ def query_loader(query_path):
    return queries


-def intBitsToFloat(b):
+def intBitsToFloat(b: bytes):
    s = struct.pack('>l', b)
    return struct.unpack('>f', s)[0]
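intBitsToFloat() mirrors Java's Float.intBitsToFloat: pack the 32-bit integer big-endian, then reinterpret the same four bytes as an IEEE-754 float (despite the new `b: bytes` annotation, the argument is actually an int). Two spot checks:

assert intBitsToFloat(0x3F800000) == 1.0                # bit pattern of 1.0f
assert round(intBitsToFloat(0x40490FDB), 5) == 3.14159  # float32 pi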

-def rescale(sourceLookup,targetLookup,tranLookup,targetVoc,sourceVoc):
+def rescale(sourceLookup: Dict[str,int],targetLookup: Dict[str,int],tranLookup: Dict[str,Dict[str,float]],\
+            targetVoc: Dict[int,str],sourceVoc: Dict[int,str]):
    for targetID in tranLookup:
        targetProbs = tranLookup[targetID]
        if targetID > 0:
@@ -228,7 +187,7 @@ def rescale(sourceLookup,targetLookup,tranLookup,targetVoc,sourceVoc):



-def load_tranProbsTable(dir_path):
+def load_tranProbsTable(dir_path: str):
    source_path = dir_path +"/source.vcb"
    sourceLookup = {}
    sourceVoc={}
@@ -260,7 +219,6 @@
            targetID = int.from_bytes(byte,"big")
            assert(targetID in targetVoc.keys())
            byte = file.read(4)
-            #tranProb = float.from_bytes(byte)
            tranProb = intBitsToFloat(int.from_bytes(byte,"big"))
            if (targetID in tranLookup.keys()) and (tranProb>minProb):
                tranLookup[targetID][sourceID] = tranProb
@@ -271,9 +229,9 @@
    return rescale(sourceLookup,targetLookup,tranLookup,targetVoc,sourceVoc)
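Judging from the reads visible above, the binary translation table is a stream of fixed 12-byte records: int32 sourceID, int32 targetID, float32 probability, all big-endian, with the float stored as Java-style int bits. struct can express one record in a single call; a hypothetical equivalent of one loop iteration:

import struct

# Sketch of one record read, assuming the (sourceID, targetID, prob) layout.
def read_record(file):
    rec = file.read(12)
    if len(rec) < 12:
        return None
    return struct.unpack('>iif', rec)  # big-endian: two int32s, one float32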


-def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: str,output_path:str,score_path:str,fieldName:str, tag: str,alpha:int,num_threads:int):
-    #print(time.time())
-    # build output path
+def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: str,output_path:str, \
+         score_path:str,fieldName:str, tag: str,alpha:int,num_threads:int):

    pool = ThreadPool(num_threads)
    searcher = JSimpleSearcher(JString(lucene_index_path))
    reader = JIndexReader().getReader(JString(lucene_index_path))
@@ -288,9 +246,7 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path:
    query= query_loader(query_path)

    i = 0
-    #print(time.time())
    for topic in topics:
-        #print(time.time())
        [test_docs, base_scores] = doc_dic[topic]
        rank_scores = []
        if i % 100==0:
@@ -300,12 +256,12 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path:
        collectProbs ={}
        for querytoken in query_text_lst:
            collectProbs[querytoken] = max(reader.totalTermFreq(JTerm(fieldName, querytoken))/totalTermFreq, minCollectProb)
-        arguments = [{"query_text_lst":query_text_lst,"test_doc":test_doc, "searcher":searcher,"fieldName":fieldName,"sourceLookup":sourceLookup,"targetLookup":targetLookup,"tran":tran,"collectProbs":collectProbs} for test_doc in test_docs]
-        #print(time.time())
-        #print(time.time())
-        rank_scores = pool.map(get_ibm_score, arguments)
+        arguments = [{"query_text_lst":query_text_lst,"test_doc":test_doc, "searcher":searcher,\
+                      "fieldName":fieldName,"sourceLookup":sourceLookup,"targetLookup":targetLookup,\
+                      "tran":tran,"collectProbs":collectProbs} for test_doc in test_docs]
+        rank_scores = pool.map(get_ibm_score, arguments)
+
        ibm_scores = normalize([p for p in rank_scores])
-        #print(time.time())
        base_scores = normalize([p for p in base_scores])

        interpolated_scores = [a * alpha + b * (1-alpha) for a, b in zip(base_scores, ibm_scores)]
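The final ranking score is a convex combination of the two normalized score lists, with alpha weighting the original base run; at the new default alpha = 0.1 the IBM score dominates. For example:

# Normalized base score 0.8, IBM score 0.5, alpha = 0.1:
score = 0.8 * 0.1 + 0.5 * (1 - 0.1)  # = 0.53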
@@ -331,24 +287,25 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path:
                        metavar="path_to_qrels", help='path to new_qrels file')
    parser.add_argument('-base', type=str, default="../ibm/run.msmarco-passage.bm25tuned.trec",
                        metavar="path_to_base_run", help='path to base run')
-    parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok",
+    parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok_raw",
                        metavar="directory_path", help='directory path to source.vcb target.vcb and Transtable bin file')
    parser.add_argument('-query_path', type=str, default="../ibm/queries.dev.small.json",
                        metavar="path_to_query", help='path to dev queries file')
    parser.add_argument('-index', type=str, default="../ibm/index-msmarco-passage-ltr-20210519-e25e33f",
                        metavar="path_to_lucene_index", help='path to lucene index folder')
-    parser.add_argument('-output', type=str, default="../ibm/runs/result-text-bert-0.txt",
+    parser.add_argument('-output', type=str, default="../ibm/runs/result-text-bert-tuned0.1.txt",
                        metavar="path_to_reranked_run", help='the path to store reranked run file')
-    parser.add_argument('-score_path', type=str, default="../ibm/result-ibm-0.json",
+    parser.add_argument('-score_path', type=str, default="../ibm/result-ibm-0.1.json",
                        metavar="path_to_base_run", help='the path to map and ndcg scores')
    parser.add_argument('-fieldName', type=str, default="text_bert_tok",
                        metavar="type of field", help='type of field used for training')
-    parser.add_argument('-alpha', type=float, default="0",
+    parser.add_argument('-alpha', type=float, default="0.1",
                        metavar="type of field", help='interpolation weight')
-    parser.add_argument('-num_threads', type=int, default="24",
+    parser.add_argument('-num_threads', type=int, default="12",
                        metavar="num_of_threads", help='number of threads to use')
    args = parser.parse_args()

    print('Using base run:', args.base)

-    rank(args.qrels, args.base, args.tran_path, args.query_path, args.index, args.output, args.score_path,args.fieldName, args.tag,args.alpha,args.num_threads)
+    rank(args.qrels, args.base, args.tran_path, args.query_path, args.index, args.output, \
+         args.score_path,args.fieldName, args.tag,args.alpha,args.num_threads)
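For anyone trying this branch: with the defaults above, an invocation would look something like `python scripts/rank_ibm.py -alpha 0.1 -num_threads 12` (my sketch; the `-tag` argument used in the call appears to be defined in a collapsed part of the argparse block), with the reranked run landing under ../ibm/runs/ and the map/ndcg scores in the `-score_path` JSON.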