-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Remove reset on last step of a rollout #1936
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1936
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New FailuresAs of commit 5222496 with merge base 23bf315 (): NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this! Ideally we'd like a non regression test.
Also now that there is a reset missing, if the user takes the tensordict passed as output and feeds it back to rollout, rollout will need to call reset on it. I don't think this was part of rollout contract before: either a tensordict is passed or reset is called but not both (as this PR will dictate)
I'll work on the test.
So this is a design choice, let me explain my view. Before the PRBefore the PR, a user could not take the last tensordict from a rollout and safely pass it back in. This is because if the env was reset on the last step, the data from that reset would not be part of the rollout results. Thus, if the user passed the output of a rollout back into a new rollout, It would be providing old data from a previous trajectory to a freshly reset env which is in a different state. After this PRIn the current state of this PR, the user still cannot just take the last tensordict from a rollout and safely pass it back in. Doing so would have the same effect as prior to this PR, with the difference that now at least the env is not in a new reset state. In my opinion this should remain like this (and like it was). If the users chooses The difference that this PR allows is that now users that want to run rollouts in a row are able to by adding the reset logic outside the rollout function (as detailed in the snippet in #1929 (comment)). Before they could not do this. |
If we patch things let's patch them correctly no? if auto_reset:
if tensordict is not None:
raise RuntimeError(
"tensordict cannot be provided when auto_reset is True"
)
tensordict = self.reset()
elif tensordict is None:
raise RuntimeError("tensordict must be provided when auto_reset is False")
else:
aggregate_reset = _aggregate_end_of_traj(tensordict)
if aggregate_reset.any():
self.reset(tensordict) I don't understand the reluctance against this, does this break anything? |
Oh no this works too! I'm not against it. It will just be a new feature. Aka rollout will check that the passed tensordict is done and it will reset the env if so. We can do it, a few considerations:
|
I think we're at the edge between new feature and bug fix. If you can't pass the tensordict you got from the last step instead of calling reset safely, this would be a bugfix. |
Yeah if this was supposed to be a feature before, I guess this is a bug fix |
If the env does not allow a step after done and you pass a tensordict that is done, it will result in an error. IMO we should capture that and make sure it does not happen for ease of use. I can't think of anyone using rollout reasonably right now who would be annoyed by this change and consider it bc-breaking, but I could be overlooking things. To me |
This consideration makes an assumption about the env. Forcing a reset on these users would narrow the flexibility of rollout with respect to prior versions. If we want to have another param to deactivate it other than auto_reset that is fine too, I just though that auto_reset fits nicely as it was not allowed before and it is true by default. EDIT: thinking about it, you might be right. Since rollout auto-resets anyway along the trajectory, it could makes sense that this is done on the input td as well. It will add computational complexity for checking the done tho, so maybe having it optional still makes sense? |
Not sure I'm following here Previously we had a broken version of rollout where the reset data of the last step was potentially lost. Anyone wanting to do anything with this was doomed to fail. I don't think that in this scenario working with env that do not reset when done was an option. So this scenario can be excluded entirely and considered as a separate issue. If we want to land this, to me having it work with tensordict that are in a done state is a pre-requirement. The change I'm proposing isn't bc-breaking in any way since it build upon a bug-fix: anyone who was doing anything with rollout before can only have fewer bugs now. The runtime won't be affected since the check I'm suggesting is already performed by |
There is something that we still did not discuss. If you really want to chain calls of rollout, then the function should call Or are we assuming that users call |
It's safe to assume that users call |
Ok if you wanna have a look now it should do what we want |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just missing a docstring and a test
(maybe also writing a note in rollout
doctring about all of this? Like "how to use rollout as a data collector" or similar?)
Further question: if |
We should be gucci! Also, I still haven't come around the problem in #1929, so currently rollout is still not working for me as a data collector. Nevertheless, this can be merged. |
Co-authored-by: vmoens <vincentmoens@gmail.com>
Discussion in #1929