Skip to content

Commit

Permalink
add sample count mode
Browse files Browse the repository at this point in the history
  • Loading branch information
supercoderhawk committed Jan 9, 2020
1 parent 79d0d11 commit 8394f64
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions wsdm_digg/data_process/rerank_data_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import random
import argparse
from multiprocessing import Pool
from pysenal import read_jsonline_lazy, get_chunk, append_jsonlines, index, read_lines
from pysenal import read_jsonline_lazy, get_chunk, append_jsonlines, index, read_lines, get_logger
from wsdm_digg.elasticsearch.data import get_paper
from wsdm_digg.constants import DATA_DIR, RESULT_DIR


class RerankDataBuilder(object):
logger = get_logger('rank data builder')

def __init__(self):
self.args = self.parse_args()
self.search_filename = self.args.search_filename
Expand All @@ -26,6 +28,9 @@ def parse_args(self):
parser.add_argument('-select_strategy', type=str,
choices=['random', 'search_result_offset', 'search_result_false_top'],
required=True)
parser.add_argument('-query_field', type=str, default='cites_text',
choices=['cites_text', 'description_text'])
parser.add_argument('-sample_count', type=int, default=1)
parser.add_argument('-offset', type=int, default=50)
args = parser.parse_args()
return args
Expand All @@ -42,23 +47,29 @@ def build_data(self):
desc_id = item['description_id']
desc_id2item[desc_id] = item

chunk_size = 500
chunk_size = 50
for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename), chunk_size):
new_item_chunk = []
for item in item_chunk:
true_paper_id = desc_id2item[item['description_id']]['paper_id']
cites_text = desc_id2item[item['description_id']]['cites_text']
train_pair = self.select_train_pair(item['docs'],
true_paper_id,
self.args.select_strategy)
true_item = desc_id2item[item['description_id']]
true_paper_id = true_item['paper_id']
cites_text = true_item['cites_text']
docs = item['docs']
item.pop('docs')
item.pop('keywords')
new_item_chunk.append({**train_pair, **item, 'cites_text': cites_text})
for idx in range(self.args.sample_count):
train_pair = self.select_train_pair(docs,
true_paper_id,
self.args.select_strategy,
idx)
new_item = {**train_pair, **item, 'cites_text': cites_text,
'description_text': true_item['description_text']}
new_item_chunk.append(new_item)
built_items = pool.map(self.build_single_query, new_item_chunk)
append_jsonlines(self.dest_filename, built_items)

def select_train_pair(self, doc_list, true_doc_id, select_strategy):
offset = self.args.offset
def select_train_pair(self, doc_list, true_doc_id, select_strategy, intra_offset):
offset = self.args.offset + intra_offset
if select_strategy == 'search_result_offset':
true_idx = index(doc_list, true_doc_id, -1)
if true_idx == -1 or true_idx + offset >= len(doc_list):
Expand All @@ -69,7 +80,8 @@ def select_train_pair(self, doc_list, true_doc_id, select_strategy):
# 2. the top-k in the benchmark is too small, and the last instance of this result list has a context similar to the true paper, which would confuse the model
false_paper_id = self.random_choose_false_id(true_doc_id)
else:
false_paper_id = doc_list[-1]
false_idx = -self.args.sample_count + intra_offset
false_paper_id = doc_list[false_idx]
else:
false_paper_id = doc_list[true_idx + offset]
elif select_strategy == 'random':
Expand All @@ -94,7 +106,7 @@ def random_choose_false_id(self, true_doc_id):
return false_paper_id

def build_single_query(self, item):
query = item['cites_text']
query = item[self.args.query_field]
true_paper = get_paper(item['true_paper_id'])
false_paper = get_paper(item['false_paper_id'])
true_text = true_paper['title'] + true_paper['abstract']
Expand Down

0 comments on commit 8394f64

Please sign in to comment.