[DRAFT] Add gradient cache support for Reranker #148

Open

Wants to merge 17 commits into base: main
gc reranker
sigridjineth committed Aug 24, 2024
commit 859c9e613f19b6bfcb661c722d6932197aedc944
15 changes: 9 additions & 6 deletions src/tevatron/reranker/driver/train.py
@@ -7,17 +7,18 @@
    HfArgumentParser,
    set_seed,
)
from transformers import TrainingArguments

from tevatron.reranker.arguments import ModelArguments, DataArguments

from tevatron.reranker.arguments import ModelArguments, DataArguments, \
    TevatronTrainingArguments as TrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.trainer import RerankerTrainer
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer
from tevatron.reranker.gc_trainer import GradCacheTrainer as GCTrainer

logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

@@ -65,6 +66,7 @@ def main():
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.unk_token_id
    tokenizer.padding_side = 'right'

    model = RerankerModel.build(
        model_args,
        training_args,
@@ -74,7 +76,8 @@
    train_dataset = RerankerTrainDataset(data_args)
    train_collator = RerankerTrainCollator(data_args, tokenizer)

    trainer = RerankerTrainer(
    trainer_cls = GCTrainer if training_args.grad_cache else RerankerTrainer
    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
@@ -89,4 +92,4 @@


if __name__ == "__main__":
    main()
    main()
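Note on the driver change above: the trainer class is chosen from `training_args.grad_cache`, and the gradient-cache trainer below also reads `training_args.gc_chunk_size` and `training_args.fp16`. The arguments module is not part of this diff, so the following is only a minimal sketch, assuming the two extra fields live on `TevatronTrainingArguments` (field names taken from the usages in this PR; defaults and help strings are invented for illustration):

    # Hypothetical sketch of the two training-argument fields this PR reads.
    # Names follow the usages above; defaults are assumptions, not the PR's values.
    from dataclasses import dataclass, field
    from transformers import TrainingArguments

    @dataclass
    class TevatronTrainingArguments(TrainingArguments):
        grad_cache: bool = field(
            default=False,
            metadata={"help": "Enable gradient caching for the reranker trainer."},
        )
        gc_chunk_size: int = field(
            default=2,
            metadata={"help": "Per-step sub-batch size used by GradCache."},
        )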
46 changes: 44 additions & 2 deletions src/tevatron/reranker/trainer.py
@@ -2,19 +2,53 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F

from transformers.trainer import Trainer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from peft import get_peft_model_state_dict

import logging

logger = logging.getLogger(__name__)

try:
    from grad_cache import GradCache

    _grad_cache_available = True
except ModuleNotFoundError:
    _grad_cache_available = False


def split_inputs(model_input: dict, chunk_size: int):
    keys = list(model_input.keys())
    chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
    return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]


def get_rep(x):
    return x.logits


class RerankerTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super(RerankerTrainer, self).__init__(*args, **kwargs)

        if not _grad_cache_available:
            raise ValueError(
                'Grad Cache package not available. You can obtain it from https://github.com/luyug/GradCache.')

        self.gc = GradCache(
            models=[self.model],
            chunk_sizes=[self.args.gc_chunk_size],
            loss_fn=self.compute_loss,
            split_input_fn=split_inputs,
            get_rep_fn=get_rep,
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
@@ -35,6 +69,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
        torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
        print(f"Save adapter model at {output_dir}")

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def compute_loss(self, model, inputs):
        return model(inputs).loss

    def training_step(self, model, inputs):
        model.train()
        _distributed = self.args.local_rank > -1
        self.gc.models = [model]
        loss = self.gc(inputs, no_sync_except_last=_distributed)
        return loss
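For readers unfamiliar with GradCache: the trainer above splits each collated batch into `gc_chunk_size`-sized sub-batches via `split_inputs`, runs them through the model one at a time while caching the representation returned by `get_rep` (here the reranker logits), computes the loss over the cached representations, and then replays the chunked forwards to backpropagate, keeping peak activation memory roughly at the sub-batch size. A small self-contained illustration of the splitting step (toy tensors, not the real collator output; `split_inputs` copied from the diff above):

    import torch

    def split_inputs(model_input: dict, chunk_size: int):
        # Split every tensor in the batch along dim 0 and regroup per chunk.
        keys = list(model_input.keys())
        chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
        return [dict(zip(keys, tt)) for tt in zip(*chunked_tensors)]

    # Toy batch standing in for a tokenized query-passage batch of size 8.
    batch = {
        "input_ids": torch.randint(0, 1000, (8, 16)),
        "attention_mask": torch.ones(8, 16, dtype=torch.long),
    }

    chunks = split_inputs(batch, chunk_size=2)
    print(len(chunks))                   # 4 sub-batches
    print(chunks[0]["input_ids"].shape)  # torch.Size([2, 16])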