diff --git a/scripts/rank_ibm.py b/scripts/reranker_ibm_colbert.py similarity index 88% rename from scripts/rank_ibm.py rename to scripts/reranker_ibm_colbert.py index 90988f711..e87ae06e0 100644 --- a/scripts/rank_ibm.py +++ b/scripts/reranker_ibm_colbert.py @@ -109,6 +109,7 @@ def get_ibm_score(arguments): target_lookup = arguments['target_lookup'] tran = arguments['tran'] collect_probs = arguments['collect_probs'] + max_sim = arguments['max_sim'] if searcher.documentRaw(test_doc) ==None: print(f'{test_doc} is not found in searcher') @@ -121,22 +122,24 @@ def get_ibm_score(arguments): target_map = {} total_tran_prob = 0 collect_prob = collect_probs[querytoken] + max_sim_score = 0 if querytoken in target_lookup.keys(): query_word_id = target_lookup[querytoken] if query_word_id in tran.keys(): target_map = tran[query_word_id] for doctoken in doc_token_lst: tran_prob = 0 - doc_word_id = 0 - if querytoken==doctoken: - tran_prob = SELF_TRAN + doc_word_id = 0 if doctoken in source_lookup.keys(): doc_word_id = source_lookup[doctoken] if doc_word_id in target_map.keys(): - tran_prob = max(target_map[doc_word_id],tran_prob) - total_tran_prob += (tran_prob/doc_size) - - query_word_prob=math.log((1 - LAMBDA_VALUE) * total_tran_prob + LAMBDA_VALUE * collect_prob) + tran_prob = max(target_map[doc_word_id],tran_prob) + max_sim_score = max(tran_prob, max_sim_score) + total_tran_prob += (tran_prob/doc_size) + if (max_sim): + query_word_prob=math.log((1 - LAMBDA_VALUE) * max_sim_score + LAMBDA_VALUE * collect_prob) + else: + query_word_prob=math.log((1 - LAMBDA_VALUE) * total_tran_prob + LAMBDA_VALUE * collect_prob) total_query_prob += query_word_prob return total_query_prob /query_size @@ -164,28 +167,28 @@ def intbits_to_float(b: bytes): def rescale(source_lookup: Dict[str,int],target_lookup: Dict[str,int],tran_lookup: Dict[str,Dict[str,float]],\ target_voc: Dict[int,str],source_voc: Dict[int,str]): for target_id in tran_lookup: - target_probs = tran_lookup[target_id] if target_id > 0: adjust_mult = (1 - SELF_TRAN) else: adjust_mult = 1 #adjust the prob with adjust_mult and add SELF_TRAN prob to self-translation pair - for source_id in target_probs.keys(): - tran_prob = target_probs[source_id] + for source_id in tran_lookup[target_id].keys(): + tran_prob = tran_lookup[target_id][source_id] if source_id >0: source_word = source_voc[source_id] target_word = target_voc[target_id] tran_prob *= adjust_mult if (source_word== target_word): tran_prob += SELF_TRAN - target_probs[source_id]= tran_prob + tran_lookup[target_id][source_id]= tran_prob # in case if self-translation pair was not included in TransTable - if target_id not in target_probs.keys(): - target_probs[target_id]= SELF_TRAN + if target_id not in tran_lookup[target_id].keys(): + target_word = target_voc[target_id] + source_id = source_lookup[target_word] + tran_lookup[target_id][source_id]= SELF_TRAN return source_lookup,target_lookup,tran_lookup - def load_tranprobs_table(dir_path: str): source_path = dir_path +"/source.vcb" source_lookup = {} @@ -229,7 +232,7 @@ def load_tranprobs_table(dir_path: str): def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: str,output_path:str, \ - score_path:str,field_name:str, tag: str,alpha:int,num_threads:int): + score_path:str,field_name:str, tag: str,alpha:int,num_threads:int, max_sim:bool): pool = ThreadPool(num_threads) searcher = JSimpleSearcher(JString(lucene_index_path)) @@ -257,7 +260,7 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: collect_probs[querytoken] = max(reader.totalTermFreq(JTerm(field_name, querytoken))/total_term_freq, MIN_COLLECT_PROB) arguments = [{"query_text_lst":query_text_lst,"test_doc":test_doc, "searcher":searcher,\ "field_name":field_name,"source_lookup":source_lookup,"target_lookup":target_lookup,\ - "tran":tran,"collect_probs":collect_probs} for test_doc in test_docs] + "tran":tran,"collect_probs":collect_probs, "max_sim":max_sim} for test_doc in test_docs] rank_scores = pool.map(get_ibm_score, arguments) ibm_scores = normalize([p for p in rank_scores]) @@ -270,7 +273,6 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: rank = index + 1 f.write(f'{topic} Q0 {doc_id} {rank} {score} {tag}\n') - f.close() map_score,ndcg_score = evaluate(qrels, output_path) with open(score_path, 'w') as outfile: @@ -286,25 +288,27 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: metavar="path_to_qrels", help='path to new_qrels file') parser.add_argument('-base', type=str, default="../ibm/run.msmarco-passage.bm25tuned.trec", metavar="path_to_base_run", help='path to base run') - parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok", + parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok_raw", metavar="directory_path", help='directory path to source.vcb target.vcb and Transtable bin file') parser.add_argument('-query_path', type=str, default="../ibm/queries.dev.small.json", metavar="path_to_query", help='path to dev queries file') parser.add_argument('-index', type=str, default="../ibm/index-msmarco-passage-ltr-20210519-e25e33f", metavar="path_to_lucene_index", help='path to lucene index folder') - parser.add_argument('-output', type=str, default="../ibm/runs/result-text-bert-tuned0.1.txt", + parser.add_argument('-output', type=str, default="../ibm/runs/result-colbert-test-alpha0.3.txt", metavar="path_to_reranked_run", help='the path to store reranked run file') - parser.add_argument('-score_path', type=str, default="../ibm/result-ibm-0.1.json", + parser.add_argument('-score_path', type=str, default="../ibm/runs/result-colbert-test-alpha0.3.json", metavar="path_to_base_run", help='the path to map and ndcg scores') parser.add_argument('-field_name', type=str, default="text_bert_tok", metavar="type of field", help='type of field used for training') - parser.add_argument('-alpha', type=float, default="0.1", + parser.add_argument('-alpha', type=float, default="0.3", metavar="type of field", help='interpolation weight') parser.add_argument('-num_threads', type=int, default="12", metavar="num_of_threads", help='number of threads to use') + parser.add_argument('-max_sim', type=bool, default=True, + metavar="bool for max sim operator", help='whether we use max sim operator or avg instead') args = parser.parse_args() print('Using base run:', args.base) rank(args.qrels, args.base, args.tran_path, args.query_path, args.index, args.output, \ - args.score_path,args.field_name, args.tag,args.alpha,args.num_threads) + args.score_path,args.field_name, args.tag,args.alpha,args.num_threads, args.max_sim)