Add colbert version of IBM reranker #918

Merged: 20 commits, Jan 3, 2022
Changes from 4 commits
scripts/rank_ibm.py: 266 changes (137 additions, 129 deletions)
@@ -18,37 +18,38 @@
import os
import json
import sys
import time
sys.path.append('..')
sys.path.append('../pyserini')
from multiprocessing.pool import ThreadPool
import subprocess
from pyserini.pyclass import autoclass, JString

from enum import Enum
from typing import List
from typing import List, Set

import spacy

import struct
import math


JSimpleSearcher = autoclass('io.anserini.search.SimpleSearcher')
JIndexReader = autoclass('io.anserini.index.IndexReaderUtils')
JTerm = autoclass('org.apache.lucene.index.Term')
JDocumentFieldContext = autoclass('io.anserini.ltr.DocumentFieldContext')
JQueryFieldContext = autoclass('io.anserini.ltr.QueryFieldContext')
selfTrans = 0.05
minProb=5e-4
lambdaValue = 0.1
alpha=0.5

selfTrans = 0.35
minProb=0.0025
lambdaValue = 0.3
minCollectProb=1e-9

def normalize(scores):
low = min(scores)
high = max(scores)
width = high - low

return [(s-low)/width for s in scores]

if width!=0:
return [(s-low)/width for s in scores]
else:
return scores


def get_lines_by_topic(path, topic, tag):
@@ -84,48 +85,22 @@ def read_qrels(path: str):
return qrels


def get_doc_to_id_from_qrun_by_topic(path: str, topic: str):
res = {}
with open(path, 'r') as f:
for line in f:
tokens = line.strip().split()
t = tokens[0]
if topic != t:
continue
doc_id = tokens[2]
score = float(tokens[-2])
res[doc_id] = score

return res


def get_docs_from_qrun_by_topic(path: str, topic: str):
x, y = [], []
def get_docs_from_qrun_by_topic(path: str):
result_dic={}
with open(path, 'r') as f:
for line in f:
tokens = line.strip().split()
t = tokens[0]
if topic != t:
continue
doc_id = tokens[2]
score = float(tokens[-2])
x.append(doc_id)
y.append(score)

return x, y

if t in result_dic.keys():
result_dic[t][0].append(doc_id)
result_dic[t][1].append(score)
else:
result_dic[t]=[[doc_id],[score]]

def get_X_Y_from_qrels_by_topic(path: str, topic: str, R: List[int]):
# always include topic 0
R.append(0)
qrels = [qrel for qrel in read_qrels(path) if qrel['topic'] == topic and qrel['relevance'] in R]
x, y = [], []
for pack in qrels:
x.append(pack['doc_id'])
label = 0 if pack['relevance'] == 0 else 1
y.append(label)
return result_dic

return x, y
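
For reference, the run file read here follows the standard TREC format, one ranked document per line: topic id, the literal Q0, document id, rank, score, and run tag; the loader keys the document ids and scores by the topic in the first column. A minimal standalone sketch of the same grouping (the sample line and helper name are illustrative only):

from collections import defaultdict

def group_run_by_topic(path):
    # A TREC run line looks like: "1048585 Q0 7187158 1 14.38 bm25"
    runs = defaultdict(lambda: ([], []))  # topic -> (doc_ids, scores)
    with open(path) as f:
        for line in f:
            topic, _q0, doc_id, _rank, score, _tag = line.split()
            runs[topic][0].append(doc_id)
            runs[topic][1].append(float(score))
    return dict(runs)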

def get_topics_from_qrun(path: str) -> Set[str]:
res = set()
@@ -141,11 +116,11 @@ def sort_str_topics_list(topics: List[str]) -> List[str]:

def evaluate(qrels_path: str, run_path: str, options: str = ''):
curdir = os.getcwd()
if curdir.endswith('clprf'):
anserini_root = '../../../anserini'
if curdir.endswith('scripts'):
anserini_root = '../../anserini'
else:
anserini_root = '../anserini'
prefix = f"{anserini_root} {qrels_path}"
prefix = f"{anserini_root}/tools/eval/trec_eval.9.0.4/trec_eval -c -M1000 -m all_trec {qrels_path}"
cmd1 = f"{prefix} {run_path} {options} | grep 'ndcg_cut_20 '"
cmd2 = f"{prefix} {run_path} {options} | grep 'map '"
ndcg_score = str(subprocess.check_output(cmd1, shell=True)).split('\\t')[-1].split('\\n')[0]
@@ -166,43 +141,59 @@ def sort_dual_list(pred, docs):



def get_ibm_score(query_text_lst,doc_token_lst, docSize,reader, fieldName,totalTermFreq,sourceLookup,targetLookup,tran):
def get_ibm_score(arguments):
query_text_lst = arguments['query_text_lst']
test_doc = arguments['test_doc']
searcher = arguments['searcher']
fieldName = arguments['fieldName']
sourceLookup = arguments['sourceLookup']
targetLookup = arguments['targetLookup']
tran = arguments['tran']
collectProbs = arguments['collectProbs']
#print(time.time())
if searcher.documentRaw(test_doc) ==None:
print(test_doc)
document_text= json.loads(searcher.documentRaw(test_doc))[fieldName]
#print(time.time())
doc_token_lst = document_text.split(" ")
totalQueryProb = 0
docSize = len(doc_token_lst)
querySize = len(query_text_lst)
for querytoken in query_text_lst:
targetMap = {}
#print(time.time())
totTranProb = 0
collectProb = max(reader.totalTermFreq(JTerm(fieldName, querytoken))/totalTermFreq, 1e-9)
collectProb = collectProbs[querytoken]
if querytoken in targetLookup.keys():
queryWordId = targetLookup[querytoken]
for doctoken in doc_token_lst:
tranProb = 0
docWordId = 0
if querytoken==doctoken:
tranProb = selfTrans
if doctoken in sourceLookup.keys():
docWordId = sourceLookup[doctoken]
if docWordId in tran.keys():
targetMap = tran[docWordId]
if queryWordId in targetMap.keys():
tranProb = max(targetMap[queryWordId],tranProb)
if (tranProb >= minProb):
totTranProb += (tranProb * ((1.0* doc_token_lst.count(doctoken)) / docSize))
queryWordProb = totTranProb*(1-lambdaValue)+lambdaValue*collectProb
#queryWordProb=math.log((1 - lambdaValue) * totTranProb + lambdaValue * collectProb) - math.log(lambdaValue * collectProb)
if totalQueryProb ==0:
totalQueryProb = queryWordProb
else:
totalQueryProb = totalQueryProb*queryWordProb

return totalQueryProb

def query_loader():
if queryWordId in tran.keys():
targetMap = tran[queryWordId]
for doctoken in doc_token_lst:
tranProb = 0
docWordId = 0
if querytoken==doctoken:
tranProb = selfTrans
if doctoken in sourceLookup.keys():
docWordId = sourceLookup[doctoken]
if docWordId in targetMap.keys():
tranProb = max(targetMap[docWordId],tranProb)
totTranProb += (tranProb/docSize)

queryWordProb=math.log((1 - lambdaValue) * totTranProb + lambdaValue * collectProb)
totalQueryProb += queryWordProb
#print(time.time())
return totalQueryProb /querySize
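
For reference, the loop above computes a smoothed IBM Model 1 query likelihood: for each query token, a translation probability is accumulated over the document tokens (with a fixed self-translation weight), mixed with a collection probability under lambdaValue, and the per-token log scores are averaged over the query length. A simplified, self-contained sketch with toy dictionaries standing in for the Anserini lookups (all names and numbers below are illustrative, not the script's API):

import math

def ibm_model1_score(query_tokens, doc_tokens, tran, source_ids, target_ids,
                     collect_probs, self_trans=0.35, lam=0.3, min_collect=1e-9):
    # tran maps target word id -> {source word id: translation probability}
    doc_len, total = len(doc_tokens), 0.0
    for q in query_tokens:
        tot_tran = 0.0
        target_map = tran.get(target_ids.get(q), {})
        for d in doc_tokens:
            p = self_trans if q == d else 0.0
            p = max(target_map.get(source_ids.get(d), 0.0), p)
            tot_tran += p / doc_len
        collect = max(collect_probs.get(q, 0.0), min_collect)
        total += math.log((1 - lam) * tot_tran + lam * collect)
    return total / len(query_tokens)

# e.g. ibm_model1_score(['car'], ['car', 'vehicle'], {1: {1: 0.4, 2: 0.2}},
#                       {'car': 1, 'vehicle': 2}, {'car': 1}, {'car': 1e-3})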



def query_loader(query_path):
queries = {}
with open(f'../ltr/queries.dev.small.json') as f:
with open(query_path) as f:
for line in f:
query = json.loads(line)
qid = query.pop('id')
query['analyzed'] = query['analyzed'].split(" ")
query['text'] = query['text_unlemm'].split(" ")
query['text'] = query['text'].split(" ")
query['text_unlemm'] = query['text_unlemm'].split(" ")
query['text_bert_tok'] = query['text_bert_tok'].split(" ")
queries[qid] = query
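
The query file consumed by query_loader is expected to be JSON lines, one query per line, with an id plus pre-tokenized fields (analyzed, text, text_unlemm, text_bert_tok) that are split on whitespace. A toy example of loading one such line (the field values are illustrative):

import json

# Hypothetical single line of the JSON-lines query file; query_loader pops the
# id and splits every remaining field on whitespace into token lists.
line = ('{"id": "42", "analyzed": "exampl queri", "text": "example query", '
        '"text_unlemm": "example query", "text_bert_tok": "example query"}')
query = json.loads(line)
qid = query.pop('id')
tokens = {field: value.split(" ") for field, value in query.items()}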
@@ -213,12 +204,28 @@ def intBitsToFloat(b):
s = struct.pack('>l', b)
return struct.unpack('>f', s)[0]
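
intBitsToFloat mirrors Java's Float.intBitsToFloat: it reinterprets a 32-bit integer as an IEEE-754 float in big-endian order, which is how the translation probabilities in the binary table below are decoded. A quick sanity check (toy value):

import struct

def int_bits_to_float(b):
    return struct.unpack('>f', struct.pack('>l', b))[0]

# 0x3F000000 is the IEEE-754 bit pattern for 0.5
assert int_bits_to_float(0x3F000000) == 0.5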

def _normalize(scores: List[float]):
low = min(scores)
high = max(scores)
width = high - low
def rescale(sourceLookup,targetLookup,tranLookup,targetVoc,sourceVoc):
for targetID in tranLookup:
targetProbs = tranLookup[targetID]
if targetID > 0:
adjustMult = (1 - selfTrans)
else:
adjustMult = 1
#adjust the prob with adjustMult and add selfTran prob to self-translation pair
for sourceID in targetProbs.keys():
tranProb = targetProbs[sourceID]
if sourceID >0:
sourceWord = sourceVoc[sourceID]
targetWord = targetVoc[targetID]
tranProb *= adjustMult
if (sourceWord== targetWord):
tranProb += selfTrans
targetProbs[sourceID]= tranProb
# in case if self-translation pair was not included in TransTable
if targetID not in targetProbs.keys():
targetProbs[targetID]= selfTrans
return sourceLookup,targetLookup,tranLookup

return [(s-low)/width for s in scores]
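
The rescale step above applies the usual self-translation smoothing: every learned probability is scaled by (1 - selfTrans), and selfTrans is added back to each word's own self-translation entry, so a fixed share of probability mass is reserved for a word translating to itself. A toy illustration with made-up numbers:

# With selfTrans = 0.35, P(car -> car) = 0.40 becomes 0.65 * 0.40 + 0.35 = 0.61
# and P(vehicle -> car) = 0.20 becomes 0.65 * 0.20 = 0.13.
self_trans = 0.35
probs = {'car': 0.40, 'vehicle': 0.20}  # P(source | target='car'), toy values
rescaled = {src: (1 - self_trans) * p + (self_trans if src == 'car' else 0.0)
            for src, p in probs.items()}
print({src: round(p, 2) for src, p in rescaled.items()})  # {'car': 0.61, 'vehicle': 0.13}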


def load_tranProbsTable(dir_path):
@@ -255,52 +262,51 @@ def load_tranProbsTable(dir_path):
byte = file.read(4)
#tranProb = float.from_bytes(byte)
tranProb = intBitsToFloat(int.from_bytes(byte,"big"))
if sourceID in tranLookup.keys():
tranLookup[sourceID][targetID] = tranProb
else:
tranLookup[sourceID] = {}
tranLookup[sourceID][targetID] = tranProb
if (targetID in tranLookup.keys()) and (tranProb>minProb):
tranLookup[targetID][sourceID] = tranProb
elif tranProb>minProb:
tranLookup[targetID] = {}
tranLookup[targetID][sourceID] = tranProb
byte = file.read(4)
return sourceLookup,targetLookup,tranLookup
return rescale(sourceLookup,targetLookup,tranLookup,targetVoc,sourceVoc)


def rank(new_qrels: str, base: str,dir_path:str, lucene_index_path: str,output_path:str,score_path:str,fieldName:str, tag: str):

def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: str,output_path:str,score_path:str,fieldName:str, tag: str,alpha:float,num_threads:int):
#print(time.time())
# build output path
pool = ThreadPool(num_threads)
searcher = JSimpleSearcher(JString(lucene_index_path))
reader = JIndexReader().getReader(JString(lucene_index_path))
sourceLookup,targetLookup,tran = load_tranProbsTable(dir_path)
nlp = spacy.load('en_core_web_sm')
totalTermFreq = reader.getSumTotalTermFreq(fieldName)

sourceLookup,targetLookup,tran = load_tranProbsTable(tran_path)
totalTermFreq = reader.getSumTotalTermFreq(fieldName)
doc_dic = get_docs_from_qrun_by_topic(base)


f = open(output_path, 'w')

skipped_topics = set()
topics = get_topics_from_qrun(base)
query= query_loader()
query= query_loader(query_path)

i = 0
#print(time.time())
for topic in topics:
test_docs, base_scores = get_docs_from_qrun_by_topic(base, topic)
#print(time.time())
[test_docs, base_scores] = doc_dic[topic]
rank_scores = []
#if (i%100==0):
print(f"Reranking {i} query")
j = 0
if i % 100==0:
print(f"Reranking {i} query")
i=i+1
for test_doc in test_docs:
document_text= json.loads(searcher.documentRaw(test_doc))[fieldName]
doc_tokens = nlp(document_text)
doc_token_lst = [i.text for i in doc_tokens]
docSize = len(doc_token_lst)
query_text_lst = query[topic][fieldName]
rank_score = get_ibm_score(query_text_lst,doc_token_lst, docSize,reader, fieldName,totalTermFreq,sourceLookup,targetLookup,tran)
rank_scores.append(rank_score)
#if (j%10==0):
#print(f"Reranking {j} topics score:{rank_score}")
j=j+1
ibm_scores = _normalize([p for p in rank_scores])
base_scores = _normalize([p for p in base_scores])
query_text_lst = query[topic][fieldName]
collectProbs ={}
for querytoken in query_text_lst:
collectProbs[querytoken] = max(reader.totalTermFreq(JTerm(fieldName, querytoken))/totalTermFreq, minCollectProb)
arguments = [{"query_text_lst":query_text_lst,"test_doc":test_doc, "searcher":searcher,"fieldName":fieldName,"sourceLookup":sourceLookup,"targetLookup":targetLookup,"tran":tran,"collectProbs":collectProbs} for test_doc in test_docs]
#print(time.time())
#print(time.time())
rank_scores = pool.map(get_ibm_score, arguments)
ibm_scores = normalize([p for p in rank_scores])
#print(time.time())
base_scores = normalize([p for p in base_scores])

interpolated_scores = [a * alpha + b * (1-alpha) for a, b in zip(base_scores, ibm_scores)]

Expand All @@ -309,38 +315,40 @@ def rank(new_qrels: str, base: str,dir_path:str, lucene_index_path: str,output_p
rank = index + 1
f.write(f'{topic} Q0 {doc_id} {rank} {score} {tag}\n')

for topic in sort_str_topics_list(list(skipped_topics)):
lines = get_lines_by_topic(base, topic, tag)
print(f'Copying over skipped topic {topic} with {len(lines)} lines')
for line in lines:
f.write(f'{line}\n')

f.close()
map_score,ndcg_score = evaluate(new_qrels, output_path)
map_score,ndcg_score = evaluate(qrels, output_path)
with open(score_path, 'w') as outfile:
json.dump({'map':map_score,'ndcg':ndcg_score}, outfile)
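
In short, rank() scores every candidate document of a topic in parallel with a ThreadPool, min-max normalizes both the IBM scores and the base run scores, and interpolates them with weight alpha on the base run. A compact sketch of that combination step (helper names here are illustrative, not part of the script):

def minmax(xs):
    lo, hi = min(xs), max(xs)
    return [(x - lo) / (hi - lo) for x in xs] if hi != lo else xs

def interpolate(base_scores, ibm_scores, alpha):
    # alpha weights the normalized base run; (1 - alpha) weights the IBM scores
    return [alpha * b + (1 - alpha) * i
            for b, i in zip(minmax(base_scores), minmax(ibm_scores))]

# e.g. interpolate([14.4, 12.0, 9.7], [-1.6, -2.3, -1.9], alpha=0.3)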


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='use tfidf vectorizer on cord-19 dataset with ccrf technique')
description='use ibm model 1 feature to rerank the base run file')
parser.add_argument('-tag', type=str, default="ibm",
metavar="tag_name", help='tag name for resulting Qrun')
parser.add_argument('-new_qrels', type=str, default="../tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt",
metavar="path_to_new_qrels", help='path to new_qrels file')
parser.add_argument('-base', type=str, default="../ltr/run.msmarco-passage.bm25tuned.trec",
parser.add_argument('-qrels', type=str, default="../tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt",
metavar="path_to_qrels", help='path to new_qrels file')
parser.add_argument('-base', type=str, default="../ibm/run.msmarco-passage.bm25tuned.trec",
metavar="path_to_base_run", help='path to base run')
parser.add_argument('-dir_path', type=str, default="../ltr/",
metavar="directory path", help='directory path')
parser.add_argument('-index', type=str, default="../ltr/index-msmarco-passage-ltr-20210519-e25e33f",
parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok",
metavar="directory_path", help='directory path to source.vcb target.vcb and Transtable bin file')
parser.add_argument('-query_path', type=str, default="../ibm/queries.dev.small.json",
metavar="path_to_query", help='path to dev queries file')
parser.add_argument('-index', type=str, default="../ibm/index-msmarco-passage-ltr-20210519-e25e33f",
metavar="path_to_lucene_index", help='path to lucene index folder')
parser.add_argument('-output', type=str, default="../ltr/result.txt",
metavar="path_to_reranked_run", help='the path to reranked run file')
parser.add_argument('-score_path', type=str, default="../ltr/result.json",
parser.add_argument('-output', type=str, default="../ibm/runs/result-text-bert-0.txt",
metavar="path_to_reranked_run", help='the path to store reranked run file')
parser.add_argument('-score_path', type=str, default="../ibm/result-ibm-0.json",
metavar="path_to_base_run", help='the path to map and ndcg scores')
parser.add_argument('-fieldName', type=str, default="text_unlemm",
parser.add_argument('-fieldName', type=str, default="text_bert_tok",
metavar="type of field", help='type of field used for training')
parser.add_argument('-alpha', type=float, default=0.0,
metavar="interpolation_weight", help='interpolation weight')
parser.add_argument('-num_threads', type=int, default=24,
metavar="num_of_threads", help='number of threads to use')
args = parser.parse_args()

print('Using base run:', args.base)
rank(args.new_qrels, args.base, args.dir_path, args.index, args.output, args.score_path,args.fieldName, args.tag)

rank(args.qrels, args.base, args.tran_path, args.query_path, args.index, args.output, args.score_path,args.fieldName, args.tag,args.alpha,args.num_threads)
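
As a usage example, the script can be invoked from the scripts/ directory so that the relative default paths resolve; the alpha value below is only illustrative:

python rank_ibm.py -qrels ../tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt -base ../ibm/run.msmarco-passage.bm25tuned.trec -alpha 0.1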