[Feature] step_and_maybe_reset in env #1611
# Conflicts: # benchmarks/ecosystem/gym_env_throughput.py
torchrl/envs/libs/pettingzoo.py
Outdated
@@ -464,12 +480,18 @@ def _reset(
        self, tensordict: Optional[TensorDictBase] = None, **kwargs
    ) -> TensorDictBase:

        _reset = tensordict.get("_reset", None)
This crashes when tensordict is None
Yeah, sorry about that. I can't test PettingZoo locally, so I'm always moving in the dark...
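A minimal sketch of the guard this comment calls for, using a plain dict as a stand-in for the tensordict so that it runs without torchrl. `read_reset_flag` is a hypothetical helper name, not something from the PR.

```python
from typing import Optional


def read_reset_flag(tensordict: Optional[dict] = None):
    """Return the "_reset" entry, or None when there is nothing to read."""
    # Guard first: calling .get on None raises AttributeError.
    if tensordict is None:
        return None
    return tensordict.get("_reset", None)


print(read_reset_flag())                           # no input tensordict
print(read_reset_flag({"_reset": [True, False]}))  # partial reset requested
```

The point is only that the `None` default in the signature has to be handled before any lookup on the argument.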
@matteobettini @albertbou92 @BY571 this should be (almost) mergeable.
Why would VMAS sporadically fail?
Not sure, you can have a look.
Was this happening before this PR?
I was referring to VMAS CI beyond this PR.
Some final questions/comments
# goes through the tensordict and brings the _reset information to
# a boolean tensor of the shape of the tensordict.
batch_size = data.batch_size
n = len(batch_size)

if done_keys is not None and reset_keys is None:
    reset_keys = {_replace_last(key, "done") for key in done_keys}
I am not following this.
We are taking the done keys (with all the terminated and truncated entries), replacing the last element of each with "done", and collecting the results in a set that we call reset keys.
This is counterintuitive, as reset_keys have a "_reset" ending and not a "done" ending.
This change seems to come from the fact that you aim to use this function in two contexts:
- normally, on the root td with the reset_keys as input
- on the "next" td in collectors, with the done_keys as input
I think we should try to write this better. Here are some suggestions:
- always call the function on the root td and pass the keys with "next" prepended if you want to use that
- do this key filtering and conversion outside of the function and let the function just operate on a set of keys to be considered as reset keys
I don't see the problem. Can you elaborate on why we should try to write this better?
Is it a naming problem? We can rename the function _aggregate_stop or something similar.
> always call the function on the root td and pass the keys with "next" prepended if you want to use that
That introduces some unwanted overhead when we can directly access "next" and read the done_keys. Recall that td.get(("next", "key")) is considerably slower than next_td.get("key"), which is what we do here.
> do this key filtering and conversion outside of the function and let the function just operate on a set of keys to be considered as reset keys
What's your suggestion for _update_traj_ids in collectors.py, for instance? We don't have a "_reset" in the "next" tensordict, but I think this function does its job of aggregating the done signals to read what the trajectory ids are.
What I understand is that the confusion comes from the "reset" in the function name, but what this function really does is just aggregate end-of-trajectory signals (either reset or done) to the root.
Given this, I don't see why it should be changed. It's a private function, it is properly tested, and I think it serves its purpose.
Connecting to the comment below, we could have _aggregate_keys(keys=), which can be called on anything.
Alternatively, we could have both _aggregate_dones and _aggregate_resets, where one calls the other.
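To make the suggestion concrete, here is a rough sketch of that split, written against plain nested dicts so it runs without torchrl. All three function names are hypothetical, and the sketch collapses everything into a single bool, whereas the helper under review returns a boolean tensor with the batch shape of the tensordict.

```python
from typing import Iterable, List, Tuple

NestedKey = Tuple[str, ...]


def _get(data: dict, key: NestedKey) -> List[bool]:
    """Follow a nested key through plain dicts."""
    for part in key:
        data = data[part]
    return data


def aggregate_keys(data: dict, keys: Iterable[NestedKey]) -> bool:
    """True if any entry stored under any of the given keys is True."""
    return any(any(_get(data, key)) for key in keys)


def aggregate_dones(next_data: dict, done_keys: Iterable[NestedKey]) -> bool:
    # Thin wrapper: the caller passes the "next" data and its done keys.
    return aggregate_keys(next_data, done_keys)


def aggregate_resets(root_data: dict, reset_keys: Iterable[NestedKey]) -> bool:
    # Thin wrapper: the caller passes the root data and its reset keys.
    return aggregate_keys(root_data, reset_keys)


next_data = {"done": [False], "agents": {"done": [True, False]}}
root_data = {"_reset": [False]}
print(aggregate_dones(next_data, [("done",), ("agents", "done")]))
print(aggregate_resets(root_data, [("_reset",)]))
```

With this layout, neither wrapper needs to rename keys: each one is handed the data and the key set that already match.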
torchrl/collectors/collectors.py
Outdated
traj_sop = _aggregate_resets(
    tensordict.get("next"), done_keys=self.env.done_keys
I assume we have to do this because the reset keys are no longer visible to the collector.
Since this is quite counterintuitive, what about an _aggregate_dones()?
I also have other suggestions in the other comment relating to this function.
Yep, we can rename the function. I would rather say that it's more direct to aggregate the dones than the "_reset" entries, which come from the dones.
    action_keys=self.action_keys,
    done_keys=self.done_keys,
)
any_done = _terminated_or_truncated(
Help me understand this a little bit better.
In an example case (e.g., PettingZoo) where I have
{
    "done": [False],
    "agents": {"done": [True, False]}
}
is any_done triggered?
If so, this is a problem for envs like PettingZoo, where _reset() will be called with {"_reset": [False]}.
The best thing is to try it out :)
from tensordict import TensorDict
from torchrl.envs.utils import _terminated_or_truncated
data = TensorDict({"done": [False], ("agent", "done"): [True, False]}, [])
print(_terminated_or_truncated(data))
which returns True
So what you're saying is that it should be False since there's a False at the root?
I can correct that
I think it should, to follow the dominance rule we imposed, right?
Or at least in this context, definitely, because we do not want to call reset.
I don't know in what other contexts this function is used, but if its primary use is to decide when to call reset, then yes.
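One possible reading of the dominance rule discussed here, sketched on plain nested dicts so it runs without torchrl: a done entry at a given level decides for everything nested below it. Both function names are hypothetical, and the rule torchrl actually applies may differ in detail.

```python
def any_done_naive(data: dict, done_key: str = "done") -> bool:
    """True if any done entry anywhere in the nested dict is True."""
    found = False
    for key, value in data.items():
        if isinstance(value, dict):
            found = found or any_done_naive(value, done_key)
        elif key == done_key:
            found = found or any(value)
    return found


def any_done_root_dominant(data: dict, done_key: str = "done") -> bool:
    """Like any_done_naive, but a done entry at a level overrides nested ones."""
    # If this level carries its own done entry, it decides for the subtree.
    if done_key in data:
        return any(data[done_key])
    # Otherwise, look for done entries one level down.
    return any(
        any_done_root_dominant(value, done_key)
        for value in data.values()
        if isinstance(value, dict)
    )


data = {"done": [False], "agents": {"done": [True, False]}}
print(any_done_naive(data))          # the nested True triggers a reset
print(any_done_root_dominant(data))  # the root False wins, no reset
```

Under this reading, the PettingZoo-style example above would not trigger a reset, which matches the `{"_reset": [False]}` that `_reset()` would otherwise receive.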
@matteobettini I addressed all your comments.
LGTM
Contribution
This PR proposes the step_and_maybe_reset method in EnvBase. This method executes a step followed by a reset, if necessary.
We also make reset more robust by ensuring that partial resets are handled uniformly. This is necessary since step_and_maybe_reset must take care of this functionality, and from our perspective handling partial resets is the responsibility of reset (the user should not have to worry about data not being updated properly). This has repercussions on the logic behind TransformedEnv._reset and BatchedEnv._reset.
I'm now considering having batched envs call reset and not _reset to make sure that the data is well presented, since the update of the input tensordict with the tensordict_reset now occurs after _reset (hence, the output of _reset in SerialEnv is incomplete). This could introduce some overhead, but that's of limited impact since step_and_maybe_reset is now there to handle things faster.
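As an illustration of the "step, then reset if necessary" idea only, here is a toy environment in plain Python. `CountdownEnv`, its `horizon` parameter, and the two-dict return value are assumptions made for this sketch; they are not the torchrl classes or signatures.

```python
class CountdownEnv:
    """Toy env: an episode ends after `horizon` steps."""

    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> dict:
        self.t = 0
        return {"obs": self.t}

    def step(self, action: int) -> dict:
        self.t += 1
        return {"obs": self.t, "done": self.t >= self.horizon}

    def step_and_maybe_reset(self, action: int):
        """Step once; if the episode ended, reset before returning."""
        transition = self.step(action)
        if transition["done"]:
            # The caller never has to check done and call reset itself.
            next_root = self.reset()
        else:
            next_root = {"obs": transition["obs"]}
        return transition, next_root


env = CountdownEnv(horizon=2)
env.reset()
for _ in range(3):
    transition, next_root = env.step_and_maybe_reset(action=0)
    print(transition, next_root)
```

The second returned dict is always a valid starting point for the next call, which is what lets a rollout loop skip the explicit done check.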