cleanup code, primarily DPR #157

Merged 2 commits on Feb 6, 2021

6 changes: 3 additions & 3 deletions pygaggle/data/msmarco.py
@@ -16,8 +16,8 @@
 __all__ = ['MsMarcoExample', 'MsMarcoDataset']
 
-# MsMarcoExample represents a query along with its ranked and re-ranked
-# candidates.
+
+# MsMarcoExample represents a query along with its ranked and re-ranked candidates.
 class MsMarcoExample(BaseModel):
     qid: str
     text: str
@@ -41,7 +41,7 @@ def load_qrels(cls, path: str) -> DefaultDict[str, Set[str]]:
         return qrels
 
     # Load a run from the provided path. The run file contains mappings from
-    # a query id and a doc title to a rank. load_run returns a dictionary
+    # a query id and a doc title to a rank. load_run returns a dictionary
     # mapping query ids to lists of doc titles sorted by ascending rank.
     @classmethod
     def load_run(cls, path: str):
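
The comment above fully specifies load_run's contract. A minimal standalone sketch of that contract, assuming a tab-separated run file of (query id, doc title, rank) triples; the actual file format is not shown in this diff:

from collections import defaultdict
from typing import Dict, List

def load_run_sketch(path: str) -> Dict[str, List[str]]:
    # Collect (rank, title) pairs per query id.
    run = defaultdict(list)
    with open(path) as f:
        for line in f:
            qid, title, rank = line.strip().split('\t')
            run[qid].append((int(rank), title))
    # Sort each query's candidates by ascending rank and keep titles only.
    return {qid: [title for _, title in sorted(pairs)]
            for qid, pairs in run.items()}
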
2 changes: 1 addition & 1 deletion pygaggle/data/retrieval.py
@@ -8,4 +8,4 @@
 class RetrievalExample:
     query: Query
     texts: List[Text]
-    groundTruthAnswers: List[List[str]]
+    ground_truth_answers: List[List[str]]
1 change: 0 additions & 1 deletion pygaggle/model/decode.py
@@ -31,7 +31,6 @@ def greedy_decode(model: PreTrainedModel,
         decode_ids = torch.cat([decode_ids,
                                 next_token_logits.max(1)[1].unsqueeze(-1)],
                                dim=-1)
-        past = outputs[1]
     if return_last_logits:
         return decode_ids, next_token_logits
     return decode_ids
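
The deleted past = outputs[1] was leftover state tracking that the loop no longer consumes. For orientation, the surrounding loop is a plain argmax-append greedy decoder; a self-contained sketch of that pattern, assuming transformers v4+ and the public t5-small checkpoint (not pygaggle's exact greedy_decode):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained('t5-small').eval()
tokenizer = T5Tokenizer.from_pretrained('t5-small')

input_ids = tokenizer('translate English to German: Hello',
                      return_tensors='pt').input_ids
decode_ids = torch.full((1, 1), model.config.decoder_start_token_id,
                        dtype=torch.long)

with torch.no_grad():
    for _ in range(20):
        outputs = model(input_ids=input_ids, decoder_input_ids=decode_ids)
        next_token_logits = outputs.logits[:, -1, :]
        # Append the highest-scoring token, as in the loop above.
        decode_ids = torch.cat([decode_ids,
                                next_token_logits.max(1)[1].unsqueeze(-1)],
                               dim=-1)

print(tokenizer.decode(decode_ids[0], skip_special_tokens=True))
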
9 changes: 5 additions & 4 deletions pygaggle/model/evaluate.py
@@ -244,6 +244,7 @@ def evaluate_by_segments(self,
             metric.accumulate(doc_scores, example)
         return metrics
 
+
 class ReaderEvaluator:
     """Class for evaluating a reader.
     Takes in a list of examples (query, texts, ground truth answers),
@@ -268,15 +269,15 @@ def evaluate(
         for example in tqdm(examples):
             answers = self.reader.predict(example.query, example.texts)
 
-            bestAnswer = answers[0].text
-            groundTruthAnswers = example.groundTruthAnswers
-            em_hit = max([ReaderEvaluator.exact_match_score(bestAnswer, ga) for ga in groundTruthAnswers])
+            best_answer = answers[0].text
+            ground_truth_answers = example.ground_truth_answers
+            em_hit = max([ReaderEvaluator.exact_match_score(best_answer, ga) for ga in ground_truth_answers])
             ems.append(em_hit)
 
             if dpr_predictions is not None:
                 dpr_predictions.append({
                     'question': example.query.text,
-                    'prediction': bestAnswer,
+                    'prediction': best_answer,
                 })
 
         return ems
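
exact_match_score itself is not shown in this diff. DPR-style evaluation conventionally applies SQuAD answer normalization (lowercase, strip punctuation and articles, collapse whitespace) before comparing strings; a sketch under that assumption, which may differ in detail from pygaggle's implementation:

import re
import string

def normalize_answer(s: str) -> str:
    # Lowercase, drop punctuation, drop articles, collapse whitespace.
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())

def exact_match_score(prediction: str, ground_truth: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(ground_truth)

# As in evaluate() above: a hit if the best answer matches any gold answer.
em_hit = max(exact_match_score('The Eiffel Tower', ga)
             for ga in ['Eiffel Tower', 'Paris'])
print(em_hit)  # True
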
1 change: 1 addition & 0 deletions pygaggle/model/writer.py
@@ -31,6 +31,7 @@ def write(self, scores: List[float], example: RelevanceExample):
         for ct, (doc, score) in enumerate(doc_scores):
             self.write_line(f"{example.query.id}\t{doc.metadata['docid']}\t{ct+1}")
 
+
 class TrecWriter(Writer):
     def write(self, scores: List[float], example: RelevanceExample):
         doc_scores = sorted(list(zip(example.documents, scores)),
26 changes: 14 additions & 12 deletions pygaggle/reader/base.py
@@ -7,13 +7,15 @@
 class Answer:
     """
     Class representing an answer.
-    A answer contains the answer text itself and potentially other metadata.
+    An answer contains the answer text itself and potentially other metadata.
     Parameters
     ----------
     text : str
         The answer text.
     metadata : Mapping[str, Any]
         Additional metadata and other annotations.
+    language: str
+        The language of the answer text.
     score : Optional[float]
         The score of the answer.
     ctx_score : Optional[float]
@@ -55,16 +57,16 @@ def predict(
         texts: List[Text],
     ) -> List[Answer]:
         """
-        Find answers from a list of texts with respect to a query.
-        Parameters
-        ----------
-        query : Query
-            The query.
-        texts : List[Text]
-            The list of texts.
-        Returns
-        -------
-        List[Answer]
-            Predicted list of answers.
+        Find answers from a list of texts with respect to a query.
+        Parameters
+        ----------
+        query : Query
+            The query.
+        texts : List[Text]
+            The list of texts.
+        Returns
+        -------
+        List[Answer]
+            Predicted list of answers.
         """
         pass
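
To make the predict contract concrete, a hypothetical toy Reader: the class name and its one-line heuristic are invented, Query and Text are assumed importable from pygaggle.rerank.base, and Answer's optional fields are assumed to carry defaults.

from typing import List

from pygaggle.reader.base import Answer, Reader
from pygaggle.rerank.base import Query, Text

class FirstSentenceReader(Reader):
    # Toy heuristic: "answer" with the first sentence of each passage.
    def predict(self, query: Query, texts: List[Text]) -> List[Answer]:
        return [Answer(text=t.text.split('.')[0]) for t in texts]
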
9 changes: 6 additions & 3 deletions pygaggle/rerank/transformer.py
@@ -42,7 +42,8 @@ def get_model(pretrained_model_name_or_path: str = 'castorini/monot5-base-msmarc
               *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
         device = torch.device(device)
-        return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()
+        return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path,
+                                                          *args, **kwargs).to(device).eval()
 
     @staticmethod
     def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
@@ -86,7 +87,8 @@ def get_model(pretrained_model_name_or_path: str = 'castorini/duot5-base-msmarco
               *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
         device = torch.device(device)
-        return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()
+        return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path,
+                                                          *args, **kwargs).to(device).eval()
 
     @staticmethod
     def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
@@ -182,7 +184,8 @@ def get_model(pretrained_model_name_or_path: str = 'castorini/monobert-large-msm
               *args, device: str = None, **kwargs) -> AutoModelForSequenceClassification:
         device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
         device = torch.device(device)
-        return AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()
+        return AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path,
+                                                                  *args, **kwargs).to(device).eval()
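
All three get_model helpers repeat one idiom: resolve a device, load with from_pretrained, move the model, and switch to eval mode. A standalone sketch of the idiom, with an arbitrary model name:

import torch
from transformers import T5ForConditionalGeneration

def load_eval_model(name: str = 't5-base', device: str = None) -> T5ForConditionalGeneration:
    # Prefer an explicit device; otherwise use CUDA when available.
    device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
    # .eval() disables dropout so reranking scores are deterministic.
    return T5ForConditionalGeneration.from_pretrained(name).to(device).eval()
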
6 changes: 2 additions & 4 deletions pygaggle/run/evaluate_document_ranker.py
@@ -4,9 +4,7 @@
 
 from pydantic import BaseModel, validator
 from transformers import (AutoModel,
-                          AutoTokenizer,
-                          AutoModelForSequenceClassification,
-                          T5ForConditionalGeneration)
+                          AutoTokenizer)
 import torch
 
 from .args import ArgumentParserBuilder, opt
@@ -20,7 +18,6 @@
 from pygaggle.rerank.random import RandomReranker
 from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
 from pygaggle.model import (SimpleBatchTokenizer,
-                            T5BatchTokenizer,
                             RerankerEvaluator,
                             metric_names,
                             MsMarcoWriter)
@@ -169,5 +166,6 @@ def main():
                          options.aggregate_method):
         logging.info(f'{metric.name:<{width}}{metric.value:.5}')
 
+
 if __name__ == '__main__':
     main()
13 changes: 8 additions & 5 deletions pygaggle/run/evaluate_passage_reader.py
@@ -6,8 +6,6 @@
 import numpy as np
 
 from pydantic import BaseModel
-from transformers import (DPRReader,
-                          DPRReaderTokenizer)
 
 from .args import ArgumentParserBuilder, opt
 from pygaggle.reader.base import Reader
@@ -31,6 +29,7 @@ class PassageReadingEvaluationOptions(BaseModel):
     num_spans_per_passage: int
     device: str
 
+
 def construct_dpr(options: PassageReadingEvaluationOptions) -> Reader:
     model = DensePassageRetrieverReader.get_model(options.model_name, options.device)
     tokenizer = DensePassageRetrieverReader.get_tokenizer(options.tokenizer_name)
@@ -41,13 +40,15 @@ def construct_dpr(options: PassageReadingEvaluationOptions) -> Reader:
                                       options.max_answer_length,
                                       options.num_spans_per_passage)
 
+
 def display(ems):
     if len(ems) == 0:
         em = -1.
     else:
         em = np.mean(np.array(ems)) * 100.
     logging.info(f'Exact Match Accuracy: {em}')
 
+
 def main():
     apb = ArgumentParserBuilder()
     apb.add_opts(
@@ -89,7 +90,7 @@ def main():
         opt('--output-file',
             type=Path,
             default=None,
-            help='File to output predictions for each example; if no output file specified, this output will be discarded'),
+            help='File to output predictions for each example; if not specified, this output will be discarded'),
         opt('--device',
             type=str,
             default='cuda:0',
@@ -133,8 +134,10 @@ def main():
         examples.append(
             RetrievalExample(
                 query=Query(text=item["question"]),
-                texts=list(map(lambda context: Text(text=context["text"].split('\n', 1)[1], title=context["text"].split('\n', 1)[0][1:-1]), item["contexts"]))[:options.use_top_k_passages],
-                groundTruthAnswers=item["answers"],
+                texts=list(map(lambda context: Text(text=context["text"].split('\n', 1)[1],
+                                                    title=context["text"].split('\n', 1)[0][1:-1]),
+                               item["contexts"]))[:options.use_top_k_passages],
+                ground_truth_answers=item["answers"],
             )
         )
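
The lambda above assumes each retrieved context stores its text as a quoted title line followed by the passage body: split('\n', 1) separates the title line from the body, and [1:-1] strips the surrounding quotes from the title. A quick illustration with an invented sample string:

# Invented sample in the assumed '"Title"\nBody' layout.
raw = '"Eiffel Tower"\nThe Eiffel Tower is a wrought-iron lattice tower in Paris.'

title_line, body = raw.split('\n', 1)
title = title_line[1:-1]  # drop the surrounding double quotes

print(title)  # Eiffel Tower
print(body)   # The Eiffel Tower is a wrought-iron lattice tower in Paris.
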
2 changes: 2 additions & 0 deletions pygaggle/settings.py
@@ -14,9 +14,11 @@ class Settings(BaseSettings):
 class MsMarcoSettings(Settings):
     pass
 
+
 class TRECCovidSettings(Settings):
     pass
 
+
 class Cord19Settings(Settings):
     # T5 model settings
     t5_model_dir: str = 'gs://neuralresearcher_data/covid/data/model_exp304'
28 changes: 15 additions & 13 deletions scripts/train_d2q.py
@@ -9,25 +9,27 @@
 from torch.utils.data import Dataset
 import argparse
 
+
 class TrainerDataset(Dataset):
     def __init__(self, path):
-        df = pd.read_csv(path, sep = "\t")
+        df = pd.read_csv(path, sep="\t")
         df = df.dropna()
         self.dataset = df
         self.tokenizer = T5Tokenizer.from_pretrained('t5-base')
 
     def __len__(self):
         return len(self.dataset)
 
     def __getitem__(self, idx):
         source = self.dataset.iloc[idx, 0]
         target = self.dataset.iloc[idx, 1]
         input_ids = self.tokenizer.encode(args.tag + ': ' + source, return_tensors='pt',
-                                          padding='max_length',truncation='longest_first', max_length=512)[0]
+                                          padding='max_length', truncation='longest_first', max_length=512)[0]
         label = self.tokenizer.encode(target, return_tensors='pt', padding='max_length',
                                       truncation='longest_first', max_length=64)[0]
-        return {'input_ids':input_ids, 'labels':label}
+        return {'input_ids': input_ids, 'labels': label}
 
+
 parser = argparse.ArgumentParser(description='Train docTquery on more datasets')
 parser.add_argument('--pretrained_model_path', default='t5-base', help='pretrained model path')
 parser.add_argument('--tag', default='msmarco', help='tag for training data', type=str)
@@ -38,24 +40,24 @@ def __getitem__(self, idx):
 parser.add_argument('--weight_decay', default=5e-5, type=float)
 parser.add_argument('--lr', default=3e-4, type=float)
 parser.add_argument('--gra_acc_steps', default=8, type=int)
-args = parser.parse_args()
+args = parser.parse_args()
 
 model = T5ForConditionalGeneration.from_pretrained(args.pretrained_model_path)
 train_dataset = TrainerDataset(args.train_data_path)
 
 training_args = TrainingArguments(
-    output_dir=args.output_path,
-    num_train_epochs=args.epoch,
-    per_device_train_batch_size=args.batch_size,
+    output_dir=args.output_path,
+    num_train_epochs=args.epoch,
+    per_device_train_batch_size=args.batch_size,
     weight_decay=args.weight_decay,
     learning_rate=args.lr,
     gradient_accumulation_steps=args.gra_acc_steps,
-    logging_dir='./logs',
+    logging_dir='./logs',
 )
 
 trainer = Trainer(
-    model=model,
-    args=training_args,
+    model=model,
+    args=training_args,
     train_dataset=train_dataset
 )
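
One gotcha in TrainerDataset above: __getitem__ reads the module-global args (for the tag prefix), so the dataset is only usable after parse_args has run. A small smoke test under that assumption, with an invented two-column TSV (header row included, since read_csv above relies on the default header handling):

import pandas as pd

# Invented toy file in the assumed (source document, target query) layout.
pd.DataFrame({'doc': ['The Eiffel Tower is in Paris.'],
              'query': ['where is the eiffel tower']}).to_csv(
    'toy.tsv', sep='\t', index=False)

dataset = TrainerDataset('toy.tsv')
example = dataset[0]                # needs `args` already parsed, see above
print(example['input_ids'].shape)   # torch.Size([512]) after padding
print(example['labels'].shape)      # torch.Size([64])
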