Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup code, primarily DPR #157

Merged
merged 2 commits into from
Feb 6, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
flake8 compliant
  • Loading branch information
ronakice committed Feb 6, 2021
commit 87ec9f6d991d208ec58b2a0951ceab364f32b74b
6 changes: 3 additions & 3 deletions pygaggle/data/msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

__all__ = ['MsMarcoExample', 'MsMarcoDataset']

# MsMarcoExample represents a query along with its ranked and re-ranked
# candidates.

# MsMarcoExample represents a query along with its ranked and re-ranked candidates.
class MsMarcoExample(BaseModel):
qid: str
text: str
Expand All @@ -41,7 +41,7 @@ def load_qrels(cls, path: str) -> DefaultDict[str, Set[str]]:
return qrels

# Load a run from the provided path. The run file contains mappings from
# a query id and a doc title to a rank. load_run returns a dictionary
# a query id and a doc title to a rank. load_run returns a dictionary
# mapping query ids to lists of doc titles sorted by ascending rank.
@classmethod
def load_run(cls, path: str):
Expand Down
9 changes: 6 additions & 3 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def get_model(pretrained_model_name_or_path: str = 'castorini/monot5-base-msmarc
*args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device(device)
return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()
return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path,
*args, **kwargs).to(device).eval()

@staticmethod
def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
Expand Down Expand Up @@ -86,7 +87,8 @@ def get_model(pretrained_model_name_or_path: str = 'castorini/duot5-base-msmarco
*args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device(device)
return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()
return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path,
*args, **kwargs).to(device).eval()

@staticmethod
def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
Expand Down Expand Up @@ -182,7 +184,8 @@ def get_model(pretrained_model_name_or_path: str = 'castorini/monobert-large-msm
*args, device: str = None, **kwargs) -> AutoModelForSequenceClassification:
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device(device)
return AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()
return AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path,
*args, **kwargs).to(device).eval()

@staticmethod
def get_tokenizer(pretrained_model_name_or_path: str = 'bert-large-uncased',
Expand Down
6 changes: 2 additions & 4 deletions pygaggle/run/evaluate_document_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

from pydantic import BaseModel, validator
from transformers import (AutoModel,
AutoTokenizer,
AutoModelForSequenceClassification,
T5ForConditionalGeneration)
AutoTokenizer)
import torch

from .args import ArgumentParserBuilder, opt
Expand All @@ -20,7 +18,6 @@
from pygaggle.rerank.random import RandomReranker
from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
from pygaggle.model import (SimpleBatchTokenizer,
T5BatchTokenizer,
RerankerEvaluator,
metric_names,
MsMarcoWriter)
Expand Down Expand Up @@ -169,5 +166,6 @@ def main():
options.aggregate_method):
logging.info(f'{metric.name:<{width}}{metric.value:.5}')


if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions pygaggle/run/evaluate_passage_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def construct_dpr(options: PassageReadingEvaluationOptions) -> Reader:
options.max_answer_length,
options.num_spans_per_passage)


def display(ems):
    """Log the exact-match accuracy computed from ``ems``.

    When ``ems`` is empty, ``-1.0`` is logged; otherwise the mean of
    ``ems`` multiplied by 100 is logged. The message is emitted at INFO
    level through the root ``logging`` module and nothing is returned.
    """
    # An empty collection has no mean, so -1 is logged in its place.
    em = -1. if len(ems) == 0 else np.mean(np.array(ems)) * 100.
    logging.info(f'Exact Match Accuracy: {em}')


def main():
apb = ArgumentParserBuilder()
apb.add_opts(
Expand Down
2 changes: 2 additions & 0 deletions pygaggle/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class Settings(BaseSettings):
# Subclass of Settings that declares no fields or overrides of its own.
class MsMarcoSettings(Settings):
    pass


# Subclass of Settings that declares no fields or overrides of its own.
class TRECCovidSettings(Settings):
    pass


class Cord19Settings(Settings):
# T5 model settings
t5_model_dir: str = 'gs://neuralresearcher_data/covid/data/model_exp304'
Expand Down