[core] refactor step method #76
Conversation
- add safety checker + manual device assignment
Generally looks good to me, just one comment! 🚀
trl/trainer/ppo_trainer.py (Outdated)
@@ -252,6 +255,17 @@ def _step_safety_checker(
                 f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
             )

+        # set scores on the correct device
+        if name == "scores":
+            scores = [score.to(self.accelerator.device) for score in scores]
I think we can do the same for queries and values as well. Although they should already be on the right device, you don't know what the user might do before passing them to step, and then the behaviour is consistent for all inputs.
Do you know what happens when the tensor is already on the device? Will it copy it again or do nothing?
Makes sense. In PyTorch, .to() should do nothing if the tensor is already on the same device, so this should be cost-free.
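For reference, a minimal sketch of the behaviour discussed above: in PyTorch, `Tensor.to(device)` returns the tensor itself (no copy) when the dtype and device already match, so the device assignment in the safety checker is a no-op for tensors that are already placed correctly.

```python
import torch

t = torch.tensor([1.0, 2.0])  # lives on CPU by default

# .to() with the same device returns the very same tensor object, no copy
same = t.to(torch.device("cpu"))
print(same is t)  # True

# Only a genuinely different target device triggers a copy
if torch.cuda.is_available():
    moved = t.to("cuda")
    print(moved is t)  # False: a new tensor on the GPU
```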
Should now be handled in b32bfbf
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
LGTM!
This PR adds a new safety checker inside the step method to make sure the rewards are set on the correct device. Regarding the queries and the responses, users should retrieve the dataloader from the trainer and use that instead, since device assignment is performed directly at the dataloader level. Because the reward is not part of the dataloader, its device assignment needs to be performed manually.

This PR also adds a new check inside the safety checker. Before this PR, if a user passed a reward tensor with a dimension different from 0 (e.g. torch.tensor([1.0])), it would break the training loop. We now force the reward tensor into the desired shape, and throw a ValueError if the dimension of the reward is > 1.

cc @lewtun @lvwerra @edbeeching
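The shape check described above can be sketched as follows. This is a hypothetical illustration, not the exact trl code: the function name `check_reward_shape` is made up, but the behaviour matches the description (a 1-element tensor such as torch.tensor([1.0]) is squeezed down to a scalar, and anything with more than one dimension raises a ValueError).

```python
import torch

def check_reward_shape(score: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of the check this PR describes, not the trl source.
    # Rewards with more than one dimension are rejected outright.
    if score.dim() > 1:
        raise ValueError(
            f"Scores must be scalar tensors - got a tensor with {score.dim()} dims"
        )
    # A 1-d, single-element reward like torch.tensor([1.0]) is squeezed
    # to a 0-d scalar so it no longer breaks the training loop.
    if score.dim() == 1:
        score = score.squeeze()
    return score

print(check_reward_shape(torch.tensor([1.0])).dim())  # 0
```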