accelerate integration #58
Conversation
#### Run PPO step
t = time.time()
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, timing, batch, rewards, t0, t, logs)
To improve, we probably want a better way to log the stats
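A minimal sketch of what a tidier logging call could look like, assuming a hypothetical `log_step` helper; the `accelerator`, `logger`, and argument names are illustrative and not from the PR:

```python
# Hypothetical sketch: fold timing and reward summaries into the stats dict
# and log once, only on the main process.
import time
import torch

def log_step(accelerator, logger, stats, rewards, t_start):
    if not accelerator.is_main_process:
        return
    stats["time/ppo_step"] = time.time() - t_start
    stats["env/reward_mean"] = torch.stack(rewards).mean().item()
    logger.log(stats)  # e.g. wandb.log
```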
trl/trainer/accelerate_ppo.py (outdated)
if isinstance(v, torch.Tensor) and k != 'objective/kl':
    # tensor_list = [torch.zeros_like(v) for _ in range(self.accelerator.num_processes)]
    dist.all_reduce(v, dist.ReduceOp.SUM)
    v /= self.accelerator.num_processes
For me in a DP setup, each GPU will need to have its own replica of `objective/kl`, since this is used to update the `kl_ctl` object above. That is why I preferred not to include it in the `all_reduce` operation, but I just wanted to confirm.
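For reference, a minimal sketch of the averaging being discussed, with `objective/kl` left untouched so each replica keeps its local value for the KL controller; the function name and `num_processes` argument are illustrative:

```python
import torch
import torch.distributed as dist

def average_stats_across_workers(stats, num_processes):
    # Average every tensor stat over data-parallel workers, except the KL
    # estimate, which each rank keeps locally to update its own kl_ctl.
    for k, v in stats.items():
        if isinstance(v, torch.Tensor) and k != "objective/kl":
            dist.all_reduce(v, op=dist.ReduceOp.SUM)
            v /= num_processes
    return stats
```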
- add docstring on most functions
- correct logging
Thanks for this @younesbelkada. My main comments are about DP. I think if we don't wrap the `step` inputs (queries/responses) in a dataloader we don't achieve proper DP. But maybe I am wrong?
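A rough sketch of what wrapping the step inputs in a dataloader could look like, letting Accelerate shard it per process; the helper name and batch size are illustrative, not part of the PR:

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

def build_sharded_loader(query_tensors, response_tensors, batch_size=8):
    # Pair up queries and responses, then let accelerator.prepare attach a
    # distributed sampler so each GPU only sees its own shard of the batch.
    pairs = list(zip(query_tensors, response_tensors))
    loader = DataLoader(pairs, batch_size=batch_size, collate_fn=lambda b: b)
    return accelerator.prepare(loader)
```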
trl/trainer/accelerate_ppo.py (outdated)
model (torch.model): Hugging Face transformer GPT2 model with value head
ref_model (torch.model): Hugging Face transformer GPT2 reference model used for KL penalty
tokenizer (tokenizer): Hugging Face tokenizer
ppo_params (dict or None): PPO parameters for training. Can include following keys:
We should replace `**config` (= `ppo_params`) with explicit kwargs or set up `TrainingArguments` like in transformers.
Can be a follow up PR btw
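In the spirit of that suggestion, a minimal sketch of an explicit config dataclass, loosely mirroring transformers' `TrainingArguments`; the class and field names are illustrative and not the PR's API:

```python
from dataclasses import dataclass

@dataclass
class PPOConfig:
    # Explicit, documented hyperparameters instead of a free-form ppo_params dict.
    learning_rate: float = 1.41e-5
    batch_size: int = 256
    forward_batch_size: int = 16
    ppo_epochs: int = 4
    init_kl_coef: float = 0.2
    gamma: float = 1.0
    lam: float = 0.95
    cliprange: float = 0.2
    vf_coef: float = 0.1
```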
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
- random init seems to converge much faster
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
wandb run (multi-GPU) after the latest commit: https://wandb.ai/distill-bloom/trl/runs/1mps4h09?workspace=user-younesbelkada
I think we are pretty close - a few open questions and minor changes :)
stats (dict[str, Any]):
    a dictionary of stats with the tensors gathered.
"""
import torch.distributed as dist
what do you think?
# In a distributed setup, only logging needs to be performed on the main process
# check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
# or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
self.is_distributed = self.accelerator.distributed_type == "MULTI_GPU"
If we can use Accelerate's `gather` method we can probably get rid of this?
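A minimal sketch of what that could look like with Accelerate's `gather`, which already behaves as a no-op on a single process; the helper name and the averaging step are illustrative:

```python
import torch

def gather_stats(accelerator, stats):
    # Gather each tensor stat from all processes and average it; with one
    # process, accelerator.gather simply returns the tensor unchanged.
    gathered = {}
    for k, v in stats.items():
        if isinstance(v, torch.Tensor):
            v = accelerator.gather(v.reshape(-1).to(accelerator.device)).mean()
        gathered[k] = v
    return gathered
```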
#### Compute sentiment score
t = time.time()
texts = [q + r for q, r in zip(batch['query'], batch['response'])]
With the remove-columns method inside the trainer, the `query` shouldn't be there anymore? Since we don't pass the data through the model internally, we don't need to remove the columns?
The queries are kept here: https://github.com/younesbelkada/trl/blob/d2c363fe4018c74df829ed6c067fad50ecaaf479/trl/trainer/ppo_trainer.py#L152, but maybe we can change that, wdyt?
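If keeping the raw text around is the intended behaviour, one way to make that explicit is a small collator that preserves every dataset column; a sketch only, with an illustrative function name:

```python
def keep_all_columns_collator(features):
    # features is a list of per-sample dicts; regroup into a dict of lists so
    # non-model columns such as the raw "query" strings survive batching.
    return {key: [f[key] for f in features] for key in features[0]}
```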
Wandb log of the final run: https://wandb.ai/distill-bloom/trl/runs/dcd2gqn1?workspace=user-younesbelkada
The documentation is not available anymore as the PR was closed or merged.
What does this PR do?
This PR integrates `trl` with `accelerate` to make it compatible with the tools provided by the library, so that models can be trained with `PPOTrainer`. This enables users to train their models in mixed precision, with Data Parallelism, etc. in a very simple manner.

Users should design their own training scripts and run them with `accelerate launch xxx.py`, based on the example scripts provided in `examples/scripts`.

This PR also integrates the Data Parallelism paradigm, enabling users to benefit from multi-GPU training if they want to speed up training.
TODOs
- `accelerate` examples
- DeepSpeed tests (check where it works)