Commit

fix the bug for dense retrieval pretraining

EC2 Default User committed Nov 22, 2023
1 parent 21940f6 commit 9fbe62e
Showing 3 changed files with 75 additions and 42 deletions.
34 changes: 1 addition & 33 deletions t5_pretrainer/dataset/data_collator.py
@@ -149,7 +149,7 @@ def __call__(self, batch):
            "teacher_neg_scores": torch.FloatTensor(s_neg)
        }

-class TripleMarginMSECollator:
+class MarginMSEforPretrainCollator:
    def __init__(self, tokenizer_type, max_length):
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
@@ -221,35 +221,3 @@ def __call__(self, batch):
            "teacher_pos_score": torch.FloatTensor(s_pos),
            "teacher_neg_score": torch.FloatTensor(s_neg),
        }
-
-    def __init__(self, tokenizer_type, max_length):
-        self.max_length = max_length
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
-
-    def __call__(self, batch):
-        queries, nway_passages, nway_scores = [], [], []
-        for elem in batch:
-            queries.append(elem["query"])
-            nway_passages.append(elem["passages"])
-            nway_scores.append(elem["scores"])
-
-        queries = self.tokenizer(queries,
-                                 add_special_tokens=True,
-                                 padding="longest",  # pad to max sequence length in batch
-                                 truncation="longest_first",  # truncates to self.max_length
-                                 max_length=self.max_length,
-                                 return_attention_mask=True,
-                                 return_tensors="pt")
-        nway_passages = self.tokenizer(flatten_list(nway_passages),
-                                       add_special_tokens=True,
-                                       padding="longest",  # pad to max sequence length in batch
-                                       truncation="longest_first",  # truncates to self.max_length
-                                       max_length=self.max_length,
-                                       return_attention_mask=True,
-                                       return_tensors="pt")  # [bz x nway, seq_len]
-
-        return {
-            "enc_query": queries,
-            "enc_nway_docs": nway_passages,
-            "labels": torch.FloatTensor(nway_scores)
-        }
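The body of the renamed collator is collapsed between the two hunks above, so only its __init__ and the tail of its __call__ are visible. As a rough sketch of the shape such a collator takes — assuming the 10-tuple that MarginMSEforPretrainDataset returns when docid_to_smtid_path is None, and making up the tokenized_* key names, which do not appear in the diff — it might look like:

# Hedged sketch only; the repository's real MarginMSEforPretrainCollator body is
# collapsed in the diff above. The "teacher_*_score" keys are visible in the hunk;
# the "tokenized_*" keys and the 10-tuple unpacking are assumptions.
import torch
from transformers import AutoTokenizer

class MarginMSEforPretrainCollatorSketch:
    def __init__(self, tokenizer_type, max_length):
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)

    def _encode(self, texts):
        # Pad to the longest sequence in the batch; truncate to max_length.
        return self.tokenizer(list(texts),
                              add_special_tokens=True,
                              padding="longest",
                              truncation="longest_first",
                              max_length=self.max_length,
                              return_attention_mask=True,
                              return_tensors="pt")

    def __call__(self, batch):
        # Each batch element is the 10-tuple from MarginMSEforPretrainDataset.__getitem__.
        q_pos, q_neg, d_pos, d_neg, s_pos, s_neg, *_ = zip(*batch)
        return {
            "tokenized_query": self._encode(q_pos),
            "tokenized_pos_doc": self._encode(d_pos),
            "tokenized_neg_doc": self._encode(d_neg),
            "teacher_pos_score": torch.FloatTensor(s_pos),
            "teacher_neg_score": torch.FloatTensor(s_neg),
        }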
65 changes: 65 additions & 0 deletions t5_pretrainer/dataset/dataset.py
@@ -615,3 +615,68 @@ def __getitem__(self, idx):
        return q_pos, q_neg, pos_doc_encoding, neg_doc_encoding, s_pos, s_neg, \
            q_pos_decoder_input_ids, q_neg_decoder_input_ids

+class MarginMSEforPretrainDataset(Dataset):
+    def __init__(self, dataset_path, document_dir, query_dir, qrels_path,
+                 docid_to_smtid_path):
+        self.document_dataset = CollectionDatasetPreLoad(document_dir, id_style="content_id")
+        self.query_dataset = CollectionDatasetPreLoad(query_dir, id_style="content_id")
+
+        self.examples = []
+        with open(dataset_path) as fin:
+            for line in fin:
+                self.examples.append(ujson.loads(line))
+
+        if docid_to_smtid_path is not None:
+            with open(docid_to_smtid_path) as fin:
+                self.docid_to_smtid = ujson.load(fin)
+            tmp_docids = list(self.docid_to_smtid.keys())
+            assert self.docid_to_smtid[tmp_docids[0]][0] == -1, self.docid_to_smtid[tmp_docids[0]]
+        else:
+            self.docid_to_smtid = None
+
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, idx):
+        example = self.examples[idx]
+        query = example["qid"]
+        positive = example["docids"][0]
+        s_pos = example["scores"][0]
+
+        neg_idx = random.sample(range(1, len(example["docids"])), k=1)[0]
+        negative = example["docids"][neg_idx]
+        s_neg = example["scores"][neg_idx]
+        #assert negative != positive, (positive, negative)
+
+        q = self.query_dataset[str(query)][1]
+        d_pos = self.document_dataset[positive][1]
+        d_neg = self.document_dataset[negative][1]
+
+        if self.docid_to_smtid is not None:
+            pos_prev_smtids = self.docid_to_smtid[str(positive)][1:]
+            neg_prev_smtids = self.docid_to_smtid[str(negative)][1:]
+
+            d_pos_decoder_input_ids = self.docid_to_smtid[str(positive)]
+            d_neg_decoder_input_ids = self.docid_to_smtid[str(negative)]
+            q_pos_decoder_input_ids = copy.deepcopy(d_pos_decoder_input_ids)
+            q_neg_decoder_input_ids = copy.deepcopy(d_neg_decoder_input_ids)
+        else:
+            pos_prev_smtids = None
+            neg_prev_smtids = None
+            d_pos_decoder_input_ids = [-1]
+            d_neg_decoder_input_ids = [-1]
+            q_pos_decoder_input_ids = [-1]
+            q_neg_decoder_input_ids = [-1]
+
+        q_pos = "query: " + q.strip()
+        q_neg = "query: " + q.strip()
+        d_pos = "document: " + d_pos.strip()
+        d_neg = "document: " + d_neg.strip()
+
+        if pos_prev_smtids is None and neg_prev_smtids is None:
+            return q_pos, q_neg, d_pos, d_neg, s_pos, s_neg, \
+                d_pos_decoder_input_ids, d_neg_decoder_input_ids, q_pos_decoder_input_ids, q_neg_decoder_input_ids
+        else:
+            return q_pos, q_neg, d_pos, d_neg, pos_prev_smtids, neg_prev_smtids, s_pos, s_neg, \
+                d_pos_decoder_input_ids, d_neg_decoder_input_ids, q_pos_decoder_input_ids, q_neg_decoder_input_ids
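As a quick orientation, a hypothetical instantiation of the new dataset might look like the following. All paths are made-up placeholders; the JSONL schema is inferred from __getitem__ above, where docids[0] is the positive and the remaining docids are negatives.

# Hypothetical usage sketch; every path below is a placeholder, not a real file
# from the repository.
dataset = MarginMSEforPretrainDataset(
    dataset_path="data/teacher_scores.jsonl",  # one {"qid", "docids", "scores"} object per line
    document_dir="data/collection",            # loaded via CollectionDatasetPreLoad
    query_dir="data/queries",
    qrels_path=None,                           # accepted but unused by __init__
    docid_to_smtid_path=None,                  # None -> decoder_input_ids default to [-1]
)
q_pos, q_neg, d_pos, d_neg, s_pos, s_neg, *decoder_ids = dataset[0]
print(q_pos[:40], s_pos - s_neg)               # teacher margin for this sampled pair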
18 changes: 9 additions & 9 deletions t5_pretrainer/main.py
@@ -20,12 +20,13 @@
    MarginMSEforT5SeqAQDataset,
    Seq2SeqForT5SeqAQDataset,
    LngKnpMarginMSEforT5SeqAQDataset,
+    MarginMSEforPretrainDataset,
)
from .dataset.data_collator import (
    MarginMSEforT5SeqAQCollator,
    Seq2SeqForT5SeqAQCollator,
    LngKnpMarginMSEforT5SeqAQCollator,
-    TripleMarginMSECollator
+    MarginMSEforPretrainCollator
)
from .arguments import ModelArguments, Arguments
from .losses.regulariaztion import RegWeightScheduler
@@ -50,14 +51,13 @@ def main():
    if args.local_rank <= 0:
        print(f"apply t5_docid_gen_encoder for data, model_name_or_path: {model_args.model_name_or_path}")

-    if args.loss_type == "t5seq_pretrain_margin_mse":
-        if args.triple_margin_mse_path is not None:
-            assert args.teacher_score_path is None
-            train_dataset = TripleMarginMSEDataset(examples_path=args.triple_margin_mse_path,
-                                                   document_dir=args.collection_path,
-                                                   query_dir=args.queries_path,
-                                                   docid_to_smtid_path=args.docid_to_smtid_path)
-            train_collator = TripleMarginMSECollator(model_args.model_name_or_path, max_length=args.max_length)
+    if args.loss_type == "t5seq_pretrain_margin_mse":
+        train_dataset = MarginMSEforPretrainDataset(dataset_path=args.teacher_score_path,
+                                                    document_dir=args.collection_path,
+                                                    query_dir=args.queries_path,
+                                                    qrels_path=args.qrels_path,
+                                                    docid_to_smtid_path=args.docid_to_smtid_path)
+        train_collator = MarginMSEforPretrainCollator(model_args.model_name_or_path, max_length=args.max_length)
    elif args.loss_type == "t5seq_aq_encoder_margin_mse":
        train_dataset = MarginMSEforT5SeqAQDataset(dataset_path=args.teacher_score_path,
                                                   document_dir=args.collection_path,
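To see how the two new pieces fit together, here is a hypothetical wiring of the dataset and collator into a vanilla PyTorch DataLoader. The real script drives this through its Arguments/trainer machinery; batch size, max_length, and the "t5-base" checkpoint below are placeholders.

# Hypothetical wiring sketch, not code from main.py. Module paths follow the
# changed files in this commit; all argument values are placeholders.
from torch.utils.data import DataLoader
from t5_pretrainer.dataset.dataset import MarginMSEforPretrainDataset
from t5_pretrainer.dataset.data_collator import MarginMSEforPretrainCollator

train_dataset = MarginMSEforPretrainDataset(
    dataset_path="data/teacher_scores.jsonl",  # placeholder path
    document_dir="data/collection",
    query_dir="data/queries",
    qrels_path=None,
    docid_to_smtid_path=None,
)
train_collator = MarginMSEforPretrainCollator("t5-base", max_length=128)

loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                    collate_fn=train_collator)
batch = next(iter(loader))
# A margin-MSE loss then matches the student's (pos - neg) score gap against
# batch["teacher_pos_score"] - batch["teacher_neg_score"].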
