
[DRAFT] Add gradient cache support for Reranker #148

Open
wants to merge 17 commits into main
1 change: 1 addition & 0 deletions examples/example_rankllama.md
@@ -18,4 +18,5 @@ deepspeed --include localhost:4,5,6,7 --master_port 60000 --module tevatron.rera
--num_train_epochs 1 \
--logging_steps 10 \
--overwrite_output_dir
--grad_cache
```
9 changes: 8 additions & 1 deletion src/tevatron/reranker/arguments.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments

@dataclass
class ModelArguments:
@@ -116,3 +116,10 @@ class DataArguments:
"enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
},
)

@dataclass
class TevatronTrainingArguments(TrainingArguments):
warmup_ratio: float = field(default=0.1)

grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache"})
gc_chunk_size: Optional[int] = field(default=2, metadata={"help": "Chunk size for gradient cache"})
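The two new flags above only declare the configuration; the trainer code that consumes them is not part of this excerpt. As a rough illustration of the gradient-cache technique the flags refer to, here is a minimal sketch, not the PR's actual `RerankerTrainer` changes; `split_inputs`, `grad_cache_step`, and `loss_fn` are hypothetical names:

```python
# Illustrative sketch only: how grad_cache / gc_chunk_size could drive a
# gradient-cache training step. Names below are hypothetical, not from the PR.
import torch


def split_inputs(batch: dict, chunk_size: int):
    """Split a tokenized batch into sub-batches of at most chunk_size rows."""
    total = next(iter(batch.values())).size(0)
    return [
        {k: v[start:start + chunk_size] for k, v in batch.items()}
        for start in range(0, total, chunk_size)
    ]


def grad_cache_step(model, batch: dict, chunk_size: int, loss_fn):
    chunks = split_inputs(batch, chunk_size)

    # Pass 1: forward each chunk without building a graph and cache the scores.
    with torch.no_grad():
        cached = [model(**c).scores for c in chunks]

    # Compute the loss on grad-enabled copies of the cached scores, then
    # backprop through the loss alone to get d(loss)/d(scores) per chunk.
    cached = [s.detach().requires_grad_() for s in cached]
    loss = loss_fn(torch.cat(cached, dim=0))
    loss.backward()

    # Pass 2: re-run each chunk with the graph enabled and push the cached
    # score gradients through the model, accumulating parameter gradients.
    for c, s in zip(chunks, cached):
        model(**c).scores.backward(gradient=s.grad)

    return loss.detach()
```

With the default `gc_chunk_size=2`, a large per-device batch would be processed as a sequence of two-row sub-batches, trading extra forward passes for a much lower peak activation memory.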
65 changes: 39 additions & 26 deletions src/tevatron/reranker/driver/train.py
@@ -1,57 +1,60 @@
import logging
import os
import sys

import torch
from transformers import AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from transformers import TrainingArguments

from tevatron.reranker.arguments import ModelArguments, DataArguments

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.trainer import RerankerTrainer
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer

logger = logging.getLogger(__name__)


def setup_ddp():
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
# We're running in a distributed environment
import torch.distributed as dist
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dist.init_process_group(backend="nccl")
return rank
else:
# We're not running in a distributed environment
return -1


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments

if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)

local_rank = setup_ddp()
training_args.local_rank = local_rank

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
bool(local_rank != -1),
training_args.fp16 or training_args.bf16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("MODEL parameters %s", model_args)
@@ -60,20 +63,30 @@ def main():

tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir
cache_dir=model_args.cache_dir,
trust_remote_code=True
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = 'right'

model = RerankerModel.build(
model_args,
training_args,
cache_dir=model_args.cache_dir,
)

# Move model to GPU
if local_rank != -1:
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

train_dataset = RerankerTrainDataset(data_args)
train_collator = RerankerTrainCollator(data_args, tokenizer)

training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
training_args.grad_cache = getattr(training_args, 'grad_cache', False)

trainer = RerankerTrainer(
model=model,
args=training_args,
@@ -82,7 +95,7 @@
)
train_dataset.trainer = trainer

trainer.train() # TODO: resume training
trainer.train()
trainer.save_model()
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)
83 changes: 33 additions & 50 deletions src/tevatron/reranker/modeling.py
@@ -1,6 +1,6 @@
import os
import logging
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Optional

import torch
from torch import nn, Tensor
@@ -9,11 +9,8 @@
from transformers import TrainingArguments
from peft import LoraConfig, PeftModel, TaskType, get_peft_model


from tevatron.reranker.arguments import ModelArguments

import logging

logger = logging.getLogger(__name__)


@@ -22,44 +19,24 @@ class RerankerOutput(ModelOutput):
loss: Optional[Tensor] = None
scores: Optional[Tensor] = None


class RerankerModel(nn.Module):
TRANSFORMER_CLS = AutoModelForSequenceClassification

def __init__(self, hf_model: PreTrainedModel, train_batch_size: int = None):
def __init__(self, hf_model: PreTrainedModel):
super().__init__()
logger.info("Initializing RerankerModel")
self.config = hf_model.config
self.hf_model = hf_model
self.train_batch_size = train_batch_size
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
if train_batch_size:
self.register_buffer(
'target_label',
torch.zeros(self.train_batch_size, dtype=torch.long, device=self.hf_model.device)
)
for name, param in self.hf_model.named_parameters():
# for some reason, ds zero 3 left some weights empty
if 'modules_to_save' in name and param.numel() == 0:
logger.warning(f'parameter {name}, shape {param.shape} is empty')
param.data = nn.Linear(self.hf_model.config.hidden_size, 1).weight.data
logger.warning('{} data: {}'.format(name, param.data.cpu().numpy()))
logger.info(f"RerankerModel initialized with config: {self.config}")

def forward(self, pair: Dict[str, Tensor] = None):
ranker_logits = self.hf_model(**pair, return_dict=True).logits
if self.train_batch_size:
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
loss = self.cross_entropy(grouped_logits, self.target_label)
return RerankerOutput(
loss = loss,
scores = ranker_logits
)
def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

return RerankerOutput(
loss = None,
scores = ranker_logits
scores=outputs.logits
)

def gradient_checkpointing_enable(self, **kwargs):
self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)

@classmethod
def build(
@@ -68,19 +45,27 @@ def build(
train_args: TrainingArguments,
**hf_kwargs,
):
logger.info(f"Building RerankerModel with args: {model_args}")
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name_or_path,
**hf_kwargs,
)
if base_model.config.pad_token_id is None:
base_model.config.pad_token_id = 0
logger.info("Set pad_token_id to 0")

if model_args.lora or model_args.lora_name_or_path:
logger.info("Applying LoRA")
if train_args.gradient_checkpointing:
base_model.enable_input_require_grads()
if model_args.lora_name_or_path:
logger.info(f"Loading LoRA from {model_args.lora_name_or_path}")
lora_config = LoraConfig.from_pretrained(model_args.lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
lora_model = PeftModel.from_pretrained(base_model, model_args.lora_name_or_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2")
else:
logger.info("Initializing new LoRA")
lora_config = LoraConfig(
base_model_name_or_path=model_args.model_name_or_path,
task_type=TaskType.SEQ_CLS,
@@ -91,37 +76,35 @@
inference_mode=False,
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
hf_model=lora_model,
train_batch_size=train_args.per_device_train_batch_size,
)
model = cls(hf_model=lora_model)
else:
model = cls(
hf_model=base_model,
train_batch_size=train_args.per_device_train_batch_size,
)
logger.info("Building model without LoRA")
model = cls(hf_model=base_model)
return model

@classmethod
def load(cls,
model_name_or_path: str,
lora_name_or_path: str = None,
**hf_kwargs):
base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
logger.info(f"Loading RerankerModel from {model_name_or_path}")
base_model = cls.TRANSFORMER_CLS.from_pretrained(model_name_or_path, num_labels=1, **hf_kwargs,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2")
if base_model.config.pad_token_id is None:
base_model.config.pad_token_id = 0
logger.info("Set pad_token_id to 0")
if lora_name_or_path:
logger.info(f"Loading LoRA from {lora_name_or_path}")
lora_config = LoraConfig.from_pretrained(lora_name_or_path, **hf_kwargs)
lora_model = PeftModel.from_pretrained(base_model, lora_name_or_path, config=lora_config)
lora_model = lora_model.merge_and_unload()
model = cls(
hf_model=lora_model,
)
model = cls(hf_model=lora_model)
else:
model = cls(
hf_model=base_model,
)
logger.info("Loading model without LoRA")
model = cls(hf_model=base_model)
return model

def save(self, output_dir: str):
self.hf_model.save_pretrained(output_dir)
logger.info(f"Saving model to {output_dir}")
self.hf_model.save_pretrained(output_dir)
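Since the rewritten `forward` now returns raw logits without computing a loss, the grouped cross-entropy that the removed code performed presumably moves into the trainer. A minimal sketch of that grouping, mirroring the removed logic (positive passage at index 0 of each group; `group_size` and `listwise_ce_loss` are assumed names, not from the PR):

```python
import torch
import torch.nn.functional as F


def listwise_ce_loss(scores: torch.Tensor, group_size: int) -> torch.Tensor:
    """Cross-entropy over groups of (1 positive + N-1 negatives) of reranker logits.

    Mirrors what the old RerankerModel.forward computed internally: logits of
    shape (batch * group_size, 1) are reshaped to one row per query, and the
    target is always index 0 (the positive passage).
    """
    grouped = scores.view(-1, group_size)
    target = torch.zeros(grouped.size(0), dtype=torch.long, device=grouped.device)
    return F.cross_entropy(grouped, target)
```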