Skip to content

Commit

Permalink
fix: ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Aug 24, 2024
1 parent 43a642d commit 0380e4e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 36 deletions.
48 changes: 31 additions & 17 deletions src/tevatron/reranker/driver/train.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import logging
import os
import sys
import torch
from transformers import AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer
from tevatron.reranker.trainer import RerankerTrainer

logger = logging.getLogger(__name__)


def setup_ddp():
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
# We're running in a distributed environment
import torch.distributed as dist
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dist.init_process_group(backend="nccl")
return rank
else:
# We're not running in a distributed environment
return -1


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

Expand All @@ -23,29 +39,22 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
local_rank = setup_ddp()
training_args.local_rank = local_rank

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
bool(local_rank != -1),
training_args.fp16 or training_args.bf16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("MODEL parameters %s", model_args)
Expand All @@ -67,11 +76,16 @@ def main():
cache_dir=model_args.cache_dir,
)

# Move model to GPU
if local_rank != -1:
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

train_dataset = RerankerTrainDataset(data_args)
train_collator = RerankerTrainCollator(data_args, tokenizer)

# Add GradCache-specific arguments to training_args
training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
training_args.grad_cache = getattr(training_args, 'grad_cache', False)

trainer = RerankerTrainer(
model=model,
Expand All @@ -81,11 +95,11 @@ def main():
)
train_dataset.trainer = trainer

trainer.train() # TODO: resume training
trainer.train()
trainer.save_model()
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
main()
main()
3 changes: 0 additions & 3 deletions src/tevatron/reranker/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ def forward(self, input_ids: Tensor = None, attention_mask: Tensor = None, **kwa
scores=outputs.logits
)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs = None):
return False

@classmethod
def build(
cls,
Expand Down
38 changes: 22 additions & 16 deletions src/tevatron/reranker/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from transformers.trainer_utils import PredictionOutput

from grad_cache import GradCache

from grad_cache.functional import cached, cat_input_tensor
from torch.cuda.amp import autocast

Expand Down Expand Up @@ -39,22 +38,26 @@ def split_inputs(model_input, chunk_size):
class RerankerTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
logger.info("Initializing RerankerTrainer with GradCache")
logger.info("Initializing RerankerTrainer")
self.args: TrainingArguments

# Add these lines to include the necessary parameters
self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) # default to 4 if not provided
self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
self.use_grad_cache = getattr(self.args, 'grad_cache', False)

if self.use_grad_cache:
# If the model is wrapped in DDP, we need to use the .module attribute
model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model

self.gc = GradCache(
models=[self.model],
chunk_sizes=self.gc_chunk_size,
loss_fn=contrastive_loss,
split_input_fn=split_inputs,
get_rep_fn=lambda x: x.scores,
fp16=self.args.fp16,
scaler=self.scaler if self.args.fp16 else None
)
logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
self.gc = GradCache(
models=[model_for_gc],
chunk_sizes=self.gc_chunk_size,
loss_fn=contrastive_loss,
split_input_fn=split_inputs,
get_rep_fn=lambda x: x.scores,
fp16=self.args.fp16,
# scaler: GradScaler = None,
)
logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")

def compute_loss(self, model, inputs, return_outputs=False):
logger.debug(f"Computing loss with inputs: {inputs.keys()}")
Expand All @@ -68,8 +71,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
logger.debug("Entering training step")
model.train()
inputs = self._prepare_inputs(inputs)
_distributed = self.args.local_rank > -1
loss = self.gc(inputs, no_sync_except_last=_distributed)
if self.use_grad_cache:
_distributed = self.args.local_rank != -1
loss = self.gc(inputs, no_sync_except_last=_distributed)
else:
loss = self.compute_loss(model, inputs)
logger.debug(f"Training step loss: {loss.item()}")
return loss

Expand Down

0 comments on commit 0380e4e

Please sign in to comment.