Skip to content

Commit

Permalink
fix: ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
sigridjineth committed Aug 24, 2024
1 parent 43a642d commit 7626bbf
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 34 deletions.
46 changes: 29 additions & 17 deletions src/tevatron/reranker/driver/train.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,58 @@
import logging
import os
import sys
import torch
from transformers import AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from tevatron.reranker.arguments import ModelArguments, DataArguments, TevatronTrainingArguments
from tevatron.reranker.modeling import RerankerModel
from tevatron.reranker.dataset import RerankerTrainDataset
from tevatron.reranker.collator import RerankerTrainCollator
from tevatron.reranker.trainer import RerankerTrainer # Make sure this is your updated RerankerTrainer
from tevatron.reranker.trainer import RerankerTrainer

logger = logging.getLogger(__name__)


def setup_ddp():
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
return local_rank


def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TevatronTrainingArguments))

parser.add_argument('--bf16', action='store_true', help='Use bfloat16 precision')

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
local_rank = -1
if training_args.local_rank != -1:
local_rank = setup_ddp()

# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
level=logging.INFO if local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
bool(local_rank != -1),
training_args.fp16 or training_args.bf16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("MODEL parameters %s", model_args)
Expand All @@ -67,11 +74,16 @@ def main():
cache_dir=model_args.cache_dir,
)

# Move model to GPU
if local_rank != -1:
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

train_dataset = RerankerTrainDataset(data_args)
train_collator = RerankerTrainCollator(data_args, tokenizer)

# Add GradCache-specific arguments to training_args
training_args.gc_chunk_size = getattr(training_args, 'gc_chunk_size', 2)
training_args.grad_cache = getattr(training_args, 'grad_cache', False)

trainer = RerankerTrainer(
model=model,
Expand All @@ -81,11 +93,11 @@ def main():
)
train_dataset.trainer = trainer

trainer.train() # TODO: resume training
trainer.train()
trainer.save_model()
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
main()
main()
41 changes: 24 additions & 17 deletions src/tevatron/reranker/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from transformers.trainer_utils import PredictionOutput

from grad_cache import GradCache

from grad_cache.functional import cached, cat_input_tensor
from torch.cuda.amp import autocast

Expand Down Expand Up @@ -39,22 +38,27 @@ def split_inputs(model_input, chunk_size):
class RerankerTrainer(Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
logger.info("Initializing RerankerTrainer with GradCache")
logger.info("Initializing RerankerTrainer")
self.args: TrainingArguments

# Add these lines to include the necessary parameters
self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4) # default to 4 if not provided

self.gc = GradCache(
models=[self.model],
chunk_sizes=self.gc_chunk_size,
loss_fn=contrastive_loss,
split_input_fn=split_inputs,
get_rep_fn=lambda x: x.scores,
fp16=self.args.fp16,
scaler=self.scaler if self.args.fp16 else None
)
logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")
self.gc_chunk_size = getattr(self.args, 'gc_chunk_size', 4)
self.use_grad_cache = getattr(self.args, 'grad_cache', False)

if self.use_grad_cache:
# If the model is wrapped in DDP, we need to use the .module attribute
model_for_gc = self.model.module if hasattr(self.model, 'module') else self.model

self.gc = GradCache(
models=[model_for_gc],
chunk_sizes=self.gc_chunk_size,
loss_fn=contrastive_loss,
split_input_fn=split_inputs,
get_rep_fn=lambda x: x.scores,
fp16=self.args.fp16,
bf16=self.args.bf16,
scaler=self.scaler if (self.args.fp16 or self.args.bf16) else None
)
logger.info(f"GradCache initialized with chunk size: {self.gc_chunk_size}")

def compute_loss(self, model, inputs, return_outputs=False):
logger.debug(f"Computing loss with inputs: {inputs.keys()}")
Expand All @@ -68,8 +72,11 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
logger.debug("Entering training step")
model.train()
inputs = self._prepare_inputs(inputs)
_distributed = self.args.local_rank > -1
loss = self.gc(inputs, no_sync_except_last=_distributed)
if self.use_grad_cache:
_distributed = self.args.local_rank != -1
loss = self.gc(inputs, no_sync_except_last=_distributed)
else:
loss = self.compute_loss(model, inputs)
logger.debug(f"Training step loss: {loss.item()}")
return loss

Expand Down

0 comments on commit 7626bbf

Please sign in to comment.