Skip to content

Commit

Permalink
add plm model loader method
Browse files Browse the repository at this point in the history
  • Loading branch information
supercoderhawk committed Jan 10, 2020
1 parent 41bb368 commit df0376b
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions wsdm_digg/reranking/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# -*- coding: UTF-8 -*-
import torch
from wsdm_digg.reranking.plm_rerank import PlmRerank
from wsdm_digg.reranking.plm_knrm import PlmKnrm
from wsdm_digg.reranking.plm_conv_knrm import PlmConvKnrm

_MODEL_NAME_SET = {'plm', 'knrm', 'conv-knrm'}


def load_model(args):
def load_rerank_model(args):
model_name = args.rerank_model_name
if model_name not in _MODEL_NAME_SET:
raise ValueError('model name {} is not implemented'.format(model_name))
Expand All @@ -19,7 +20,7 @@ def load_model(args):
return model


def get_score_func(model, prefix=None):
def get_score_func(model, prefix=None, inference=False):
def calculate(batch):
if prefix:
token_field = '{}_token'.format(prefix)
Expand All @@ -33,11 +34,20 @@ def calculate(batch):
mask_field = 'mask'
query_lens_field = 'query_lens'
doc_lens_field = 'doc_lens'
scores = model(token_ids=batch[token_field],
segment_ids=batch[segment_field],
token_mask=batch[mask_field],
query_lens=batch[query_lens_field],
doc_lens=batch[doc_lens_field])
if inference:
with torch.no_grad():
scores = model(token_ids=batch[token_field],
segment_ids=batch[segment_field],
token_mask=batch[mask_field],
query_lens=batch[query_lens_field],
doc_lens=batch[doc_lens_field])
scores = scores.squeeze(1)
else:
scores = model(token_ids=batch[token_field],
segment_ids=batch[segment_field],
token_mask=batch[mask_field],
query_lens=batch[query_lens_field],
doc_lens=batch[doc_lens_field])
return scores

return calculate

0 comments on commit df0376b

Please sign in to comment.