add "_prepare_fsdp" for DPOTrainer #2539
base: main
Conversation
Thanks a lot for the fix @faaany - overall it looks great!
Would you mind confirming that the following demo command works with your PR (once activation checkpointing is removed):
accelerate launch --config_file=examples/accelerate_configs/fsdp_qlora.yaml --num_processes=NUM_GPUS trl/scripts/dpo.py \
--dataset_name trl-lib/ultrafeedback_binarized \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--learning_rate 5.0e-7 \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
--output_dir Qwen2-0.5B-DPO \
--no_remove_unused_columns
If it runs without error, can you please rename fsdp_qlora.yaml to fsdp.yaml so it runs for both modes?
A question for @qgallouedec: should this helper function live in a utils module somewhere so we don't have to copy it into every other trainer? (A rough sketch of what that could look like is below.)
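For illustration only, here is a minimal sketch of such a shared helper. The module location and the name `prepare_fsdp` are hypothetical (not part of this PR), and the keyword arguments assume the attributes exposed by Accelerate's `FullyShardedDataParallelPlugin`, with `auto_wrap_policy` already resolved to a callable:

```python
# Hypothetical shared helper (location and name are assumptions, not part
# of this PR). Wraps an eval-only reference model with FSDP using the
# settings already configured on the accelerator's FSDP plugin.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def prepare_fsdp(model, accelerator):
    fsdp_plugin = accelerator.state.fsdp_plugin
    kwargs = {
        "sharding_strategy": fsdp_plugin.sharding_strategy,
        "cpu_offload": fsdp_plugin.cpu_offload,
        "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
        "mixed_precision": fsdp_plugin.mixed_precision_policy,
        "sync_module_states": fsdp_plugin.sync_module_states,
        "backward_prefetch": fsdp_plugin.backward_prefetch,
        "forward_prefetch": fsdp_plugin.forward_prefetch,
        "use_orig_params": fsdp_plugin.use_orig_params,
        "param_init_fn": fsdp_plugin.param_init_fn,
        "ignored_modules": fsdp_plugin.ignored_modules,
        "limit_all_gathers": fsdp_plugin.limit_all_gathers,
        "device_id": accelerator.device,
    }
    model = FSDP(model, **kwargs)
    model.eval()  # the reference model is only used for forward passes
    return model
```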
"device_id": self.accelerator.device, | ||
} | ||
model = FSDP(model, **kwargs) | ||
if fsdp_plugin.activation_checkpointing: |
If I'm not mistaken, we don't need activation checkpointing since we never do a backward pass with the reference model. I think this block can thus be removed
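For concreteness, a minimal sketch of how the block would end once the activation-checkpointing branch is dropped (the `eval()` call is an assumption about how the eval-only reference model is handled elsewhere):

```python
model = FSDP(model, **kwargs)
# No activation-checkpointing branch: the reference model never runs a
# backward pass, so there are no activations worth checkpointing.
model.eval()  # assumption: the reference model is forward-only
return model
```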
I tried running the demo command without QLoRA and got the following error: @faaany, I am wondering if you were able to replicate or fix this. I am attaching the trainer code for reference.
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
What does this PR do?
While training with DPOTrainer using FSDP and Accelerate, I hit the same error as reported in #1147. Similar to "_prepare_deepspeed", I fixed the issue by adding a new method called "_prepare_fsdp".
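For context, here is a rough sketch of where the new method slots into DPOTrainer's `__init__`, mirroring the existing "_prepare_deepspeed" branch. The attribute names (`is_deepspeed_enabled`, `is_fsdp_enabled`) follow the transformers Trainer convention; this is a sketch, not a verbatim excerpt of the patch:

```python
# Sketch of the reference-model dispatch in DPOTrainer.__init__
# (names assumed to mirror the existing DeepSpeed path).
if self.ref_model is not None:
    if self.is_deepspeed_enabled:
        self.ref_model = self._prepare_deepspeed(self.ref_model)
    elif self.is_fsdp_enabled:
        self.ref_model = self._prepare_fsdp(self.ref_model)
    else:
        self.ref_model = self.accelerator.prepare_model(
            self.ref_model, evaluation_mode=True
        )
```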