
[Feature] Gym compatibility: Terminal and truncated #1539

Merged: 181 commits, Sep 29, 2023
Conversation

vmoens
Contributor

@vmoens vmoens commented Sep 15, 2023

Final list of changes part of this PR:

Add full_done_spec, full_observation_spec, full_action_spec, full_state_spec, full_reward_spec properties.
Make all envs have "done" and "terminated".
Add a _complete_done private method within envs to complete output data if done_keys are missing. Make this compatible with transforms to avoid duplicated calls to the method.
Full conversion to done_keys wherever done_key was used.
Versioned reading helpers in GymWrapper to account for discrepancies between versions
terminated_or_truncated, a function that reads the data, writes the "_reset" keys and returns a boolean indicating if anything is done.
Fix @implement_for, set_gym_backend and related + add tests
Fix envs tutorial and add info wrt new "done" attributes
Compatibility with batched environments
Fix transforms: VecGymEnv, Reward2GoTransform, RenameTransform, InitTracker, TimeMaxPool, SelectTransform, ExcludeTransform, StepCounter
Fix ModelBased envs, RoboHive, D4RL, RLHF
Adapt collectors to new logic
Better handling of workers refusing to close in collectors
Fix Gym tests across versions

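As a rough illustration of the terminated_or_truncated helper listed above, here is a minimal sketch using plain nested dicts in place of TensorDict; the names, structure and defaults are assumptions for illustration, not the PR's actual implementation:

```python
def terminated_or_truncated(data, done_keys=("done", "truncated")):
    """Write a "_reset" entry at each nesting level as the OR of the
    done-like flags found there; return True if anything is done."""
    any_done = False
    reset = False
    for key, value in list(data.items()):
        if isinstance(value, dict):  # nested agent group: recurse
            any_done |= terminated_or_truncated(value, done_keys)
        elif key in done_keys:
            reset = reset or bool(value)
    data["_reset"] = reset
    return any_done or reset
```

On batched environments the same logic would operate on boolean tensors rather than Python bools, with elementwise OR in place of `or`.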
cc @skandermoalla

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 15, 2023
@vmoens vmoens marked this pull request as draft September 15, 2023 17:15
@vmoens vmoens changed the title [Feature] Gym compatibility: Terminal and truncated [WIP] Gym compatibility: Terminal and truncated Sep 15, 2023
@vmoens vmoens added the Environments Adds or modifies an environment wrapper label Sep 15, 2023
@vmoens
Contributor Author

vmoens commented Sep 17, 2023

@skandermoalla The current plan is to keep done as it was and add truncated.
Most of our code base interprets done as termination, i.e. in gym you have done = termination | truncation, whereas in torchrl you'd have end_of_traj = done | truncation.
This naming differs a bit from gym, but the alternative would be to rename done to termination everywhere in the code.
With @matteobettini we thought that proceeding as we do now is less disruptive than renaming every single occurrence of done. Happy to read your thoughts about this!
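Spelled out as a purely illustrative snippet, the two conventions relate as follows:

```python
terminated, truncated = False, True  # example: an external time limit was hit

# gym (<0.26) collapses both end-of-episode causes into one flag:
gym_done = terminated or truncated        # True

# torchrl (this PR): "done" keeps its termination meaning,
# and end-of-trajectory is derived by OR-ing in "truncated":
done = terminated                         # False: the task did not terminate
end_of_traj = done or truncated           # True: the trajectory still ends here
```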

@skandermoalla
Contributor

Thanks a lot for starting this and for pinging me!

IMO naming differences are fine as long as done and truncation consistently mean the Gymnasium termination and truncation everywhere else in TorchRL and that they're correctly used.

Most of our code base interprets done as termination

So here "most" should be very carefully assessed. It would be terrible to have done mean terminated at some point and end_of_traj at another. I can see exceptions coming from the fact that TorchRL supports multiple environment libraries and some of them only return end_of_traj, but there is no way to disambiguate this underspecified signal, and from the TorchRL perspective it should be treated as terminated.

I think I sometimes confused the done in TorchRL for an end_of_traj signal. I don't recall exactly where, but that's to emphasize that it would be good to check this, especially in value estimators and transforms. (An end_of_traj signal would not allow correct bootstrapping without further consulting the truncated bit.)

So maybe also making the docs more explicit about this definition would help a lot of people arriving from Gymnasium.

I'm also thinking it would be useful to allow using StepCounter(max_steps=N) either to mark a native time limit (with the done key), i.e. the finite horizon of an env, or to mark an external time limit (truncation). I think this is already possible by just changing the truncation key to done, right? Will this be safe?
I don't know how Gymnasium differentiates between internal and external time limits to correctly set the truncation key.
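To make the bootstrapping point concrete: a one-step TD target must keep the bootstrap term through a truncation but drop it at a true termination. A plain-Python sketch (the names and signature are made up for illustration):

```python
def td_target(reward, next_value, terminated, truncated, gamma=0.99):
    """One-step TD target: r + gamma * V(s') * (1 - terminated).

    Note that `truncated` does not enter the target at all: truncation
    only ends the trajectory, it does not zero the bootstrap term.
    An undifferentiated end_of_traj flag used in place of `terminated`
    would wrongly drop the bootstrap at time limits.
    """
    return reward + gamma * next_value * (0.0 if terminated else 1.0)
```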

Contributor

@skandermoalla skandermoalla left a comment


Adding an early review before disappearing until the ICLR deadline 🫠.

Comment on lines 228 to 239
    recursive: bool = True,
):
    found_entry = False
-   for value in spec.values():
+   for key, value in spec.items():
        if isinstance(value, CompositeSpec):
            _check_only_one_entry(value, error, recursive)
        else:
-           if not found_entry:
-               found_entry = True
-           else:
-               raise error
+           if key == "done":
+               if found_entry:
+                   raise error
+               else:
+                   found_entry = True
Contributor

Recursivity is still not implemented here right?

Contributor Author

@matteobettini I'm not even sure we need this function anymore. We wanted to avoid that people would have "done_agent1" and "done_agent2" and rather ("agent1", "done") IIRC.
Now we will allow "done_spec" to contain "done" and "truncated" at any level. Per se the done_spec could even be richer. So I don't think we need this check anymore: it's good practice to have only one done per level but I don't think that will break anywhere if you have more.

The only exception I see is computing the "_reset" signal, which usually reflects done | truncated. In other words, _reset reflects something like

def _make_reset(data, done_spec):
    out = {}
    for key, item in done_spec.items():
        if isinstance(item, CompositeSpec):
            out[key] = _make_reset(data.get(key), item)
        else:
            out["_reset"] = out.get("_reset", False) | data.get(key)
    return data.update(out)

This means that if you put done_agent1 and done_agent2 at the same level you'll get a single reset for both. Not ideal, but not really our business if the MARL API is well documented IMO.

In collectors we use "done" to compute the trajectory ids but this is an extra feature that could apply only in settings where done is at the root (or in other restricted settings). The trajectory id doesn't influence the collector behaviour, which only cares about the number of frames (number of elements of the root tensordict).

For the rest, loss modules have parametric dones so I think we could rename "done" to "merylstreep" and they could still work. They will need refactoring to account for "truncated" anyway.

Contributor

So the situation is a little more complex, let me try to explain it here:

Since we decided that our naming convention is to have one "_reset" key per tensordict (because if we had more we would have to change the name), this key has to reflect the shape of the "done" in that tensordict.

If we allowed more than one leaf key in each td of the done_spec, the problem is what shape the _reset should reflect. We cannot do as Vincent said because the shapes might be arbitrary.

Therefore the established convention is that a key with the name "done" (and exactly "done") should exist in each td of the done_spec, and the reset will mimic its shape.

If we now want to change the convention and say that the constraint is that all values in the same tensordict of the done_spec should have the same shape, that is fine by me, but then we have to check and enforce it.

Contributor

@matteobettini matteobettini Sep 18, 2023


If we do as in this PR and assume truncated is always present in the done_spec, then we have to carefully adjust all the components of the codebase that create the _reset key.

Like here

for done_key in self.env.done_keys:

and here

for done_key in self.done_keys:

For example.

These lines currently create a _reset for each of done_spec.keys(True, True).
But if the done_spec now contains more than one leaf per td, we have to change all this.

Contributor Author

If now we want to change the convention and say that the constraint now is that all the values in the same tensordict of the done_spec should have the same shape, that is fine by me, but then we have to check and enforce that.

Per se, if someone writes an env that does not require _reset (which we use internally but is not exposed to users), that's fine. What I mean is that if a data structure has "done_agent1" of shape X and "done_agent2" of shape Y, computing a "_reset" key will fail, but that only means the env cannot be used within a collector.
Since we want to provide independent, modular components, I think it's an acceptable design choice not to enforce that every env written by any user fits in a collector. We do need to make sure that the collector tells you specifically what is wrong with your env, though.

What about letting envs have the done_spec they want, and letting the collectors / parallel envs spot the potential problems during construction?

Contributor

Yeah but also the rollout function uses the reset logic. So how would we treat that?

Contributor Author

Imagine someone wants to use TorchRL to write a multiagent env that does some job with the "done_agentX" structure.
I think it's cool if they do that and put it in their app. I don't think they should get an exception or a warning telling them they're going to hit a dead end when using a collector. Maybe they don't want to use the collectors.
Let's not prevent legitimate use of one class just to prevent bugs in another when use of the other is optional, wdyt?

Contributor

Yes, but vectorized envs will need to be fed a _reset flag upon reset in the env.rollout function. So you're saying we provide that on a best-effort basis, and if we cannot, we just give a warning and say that in the collector this will be an error?

Contributor Author

Yeah but also the rollout function uses the reset logic. So how would we treat that?

Oh yeah good point.
This is a bit tricky: per se, we set a _reset in rollout and hope for the best, but we don't enforce that it's read anywhere.
The thing is that we don't really have any agency over what reset(tensordict) does under the hood. For ParallelEnv and similar we know what to do, but in other cases I'm not 100% sure what could be done generically.

But yeah, since we use _reset within the env class it makes sense that it's computed in a consistent way within the library.
So what we would opt for is:

  • during construction, walk through the specs and check that at each level, all done specs have a matching shape
  • during execution, compute _reset as the combination of all _reset keys at each level
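A sketch of what that construction-time check could look like, using nested dicts whose leaves are shape tuples as a stand-in for the done spec (all names here are hypothetical):

```python
def check_done_shapes(spec):
    """Recursively verify that all done-like leaves at each nesting level
    of the spec share the same shape; raise ValueError otherwise."""
    shapes = set()
    for value in spec.values():
        if isinstance(value, dict):      # nested level: check it independently
            check_done_shapes(value)
        else:                            # leaf: a shape tuple in this sketch
            shapes.add(tuple(value))
    if len(shapes) > 1:
        raise ValueError(f"mismatched done shapes at one level: {shapes}")
```

With such a check in place, the runtime _reset can safely be computed as the elementwise OR of the same-shaped done entries at each level.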

Contributor

during execution, compute _reset as the combination of all _reset keys at each level

-> as the combination of all done_spec_keys at each level

but yes that could work!

@@ -228,14 +228,15 @@ def _check_only_one_entry(
recursive: bool = True,
):
Contributor

The method would benefit from a quick docstring or a better name. I'm not sure I understand what it does without more context.

Comment on lines 142 to 144
This method returns:
- a done state to be set in the environment
- a boolean value indicating whether the frame_skip loop should be broken
Contributor

This doesn't seem to match its current output (tuple of 3 elements). Missing the truncated output?

Contributor Author

yes I still need to go through the docstrings but thanks for reminding me!

Comment on lines 139 to 140
In TorchRL, done means that a trajectory is terminated (we do not support the
terminated signal). Truncated means the trajectory has been interrupted.
Contributor

If done is exactly terminated then what does "we do not support the terminated signal" mean?

Contributor Author

I agree it's confusing. Maybe

In torchrl, a `done` signal means that a trajectory is terminated (what is referred to as `termination` in gymnasium).

Contributor

Yep! That's clear!

Comment on lines 151 to 153
if termination is not None:
    return termination, truncation, done
return done, truncation, done
Contributor

This means eventual bootstrapping will be done correctly for new versions of gym and incorrectly for old versions (which, when truncated, set done to true and report the truncation bit in the info variable). A note should be added about this somewhere.

Contributor Author

Let me rephrase what you're saying here:
in older versions of gym (<0.26), a truncation signal could be retrieved from the info dict.
Currently we don't use it, so assuming done=truncated will not work properly with value functions, is that right?

If so, I think we should advise users to adopt a dedicated info reader to write the truncated signal, which will solve the problem.

Contributor

if older version of gym (<0.26) a truncation signal could be retrieved from the info.

Yes in the TimeLimit.truncated key if it happened. See https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/#solution.

assuming that done=truncated will not work properly with value functions

Yes, and to be exact, done here with the old gym versions would be gym_done = terminated | truncated. (It was an end_of_traj.)

I think that we should advise users of adopting a dedicated info reader to write the truncated signal, which will solve the problem.

Yes, that would be the correct way to record the truncated key and propagate it to the rest of TorchRL's components. However, the gym_done key would still need to be corrected with torchrl_done = gym_done and ~truncated, and I don't know how this could be done by the users.
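A hedged sketch of that correction for old-gym (<0.26) step outputs; the TimeLimit.truncated info key is gym's documented convention, while the helper name and placement are made up for this illustration:

```python
def split_old_gym_done(gym_done, info):
    """Recover (terminated, truncated) from an old-gym (done, info) pair.

    Old gym sets done=True on a time-limit truncation and records the
    cause under info["TimeLimit.truncated"]; newer gym and gymnasium
    return the two flags separately instead.
    """
    truncated = bool(info.get("TimeLimit.truncated", False))
    terminated = bool(gym_done) and not truncated
    return terminated, truncated
```

One caveat: old gym cannot represent an episode that terminates and hits the time limit on the same step, so this recovery is necessarily lossy in that corner case.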

torchrl/envs/gym_like.py (resolved, outdated)
@@ -187,26 +195,12 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:

    reward = 0
    for _ in range(self.wrapper_frame_skip):
-        obs, _reward, done, *info = self._output_transform(
+        obs, _reward, termination, truncation, done, info = self._output_transform(
Contributor

Realizing once more that variable order can so easily be messed up here. Shout out to Tensordict for solving this!

@@ -133,18 +133,24 @@ def read_action(self, action):
"""
return self.action_spec.to_numpy(action, safe=False)

-    def read_done(self, done):
+    def read_done(self, done, termination, truncation):
Contributor

Maybe the variables in this file referring to the gym done can be renamed to gym_done to avoid confusion with the torchrl done?

Contributor Author

good point, it's getting a bit confusing!

@github-actions

github-actions bot commented Sep 18, 2023

⚠️ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: 34. Worsened: 10.

Detailed results:
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1520s 0.1501s 6.6616 Ops/s 8.6325 Ops/s $\textbf{\color{#d91a1a}-22.83\%}$
test_sync 91.5973ms 80.9036ms 12.3604 Ops/s 15.8629 Ops/s $\textbf{\color{#d91a1a}-22.08\%}$
test_async 0.2039s 79.7407ms 12.5406 Ops/s 15.7839 Ops/s $\textbf{\color{#d91a1a}-20.55\%}$
test_simple 1.3358s 1.2287s 0.8139 Ops/s 1.0459 Ops/s $\textbf{\color{#d91a1a}-22.19\%}$
test_transformed 1.5936s 1.5333s 0.6522 Ops/s 0.8244 Ops/s $\textbf{\color{#d91a1a}-20.89\%}$
test_serial 3.4552s 3.3403s 0.2994 Ops/s 0.3646 Ops/s $\textbf{\color{#d91a1a}-17.90\%}$
test_parallel 2.7917s 2.6656s 0.3752 Ops/s 0.4225 Ops/s $\textbf{\color{#d91a1a}-11.20\%}$
test_step_mdp_speed[True-True-True-True-True] 2.3443ms 47.5616μs 21.0254 KOps/s 19.5929 KOps/s $\textbf{\color{#35bf28}+7.31\%}$
test_step_mdp_speed[True-True-True-True-False] 1.4567ms 28.1532μs 35.5200 KOps/s 36.0895 KOps/s $\color{#d91a1a}-1.58\%$
test_step_mdp_speed[True-True-True-False-True] 1.6109ms 35.7529μs 27.9698 KOps/s 27.6125 KOps/s $\color{#35bf28}+1.29\%$
test_step_mdp_speed[True-True-True-False-False] 1.4395ms 20.3467μs 49.1480 KOps/s 48.0980 KOps/s $\color{#35bf28}+2.18\%$
test_step_mdp_speed[True-True-False-True-True] 1.8869ms 50.3061μs 19.8783 KOps/s 20.1280 KOps/s $\color{#d91a1a}-1.24\%$
test_step_mdp_speed[True-True-False-True-False] 1.2443ms 27.8566μs 35.8981 KOps/s 30.6501 KOps/s $\textbf{\color{#35bf28}+17.12\%}$
test_step_mdp_speed[True-True-False-False-True] 3.7292ms 40.3144μs 24.8050 KOps/s 25.0849 KOps/s $\color{#d91a1a}-1.12\%$
test_step_mdp_speed[True-True-False-False-False] 2.7486ms 21.1617μs 47.2552 KOps/s 42.2469 KOps/s $\textbf{\color{#35bf28}+11.85\%}$
test_step_mdp_speed[True-False-True-True-True] 1.4861ms 53.0680μs 18.8438 KOps/s 18.2045 KOps/s $\color{#35bf28}+3.51\%$
test_step_mdp_speed[True-False-True-True-False] 1.8459ms 32.4297μs 30.8359 KOps/s 30.6470 KOps/s $\color{#35bf28}+0.62\%$
test_step_mdp_speed[True-False-True-False-True] 3.5215ms 36.6657μs 27.2734 KOps/s 26.1134 KOps/s $\color{#35bf28}+4.44\%$
test_step_mdp_speed[True-False-True-False-False] 1.5480ms 21.8291μs 45.8104 KOps/s 43.9778 KOps/s $\color{#35bf28}+4.17\%$
test_step_mdp_speed[True-False-False-True-True] 4.0434ms 58.2856μs 17.1569 KOps/s 17.8605 KOps/s $\color{#d91a1a}-3.94\%$
test_step_mdp_speed[True-False-False-True-False] 2.5648ms 35.1649μs 28.4375 KOps/s 29.5241 KOps/s $\color{#d91a1a}-3.68\%$
test_step_mdp_speed[True-False-False-False-True] 1.5462ms 38.7024μs 25.8382 KOps/s 25.1199 KOps/s $\color{#35bf28}+2.86\%$
test_step_mdp_speed[True-False-False-False-False] 4.2800ms 24.4261μs 40.9397 KOps/s 42.4358 KOps/s $\color{#d91a1a}-3.53\%$
test_step_mdp_speed[False-True-True-True-True] 3.5002ms 54.5293μs 18.3388 KOps/s 18.6064 KOps/s $\color{#d91a1a}-1.44\%$
test_step_mdp_speed[False-True-True-True-False] 1.5240ms 33.3729μs 29.9645 KOps/s 28.2972 KOps/s $\textbf{\color{#35bf28}+5.89\%}$
test_step_mdp_speed[False-True-True-False-True] 1.4671ms 43.1279μs 23.1868 KOps/s 23.8162 KOps/s $\color{#d91a1a}-2.64\%$
test_step_mdp_speed[False-True-True-False-False] 1.6163ms 24.4371μs 40.9213 KOps/s 38.5422 KOps/s $\textbf{\color{#35bf28}+6.17\%}$
test_step_mdp_speed[False-True-False-True-True] 1.5880ms 56.8706μs 17.5838 KOps/s 18.1635 KOps/s $\color{#d91a1a}-3.19\%$
test_step_mdp_speed[False-True-False-True-False] 1.9065ms 35.3787μs 28.2656 KOps/s 28.9811 KOps/s $\color{#d91a1a}-2.47\%$
test_step_mdp_speed[False-True-False-False-True] 1.7214ms 45.0859μs 22.1799 KOps/s 24.0683 KOps/s $\textbf{\color{#d91a1a}-7.85\%}$
test_step_mdp_speed[False-True-False-False-False] 9.1869ms 26.8996μs 37.1753 KOps/s 36.6084 KOps/s $\color{#35bf28}+1.55\%$
test_step_mdp_speed[False-False-True-True-True] 2.7321ms 57.2490μs 17.4676 KOps/s 16.3324 KOps/s $\textbf{\color{#35bf28}+6.95\%}$
test_step_mdp_speed[False-False-True-True-False] 2.2998ms 36.2487μs 27.5872 KOps/s 23.8731 KOps/s $\textbf{\color{#35bf28}+15.56\%}$
test_step_mdp_speed[False-False-True-False-True] 2.4786ms 42.8941μs 23.3133 KOps/s 21.2729 KOps/s $\textbf{\color{#35bf28}+9.59\%}$
test_step_mdp_speed[False-False-True-False-False] 3.5744ms 26.5392μs 37.6801 KOps/s 37.0256 KOps/s $\color{#35bf28}+1.77\%$
test_step_mdp_speed[False-False-False-True-True] 1.8297ms 60.5148μs 16.5249 KOps/s 16.1699 KOps/s $\color{#35bf28}+2.20\%$
test_step_mdp_speed[False-False-False-True-False] 4.9242ms 38.9195μs 25.6941 KOps/s 25.7345 KOps/s $\color{#d91a1a}-0.16\%$
test_step_mdp_speed[False-False-False-False-True] 1.5061ms 41.5351μs 24.0760 KOps/s 21.6581 KOps/s $\textbf{\color{#35bf28}+11.16\%}$
test_step_mdp_speed[False-False-False-False-False] 5.4670ms 27.7662μs 36.0150 KOps/s 33.8220 KOps/s $\textbf{\color{#35bf28}+6.48\%}$
test_values[generalized_advantage_estimate-True-True] 21.8422ms 16.7033ms 59.8686 Ops/s 54.9726 Ops/s $\textbf{\color{#35bf28}+8.91\%}$
test_values[vec_generalized_advantage_estimate-True-True] 0.1274s 78.4241ms 12.7512 Ops/s 12.4612 Ops/s $\color{#35bf28}+2.33\%$
test_values[td0_return_estimate-False-False] 2.2441ms 0.6966ms 1.4355 KOps/s 1.4467 KOps/s $\color{#d91a1a}-0.78\%$
test_values[td1_return_estimate-False-False] 22.1129ms 17.7120ms 56.4589 Ops/s 56.5616 Ops/s $\color{#d91a1a}-0.18\%$
test_values[vec_td1_return_estimate-False-False] 0.1050s 79.5756ms 12.5667 Ops/s 12.9641 Ops/s $\color{#d91a1a}-3.07\%$
test_values[td_lambda_return_estimate-True-False] 56.0795ms 51.9957ms 19.2323 Ops/s 19.1128 Ops/s $\color{#35bf28}+0.63\%$
test_values[vec_td_lambda_return_estimate-True-False] 0.1062s 78.3074ms 12.7702 Ops/s 12.2649 Ops/s $\color{#35bf28}+4.12\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 22.3703ms 14.6167ms 68.4149 Ops/s 72.5063 Ops/s $\textbf{\color{#d91a1a}-5.64\%}$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 10.8682ms 5.2302ms 191.1978 Ops/s 177.2607 Ops/s $\textbf{\color{#35bf28}+7.86\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 4.3420ms 0.9028ms 1.1077 KOps/s 1.0838 KOps/s $\color{#35bf28}+2.21\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 78.1810ms 74.2825ms 13.4621 Ops/s 13.4408 Ops/s $\color{#35bf28}+0.16\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 9.3381ms 6.6182ms 151.0993 Ops/s 143.3184 Ops/s $\textbf{\color{#35bf28}+5.43\%}$
test_dqn_speed 8.6350ms 3.9416ms 253.7030 Ops/s 249.1651 Ops/s $\color{#35bf28}+1.82\%$
test_ddpg_speed 13.7405ms 7.6336ms 130.9999 Ops/s 131.6222 Ops/s $\color{#d91a1a}-0.47\%$
test_sac_speed 29.3333ms 20.5020ms 48.7757 Ops/s 48.8679 Ops/s $\color{#d91a1a}-0.19\%$
test_redq_speed 39.5370ms 31.5141ms 31.7318 Ops/s 31.1134 Ops/s $\color{#35bf28}+1.99\%$
test_redq_deprec_speed 37.9842ms 30.0608ms 33.2659 Ops/s 30.8167 Ops/s $\textbf{\color{#35bf28}+7.95\%}$
test_td3_speed 27.1305ms 20.3410ms 49.1618 Ops/s 49.0771 Ops/s $\color{#35bf28}+0.17\%$
test_cql_speed 0.1223s 91.9787ms 10.8721 Ops/s 13.3079 Ops/s $\textbf{\color{#d91a1a}-18.30\%}$
test_a2c_speed 22.9850ms 14.5774ms 68.5994 Ops/s 70.0593 Ops/s $\color{#d91a1a}-2.08\%$
test_ppo_speed 22.1010ms 14.5217ms 68.8626 Ops/s 63.3403 Ops/s $\textbf{\color{#35bf28}+8.72\%}$
test_reinforce_speed 28.0911ms 12.1667ms 82.1918 Ops/s 85.3810 Ops/s $\color{#d91a1a}-3.74\%$
test_iql_speed 81.0886ms 61.4065ms 16.2849 Ops/s 16.1663 Ops/s $\color{#35bf28}+0.73\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 5.8806ms 3.2130ms 311.2367 Ops/s 271.5736 Ops/s $\textbf{\color{#35bf28}+14.60\%}$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 6.1424ms 3.1324ms 319.2484 Ops/s 276.5881 Ops/s $\textbf{\color{#35bf28}+15.42\%}$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 8.9960ms 3.3200ms 301.2038 Ops/s 283.1985 Ops/s $\textbf{\color{#35bf28}+6.36\%}$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 4.7876ms 2.9547ms 338.4425 Ops/s 298.9432 Ops/s $\textbf{\color{#35bf28}+13.21\%}$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 8.1250ms 3.2951ms 303.4831 Ops/s 289.2854 Ops/s $\color{#35bf28}+4.91\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 5.2123ms 3.1451ms 317.9543 Ops/s 238.8304 Ops/s $\textbf{\color{#35bf28}+33.13\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 5.3671ms 3.0135ms 331.8377 Ops/s 310.6851 Ops/s $\textbf{\color{#35bf28}+6.81\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 6.0443ms 3.3323ms 300.0925 Ops/s 241.0044 Ops/s $\textbf{\color{#35bf28}+24.52\%}$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 9.4841ms 3.4549ms 289.4463 Ops/s 278.6298 Ops/s $\color{#35bf28}+3.88\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 5.9577ms 3.4111ms 293.1584 Ops/s 227.1365 Ops/s $\textbf{\color{#35bf28}+29.07\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 11.2787ms 3.3547ms 298.0914 Ops/s 271.2144 Ops/s $\textbf{\color{#35bf28}+9.91\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 6.3453ms 3.2779ms 305.0698 Ops/s 271.9679 Ops/s $\textbf{\color{#35bf28}+12.17\%}$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 5.3576ms 3.0207ms 331.0490 Ops/s 319.0112 Ops/s $\color{#35bf28}+3.77\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.1883s 3.6347ms 275.1226 Ops/s 288.2381 Ops/s $\color{#d91a1a}-4.55\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 6.9478ms 3.3685ms 296.8674 Ops/s 298.1164 Ops/s $\color{#d91a1a}-0.42\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.6910ms 3.0574ms 327.0739 Ops/s 319.0371 Ops/s $\color{#35bf28}+2.52\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 5.9688ms 3.3096ms 302.1529 Ops/s 293.7645 Ops/s $\color{#35bf28}+2.86\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 6.6344ms 3.3876ms 295.1961 Ops/s 277.8849 Ops/s $\textbf{\color{#35bf28}+6.23\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.2737s 38.5404ms 25.9468 Ops/s 24.6046 Ops/s $\textbf{\color{#35bf28}+5.46\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1903s 32.5464ms 30.7254 Ops/s 28.8165 Ops/s $\textbf{\color{#35bf28}+6.62\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1692s 34.9829ms 28.5854 Ops/s 26.6485 Ops/s $\textbf{\color{#35bf28}+7.27\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1617s 32.2699ms 30.9886 Ops/s 29.3244 Ops/s $\textbf{\color{#35bf28}+5.68\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1759s 35.8962ms 27.8581 Ops/s 26.2619 Ops/s $\textbf{\color{#35bf28}+6.08\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1895s 34.7836ms 28.7492 Ops/s 26.6320 Ops/s $\textbf{\color{#35bf28}+7.95\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1903s 36.6567ms 27.2801 Ops/s 27.8482 Ops/s $\color{#d91a1a}-2.04\%$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1797s 35.3796ms 28.2649 Ops/s 26.5305 Ops/s $\textbf{\color{#35bf28}+6.54\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1691s 30.9726ms 32.2866 Ops/s 28.5540 Ops/s $\textbf{\color{#35bf28}+13.07\%}$

Comment on lines 470 to 481
def _post_init(self):
# writes the functions that are gym-version specific to the instance
# once and for all. This is aimed at avoiding the need of decorating code
# with set_gym_backend + allowing for parallel execution (which would
# be troublesome when both an old version of gym and recent gymnasium
# are present within the same virtual env).
# These calls seemingly do nothing but they actually get rid of the @implement_for decorator.
# We execute them within the set_gym_backend context manager to make sure we get
# the right implementation.
with set_gym_backend(self.get_library_name(self._env)):
self._reset_output_transform = self._reset_output_transform
self._output_transform = self._output_transform
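The trick in this quoted _post_init (re-binding a decorated method under the right backend so later calls no longer go through dispatch) can be illustrated with a generic version-dispatch registry. Everything below is a simplified stand-in written for this explanation, not torchrl's actual @implement_for:

```python
import functools

_REGISTRY = {}          # (function name, backend) -> implementation
_BACKEND = "gymnasium"  # stand-in for set_gym_backend's current choice

def implement_for(backend):
    """Register one implementation per backend; the public name
    dispatches on the currently selected backend at call time."""
    def deco(func):
        _REGISTRY[(func.__name__, backend)] = func

        @functools.wraps(func)
        def dispatch(*args, **kwargs):
            return _REGISTRY[(func.__name__, _BACKEND)](*args, **kwargs)
        return dispatch
    return deco

@implement_for("gym")
def read_end_flags(step_out):
    done, info = step_out                    # old API: one flag plus info
    return done, done, False

@implement_for("gymnasium")
def read_end_flags(step_out):
    terminated, truncated = step_out         # new API: two separate flags
    return terminated or truncated, terminated, truncated

# What _post_init effectively does: pin the resolved implementation once,
# so subsequent calls skip the per-call dispatch entirely.
read_end_flags = _REGISTRY[("read_end_flags", _BACKEND)]
```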
Contributor Author

Contributor

@matteobettini matteobettini left a comment


LGTM, thanks a lot for this.

I left a few comments. I'm interested in the new function; ping me when we adopt it in rollout and the collector.

@vmoens
Contributor Author

vmoens commented Sep 27, 2023

I think we're reaching a good state for this PR.
I will do some cleanup tomorrow and maybe add a couple of tests. I will also test that on a separate branch with the SOTA implementations we have to check that everything's running smoothly!

@vmoens
Contributor Author

vmoens commented Sep 28, 2023

Final list of changes part of this PR:

  • Add full_done_spec, full_observation_spec, full_action_spec, full_state_spec, full_reward_spec properties.
  • Make all envs have "done" and "terminated".
  • Add a _complete_done private method within envs to complete output data if done_keys are missing. Make this compatible with transforms to avoid duplicated calls to the method.
  • Full conversion to done_keys wherever done_key was used.
  • Versioned reading helpers in GymWrapper to account for discrepancies between versions
  • terminated_or_truncated, a function that reads the data, writes the "_reset" keys and returns a boolean indicating if anything is done.
  • Fix @implement_for, set_gym_backend and related + add tests
  • Fix envs tutorial and add info wrt new "done" attributes
  • Compatibility with batched environments
  • Fix transforms: VecGymEnv, Reward2GoTransform, RenameTransform, InitTracker, TimeMaxPool, SelectTransform, ExcludeTransform, StepCounter
  • Fix ModelBased envs, RoboHive, D4RL, RLHF
  • Adapt collectors to new logic
  • Better handling of workers refusing to close in collectors
  • Fix Gym tests across versions
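The terminated_or_truncated helper listed above can be pictured with a minimal sketch. Plain dicts and torch tensors stand in for TensorDicts here; the function name, defaults and exact semantics are illustrative assumptions, not TorchRL's actual API:

```python
import torch

def terminated_or_truncated_sketch(data: dict) -> bool:
    """Read the done-type entries, write a "_reset" mask, report if anything is done."""
    terminated = data.get("terminated", torch.zeros(1, dtype=torch.bool))
    truncated = data.get("truncated", torch.zeros_like(terminated))
    done = terminated | truncated
    # Envs that later see this mask only reset the marked entries.
    data["_reset"] = done
    return bool(done.any())

data = {
    "terminated": torch.tensor([False, True]),
    "truncated": torch.tensor([False, False]),
}
assert terminated_or_truncated_sketch(data) is True
assert data["_reset"].tolist() == [False, True]
```

The same pattern extends to nested done keys in multi-agent settings; this sketch only covers the flat single-group case.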

@vmoens
Contributor Author

vmoens commented Sep 28, 2023

I'm leaving one todo in this PR, which is the compatibility with RLHF text data parsing where we simply compute a done, which can either mean terminated or truncated. This will be part of a follow-up PR

@matteobettini
Contributor

I'll try to make a full review before tonight if that's ok

@matteobettini matteobettini left a comment

Some comments, also some general points

  • pettingzoo has truncated and terminated in each agent group so we can adapt that
  • vmas, smac and similar simulators should output terminated instead of done so that we keep consistency

docs/source/reference/envs.rst
@@ -1700,6 +1748,7 @@ def _step(
self.count += one_hot_action.to(torch.int)
td["observation"] += expand_right(self.count, td["observation"].shape)
done["done"] = self.count > self.max_steps
done["terminated"] = self.count > self.max_steps
Contributor

Instead of adding terminated and keeping done, why don't we just use terminated instead of done?

Contributor Author

Because we want done to let us quickly know whether we should reset. It avoids always going through all the possible keys and carries the summary of the end-of-trajectory in one tensor.
So now all envs have at least one done and one terminated, possibly a truncated.
It's what we discussed with @skandermoalla IIRC.
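The rationale above, done as a one-tensor summary so a reset check need not walk every terminated/truncated key, can be sketched as follows (a hypothetical step output with plain dicts standing in for TensorDicts, not TorchRL code):

```python
import torch

# Hypothetical output of a step over a batch of two envs.
step_out = {
    "terminated": torch.tensor([True, False]),  # env 0 reached a terminal state
    "truncated": torch.tensor([False, False]),  # no time-limit cut-off
}
# The summary is aggregated once at write time...
step_out["done"] = step_out["terminated"] | step_out["truncated"]
# ...so "should we reset?" becomes a single lookup instead of a scan over keys.
needs_reset = bool(step_out["done"].any())
assert needs_reset is True
```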

@matteobettini matteobettini Sep 28, 2023

But the done would be added automatically by EnvBase, no?
I am talking about this mocking class: it could just output terminated and done would be built

Contributor

Not sure if my input is needed here, but here's how I understand this:

So now all envs have at least one done and one terminated, possibly a truncated.

Yes, this is what is expected.

but the done would be added automatically by EnvBase no?
i am talking about this mocking class, it could just output terminated and done would be built

Yes, the done key can be inferred from the terminated (possibly truncated keys). I haven't looked at the implementation yet to know where this is done, but if the EnvBase takes care of this, then this mock can avoid writing the done key indeed to test the automatic done inference feature.

Contributor Author

I think it's best if we show both whenever we can. We can copy the done if you prefer, but not showing that it's populated isn't great from a UX perspective IMO
The auto patch is supposedly just there to ensure that everything runs fine even if someone just writes a terminated/done key

Contributor

If this is our rule then we should apply it to all envs.

Let's just do it uniformly

Contributor Author

just did, what did I miss?

torchrl/envs/utils.py
torchrl/envs/utils.py
test/test_collector.py
if self._single_task:
# select + clone creates 2 tds, but we can create one only
out = TensorDict(
{}, batch_size=self.shared_tensordict_parent.shape, device=self.device
)
for key in self._selected_reset_keys:
if _unravel_key_to_tuple(key)[-1] != "_reset":
key = unravel_keys(key)
if key not in self.reset_keys:
Contributor

the reset key should not be removed here

Contributor Author

Why? Can you explain? It used to be removed (see diff) and I'm pretty sure tests fail if you don't.

@matteobettini matteobettini Sep 28, 2023

There are many reasons we discussed:

  • they should always be removed by env.reset() for uniformity
  • this removes them only when self._single_task and not otherwise

Contributor Author

ok let me try

Contributor Author

Right, so the problem is this:
if there is no reset in the input we assume it's all true.
But the buffer is not necessarily all true (it's more likely that it is all false). So if there's a _reset in the buffer, we don't want to pass it in the result of reset because its value isn't synced.

Now you could argue that we must then update the _reset in the buffer with the one that we got from the input or that we inferred but then the risk is that these _resets do not have the same shape.

Transforms such as reward sum look at the tensordict that results from _reset before it is pruned of its "_reset" and decide what to do from there.

So here are the cases we need to account for:

  1. the input tensordict does not have a "_reset" (or there is no input tensordict) => we infer that "_reset" is full of True
  2. the input tensordict has a reset but its shape mismatches that of the parallel env. This can happen if the "_reset" in the buffer has the shape of the "done" which, like you mentioned on some occasions, can have more dims than 1, but the user just passed a "_reset" that matched the number of envs (1-dim). In that case we could update the buffer by extending on the right I guess, but then the problem is if the env on the other end (in the subprocess) reads that "_reset". Then there's a chance it will be corrupt.
  3. All is good and the "_reset" masks match.

I think case 2 is hard to handle. The tests pass if we exclude the "_reset". And we update the tensordict that results from ParallelEnv._reset in TransformedEnv._reset so the transforms can see the partial "_reset" if there is one.
All in all I don't see why not excluding the "_reset" in ParallelEnv._reset: the only methods that will look for a "_reset" are TransformedEnv and EnvBase in their _reset / reset methods, respectively. I don't see anywhere where that will break.
If the point is that it should also be excluded in the case where there is more than 1 task I totally agree!
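A minimal sketch of cases 1 and 3 above, and of excluding "_reset" from the reset output (plain tensors and dicts with hypothetical names, not the actual ParallelEnv internals; the hard case 2, shape mismatch, is deliberately left out):

```python
import torch

def parallel_reset_sketch(input_td, num_envs):
    if input_td is None or "_reset" not in input_td:
        # Case 1: no "_reset" provided -> assume everything must reset.
        reset_mask = torch.ones(num_envs, dtype=torch.bool)
    else:
        # Case 3: a matching mask was provided.
        reset_mask = input_td["_reset"]
    obs = torch.zeros(num_envs)  # fresh observations for the reset entries
    # "_reset" is deliberately NOT copied into the output: nothing
    # downstream reads it from the result of _reset.
    return {"observation": obs, "reset_mask_seen": reset_mask}

out = parallel_reset_sketch(None, 3)
assert "_reset" not in out
assert out["reset_mask_seen"].all()
```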

@matteobettini matteobettini Sep 28, 2023

If the point is that it should also be excluded in the case where there is more than 1 task I totally agree!

Definitely, let's make this uniform.

My other point is just that I thought we decided that the convention was that all the _resets are removed only in env.reset(), and in this particular instance they are removed before. This has 2 problems:

  • env.reset() cannot perform the done check appropriately (it will see a missing _reset and assume it is all true, but that is not the case). This is a problem if the reset was partial and some dones are legitimately true after the reset, as they cause the check to fail.
  • and this is not uniform, as the batched env _reset() will be the only one to not provide _resets (and this will make env.reset() not perform its checks properly)

This was not a problem before as in env.reset() I was caching the _resets in the input (before any other call) to check the dones

Contributor Author

env.reset() cannot perform the done check appropriately (it will see a missing _reset and assume it is all true but that is not the case)(this is a problem if the reset was partial and some dones are true after the reset as they should but they cause the check to fail)

No, because env.reset gets its "_reset" from the input tensordict, not the output

and that this is not uniform as the batched env _reset() will be the only one to not provide _resets

Same, I don't think it is true: per se, all _reset methods must return a new tensordict, and it's written nowhere in the contract of that function that you must copy any part of the content of the input into the output. In other words: we never read "_reset" from the output of _reset. Check the reset method for more context if you need reassurance.

Contributor

ok got it, my bad.

let's just test that BatchedEnv._reset() does not return _resets in all branches

torchrl/envs/batched_envs.py
torchrl/envs/batched_envs.py
torchrl/envs/common.py
torchrl/envs/common.py
@vmoens
Contributor Author

vmoens commented Sep 28, 2023

pettingzoo has truncated and terminated in each agent group so we can adapt that

I don't get this comment, what do you mean? Is there some extra work needed? I think we cover multiple truncated and terminated within the data structure

@vmoens
Contributor Author

vmoens commented Sep 28, 2023

vmas, smac and similar simulators should output terminated instead of done so that we keep consistency

As discussed, they now output both without you (the user) needing to do anything:

(screenshot: rollout output showing both "done" and "terminated" entries)

@matteobettini
Contributor

matteobettini commented Sep 28, 2023

As discussed, they now output both without you (the user) needing to do anything:

Yep, got that, but if we keep writing done explicitly instead of terminated in the wrappers, then users will be deceived about the meaning of what comes from the simulator.

I think terminated should be the one provided and done created through a clone

Also because done may be overwritten but terminated may not.

@matteobettini
Contributor

I don't get this comment, what do you mean? Is there some extra work needed? I think we cover multiple truncated and terminated within the data structure

Right now in pettingzoo there is an or between the 2 to create a done; we should just adapt that

@vmoens
Contributor Author

vmoens commented Sep 28, 2023

As discussed, they now output both without you (the user) needing to do anything:

yep got that but if we keep writing done explicitly instead of terminated in the wrappers then the users will be deceived on the meaning of what comes from the simulator.

Ah, got it: so your point is that, as part of this PR, all simulators that were writing done should write terminated instead?

I think terminated should be the one provided and done created through a clone

Per se, both are always provided.
I think it's ok, but it's not user-facing: mostly stuff that lives in private methods. We can check that the doc is accurate.

EDIT: another datapoint: Brax explicitly calls its end-of-trajectory signal "done"; not sure if we want to rename that. To me, if an env calls something done, we document that we also create a terminated, but I don't think that

        state["terminated"] = state.get("done").view(*self.reward_spec.shape)

within torchrl is less "deceiving" than the other option

EDIT: if you look at the latest commit, what I did is just write both terminated and done in each env. I think it's clearer than placing just one and letting the magic happen
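The convention settled on here, writing both entries explicitly in each wrapper, can be sketched like so (the dict stands in for a TensorDict; a simulator that, like Brax, only reports done; values are hypothetical):

```python
import torch

sim_done = torch.tensor([True, False])  # what a done-only simulator returns

state = {}
state["done"] = sim_done.clone()
# For a simulator that only distinguishes "done", done == terminated.
state["terminated"] = sim_done.clone()

assert torch.equal(state["done"], state["terminated"])
```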

@matteobettini
Contributor

EDIT: if you look at the latest commit, what I did is just placing done in terminated and done in each env. I think it's clearer than placing just one and let the magic happen

So placing both directly in the env? We can do this but let's make sure to have it uniformly everywhere

@matteobettini matteobettini left a comment

LGTM, let's just address the remaining open comments and we are good to go for me

Thanks for this btw, it is impressive work!

torchrl/collectors/collectors.py
torchrl/envs/batched_envs.py
torchrl/envs/batched_envs.py
@vmoens vmoens merged commit 802f0e4 into main Sep 29, 2023
@vmoens vmoens deleted the terminal_truncated branch September 29, 2023 15:55
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Environments Adds or modifies an environment wrapper major major refactoring in the code base
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] SelectTransform and ExcludeTransform don't Select and Exclude
4 participants