[DRAFT] Add gradient cache support for Reranker #148

Open · wants to merge 17 commits into main
fix: prediction step
sigridjineth committed Aug 24, 2024
commit 78ff5800d86b38c0a0a9282ced59f7d81726c006
43 changes: 9 additions & 34 deletions src/tevatron/reranker/modeling.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any
from typing import Optional

import torch
from torch import nn, Tensor
@@ -23,37 +23,22 @@ class RerankerOutput(ModelOutput):
class RerankerModel(nn.Module):
TRANSFORMER_CLS = AutoModelForSequenceClassification

def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
def __init__(self, hf_model: PreTrainedModel):
super().__init__()
logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
logger.info("Initializing RerankerModel")
self.config = hf_model.config
self.hf_model = hf_model
self.train_batch_size = train_batch_size
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
if train_batch_size:
self.register_buffer(
'target_label',
torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
)
logger.info(f"RerankerModel initialized with config: {self.config}")

def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

if labels is not None:
loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
logger.debug(f"Computed loss: {loss.item()}")
else:
loss = None
logger.debug("No labels provided, skipping loss computation")

return RerankerOutput(
loss=loss,
scores=outputs.logits
)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs = None):
return False

@classmethod
@@ -94,16 +79,10 @@ def build(
inference_mode=False,
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
hf_model=lora_model,
train_batch_size=train_args.per_device_train_batch_size,
)
model = cls(hf_model=lora_model)
else:
logger.info("Building model without LoRA")
model = cls(
hf_model=base_model,
train_batch_size=train_args.per_device_train_batch_size,
)
model = cls(hf_model=base_model)
return model

@classmethod
@@ -123,14 +102,10 @@ def load(cls,
lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
lora_model = lora_model.merge_and_unload()
model = cls(
hf_model=lora_model,
)
model = cls(hf_model=lora_model)
else:
logger.info("Loading model without LoRA")
model = cls(
hf_model=base_model,
)
model = cls(hf_model=base_model)
return model

def save(self, output_dir: str):
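For context, a minimal sketch of how the slimmed-down RerankerModel is meant to be used after this commit: the model only produces scores, and any loss is now computed by the trainer rather than inside forward. This is not part of the diff; the checkpoint name is a placeholder and constructing the model directly from AutoModelForSequenceClassification is just for illustration.

```python
# Illustrative sketch, not part of the diff: the model now returns scores only.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tevatron.reranker.modeling import RerankerModel

# "bert-base-uncased" is a placeholder checkpoint used only for this example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=1)
model = RerankerModel(hf_model=hf_model)  # no train_batch_size argument anymore

query, passage = "what is gradient cache?", "Gradient Cache trades compute for memory."
inputs = tokenizer(query, passage, truncation=True, padding=True, return_tensors="pt")

with torch.no_grad():
    output = model(**inputs)   # RerankerOutput carrying the raw logits as scores
print(output.scores.shape)     # e.g. torch.Size([1, 1])
```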
56 changes: 49 additions & 7 deletions src/tevatron/reranker/trainer.py
@@ -1,19 +1,38 @@
from tevatron.reranker.modeling import RerankerOutput
from tevatron.retriever.trainer import TevatronTrainer
import logging
from typing import Dict, Union, Any

import torch
from torch import nn
from transformers import Trainer
from transformers.trainer_utils import PredictionOutput

from grad_cache import GradCache

from tevatron.reranker.arguments import TevatronTrainingArguments

logger = logging.getLogger(__name__)

def split_inputs(model_input, chunk_size):
logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
keys = list(model_input.keys())
chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

def get_rep(x: RerankerOutput):
return x.scores
def get_rep(model_output):
logger.debug(f"Getting representation from model output: {type(model_output)}")
return model_output.scores

class RerankerTrainer(TevatronTrainer):
class RerankerTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
logger.info("Initializing RerankerTrainer")
self.args: TevatronTrainingArguments

def loss_fn(scores, labels):
grouped_scores = scores.view(self.args.train_group_size, -1)
labels = torch.zeros(self.args.train_group_size, dtype=torch.long, device=scores.device)
return nn.CrossEntropyLoss()(grouped_scores, labels)

self.gc = GradCache(
models=[self.model],
chunk_sizes=[self.args.gc_chunk_size],
@@ -23,14 +42,37 @@ def __init__(self, *args, **kwargs):
fp16=self.args.fp16,
scaler=self.scaler if self.args.fp16 else None
)
logger.info(f"GradCache initialized with chunk size: {self.args.gc_chunk_size}")

def compute_loss(self, model, inputs, return_outputs=False):
logger.debug(f"Computing loss with inputs: {inputs.keys()}")
outputs = model(**inputs)
loss = outputs.loss
logger.debug(f"Computed loss: {loss.item()}")
return (loss, outputs) if return_outputs else loss

def training_step(self, model, inputs):
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
logger.debug("Entering training step")
model.train()
inputs = self._prepare_inputs(inputs)
_distributed = self.args.local_rank > -1
self.gc.models = [model]
loss = self.gc(inputs, no_sync_except_last=_distributed)
logger.debug(f"Training step loss: {loss.item()}")
return loss

def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: bool = None,
) -> PredictionOutput:
logger.debug("Entering prediction step")
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
outputs = model(**inputs)
loss = outputs.loss
logits = outputs.scores
logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
return PredictionOutput(predictions=logits, label_ids=inputs.get("labels"), metrics=None)
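For intuition about what the GradCache call in training_step does, here is a simplified, self-contained sketch of the gradient-cache idea. It is not the grad_cache library's implementation, and the toy linear model, group layout, and shapes are made up for illustration.

```python
# Simplified sketch of gradient caching: score chunks without building graphs,
# compute the full-batch loss once on the cached scores, then replay each chunk
# with gradients enabled and backpropagate the cached per-score gradients.
import torch
from torch import nn

model = nn.Linear(16, 1)                      # stand-in for the reranker
full_batch = torch.randn(8, 16)               # 8 query-passage pairs
labels = torch.zeros(2, dtype=torch.long)     # 2 groups of 4, positive first
chunks = full_batch.split(2, dim=0)           # chunk size 2

# Pass 1: score every chunk without autograd graphs and cache the scores.
with torch.no_grad():
    cached = [model(c) for c in chunks]
scores = torch.cat(cached).requires_grad_()   # full-batch scores, detached

# Full-batch loss on the cached scores, then d(loss)/d(scores) per chunk.
loss = nn.CrossEntropyLoss()(scores.view(2, 4), labels)
score_grads = torch.autograd.grad(loss, scores)[0].split(2, dim=0)

# Pass 2: recompute each chunk with gradients and backprop the cached grads.
for chunk, grad in zip(chunks, score_grads):
    model(chunk).backward(gradient=grad)      # accumulates parameter gradients
print(loss.item())
```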