fix: forward method
sigridjineth committed Aug 24, 2024
1 parent ca4b04b commit 3ec6cfd
Showing 2 changed files with 40 additions and 83 deletions.
55 changes: 29 additions & 26 deletions src/tevatron/reranker/modeling.py
@@ -1,6 +1,6 @@
import os
import logging
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Optional

import torch
from torch import nn, Tensor
@@ -11,8 +11,6 @@

from tevatron.reranker.arguments import ModelArguments

import logging

logger = logging.getLogger(__name__)


@@ -27,6 +25,7 @@ class RerankerModel(nn.Module):

def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
super().__init__()
logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
self.config = hf_model.config
self.hf_model = hf_model
self.train_batch_size = train_batch_size
@@ -36,54 +35,52 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
'target_label',
torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
)
for name, param in self.hf_model.named_parameters():
# for some reason, ds zero 3 left some weights empty
if 'modules_to_save' in name and param.numel() == 0:
logger.warning(f'parameter {name}, shape {param.shape} is empty')
param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
logger.info(f"RerankerModel initialized with config: {self.config}")

def forward(self, pair: Dict[str, Tensor] = None):
ranker_logits = self.hf_model(**pair, return_dict=True).logits
if self.train_batch_size:
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
loss = self.cross_entropy(grouped_logits, self.target_label)
return RerankerOutput(
loss=loss,
scores=ranker_logits
)
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

if labels is not None:
loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
logger.debug(f"Computed loss: {loss.item()}")
else:
loss = None
logger.debug("No labels provided, skipping loss computation")

return RerankerOutput(
loss=None,
scores=ranker_logits
loss=loss,
scores=outputs.logits
)

def gradient_checkpointing_enable(self, **kwargs):
return False
# self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)

@classmethod
def build(
cls,
model_args: ModelArguments,
train_args: TrainingArguments,
**hf_kwargs,
):
logger.info(f"Building RerankerModel with args: {model_args}")
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name_or_path,
**hf_kwargs,
)
if base_model.config.pad_token_id is None:
base_model.config.pad_token_id = 0
logger.info("Set pad_token_id to 0")

if model_args.lora or model_args.lora_name_or_path:
logger.info("Applying LoRA")
if train_args.gradient_checkpointing:
base_model.enable_input_require_grads()
if model_args.lora_name_or_path:
logger.info(f"Loading LoRA from {model_args.lora_name_or_path}")
lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2")
else:
logger.info("Initializing new LoRA")
lora_config = LoraConfig(
base_model_name_or_path=model_args.model_name_or_path,
task_type=TaskType.SEQ_CLS,
@@ -99,6 +96,7 @@ def build(
train_batch_size=train_args.per_device_train_batch_size,
)
else:
logger.info("Building model without LoRA")
model = cls(
hf_model=base_model,
train_batch_size=train_args.per_device_train_batch_size,
@@ -110,23 +108,28 @@ def load(cls,
model_name_or_path: str,
lora_name_or_path: str = None,
**hf_kwargs):
logger.info(f"Loading RerankerModel from {model_name_or_path}")
base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2")
if base_model.config.pad_token_id is None:
base_model.config.pad_token_id = 0
logger.info("Set pad_token_id to 0")
if lora_name_or_path:
logger.info(f"Loading LoRA from {lora_name_or_path}")
lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
lora_model = lora_model.merge_and_unload()
model = cls(
hf_model=lora_model,
)
else:
logger.info("Loading model without LoRA")
model = cls(
hf_model=base_model,
)
return model

def save(self, output_dir: str):
self.hf_model.save_pretrained(output_dir)
logger.info(f"Saving model to {output_dir}")
self.hf_model.save_pretrained(output_dir)
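
For context, a minimal usage sketch of the reworked forward signature (not part of the diff): explicit input_ids and attention_mask tensors replace the single pair dict, and loss comes back as None when no labels are given. The checkpoint name and the use of AutoModelForSequenceClassification below are illustrative assumptions, and the wrapper is constructed directly rather than via RerankerModel.load, which hard-codes bfloat16 and flash_attention_2.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tevatron.reranker.modeling import RerankerModel

# Illustrative checkpoint; any sequence-classification reranker with a single logit works.
name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
hf_model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=1)

model = RerankerModel(hf_model=hf_model)  # no train_batch_size: inference only
pair = tokenizer("what is tevatron?",
                 "Tevatron is a toolkit for training neural retrievers and rerankers.",
                 return_tensors="pt", truncation=True)
with torch.no_grad():
    out = model(input_ids=pair["input_ids"], attention_mask=pair["attention_mask"])
print(out.scores)  # relevance logit; out.loss is None because labels were not passed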
68 changes: 11 additions & 57 deletions src/tevatron/reranker/trainer.py
@@ -1,81 +1,35 @@
import os
from typing import Optional
from tevatron.reranker.modeling import RerankerOutput
from tevatron.retriever.trainer import TevatronTrainer
from grad_cache import GradCache

import torch
from torch import Tensor
from torch.nn import functional as F

from transformers.trainer import Trainer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from peft import get_peft_model_state_dict

import logging

logger = logging.getLogger(__name__)

try:
from grad_cache import GradCache

_grad_cache_available = True
except ModuleNotFoundError:
_grad_cache_available = False


def split_inputs(model_input: dict, chunk_size: int):
def split_inputs(model_input, chunk_size):
keys = list(model_input.keys())
chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

def get_rep(x: RerankerOutput):
return x.scores

def get_rep(x):
return x.logits


class RerankerTrainer(Trainer):
class RerankerTrainer(TevatronTrainer):
def __init__(self, *args, **kwargs):
super(RerankerTrainer, self).__init__(*args, **kwargs)

if not _grad_cache_available:
raise ValueError(
'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')

super().__init__(*args, **kwargs)
loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
self.gc = GradCache(
models=[self.model],
chunk_sizes=[self.args.gc_chunk_size],
loss_fn=self.compute_loss,
loss_fn=loss_fn,
split_input_fn=split_inputs,
get_rep_fn=get_rep,
fp16=self.args.fp16,
scaler=self.scaler if self.args.fp16 else None
)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
self.model.save(output_dir)

if is_deepspeed_zero3_enabled():
if state_dict is None:
state_dict = self.model.state_dict()
prefix = 'hf_model.'
assert all(
k.startswith(prefix) or k == "target_label"
for k in state_dict.keys()
), list(state_dict.keys())
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
lora_state_dict = get_peft_model_state_dict(self.model.hf_model, state_dict)
if self.args.process_index <= 0:
torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
print(f"Save adapter model at {output_dir}")

def compute_loss(self, model, inputs, return_outputs=False):
outputs = model(inputs)
outputs = model(**inputs)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss

def training_step(self, model, inputs):
model.train()
_distributed = self.args.local_rank > -1
self.gc.models = [model]
loss = self.gc(inputs, no_sync_except_last=_distributed)
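
As a side note on the GradCache wiring above, a small sketch (not part of the diff) of what split_inputs produces: the training batch dict is cut along dim 0 into chunk-sized sub-batches that GradCache forwards one at a time before the cached backward pass. The tensor shapes are made up for the example.

import torch
from tevatron.reranker.trainer import split_inputs

# A batch of 8 query-passage pairs, 16 tokens each (shapes chosen only for illustration).
batch = {
    "input_ids": torch.randint(0, 1000, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}
chunks = split_inputs(batch, chunk_size=2)
print(len(chunks))                   # 4 sub-batches
print(chunks[0]["input_ids"].shape)  # torch.Size([2, 16])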