Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Add support for gradient cache for Reranker #148

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
fix: trainer
  • Loading branch information
sigridjineth committed Aug 24, 2024
commit 43a642d3bee6543fd8da81e20541dc0dfeb3df8b
2 changes: 1 addition & 1 deletion src/tevatron/reranker/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ class TevatronTrainingArguments(TrainingArguments):
warmup_ratio: float = field(default=0.1)

grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache"})
gc_chunk_size: Optional[int] = field(default=None, metadata={"help": "Chunk size for gradient cache"})
gc_chunk_size: Optional[int] = field(default=2, metadata={"help": "Chunk size for gradient cache"})
9 changes: 6 additions & 3 deletions src/tevatron/reranker/driver/train.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import logging
import os
import sys

from transformers import AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)

from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer
from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer

logger = logging.getLogger(__name__)


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

Expand Down Expand Up @@ -71,6 +70,9 @@ def main():
train_dataset = RerankerTrainDataset(data_args)
train_collator = RerankerTrainCollator(data_args, tokenizer)

# Add GradCache-specific arguments to training_args
training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)

trainer = RerankerTrainer(
model=model,
args=training_args,
Expand All @@ -84,5 +86,6 @@ def main():
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
main()
62 changes: 37 additions & 25 deletions src/tevatron/reranker/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,64 @@

import torch
from torch import nn
from transformers import Trainer
from transformers import Trainer, TrainingArguments
from transformers.trainer_utils import PredictionOutput

from grad_cache import GradCache

from tevatron.reranker.arguments import TevatronTrainingArguments
from grad_cache.functional import cached, cat_input_tensor
from torch.cuda.amp import autocast

logger = logging.getLogger(__name__)


@cached
@autocast()
def get_model_rep(model, inputs):
    """Run a forward pass and return the model's score representation.

    Wrapped with GradCache's ``@cached`` so activations can be recomputed
    chunk-by-chunk, and with autocast for mixed-precision execution.
    """
    return model(**inputs).scores


@cat_input_tensor
@autocast()
def contrastive_loss(scores):
    """Cross-entropy contrastive loss over the concatenated chunk scores.

    ``@cat_input_tensor`` concatenates the per-chunk score tensors along
    dim 0 before this body runs, so ``scores`` covers the full batch.

    NOTE(review): ``labels`` has ``scores.size(0) // 2`` entries, but
    ``cross_entropy`` requires exactly one label per row of ``scores`` —
    as written this raises unless ``scores`` is reshaped upstream.
    Confirm the intended score layout (e.g. (batch, group) pairs).
    """
    batch_size = scores.size(0) // 2
    labels = torch.arange(batch_size, device=scores.device)
    # Functional API: avoids constructing a CrossEntropyLoss module on
    # every call; defaults are identical to nn.CrossEntropyLoss().
    return nn.functional.cross_entropy(scores, labels)


def split_inputs(model_input, chunk_size):
    """Split a batched input dict into a list of smaller per-chunk dicts.

    Each tensor value is split along dim 0 into pieces of at most
    ``chunk_size`` rows; the i-th returned dict holds the i-th piece of
    every key.
    """
    logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
    names = list(model_input)
    per_key_pieces = [model_input[name].split(chunk_size, dim=0) for name in names]
    return [dict(zip(names, pieces)) for pieces in zip(*per_key_pieces)]

def get_rep(model_output):
    """Pull the score tensor out of a reranker forward output."""
    logger.debug(f"Getting representation from model output: {type(model_output)}")
    scores = model_output.scores
    return scores

class RerankerTrainer(Trainer):
    """Hugging Face ``Trainer`` specialised for reranker training with GradCache.

    GradCache splits each batch into small chunks so large effective batch
    sizes fit in GPU memory while still producing exact gradients.

    NOTE(review): reconstructed from a GitHub diff view that interleaved
    removed and added lines; collapsed diff context between methods may be
    missing — verify against the repository before relying on this.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        logger.info("Initializing RerankerTrainer with GradCache")
        self.args: TrainingArguments

        # Chunk size for GradCache; falls back to 4 when the training
        # arguments do not define `gc_chunk_size`.
        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)

        self.gc = GradCache(
            models=[self.model],
            chunk_sizes=self.gc_chunk_size,
            loss_fn=contrastive_loss,
            split_input_fn=split_inputs,
            get_rep_fn=lambda x: x.scores,
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )
        logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the contrastive loss for a single (non-cached) forward pass."""
        logger.debug(f"Computing loss with inputs: {inputs.keys()}")
        outputs = model(**inputs)
        scores = outputs.scores
        loss = contrastive_loss(scores)
        logger.debug(f"Computed loss: {loss.item()}")
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run one gradient-cached training step and return its loss."""
        model.train()
        inputs = self._prepare_inputs(inputs)
        # Under distributed training, skip gradient sync on all but the
        # last chunk so gradients are all-reduced only once per step.
        _distributed = self.args.local_rank > -1
        # Re-point GradCache at the (possibly wrapped) model for this step.
        self.gc.models = [model]
        loss = self.gc(inputs, no_sync_except_last=_distributed)
        logger.debug(f"Training step loss: {loss.item()}")
        return loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: bool = None,
    ) -> PredictionOutput:
        """Evaluate one batch without gradients; returns scores as predictions."""
        logger.debug("Entering prediction step")
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            outputs = model(**inputs)
            scores = outputs.scores
            loss = contrastive_loss(scores)
        logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
        return PredictionOutput(predictions=scores, label_ids=None, metrics=None)