[DRAFT] Add gradient cache support for Reranker #148

Open · wants to merge 17 commits into main
hotfix: gradient_checkpointing_enable
sigridjineth committed Aug 24, 2024
commit ca4b04b7fb40e4d30df16d12d6ee9c9261091209
src/tevatron/reranker/modeling.py (23 changes: 14 additions & 9 deletions)
@@ -9,7 +9,6 @@
 from transformers import TrainingArguments
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 
-
 from tevatron.reranker.arguments import ModelArguments
 
 import logging
@@ -22,6 +21,7 @@ class RerankerOutput(ModelOutput):
     loss: Optional[Tensor] = None
     scores: Optional[Tensor] = None
 
+
 class RerankerModel(nn.Module):
     TRANSFORMER_CLS = AutoModelForSequenceClassification
 
@@ -49,17 +49,18 @@ def forward(self, pair: Dict[str, Tensor] = None):
             grouped_logits = ranker_logits.view(self.train_batch_size, -1)
             loss = self.cross_entropy(grouped_logits, self.target_label)
             return RerankerOutput(
-                loss = loss,
-                scores = ranker_logits
+                loss=loss,
+                scores=ranker_logits
             )
 
         return RerankerOutput(
-            loss = None,
-            scores = ranker_logits
+            loss=None,
+            scores=ranker_logits
         )
 
     def gradient_checkpointing_enable(self, **kwargs):
-        self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
+        return False
+        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
 
     @classmethod
     def build(
@@ -79,7 +80,9 @@ def build(
                 base_model.enable_input_require_grads()
             if model_args.lora_name_or_path:
                 lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
-                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+                lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
+                                                       torch_dtype=torch.bfloat16,
+                                                       attn_implementation="flash_attention_2")
             else:
                 lora_config = LoraConfig(
                     base_model_name_or_path=model_args.model_name_or_path,
@@ -107,7 +110,9 @@ def load(cls,
              model_name_or_path: str,
              lora_name_or_path: str = None,
              **hf_kwargs):
-        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
+        base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
+                                                         torch_dtype=torch.bfloat16,
+                                                         attn_implementation="flash_attention_2")
         if base_model.config.pad_token_id is None:
             base_model.config.pad_token_id = 0
         if lora_name_or_path:
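
Context on the hotfix hunk above: when TrainingArguments(gradient_checkpointing=True) is set, the Hugging Face Trainer calls gradient_checkpointing_enable() on the top-level module, and this commit turns RerankerModel's override into a no-op (return False, with the old forwarding call commented out), presumably because the .base_model.model attribute chain only exists once hf_model is PEFT-wrapped. A minimal sketch of an alternative, assuming self.hf_model is either a plain PreTrainedModel or a PeftModel (both should accept gradient_checkpointing_enable), forwards the call directly; the class below is hypothetical and not part of the PR:

import torch.nn as nn


class RerankerWrapperSketch(nn.Module):
    """Hypothetical wrapper, used only to illustrate forwarding the Trainer's call."""

    def __init__(self, hf_model: nn.Module):
        super().__init__()
        self.hf_model = hf_model  # plain PreTrainedModel or PEFT-wrapped model

    def gradient_checkpointing_enable(self, **kwargs):
        # Forward to the wrapped model instead of reaching into the PEFT-only
        # .base_model.model chain that the commented-out line relied on.
        self.hf_model.gradient_checkpointing_enable(**kwargs)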