Skip to content

Commit

Permalink
fix: prediction step
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Aug 24, 2024
1 parent a79b647 commit 5f999c1
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 41 deletions.
43 changes: 9 additions & 34 deletions src/tevatron/reranker/modeling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Optional, Dict, Any
from typing import Optional

import torch
from torch import nn, Tensor
Expand All @@ -23,37 +23,22 @@ class RerankerOutput(ModelOutput):
class RerankerModel(nn.Module):
TRANSFORMER_CLS = AutoModelForSequenceClassification

def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
def __init__(self, hf_model: PreTrainedModel):
super().__init__()
logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
logger.info("Initializing RerankerModel")
self.config = hf_model.config
self.hf_model = hf_model
self.train_batch_size = train_batch_size
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
if train_batch_size:
self.register_buffer(
'target_label',
torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
)
logger.info(f"RerankerModel initialized with config: {self.config}")

def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

if labels is not None:
loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
logger.debug(f"Computed loss: {loss.item()}")
else:
loss = None
logger.debug("No labels provided, skipping loss computation")

return RerankerOutput(
loss=loss,
scores=outputs.logits
)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs = None):
return False

@classmethod
Expand Down Expand Up @@ -94,16 +79,10 @@ def build(
inference_mode=False,
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
hf_model=lora_model,
train_batch_size=train_args.per_device_train_batch_size,
)
model = cls(hf_model=lora_model)
else:
logger.info("Building model without LoRA")
model = cls(
hf_model=base_model,
train_batch_size=train_args.per_device_train_batch_size,
)
model = cls(hf_model=base_model)
return model

@classmethod
Expand All @@ -123,14 +102,10 @@ def load(cls,
lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
lora_model = lora_model.merge_and_unload()
model = cls(
hf_model=lora_model,
)
model = cls(hf_model=lora_model)
else:
logger.info("Loading model without LoRA")
model = cls(
hf_model=base_model,
)
model = cls(hf_model=base_model)
return model

def save(self, output_dir: str):
Expand Down
55 changes: 48 additions & 7 deletions src/tevatron/reranker/trainer.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
from tevatron.reranker.modeling import RerankerOutput
from tevatron.retriever.trainer import TevatronTrainer
import logging
from typing import Dict, Union, Any

import torch
from torch import nn
from transformers import Trainer
from transformers.trainer_utils import PredictionOutput

from tevatron.arguments import TevatronTrainingArguments
from grad_cache import GradCache

logger = logging.getLogger(__name__)

def split_inputs(model_input, chunk_size):
logger.debug(f"Splitting inputs with chunk size: {chunk_size}")
keys = list(model_input.keys())
chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

def get_rep(x: RerankerOutput):
return x.scores
def get_rep(model_output):
logger.debug(f"Getting representation from model output: {type(model_output)}")
return model_output.scores

class RerankerTrainer(TevatronTrainer):
class RerankerTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
logger.info("Initializing RerankerTrainer")
self.args: TevatronTrainingArguments

def loss_fn(scores, labels):
grouped_scores = scores.view(self.args.train_group_size, -1)
labels = torch.zeros(self.args.train_group_size, dtype=torch.long, device=scores.device)
return nn.CrossEntropyLoss()(grouped_scores, labels)

self.gc = GradCache(
models=[self.model],
chunk_sizes=[self.args.gc_chunk_size],
Expand All @@ -23,14 +41,37 @@ def __init__(self, *args, **kwargs):
fp16=self.args.fp16,
scaler=self.scaler if self.args.fp16 else None
)
logger.info(f"GradCache initialized with chunk size: {self.args.gc_chunk_size}")

def compute_loss(self, model, inputs, return_outputs=False):
logger.debug(f"Computing loss with inputs: {inputs.keys()}")
outputs = model(**inputs)
loss = outputs.loss
logger.debug(f"Computed loss: {loss.item()}")
return (loss, outputs) if return_outputs else loss

def training_step(self, model, inputs):
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
logger.debug("Entering training step")
model.train()
inputs = self._prepare_inputs(inputs)
_distributed = self.args.local_rank > -1
self.gc.models = [model]
loss = self.gc(inputs, no_sync_except_last=_distributed)
logger.debug(f"Training step loss: {loss.item()}")
return loss

def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: bool = None,
) -> PredictionOutput:
logger.debug("Entering prediction step")
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
outputs = model(**inputs)
loss = outputs.loss
logits = outputs.scores
logger.debug(f"Prediction step loss: {loss.item() if loss is not None else 'N/A'}")
return PredictionOutput(predictions=logits, label_ids=inputs.get("labels"), metrics=None)

0 comments on commit 5f999c1

Please sign in to comment.