[DRAFT] Add support gradient cache for Reranker #148

Open · wants to merge 17 commits into base: main
fix: forward method
sigridjineth committed Aug 24, 2024
commit a79b647c7effb80f5477d0fdcacd4138b4045fb7
56 changes: 31 additions & 25 deletions src/tevatron/reranker/modeling.py
@@ -1,6 +1,6 @@
import os
import logging
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Optional, Dict, Any

import torch
from torch import nn, Tensor
@@ -11,8 +11,6 @@

from tevatron.reranker.arguments import ModelArguments

import logging

logger = logging.getLogger(__name__)


@@ -27,6 +25,7 @@ class RerankerModel(nn.Module):

def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
super().__init__()
logger.info(f"Initializing RerankerModel with train_batch_size: {train_batch_size}")
self.config = hf_model.config
self.hf_model = hf_model
self.train_batch_size = train_batch_size
@@ -36,31 +35,26 @@ def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
'target_label',
torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
)
for name, param in self.hf_model.named_parameters():
# for some reason, ds zero 3 left some weights empty
if 'modules_to_save' in name and param.numel() == 0:
logger.warning(f'parameter {name}, shape {param.shape} is empty')
param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))

def forward(self, pair: Dict[str, Tensor] = None):
ranker_logits = self.hf_model(**pair, return_dict=True).logits
if self.train_batch_size:
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
loss = self.cross_entropy(grouped_logits, self.target_label)
return RerankerOutput(
loss=loss,
scores=ranker_logits
)
logger.info(f"RerankerModel initialized with config: {self.config}")

def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, labels: Tensor = None, **kwargs):
logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

if labels is not None:
loss = self.cross_entropy(outputs.logits.view(self.train_batch_size, -1), labels)
logger.debug(f"Computed loss: {loss.item()}")
else:
loss = None
logger.debug("No labels provided, skipping loss computation")

return RerankerOutput(
loss=None,
scores=ranker_logits
loss=loss,
scores=outputs.logits
)

def gradient_checkpointing_enable(self, **kwargs):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None):
return False
# self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)

@classmethod
def build(
@@ -69,21 +63,27 @@ def build(
train_args: TrainingArguments,
**hf_kwargs,
):
logger.info(f"Building RerankerModel with args: {model_args}")
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name_or_path,
**hf_kwargs,
)
if base_model.config.pad_token_id is None:
base_model.config.pad_token_id = 0
logger.info("Set pad_token_id to 0")

if model_args.lora or model_args.lora_name_or_path:
logger.info("Applying LoRA")
if train_args.gradient_checkpointing:
base_model.enable_input_require_grads()
if model_args.lora_name_or_path:
logger.info(f"Loading LoRA from {model_args.lora_name_or_path}")
lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2")
else:
logger.info("Initializing new LoRA")
lora_config = LoraConfig(
base_model_name_or_path=model_args.model_name_or_path,
task_type=TaskType.SEQ_CLS,
@@ -99,6 +99,7 @@ def build(
train_batch_size=train_args.per_device_train_batch_size,
)
else:
logger.info("Building model without LoRA")
model = cls(
hf_model=base_model,
train_batch_size=train_args.per_device_train_batch_size,
@@ -110,23 +111,28 @@ def load(cls,
model_name_or_path: str,
lora_name_or_path: str = None,
**hf_kwargs):
logger.info(f"Loading RerankerModel from {model_name_or_path}")
base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2")
if base_model.config.pad_token_id is None:
base_model.config.pad_token_id = 0
logger.info("Set pad_token_id to 0")
if lora_name_or_path:
logger.info(f"Loading LoRA from {lora_name_or_path}")
lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
lora_model = lora_model.merge_and_unload()
model = cls(
hf_model=lora_model,
)
else:
logger.info("Loading model without LoRA")
model = cls(
hf_model=base_model,
)
return model

def save(self, output_dir: str):
self.hf_model.save_pretrained(output_dir)
logger.info(f"Saving model to {output_dir}")
self.hf_model.save_pretrained(output_dir)
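
Note on the revised forward above: the reranker head emits one logit per query-passage pair, and the loss regroups those flat logits per query before applying softmax cross-entropy. Below is a minimal standalone sketch of that grouping step; the batch of 2 queries with 4 candidates each is an illustrative assumption for the example, not a value taken from the PR, and the all-zero target mirrors the zero-filled target_label buffer registered in the original __init__.

import torch
from torch import nn

# Illustrative sizes only: 2 queries per step, 4 candidate passages per query.
train_batch_size, group_size = 2, 4

# One logit per (query, passage) pair, flattened as the model returns them.
flat_logits = torch.randn(train_batch_size * group_size, 1)

# Regroup so each row holds one query's candidate logits.
grouped_logits = flat_logits.view(train_batch_size, -1)  # shape (2, 4)

# Assumes the positive passage comes first in each group, matching the
# zero-filled target_label buffer used by the original forward.
target = torch.zeros(train_batch_size, dtype=torch.long)

loss = nn.CrossEntropyLoss(reduction='mean')(grouped_logits, target)
print(grouped_logits.shape, loss.item())
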
68 changes: 11 additions & 57 deletions src/tevatron/reranker/trainer.py
@@ -1,81 +1,35 @@
import os
from typing import Optional
from tevatron.reranker.modeling import RerankerOutput
from tevatron.retriever.trainer import TevatronTrainer
from grad_cache import GradCache

import torch
from torch import Tensor
from torch.nn import functional as F

from transformers.trainer import Trainer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from peft import get_peft_model_state_dict

import logging

logger = logging.getLogger(__name__)

try:
from grad_cache import GradCache

_grad_cache_available = True
except ModuleNotFoundError:
_grad_cache_available = False


def split_inputs(model_input: dict, chunk_size: int):
def split_inputs(model_input, chunk_size):
keys = list(model_input.keys())
chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

def get_rep(x: RerankerOutput):
return x.scores

def get_rep(x):
return x.logits


class RerankerTrainer(Trainer):
class RerankerTrainer(TevatronTrainer):
def __init__(self, *args, **kwargs):
super(RerankerTrainer, self).__init__(*args, **kwargs)

if not _grad_cache_available:
raise ValueError(
'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')

super().__init__(*args, **kwargs)
loss_fn = lambda x, y: self.compute_loss(self.model, {'input_ids': x, 'labels': y})
self.gc = GradCache(
models=[self.model],
chunk_sizes=[self.args.gc_chunk_size],
loss_fn=self.compute_loss,
loss_fn=loss_fn,
split_input_fn=split_inputs,
get_rep_fn=get_rep,
fp16=self.args.fp16,
scaler=self.scaler if self.args.fp16 else None
)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
self.model.save(output_dir)

if is_deepspeed_zero3_enabled():
if state_dict is None:
state_dict = self.model.state_dict()
prefix = 'hf_model.'
assert all(
k.startswith(prefix) or k == "target_label"
for k in state_dict.keys()
), list(state_dict.keys())
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
lora_state_dict = get_peft_model_state_dict(self.model.hf_model, state_dict)
if self.args.process_index <= 0:
torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
print(f"Save adapter model at {output_dir}")

def compute_loss(self, model, inputs, return_outputs=False):
outputs = model(inputs)
outputs = model(**inputs)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss

def training_step(self, model, inputs):
model.train()
_distributed = self.args.local_rank > -1
self.gc.models = [model]
loss = self.gc(inputs, no_sync_except_last=_distributed)
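
For context on the GradCache wiring above: split_inputs is what lets the cached-gradient step process an oversized batch in chunks, splitting every tensor in the batch dict along the batch dimension and re-zipping the pieces into smaller dicts. A standalone sketch of that helper on a toy batch follows; the tensor shapes and the chunk size of 2 are illustrative assumptions, not values from the PR.

import torch

def split_inputs(model_input, chunk_size):
    # Split every tensor in the batch dict along dim 0, then re-zip the
    # pieces into a list of smaller dicts, one per chunk.
    keys = list(model_input.keys())
    chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
    return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

# Illustrative batch: 8 query-passage pairs of length 16, chunked into pieces of 2.
batch = {
    "input_ids": torch.randint(0, 1000, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}
chunks = split_inputs(batch, chunk_size=2)
print(len(chunks), chunks[0]["input_ids"].shape)  # 4 chunks, each (2, 16)
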