Skip to content

Commit

Permalink
add cmd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
supercoderhawk committed Jan 9, 2020
1 parent 059a7c1 commit 3764bc6
Showing 1 changed file with 59 additions and 25 deletions.
84 changes: 59 additions & 25 deletions wsdm_digg/reranking/predict.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
# -*- coding: UTF-8 -*-
import os
import time
import argparse
from collections import OrderedDict
from munch import Munch
import torch
from pysenal import read_jsonline_lazy, read_json, append_jsonline
from wsdm_digg.reranking.dataloader import RerankDataLoader
from wsdm_digg.reranking.model_loader import load_model, get_score_func
from wsdm_digg.reranking.parse_args import parse_args
from wsdm_digg.utils import result_format
from wsdm_digg.constants import MODEL_DICT, DATA_DIR


class PlmRerankReranker(object):
def __init__(self, model_path, batch_size):
self.model_path = model_path
def __init__(self, model_info, batch_size, parser=None):
self.batch_size = batch_size
self.config = self.load_config()
self.model = self.load_model(load_model(self.config), model_path)
model_info = MODEL_DICT[self.config.model_name]
if 'path' in model_info:
tokenizer_path = model_info['path'] + 'vocab.txt'
self.parser = parser
if isinstance(model_info, str):
self.model_path = model_info
self.config = self.load_config()
self.model = self.load_model(load_model(self.config), model_info)

# for k in self.model.kernel.kernel_list:
# print(k.mean, k.stddev)

model_dict = MODEL_DICT[self.config.plm_model_name]
if 'path' in model_dict:
tokenizer_path = model_dict['path'] + 'vocab.txt'
else:
tokenizer_path = self.config.model_name
self.tokenizer = model_dict['tokenizer_class'].from_pretrained(tokenizer_path)
elif isinstance(model_info, dict):
self.model = model_info['model']
self.config = model_info['config']
self.tokenizer = model_info['tokenizer']
else:
tokenizer_path = self.config.model_name
self.tokenizer = model_info['tokenizer_class'].from_pretrained(tokenizer_path)
raise ValueError('error')

def load_model(self, model, model_path):
if torch.cuda.is_available():
Expand All @@ -41,8 +55,11 @@ def load_model(self, model, model_path):
return model

def load_config(self):
default_config = vars(parse_args(parser=self.parser))
config_path = os.path.splitext(self.model_path)[0] + '.json'
config_dict = read_json(config_path)
model_config = read_json(config_path)
# config_dict = model_config
config_dict = {**default_config, **model_config}
config_dict['batch_size'] = self.batch_size
config = Munch(config_dict)
return config
Expand All @@ -55,13 +72,14 @@ def rerank_file(self, search_filename, golden_filename, dest_filename, topk, is_
assert topk >= self.batch_size
assert topk % self.batch_size == 0

self.model.eval()
searched_desc_ids = self.get_searched_desc_id(dest_filename)
data_source = {'search_filename': search_filename,
'golden_filename': golden_filename,
'topk': topk,
'searched_id_list': searched_desc_ids}
loader = RerankDataLoader(data_source, self.tokenizer, self.config, 'eval')
score_func = get_score_func(self.model)
score_func = get_score_func(self.model, inference=True)

start = time.time()

Expand All @@ -77,7 +95,16 @@ def rerank_file(self, search_filename, golden_filename, dest_filename, topk, is_
desc_id = batch['raw'][0]['description_id']
paper_id_score_list = list(zip([i['doc_id'] for i in batch['raw']], scores))
if desc_id not in desc_id2id_score_list:
desc_id2id_score_list[desc_id] = {'description_id': desc_id, 'docs': paper_id_score_list}
if len(paper_id_score_list) == topk:
sorted_id_score_list = sorted(paper_id_score_list,
key=lambda i: (i[1], i[0]), reverse=True)
sorted_paper_ids = [idx for idx, _ in sorted_id_score_list]
result_item = {'description_id': desc_id, 'docs': sorted_paper_ids,
'docs_with_score': sorted_id_score_list}
append_jsonline(dest_filename, result_item)
else:
desc_id2id_score_list[desc_id] = {'description_id': desc_id,
'docs': paper_id_score_list}
else:
desc_id2id_score_list[desc_id]['docs'].extend(paper_id_score_list)
if len(desc_id2id_score_list[desc_id]['docs']) == topk:
Expand Down Expand Up @@ -106,6 +133,7 @@ def rerank_file(self, search_filename, golden_filename, dest_filename, topk, is_
result_format(dest_filename)

def rerank(self, query, doc_id_list):
self.model.eval()
doc_len = len(doc_id_list)
input_data = list(zip([query] * doc_len, doc_id_list))
loader = RerankDataLoader(input_data, self.tokenizer, self.config, 'inference')
Expand All @@ -124,17 +152,23 @@ def rerank(self, query, doc_id_list):
return sorted_paper_ids


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-eval_search_filename', type=str, required=True)
parser.add_argument('-golden_filename', type=str, required=True)
parser.add_argument('-dest_filename', type=str, required=True)
parser.add_argument('-model_path', type=str, required=True)
parser.add_argument('-topk', type=int, default=20)
parser.add_argument('-eval_batch_size', type=int, default=10)
args = parser.parse_args()

ranker = PlmRerankReranker(args.model_path, args.eval_batch_size, parser)
ranker.rerank_file(args.eval_search_filename,
args.golden_filename,
args.dest_filename,
args.topk,
is_submit=True)


if __name__ == '__main__':
topk = 100
model_path = DATA_DIR + 'rerank/cite_textrank_top10_rerank_search_result' \
'/cite_textrank_top10_rerank_search_result_epoch_5_step_70000.model'

# search_filename = DATA_DIR + 'result/cite_textrank_top10.jsonl'
# golden_filename = DATA_DIR + 'test.jsonl'
# dest_filename = DATA_DIR + 'rerank_result/cite_textrank_top10_rerank_top{}.jsonl'.format(topk)

search_filename = DATA_DIR + 'submit_result/cite_textrank_top10.jsonl'
golden_filename = DATA_DIR + 'validation.jsonl'
dest_filename = DATA_DIR + 'submit_result/cite_textrank_top10_rerank_top{}.jsonl'.format(topk)
ranker = PlmRerankReranker(model_path, 5)
ranker.rerank_file(search_filename, golden_filename, dest_filename, topk, is_submit=True)
main()

0 comments on commit 3764bc6

Please sign in to comment.