From 14c81296d59caefa14b93c4687ef633b7f4bada6 Mon Sep 17 00:00:00 2001
From: sigridjineth
Date: Sat, 24 Aug 2024 22:16:09 +0900
Subject: [PATCH] fix: ddp

Initialize the torch.distributed process group (NCCL backend) and wrap the
reranker model in DistributedDataParallel when launched with a distributed
runner. GradCache becomes opt-in via a `grad_cache` training argument,
gradient checkpointing is stubbed out on RerankerModel, the 16-bit logging
line now covers bf16 as well, and the non-empty output_dir guard is dropped.

---
 src/tevatron/reranker/driver/train.py | 44 ++++++++++++++++-----------
 src/tevatron/reranker/modeling.py     |  4 +++
 src/tevatron/reranker/trainer.py      | 38 +++++++++++++----------
 3 files changed, 53 insertions(+), 33 deletions(-)

diff --git a/src/tevatron/reranker/driver/train.py b/src/tevatron/reranker/driver/train.py
index 8115d911..dcae1d9a 100644
--- a/src/tevatron/reranker/driver/train.py
+++ b/src/tevatron/reranker/driver/train.py
@@ -1,20 +1,31 @@
 import logging
 import os
 import sys
+import torch
 
 from transformers import AutoTokenizer
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 
 from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
 from tevatron.reranker.modeling import RerankerModel
 from tevatron.reranker.dataset import RerankerTrainDataset
 from tevatron.reranker.collator import RerankerTrainCollator
-from tevatron.reranker.trainer import RerankerTrainer  # Make sure this is your updated RerankerTrainer
+from tevatron.reranker.trainer import RerankerTrainer
 
 logger = logging.getLogger(__name__)
 
 
+def setup_ddp():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+    return local_rank
+
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))
@@ -23,29 +34,23 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    local_rank = -1
+    if training_args.local_rank != -1:
+        local_rank = setup_ddp()
 
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+        level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
     )
     logger.warning(
         "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.local_rank,
+        local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
-        training_args.fp16,
+        bool(local_rank != -1),
+        training_args.fp16 or training_args.bf16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
     logger.info("MODEL parameters %s", model_args)
@@ -67,11 +72,16 @@ def main():
         cache_dir=model_args.cache_dir,
     )
 
+    # Move model to GPU
+    if local_rank != -1:
+        model = model.to(local_rank)
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     train_dataset = RerankerTrainDataset(data_args)
     train_collator = RerankerTrainCollator(data_args, tokenizer)
 
-    # Add GradCache-specific arguments to training_args
     training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
+    training_args.grad_cache = getattr(training_args, 'grad_cache', False)
 
     trainer = RerankerTrainer(
         model=model,
@@ -81,11 +91,11 @@ def main():
     )
     train_dataset.trainer = trainer
 
-    trainer.train()  # TODO: resume training
+    trainer.train()
     trainer.save_model()
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/src/tevatron/reranker/modeling.py b/src/tevatron/reranker/modeling.py
index 887d6d24..d3502d2e 100644
--- a/src/tevatron/reranker/modeling.py
+++ b/src/tevatron/reranker/modeling.py
@@ -30,6 +30,10 @@ def __init__(self, hf_model: PreTrainedModel):
         self.hf_model = hf_model
         logger.info(f"RerankerModel initialized with config: {self.config}")
 
+    def gradient_checkpointing_enable(self, **kwargs):
+        return False
+        # self.hf_model.base_model.model.gradient_checkpointing_enable(**kwargs)
+
     def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwargs):
         logger.debug(f"Forward pass with input shape: {input_ids.shape if input_ids is not None else 'None'}")
         outputs = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
diff --git a/src/tevatron/reranker/trainer.py b/src/tevatron/reranker/trainer.py
index 79cc2bc1..8534b68a 100644
--- a/src/tevatron/reranker/trainer.py
+++ b/src/tevatron/reranker/trainer.py
@@ -7,7 +7,6 @@
 from transformers.trainer_utils import PredictionOutput
 
 from grad_cache import GradCache
-from grad_cache.functional import cached, cat_input_tensor
 
 from torch.cuda.amp import autocast
 
@@ -39,22 +38,26 @@ def split_inputs(model_input, chunk_size):
 class RerankerTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        logger.info("Initializing RerankerTrainer with GradCache")
+        logger.info("Initializing RerankerTrainer")
         self.args: TrainingArguments
-        # Add these lines to include the necessary parameters
-        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)  # default to 4 if not provided
+        self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
+        self.use_grad_cache = getattr(self.args, 'grad_cache', False)
+
+        if self.use_grad_cache:
+            # If the model is wrapped in DDP, we need to use the .module attribute
+            model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model
 
-        self.gc = GradCache(
-            models=[self.model],
-            chunk_sizes=self.gc_chunk_size,
-            loss_fn=contrastive_loss,
-            split_input_fn=split_inputs,
-            get_rep_fn=lambda x: x.scores,
-            fp16=self.args.fp16,
-            scaler=self.scaler if self.args.fp16 else None
-        )
-        logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
+            self.gc = GradCache(
+                models=[model_for_gc],
+                chunk_sizes=self.gc_chunk_size,
+                loss_fn=contrastive_loss,
+                split_input_fn=split_inputs,
+                get_rep_fn=lambda x: x.scores,
+                fp16=self.args.fp16,
+                # scaler: GradScaler = None,
+            )
+            logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
 
     def compute_loss(self, model, inputs, return_outputs=False):
         logger.debug(f"Computing loss with inputs: {inputs.keys()}")
@@ -68,8 +71,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         logger.debug("Entering training step")
         model.train()
         inputs = self._prepare_inputs(inputs)
-        _distributed = self.args.local_rank > -1
-        loss = self.gc(inputs, no_sync_except_last=_distributed)
+        if self.use_grad_cache:
+            _distributed = self.args.local_rank != -1
+            loss = self.gc(inputs, no_sync_except_last=_distributed)
+        else:
+            loss = self.compute_loss(model, inputs)
         logger.debug(f"Training step loss: {loss.item()}")
         return loss
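
A note on launching: setup_ddp() reads the LOCAL_RANK environment variable
and selects the NCCL backend, so the DDP path assumes one process per GPU
started by a launcher such as torchrun (e.g. `torchrun --nproc_per_node=<num_gpus>
-m tevatron.reranker.driver.train ...`; the module path is inferred from the
file layout, not stated in the patch), which sets LOCAL_RANK for every worker.
A plain single-process run keeps local_rank at -1 and skips DDP entirely. A
matching teardown helper (hypothetical, not in the patch) would avoid NCCL
shutdown warnings at the end of main():

    import torch.distributed as dist

    def cleanup_ddp():
        # Tear down only if setup_ddp() actually created a process group,
        # so the single-process path can call this unconditionally.
        if dist.is_initialized():
            dist.destroy_process_group()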
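One caveat with the opt-in flag: `getattr(training_args, 'grad_cache', False)`
can only ever return the default unless TevatronTrainingArguments actually
declares the field, so `--grad_cache` is not yet exposed on the command line
(and the two gc_chunk_size defaults disagree: 2 in train.py, 4 in the trainer).
A hedged sketch of the missing declaration, assuming TevatronTrainingArguments
extends transformers.TrainingArguments as its usage here suggests:

    from dataclasses import dataclass, field

    from transformers import TrainingArguments

    @dataclass
    class TevatronTrainingArguments(TrainingArguments):
        # Hypothetical field declarations; names match the getattr() lookups above.
        grad_cache: bool = field(
            default=False,
            metadata={"help": "Use GradCache chunked forward/backward in the reranker trainer"},
        )
        gc_chunk_size: int = field(
            default=2,
            metadata={"help": "Sub-batch chunk size for GradCache"},
        )

With the fields declared, HfArgumentParser picks them up automatically and
both getattr() fallbacks become dead code.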
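Finally, the non-GradCache branch of training_step returns a loss without
running a backward pass: `self.gc(inputs, ...)` performs its chunked backward
internally, but `self.compute_loss(model, inputs)` does not, so on that path
gradients are never populated and the optimizer step is a no-op. A minimal
sketch of a fix, assuming a transformers version where Trainer exposes
`self.accelerator` (gradient-accumulation scaling elided for brevity):

    # Inside RerankerTrainer.
    def training_step(self, model, inputs):
        model.train()
        inputs = self._prepare_inputs(inputs)
        if self.use_grad_cache:
            # GradCache drives its own chunked forward/backward passes.
            loss = self.gc(inputs, no_sync_except_last=self.args.local_rank != -1)
        else:
            loss = self.compute_loss(model, inputs)
            # Run backward here; otherwise this branch returns a loss but
            # never produces gradients for the optimizer step.
            self.accelerator.backward(loss)
        return loss.detach()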