[Feature] Gym compatibility: Terminal and truncated #1539
Conversation
@skandermoalla The current project is to keep […]
Thanks a lot for starting this and for pinging me! IMO naming differences are fine as long as […].
So here "most" should be very carefully assessed. It would be terrible to have […]. I think I sometimes confused the […]. So maybe also making the docs more explicit about this definition would help a lot of people arriving from Gymnasium. I'm also thinking it would be useful to allow the use […].
Adding an early review before disappearing until the ICLR deadline 🫠.
torchrl/data/utils.py (outdated)

     recursive: bool = True,
 ):
     found_entry = False
-    for value in spec.values():
+    for key, value in spec.items():
         if isinstance(value, CompositeSpec):
             _check_only_one_entry(value, error, recursive)
         else:
-            if not found_entry:
-                found_entry = True
-            else:
-                raise error
+            if key == "done":
+                if found_entry:
+                    raise error
+                else:
+                    found_entry = True
Recursion is still not implemented here, right?
@matteobettini I'm not even sure we need this function anymore. We wanted to avoid people having "done_agent1" and "done_agent2" and rather use ("agent1", "done"), IIRC.
Now we will allow "done_spec" to contain "done" and "truncated" at any level. Per se, the done_spec could even be richer. So I don't think we need this check anymore: it's good practice to have only one done per level, but I don't think anything will break if you have more.
The only exception I see is computing the "_reset" signal, which usually reflects done | truncated. In other words, _reset reflects something like
def _make_reset(data, done_spec):
    # walk the done spec; OR all done-like leaves at each level into "_reset"
    out = {}
    for key, item in done_spec.items():
        if isinstance(item, CompositeSpec):  # recurse into nested levels
            out[key] = _make_reset(data.get(key), item)
        else:
            out["_reset"] = out.get("_reset", False) | data.get(key)
    return data.update(out)
This means that if you put done_agent1 and done_agent2 at the same level, you'll get a single reset for both. Not ideal, but not really our business if the MARL API is well documented, IMO.
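For illustration, a minimal toy sketch of that collapse (hypothetical keys and shapes, not code from this PR):

import torch
from tensordict import TensorDict

data = TensorDict(
    {
        "done_agent1": torch.tensor([True]),   # agent 1 is done
        "done_agent2": torch.tensor([False]),  # agent 2 is not
    },
    batch_size=[1],
)
# ORing the leaves at this level yields one flag for both agents:
reset = data["done_agent1"] | data["done_agent2"]  # tensor([True])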
In collectors we use "done" to compute the trajectory ids, but this is an extra feature that could apply only in settings where done is at the root (or in other restricted settings). The trajectory id doesn't influence the collector behaviour, which only cares about the number of frames (number of elements of the root tensordict).
For the rest, loss modules have parametric dones, so I think we could rename "done" to "merylstreep" and they would still work. They will need refactoring to account for "truncated" anyway.
So the situation is a little more complex, let me try to explain it here:
Since we decided that our naming convention is to have one "_reset" key per tensordict (due to the fact that if we had more we would have to change the name), this key has to reflect the shape of the "done" in that tensordict.
If we allowed more than one leaf key in each td of the done_spec, then the problem is what shape the _reset should reflect. We cannot do like Vincent said because the shapes might be arbitrary.
Therefore the convention established is that at least a key with the name "done" (and exactly "done") should exist in each td of the done_spec, and the reset will mimic its shape.
If we now want to change the convention and say that the constraint is that all the values in the same tensordict of the done_spec should have the same shape, that is fine by me, but then we have to check and enforce that. A sketch of such a check is below.
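A minimal sketch of what that enforcement could look like, assuming nested levels are CompositeSpec instances and leaves expose .shape (not the PR's actual code):

from torchrl.data import CompositeSpec

def _check_matching_done_shapes(spec):
    # at each level of the done spec, all leaf specs must share one shape
    leaf_shapes = set()
    for value in spec.values():
        if isinstance(value, CompositeSpec):
            _check_matching_done_shapes(value)  # each nested level is checked on its own
        else:
            leaf_shapes.add(tuple(value.shape))
    if len(leaf_shapes) > 1:
        raise ValueError(f"mismatching done shapes within one tensordict: {leaf_shapes}")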
If we do like in this PR and assume truncated is always in the done_spec, then we have to carefully adjust all the components of the codebase that create the _reset key. Like here:

rl/torchrl/collectors/collectors.py, line 797 in fe91c4f:

    for done_key in self.env.done_keys:

and here, line 1590 in fe91c4f:

    for done_key in self.done_keys:

for example. These lines will currently create a _reset for each done_spec.keys(True, True). But if the done_spec now contains more than one leaf per td, we have to change all this.
> If now we want to change the convention and say that the constraint now is that all the values in the same tensordict of the done_spec should have the same shape, that is fine by me, but then we have to check and enforce that.

Per se, if someone writes an env that does not require _reset (which we use internally but do not expose to users), it's fine. What I mean is that if a data structure has "done_agent1" of shape X and "done_agent2" of shape Y, computing a "_reset" key will fail, but that only means that the env cannot be used within a collector.
Since we want to provide independent, modular components, I think it's an acceptable design choice not to enforce that every env written by any user fits in a collector. We need to make sure that the collector tells you specifically what is wrong with your env, though.
What about letting envs have the done_spec they want, and letting the collectors / parallel envs spot potential problems during construction?
Yeah, but the rollout function also uses the reset logic. So how would we treat that?
Imagine someone wants to use TorchRL to write a multiagent env that does some job with the "done_agentX" structure.
I think it's cool if they do that and put that in their app. I don't think they should get an exception or a warning telling them that they're heading into a dead end when using a collector. Maybe they don't want to use the collectors.
Let's not prevent legitimate use of one class to prevent a bug in another if the usage of the other is optional, wdyt?
Yes, but vectorized envs will need to be fed a _reset flag upon reset in the env.rollout function. So you are saying we will provide that on a best-effort basis, but if we cannot, we just give a warning and say that in a collector this will be an error?
> Yeah but also the rollout function uses the reset logic. So how would we treat that?

Oh yeah, good point.
This is a bit tricky: per se, we set a _reset in rollout and hope for the best, but we don't enforce that it's being read anywhere.
The thing is that we don't really have any agency over what reset(tensordict) does under the hood. For ParallelEnv and similar we know what to do, but in other cases I'm not 100% sure what could be done generically.
But yeah, since we use _reset within the env class, it makes sense that it's computed in a consistent way within the library.
So what we would opt for is:
- during construction, walk through the specs and check that at each level, all done specs have a matching shape
- during execution, compute _reset as the combination of all _reset keys at each level
> during execution, compute _reset as the combination of all _reset keys at each level

-> as the combination of all done_spec keys at each level
But yes, that could work!
torchrl/data/utils.py (outdated)

@@ -228,14 +228,15 @@ def _check_only_one_entry(
     recursive: bool = True,
 ):
The method would benefit from a quick docstring or a better name? I'm not sure I understand what it does without more context.
torchrl/envs/gym_like.py (outdated)

This method returns:
- a done state to be set in the environment
- a boolean value indicating whether the frame_skip loop should be broken
This doesn't seem to match its current output (a tuple of 3 elements). Is the truncated output missing?
yes I still need to go through the docstrings but thanks for reminding me!
torchrl/envs/gym_like.py (outdated)

In TorchRL, done means that a trajectory is terminated (we do not support the
terminated signal). Truncated means the trajectory has been interrupted.
If done is exactly terminated, then what does "we do not support the terminated signal" mean?
I agree it's confusing. Maybe
In torchrl, a `done` signal means that a trajectory is terminated (what is referred to as `termination` in gymnasium).
Yep! That's clear!
torchrl/envs/gym_like.py (outdated)

if termination is not None:
    return termination, truncation, done
return done, truncation, done
This means eventual bootstrapping will be done correctly for new versions of gym and incorrectly for old versions (when truncated, they set done to True and put the truncation bit in the info variable). A note should be added about this somewhere.
Let me rephrase what you're saying here:
In older versions of gym (<0.26), a truncation signal could be retrieved from the info dict. Currently we don't use it, so assuming done=truncated will not work properly with value functions, is that right?
If so, I think we should advise users to adopt a dedicated info reader to write the truncated signal, which will solve the problem.
> in older versions of gym (<0.26), a truncation signal could be retrieved from the info.

Yes, in the TimeLimit.truncated key, if it happened. See https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/#solution.

> assuming that done=truncated will not work properly with value functions

Yes, and to be exact, done here with the old gym versions would be gym_done = terminated | truncated. (It was an end_of_traj.)

> I think that we should advise users of adopting a dedicated info reader to write the truncated signal, which will solve the problem.

Yes, that would be the correct way to record the truncated key and propagate it to the rest of TorchRL's components. However, the gym_done key would still need to be corrected with torchrl_done = gym_done & ~truncated. I don't see how this could be done by the users.
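For illustration, a minimal sketch of that correction on old-gym (<0.26) step outputs; the helper name is hypothetical and this is not the PR's code:

def split_old_gym_done(done, info):
    # old gym (<0.26): done == terminated | truncated, with the time-limit
    # truncation flagged in info["TimeLimit.truncated"]
    truncated = bool(info.get("TimeLimit.truncated", False))
    terminated = bool(done) and not truncated
    return terminated, truncated

# e.g. a time-limit cut: done=True, info carries the truncation bit
print(split_old_gym_done(True, {"TimeLimit.truncated": True}))  # (False, True)

With terminated and truncated separated this way, value functions can bootstrap on truncation instead of treating it as a true termination.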
torchrl/envs/gym_like.py (outdated)

@@ -187,26 +195,12 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

     reward = 0
     for _ in range(self.wrapper_frame_skip):
-        obs, _reward, done, *info = self._output_transform(
+        obs, _reward, termination, truncation, done, info = self._output_transform(
Realizing once more how easily variable order can be messed up here. Shout out to TensorDict for solving this!
torchrl/envs/gym_like.py (outdated)

@@ -133,18 +133,24 @@ def read_action(self, action):
         """
         return self.action_spec.to_numpy(action, safe=False)

-    def read_done(self, done):
+    def read_done(self, done, termination, truncation):
Maybe the variables in this file referring to the gym done could be renamed to gym_done to avoid confusion with the torchrl done?
good point, it's getting a bit confusing!
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_single | 0.1520s | 0.1501s | 6.6616 Ops/s | 8.6325 Ops/s | |
test_sync | 91.5973ms | 80.9036ms | 12.3604 Ops/s | 15.8629 Ops/s | |
test_async | 0.2039s | 79.7407ms | 12.5406 Ops/s | 15.7839 Ops/s | |
test_simple | 1.3358s | 1.2287s | 0.8139 Ops/s | 1.0459 Ops/s | |
test_transformed | 1.5936s | 1.5333s | 0.6522 Ops/s | 0.8244 Ops/s | |
test_serial | 3.4552s | 3.3403s | 0.2994 Ops/s | 0.3646 Ops/s | |
test_parallel | 2.7917s | 2.6656s | 0.3752 Ops/s | 0.4225 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 2.3443ms | 47.5616μs | 21.0254 KOps/s | 19.5929 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 1.4567ms | 28.1532μs | 35.5200 KOps/s | 36.0895 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 1.6109ms | 35.7529μs | 27.9698 KOps/s | 27.6125 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 1.4395ms | 20.3467μs | 49.1480 KOps/s | 48.0980 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 1.8869ms | 50.3061μs | 19.8783 KOps/s | 20.1280 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 1.2443ms | 27.8566μs | 35.8981 KOps/s | 30.6501 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 3.7292ms | 40.3144μs | 24.8050 KOps/s | 25.0849 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 2.7486ms | 21.1617μs | 47.2552 KOps/s | 42.2469 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 1.4861ms | 53.0680μs | 18.8438 KOps/s | 18.2045 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 1.8459ms | 32.4297μs | 30.8359 KOps/s | 30.6470 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 3.5215ms | 36.6657μs | 27.2734 KOps/s | 26.1134 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 1.5480ms | 21.8291μs | 45.8104 KOps/s | 43.9778 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 4.0434ms | 58.2856μs | 17.1569 KOps/s | 17.8605 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 2.5648ms | 35.1649μs | 28.4375 KOps/s | 29.5241 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 1.5462ms | 38.7024μs | 25.8382 KOps/s | 25.1199 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 4.2800ms | 24.4261μs | 40.9397 KOps/s | 42.4358 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 3.5002ms | 54.5293μs | 18.3388 KOps/s | 18.6064 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 1.5240ms | 33.3729μs | 29.9645 KOps/s | 28.2972 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 1.4671ms | 43.1279μs | 23.1868 KOps/s | 23.8162 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 1.6163ms | 24.4371μs | 40.9213 KOps/s | 38.5422 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 1.5880ms | 56.8706μs | 17.5838 KOps/s | 18.1635 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 1.9065ms | 35.3787μs | 28.2656 KOps/s | 28.9811 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 1.7214ms | 45.0859μs | 22.1799 KOps/s | 24.0683 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 9.1869ms | 26.8996μs | 37.1753 KOps/s | 36.6084 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 2.7321ms | 57.2490μs | 17.4676 KOps/s | 16.3324 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 2.2998ms | 36.2487μs | 27.5872 KOps/s | 23.8731 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 2.4786ms | 42.8941μs | 23.3133 KOps/s | 21.2729 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 3.5744ms | 26.5392μs | 37.6801 KOps/s | 37.0256 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 1.8297ms | 60.5148μs | 16.5249 KOps/s | 16.1699 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 4.9242ms | 38.9195μs | 25.6941 KOps/s | 25.7345 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 1.5061ms | 41.5351μs | 24.0760 KOps/s | 21.6581 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 5.4670ms | 27.7662μs | 36.0150 KOps/s | 33.8220 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 21.8422ms | 16.7033ms | 59.8686 Ops/s | 54.9726 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1274s | 78.4241ms | 12.7512 Ops/s | 12.4612 Ops/s | |
test_values[td0_return_estimate-False-False] | 2.2441ms | 0.6966ms | 1.4355 KOps/s | 1.4467 KOps/s | |
test_values[td1_return_estimate-False-False] | 22.1129ms | 17.7120ms | 56.4589 Ops/s | 56.5616 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 0.1050s | 79.5756ms | 12.5667 Ops/s | 12.9641 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 56.0795ms | 51.9957ms | 19.2323 Ops/s | 19.1128 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 0.1062s | 78.3074ms | 12.7702 Ops/s | 12.2649 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 22.3703ms | 14.6167ms | 68.4149 Ops/s | 72.5063 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 10.8682ms | 5.2302ms | 191.1978 Ops/s | 177.2607 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 4.3420ms | 0.9028ms | 1.1077 KOps/s | 1.0838 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 78.1810ms | 74.2825ms | 13.4621 Ops/s | 13.4408 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 9.3381ms | 6.6182ms | 151.0993 Ops/s | 143.3184 Ops/s | |
test_dqn_speed | 8.6350ms | 3.9416ms | 253.7030 Ops/s | 249.1651 Ops/s | |
test_ddpg_speed | 13.7405ms | 7.6336ms | 130.9999 Ops/s | 131.6222 Ops/s | |
test_sac_speed | 29.3333ms | 20.5020ms | 48.7757 Ops/s | 48.8679 Ops/s | |
test_redq_speed | 39.5370ms | 31.5141ms | 31.7318 Ops/s | 31.1134 Ops/s | |
test_redq_deprec_speed | 37.9842ms | 30.0608ms | 33.2659 Ops/s | 30.8167 Ops/s | |
test_td3_speed | 27.1305ms | 20.3410ms | 49.1618 Ops/s | 49.0771 Ops/s | |
test_cql_speed | 0.1223s | 91.9787ms | 10.8721 Ops/s | 13.3079 Ops/s | |
test_a2c_speed | 22.9850ms | 14.5774ms | 68.5994 Ops/s | 70.0593 Ops/s | |
test_ppo_speed | 22.1010ms | 14.5217ms | 68.8626 Ops/s | 63.3403 Ops/s | |
test_reinforce_speed | 28.0911ms | 12.1667ms | 82.1918 Ops/s | 85.3810 Ops/s | |
test_iql_speed | 81.0886ms | 61.4065ms | 16.2849 Ops/s | 16.1663 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.8806ms | 3.2130ms | 311.2367 Ops/s | 271.5736 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 6.1424ms | 3.1324ms | 319.2484 Ops/s | 276.5881 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 8.9960ms | 3.3200ms | 301.2038 Ops/s | 283.1985 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 4.7876ms | 2.9547ms | 338.4425 Ops/s | 298.9432 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 8.1250ms | 3.2951ms | 303.4831 Ops/s | 289.2854 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 5.2123ms | 3.1451ms | 317.9543 Ops/s | 238.8304 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.3671ms | 3.0135ms | 331.8377 Ops/s | 310.6851 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 6.0443ms | 3.3323ms | 300.0925 Ops/s | 241.0044 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 9.4841ms | 3.4549ms | 289.4463 Ops/s | 278.6298 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.9577ms | 3.4111ms | 293.1584 Ops/s | 227.1365 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 11.2787ms | 3.3547ms | 298.0914 Ops/s | 271.2144 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 6.3453ms | 3.2779ms | 305.0698 Ops/s | 271.9679 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.3576ms | 3.0207ms | 331.0490 Ops/s | 319.0112 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.1883s | 3.6347ms | 275.1226 Ops/s | 288.2381 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 6.9478ms | 3.3685ms | 296.8674 Ops/s | 298.1164 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.6910ms | 3.0574ms | 327.0739 Ops/s | 319.0371 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 5.9688ms | 3.3096ms | 302.1529 Ops/s | 293.7645 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 6.6344ms | 3.3876ms | 295.1961 Ops/s | 277.8849 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2737s | 38.5404ms | 25.9468 Ops/s | 24.6046 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1903s | 32.5464ms | 30.7254 Ops/s | 28.8165 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1692s | 34.9829ms | 28.5854 Ops/s | 26.6485 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1617s | 32.2699ms | 30.9886 Ops/s | 29.3244 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1759s | 35.8962ms | 27.8581 Ops/s | 26.2619 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1895s | 34.7836ms | 28.7492 Ops/s | 26.6320 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1903s | 36.6567ms | 27.2801 Ops/s | 27.8482 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1797s | 35.3796ms | 28.2649 Ops/s | 26.5305 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1691s | 30.9726ms | 32.2866 Ops/s | 28.5540 Ops/s |
torchrl/envs/libs/gym.py (outdated)

def _post_init(self):
    # writes the functions that are gym-version specific to the instance
    # once and for all. This is aimed at avoiding the need of decorating code
    # with set_gym_backend + allowing for parallel execution (which would
    # be troublesome when both an old version of gym and recent gymnasium
    # are present within the same virtual env).
    # These calls seemingly do nothing but they actually get rid of the
    # @implement_for decorator. We execute them within the set_gym_backend
    # context manager to make sure we get the right implementation.
    with set_gym_backend(self.get_library_name(self._env)):
        self._reset_output_transform = self._reset_output_transform
        self._output_transform = self._output_transform
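A self-contained toy illustrating why this self-assignment idiom pins the implementation; this is a simplified stand-in, not implement_for's real mechanics:

import contextlib

_BACKEND = {"name": "gym"}

@contextlib.contextmanager
def set_backend(name):
    # toy stand-in for set_gym_backend
    previous = _BACKEND["name"]
    _BACKEND["name"] = name
    try:
        yield
    finally:
        _BACKEND["name"] = previous

class VersionDispatched:
    # non-data descriptor: resolved at attribute access time
    def __get__(self, obj, objtype=None):
        if _BACKEND["name"] == "gymnasium":
            return lambda: "5-element step output"
        return lambda: "4-element step output"

class Wrapper:
    _output_transform = VersionDispatched()

    def __init__(self, backend):
        with set_backend(backend):
            # reading resolves the descriptor; writing back stores the
            # resolved function in the instance dict, freezing the choice
            self._output_transform = self._output_transform

w = Wrapper("gymnasium")
with set_backend("gym"):
    print(w._output_transform())  # still "5-element step output"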
LGTM, thanks a lot for this.
I left a few comments. I am interested in the new function; ping me when we adopt it in rollout and collector.
I think we're reaching a good state for this PR.
Final list of changes part of this PR: (see the full list at the end of this thread)
I'm leaving one todo in this PR, which is compatibility with RLHF text data parsing, where we simply compute a done that can either mean terminated or truncated. This will be part of a follow-up PR.
I'll try to make a full review before tonight if that's ok.
Some comments, plus some general points:
- pettingzoo has truncated and terminated in each agent group, so we can adapt that
- vmas, smac and similar simulators should output terminated instead of done so that we keep consistency
@@ -1700,6 +1748,7 @@ def _step(

     self.count += one_hot_action.to(torch.int)
     td["observation"] += expand_right(self.count, td["observation"].shape)
     done["done"] = self.count > self.max_steps
+    done["terminated"] = self.count > self.max_steps
Instead of adding terminated and keeping done, why don't we just use terminated instead of done?
Because we want done to quickly tell us whether we should reset. It avoids always going through all the possible keys and carries the summary of the end-of-traj in one tensor.
So now all envs have at least one done and one terminated, and possibly a truncated.
It's what we discussed with @skandermoalla, IIRC.
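As a toy sketch of that convention (illustrative values, not TorchRL internals):

import torch

terminated = torch.tensor([False, True, False])  # reached a terminal state
truncated = torch.tensor([False, False, True])   # cut short, e.g. by a time limit
done = terminated | truncated                    # one-tensor end-of-traj summary
# done -> tensor([False, True, True]): the last two envs should reset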
But the done would be added automatically by EnvBase, no?
I am talking about this mocking class: it could just output terminated, and done would be built.
Not sure if my input is needed here, but here's how I understand this:

> So now all envs have at least one done and one terminated, possibly a truncated.

Yes, this is what is expected.

> but the done would be added automatically by EnvBase no?
> i am talking about this mocking class, it could just output terminated and done would be built

Yes, the done key can be inferred from the terminated (and possibly truncated) keys. I haven't looked at the implementation yet to know where this happens, but if EnvBase takes care of it, then this mock can indeed avoid writing the done key, to test the automatic done inference feature.
I think it's best if we show both whenever we can. We can copy the done if you prefer, but not showing that it's populated isn't great from a UX perspective, IMO.
The auto patch is supposedly just there to ensure that everything runs fine even if someone writes only a terminated/done key.
If this is our rule then we should apply it to all envs.
Let's just do it uniformly.
just did, what did I miss?
torchrl/envs/batched_envs.py (outdated)

 if self._single_task:
     # select + clone creates 2 tds, but we can create one only
     out = TensorDict(
         {}, batch_size=self.shared_tensordict_parent.shape, device=self.device
     )
     for key in self._selected_reset_keys:
-        if _unravel_key_to_tuple(key)[-1] != "_reset":
+        key = unravel_keys(key)
+        if key not in self.reset_keys:
the reset key should not be removed here
Why? Can you explain? It used to be removed (see diff), and I'm pretty sure tests fail if you don't.
There are many reasons we discussed:
- they should always be removed by env.reset() for uniformity
- this removes them only when self._single_task, and not otherwise
ok let me try
Right, so the problem is this:
If there is no reset in the input, we assume it's all true.
But the buffer is not necessarily all true (it's more likely that it is all false). So if there's a _reset in the buffer, we don't want to pass it in the result of reset, because its value isn't synced.
Now you could argue that we must then update the _reset in the buffer with the one that we got from the input or that we inferred, but then the risk is that these _resets do not have the same shape.
Transforms such as reward sum look at the tensordict that results from _reset, before it is pruned of its "_reset", and decide what to do from there.
So here are the cases we need to account for:
- the input tensordict does not have a "_reset" (or there is no input tensordict) => we infer that "_reset" is full of True (see the sketch after this comment)
- the input tensordict has a reset, but its shape mismatches the parallel env's. This can happen if the "_reset" in the buffer has the shape of the "done" which, like you mentioned on some occasions, can have more dims than 1, but the user just passed a "_reset" that matched the number of envs (1-dim). In that case we could update the buffer by extending on the right, I guess, but then the problem is if the env on the other end (in the subprocess) reads that "_reset". Then there's a chance it will be corrupt.
- all is good and the "_reset"s match.
I think case 2 is hard to handle. The tests pass if we exclude the "_reset". And we update the tensordict that results from ParallelEnv._reset in TransformedEnv._reset, so the transforms can see the partial "_reset" if there is one.
All in all, I don't see why we shouldn't exclude the "_reset" in ParallelEnv._reset: the only methods that will look for a "_reset" are TransformedEnv and EnvBase, in their _reset / reset methods, respectively. I don't see anywhere this will break.
If the point is that it should also be excluded in the case where there is more than one task, I totally agree!
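A minimal sketch of case 1 above, assuming a tensordict-like input and a known done shape (names illustrative, not the batched-env code):

import torch

def infer_reset(tensordict, done_shape):
    # case 1: no input tensordict, or no "_reset" entry -> reset everything
    if tensordict is None or "_reset" not in tensordict.keys():
        return torch.ones(done_shape, dtype=torch.bool)
    # otherwise keep the (possibly partial) user-provided mask
    return tensordict.get("_reset")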
> If the point is that it should also be excluded in the case where there is more than 1 task I totally agree!

Definitely, let's make this uniform.
My other point is just that I thought we decided that the convention was that all the _resets are removed only in env.reset(), and in this particular instance they are removed before. This has 2 problems:
- env.reset() cannot perform the done check appropriately (it will see a missing _reset and assume it is all true, but that is not the case). This is a problem if the reset was partial and some dones are true after the reset, as they should be, but they cause the check to fail.
- this is not uniform, as the batched env _reset() will be the only one not to provide _resets (and this will make env.reset() not perform its checks properly).
This was not a problem before, as in env.reset() I was caching the _resets in the input (before any other call) to check the dones.
> env.reset() cannot perform the done check appropriately (it will see a missing _reset and assume it is all true but that is not the case)

No, because env.reset gets its "_reset" from the input tensordict, not the output.

> and that this is not uniform as the batched env _reset() will be the only one to not provide _resets

Same, I don't think that is true: per se, all _reset methods must return a new tensordict, and it's written nowhere in the contract of that function that you must copy any part of the content of the input into the output. In other words: we never read "_reset" from the output of _reset. Check the reset method for more context if you need reassurance.
Ok, got it, my bad.
Let's just test that BatchedEnv._reset() does not return _resets in all branches.
I don't get this comment, what do you mean? Is there some extra work needed? I think we cover multiple truncated and terminated within the data structure.
Yep, got that, but if we keep writing done explicitly instead of terminated in the wrappers, then users will be misled about the meaning of what comes from the simulator. I think terminated should be the one provided, and done created through a clone, also because done may be overwritten but terminated may not.
Right now in pettingzoo there is an OR between the two to create a done; we should just adapt that.
Ah, got it. So your point is that as part of this PR, all simulators that were writing done should write terminated instead?
Per se, both are always provided.
EDIT: another datapoint: Brax explicitly calls its end-of-trajectory signal "done"; not sure if we want to rename that. To me, if an env calls something done, we document that we also create a terminated, but I don't think that state["terminated"] = state.get("done").view(*self.reward_spec.shape) within torchrl is less "deceiving" than the other option.
EDIT: if you look at the latest commit, what I did is just placing both terminated and done in each env. I think it's clearer than placing just one and letting the magic happen.
So placing both directly in the env? We can do this, but let's make sure to have it uniformly everywhere.
LGTM, let's just address the remaining open comments and we are good to go for me.
Thanks for this btw, it is impressive work!
Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Final list of changes part of this PR:
- Add full_done_spec, full_observation_spec, full_action_spec, full_state_spec, full_reward_spec properties.
- Make all envs have "done" and "terminated".
- Add a _complete_done private method within envs to complete output data if done_keys are missing. Make this compatible with transforms to avoid duplicated calls to the method.
- Full conversion to done_keys wherever done_key was used.
- Versioned reading helpers in GymWrapper to account for discrepancies between versions.
- terminated_or_truncated, a function that reads the data, writes the "_reset" keys and returns a boolean indicating if anything is done (see the sketch after this list).
- Fix @implement_for, set_gym_backend and related + add tests.
- Fix envs tutorial and add info wrt the new "done" attributes.
- Compatibility with batched environments.
- Fix transforms: VecGymEnv, Reward2GoTransform, RenameTransform, InitTracker, TimeMaxPool, SelectTransform, ExcludeTransform, StepCounter.
- Fix ModelBased envs, RoboHive, D4RL, RLHF.
- Adapt collectors to the new logic.
- Better handling of workers refusing to close in collectors.
- Fix Gym tests across versions.
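A hedged usage sketch of the terminated_or_truncated helper described above; the import path and signature are assumptions based on this thread, not verified against the merged code:

import torch
from tensordict import TensorDict
from torchrl.envs.utils import terminated_or_truncated  # assumed location

data = TensorDict(
    {
        "terminated": torch.tensor([False]),
        "truncated": torch.tensor([True]),
    },
    batch_size=[1],
)
# reads the done-like entries, writes the "_reset" keys in place,
# and returns whether anything is done (True here, via truncation)
any_done = terminated_or_truncated(data)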
cc @skandermoalla