Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
supercoderhawk committed Jan 17, 2020
1 parent 0f36470 commit 480d51f
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions wsdm_digg/reranking/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ def __init__(self, model_info, batch_size, parser=None):
self.model_path = model_info
self.config = self.load_config()
self.model = self.load_model(load_rerank_model(self.config), model_info)

# for k in self.model.kernel.kernel_list:
# print(k.mean, k.stddev)

model_dict = MODEL_DICT[self.config.plm_model_name]
if 'path' in model_dict:
tokenizer_path = model_dict['path'] + 'vocab.txt'
Expand Down Expand Up @@ -67,7 +63,9 @@ def load_config(self):
def get_searched_desc_id(self, filename):
    """Collect the ``description_id`` of every record stored in *filename*.

    Records are obtained from ``read_jsonline_lazy(filename, default=[])``.

    Args:
        filename: path handed straight to ``read_jsonline_lazy``.

    Returns:
        A set holding the distinct ``description_id`` values that were read.
    """
    seen_ids = set()
    for record in read_jsonline_lazy(filename, default=[]):
        seen_ids.add(record['description_id'])
    return seen_ids
def rerank_pairwise_file(self, search_filename, golden_filename, dest_filename, topk, is_submit=False):

def rerank_pairwise_file(self, search_filename, golden_filename,
dest_filename, topk, is_submit=False):
self.model.eval()
searched_desc_ids = self.get_searched_desc_id(dest_filename)
data_source = {'search_filename': search_filename,
Expand All @@ -90,47 +88,45 @@ def rerank_pairwise_file(self, search_filename, golden_filename, dest_filename,
scores = [scores]

# desc_id = batch['raw'][0]['description_id']
for i,s in zip(batch['raw'],scores):
for i, s in zip(batch['raw'], scores):
desc_id = i['description_id']
fid = i['first_doc_id']
if desc_id not in desc_id2id_score_list:
desc_id2id_score_list[desc_id] = {fid: [s]}
else:
if fid not in desc_id2id_score_list[desc_id]:
desc_id2id_score_list[desc_id][fid] = [s]
else:
else:
desc_id2id_score_list[desc_id][fid].append(s)
if len(desc_id2id_score_list[desc_id][fid]) == topk-1:
if len(desc_id2id_score_list[desc_id][fid]) == topk - 1:
if desc_id not in desc_id2final:
desc_id2final[desc_id] = [(fid,sum(desc_id2id_score_list[desc_id][fid]))]
desc_id2final[desc_id] = [(fid, sum(desc_id2id_score_list[desc_id][fid]))]
else:
desc_id2final[desc_id].append((fid,sum(desc_id2id_score_list[desc_id][fid])))
desc_id2id_score_list[desc_id].pop(fid)
if len(desc_id2final[desc_id]) == topk:
sorted_id_score_list = sorted(desc_id2final[desc_id], key=lambda i:i[1],reverse=True)
desc_id2final[desc_id].append((fid, sum(desc_id2id_score_list[desc_id][fid])))
desc_id2id_score_list[desc_id].pop(fid)
if len(desc_id2final[desc_id]) == topk:
sorted_id_score_list = sorted(desc_id2final[desc_id],
key=lambda i: i[1], reverse=True)
sorted_paper_ids = [idx for idx, _ in sorted_id_score_list]
result_item = {'description_id': desc_id, 'docs': sorted_paper_ids,
'docs_with_score': sorted_id_score_list}
append_jsonline(dest_filename, result_item)
desc_id2final.pop(desc_id)
'docs_with_score': sorted_id_score_list}
append_jsonline(dest_filename, result_item)
desc_id2final.pop(desc_id)

for desc_id,fid2score in desc_id2id_score_list.items():
for desc_id, fid2score in desc_id2id_score_list.items():
for fid in fid2score:
if desc_id not in desc_id2final:
desc_id2final[desc_id] = [(fid,sum(fid2score[fid]))]
desc_id2final[desc_id] = [(fid, sum(fid2score[fid]))]
else:
desc_id2final[desc_id].append((fid,sum(fid2score[fid])))
desc_id2final[desc_id].append((fid, sum(fid2score[fid])))

for desc_id in desc_id2final:
sorted_id_score_list = sorted(desc_id2final[desc_id], key=lambda i:i[1],reverse=True)
sorted_id_score_list = sorted(desc_id2final[desc_id], key=lambda i: i[1], reverse=True)
sorted_paper_ids = [idx for idx, _ in sorted_id_score_list]
result_item = {'description_id': desc_id, 'docs': sorted_paper_ids,
'docs_with_score': sorted_id_score_list}
'docs_with_score': sorted_id_score_list}
append_jsonline(dest_filename, result_item)
# pass


# paper_id_score_list = list(zip([i['doc_id'] for i in batch['raw']], scores))
def rerank_file(self, search_filename, golden_filename, dest_filename, topk, is_submit=False):
assert topk >= self.batch_size
assert topk % self.batch_size == 0
Expand Down

0 comments on commit 480d51f

Please sign in to comment.