[BugFix] Fix LSTM - VecEnv compatibility #1427
# def transform_input_spec(
#     self, input_spec: CompositeSpec
# ) -> CompositeSpec:
#     if not isinstance(input_spec, CompositeSpec):
#         raise ValueError(
#             f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead."
#         )
#     state_spec = input_spec["_state_spec"]
#     for key, spec in self.primers.items():
#         if spec.shape[: len(state_spec.shape)] != state_spec.shape:
#             raise RuntimeError(
#                 f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
#                 f"Got state_spec.shape={state_spec.shape} but the '{key}' entry's shape is {spec.shape}."
#             )
#         try:
#             device = state_spec.device
#         except RuntimeError:
#             device = self.device
#         state_spec[key] = spec.to(device)
#     return input_spec
I'd like to keep this commented here
A bit unsure about what to do with it: per se, the idea is that values set by the primer can be inputs, but I don't think that will systematically be the case.
An observation can be an input too if it is not cleared.
However one may want to consider these as inputs to the env...
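For reference, the leading-shape check in the commented-out method can be illustrated in isolation. This is a minimal sketch using plain tuples in place of spec objects; `check_leading_shape` and the `recurrent_state_h` key are hypothetical names chosen for the example, not part of torchrl.

```python
def check_leading_shape(parent_shape, primer_shapes):
    """Raise if a primer shape does not start with the parent env's batch shape.

    Plain tuples stand in for spec shapes here (an assumption for illustration).
    """
    n = len(parent_shape)
    for key, shape in primer_shapes.items():
        if tuple(shape[:n]) != tuple(parent_shape):
            raise RuntimeError(
                f"Got state_spec.shape={parent_shape} but the '{key}' "
                f"entry's shape is {shape}."
            )


# A batch of 4 envs: every primer entry must be shaped (4, ...).
check_leading_shape((4,), {"recurrent_state_h": (4, 1, 128)})  # passes silently
```

A primer shaped `(1, 128)` against a parent shape of `(4,)` would raise, which is the situation the error message in the snippet describes.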
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_single | 0.1394s | 0.1391s | 7.1895 Ops/s | 7.1787 Ops/s | |
test_sync | 0.1498s | 78.2219ms | 12.7841 Ops/s | 12.6920 Ops/s | |
test_async | 0.1994s | 72.2234ms | 13.8459 Ops/s | 14.0156 Ops/s | |
test_simple | 0.6918s | 0.6221s | 1.6074 Ops/s | 1.5921 Ops/s | |
test_transformed | 1.7016s | 1.6416s | 0.6091 Ops/s | 0.6079 Ops/s | |
test_serial | 1.8080s | 1.7495s | 0.5716 Ops/s | 0.5688 Ops/s | |
test_parallel | 1.5762s | 1.4999s | 0.6667 Ops/s | 0.6677 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.4088ms | 44.8511μs | 22.2960 KOps/s | 22.5315 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 51.4000μs | 25.7877μs | 38.7781 KOps/s | 39.2294 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 50.6010μs | 31.5594μs | 31.6863 KOps/s | 31.8699 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 79.4000μs | 17.6342μs | 56.7081 KOps/s | 57.2098 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 70.0000μs | 46.8076μs | 21.3640 KOps/s | 21.6022 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 51.4000μs | 27.5048μs | 36.3573 KOps/s | 36.7106 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1352ms | 34.3069μs | 29.1486 KOps/s | 29.8510 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.2167ms | 19.7975μs | 50.5116 KOps/s | 51.3908 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 83.0000μs | 48.7447μs | 20.5150 KOps/s | 20.7623 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 47.9000μs | 29.3582μs | 34.0620 KOps/s | 34.4631 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 58.0010μs | 33.5485μs | 29.8076 KOps/s | 29.7724 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 47.0000μs | 20.0087μs | 49.9782 KOps/s | 51.2048 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 0.1158ms | 49.8874μs | 20.0452 KOps/s | 19.8906 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 96.4000μs | 31.1968μs | 32.0545 KOps/s | 32.8892 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 54.6000μs | 35.0626μs | 28.5204 KOps/s | 28.7501 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 55.6010μs | 21.5796μs | 46.3401 KOps/s | 47.5952 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 74.4010μs | 48.4405μs | 20.6439 KOps/s | 20.8166 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.1039ms | 29.5325μs | 33.8610 KOps/s | 34.4712 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 58.3000μs | 37.3767μs | 26.7547 KOps/s | 27.0398 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 3.4922ms | 22.1768μs | 45.0922 KOps/s | 44.4717 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1220ms | 49.6498μs | 20.1411 KOps/s | 20.1557 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 54.4000μs | 31.0229μs | 32.2342 KOps/s | 32.4732 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 70.0000μs | 39.0839μs | 25.5860 KOps/s | 25.9828 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 59.7010μs | 23.5734μs | 42.4208 KOps/s | 42.8602 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 82.9010μs | 51.7663μs | 19.3176 KOps/s | 19.4252 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 61.5000μs | 32.9723μs | 30.3285 KOps/s | 30.4036 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 88.5000μs | 39.4409μs | 25.3544 KOps/s | 25.7644 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 82.8010μs | 23.3943μs | 42.7454 KOps/s | 43.4843 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.1370ms | 53.1640μs | 18.8097 KOps/s | 19.0104 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 0.6747ms | 34.6048μs | 28.8978 KOps/s | 29.4549 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 84.8010μs | 40.4878μs | 24.6988 KOps/s | 25.2218 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 48.7010μs | 24.9365μs | 40.1018 KOps/s | 40.6406 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 13.8457ms | 13.2438ms | 75.5072 Ops/s | 73.5982 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 49.6258ms | 41.8195ms | 23.9123 Ops/s | 23.0938 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.6586ms | 0.2180ms | 4.5875 KOps/s | 4.2265 KOps/s | |
test_values[td1_return_estimate-False-False] | 13.2415ms | 12.8526ms | 77.8050 Ops/s | 75.8780 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 51.2700ms | 41.6907ms | 23.9861 Ops/s | 23.4260 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 31.3629ms | 30.9133ms | 32.3486 Ops/s | 31.0802 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 46.4901ms | 41.5350ms | 24.0761 Ops/s | 23.4192 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.6593ms | 11.4756ms | 87.1411 Ops/s | 85.6316 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 6.1906ms | 3.3174ms | 301.4439 Ops/s | 300.9391 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8924ms | 0.4736ms | 2.1116 KOps/s | 2.1041 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 57.5788ms | 52.5711ms | 19.0219 Ops/s | 17.8918 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 11.2277ms | 2.8364ms | 352.5567 Ops/s | 354.2147 Ops/s | |
test_dqn_speed | 8.8088ms | 1.8957ms | 527.5002 Ops/s | 532.0392 Ops/s | |
test_ddpg_speed | 9.6478ms | 2.8011ms | 357.0057 Ops/s | 359.8453 Ops/s | |
test_sac_speed | 15.7404ms | 8.1774ms | 122.2879 Ops/s | 118.5418 Ops/s | |
test_redq_speed | 21.7682ms | 16.2396ms | 61.5780 Ops/s | 60.7531 Ops/s | |
test_redq_deprec_speed | 20.3900ms | 13.0498ms | 76.6296 Ops/s | 76.8492 Ops/s | |
test_td3_speed | 11.8095ms | 10.2777ms | 97.2981 Ops/s | 96.8364 Ops/s | |
test_cql_speed | 33.2727ms | 26.9613ms | 37.0902 Ops/s | 37.2744 Ops/s | |
test_a2c_speed | 12.8022ms | 5.5746ms | 179.3861 Ops/s | 183.8191 Ops/s | |
test_ppo_speed | 10.6342ms | 5.7884ms | 172.7586 Ops/s | 174.8142 Ops/s | |
test_reinforce_speed | 11.6292ms | 4.1339ms | 241.9034 Ops/s | 237.6728 Ops/s | |
test_iql_speed | 27.1844ms | 21.2543ms | 47.0493 Ops/s | 45.7227 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.0549ms | 2.6615ms | 375.7331 Ops/s | 374.5212 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.0921ms | 2.8355ms | 352.6704 Ops/s | 354.8226 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.6716ms | 2.8285ms | 353.5481 Ops/s | 354.6927 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.5814ms | 2.6721ms | 374.2352 Ops/s | 375.2004 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 5.2328ms | 2.8098ms | 355.8966 Ops/s | 357.7303 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.9375ms | 2.8555ms | 350.1964 Ops/s | 359.2631 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.7453ms | 2.6792ms | 373.2471 Ops/s | 378.8513 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.6086ms | 2.8017ms | 356.9249 Ops/s | 352.4715 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.5610ms | 2.8847ms | 346.6526 Ops/s | 356.2753 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.6620ms | 2.7474ms | 363.9768 Ops/s | 380.0462 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.7573ms | 2.8909ms | 345.9135 Ops/s | 356.4181 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.8827ms | 2.8473ms | 351.2047 Ops/s | 352.9711 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.3580ms | 2.6919ms | 371.4816 Ops/s | 375.1595 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 6.0098ms | 2.8361ms | 352.5910 Ops/s | 358.8161 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.9902ms | 2.8577ms | 349.9263 Ops/s | 355.3564 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.6377ms | 2.7591ms | 362.4378 Ops/s | 384.5618 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 5.0708ms | 2.9070ms | 343.9954 Ops/s | 354.9698 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.0887ms | 2.8852ms | 346.5958 Ops/s | 355.3806 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2593s | 29.2304ms | 34.2110 Ops/s | 32.0119 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1444s | 29.5313ms | 33.8624 Ops/s | 33.9138 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1445s | 26.8537ms | 37.2389 Ops/s | 36.7757 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1454s | 29.2808ms | 34.1520 Ops/s | 33.8305 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1474s | 29.2362ms | 34.2042 Ops/s | 36.8514 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1438s | 26.6803ms | 37.4809 Ops/s | 33.7118 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1430s | 29.0018ms | 34.4806 Ops/s | 36.8805 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1490s | 26.7430ms | 37.3930 Ops/s | 33.8815 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1473s | 26.7151ms | 37.4320 Ops/s | 36.7965 Ops/s |
Putting this on hold until #1448 has landed
# Conflicts:
#	torchrl/collectors/collectors.py
#	torchrl/envs/common.py
#	torchrl/envs/transforms/transforms.py
#	torchrl/envs/utils.py
#	torchrl/envs/vec_env.py
# Conflicts:
#	torchrl/envs/common.py
#	torchrl/envs/vec_env.py
# Conflicts:
#	torchrl/data/tensor_specs.py
#	torchrl/envs/common.py
@albertbou92 tests are now passing, but I'm still not 100% confident about TensorDictPrimer. First let me explain what TensorDictPrimer is about:
My main concern is basically that this is all a bit hacky: we're telling the env that it must expect tensors as input when, in fact, they have nothing to do with the env. Perhaps we could add a spec to the policy instead and tell collectors to fetch the specs from the policy as well as the env, and combine the two. That way we'd keep things separated. One long-term solution I could propose is to have
There are drawbacks:
env = make_env()
policy = SafeActor(network, input_spec=CompositeSpec(obs=...), output_spec=CompositeSpec(action=...))
@smorad @matteobettini wdyt?
Safe modules are a nice optional feature, but forcing practitioners to use them is not very user-friendly. Regarding the LSTM hidden states, the TensorDictPrimer transform places them in the TensorDict, so you already obtain them after calling reset. During the reset, the environment creates these hidden states, and in every subsequent step it passes them forward unchanged but still returns them. So, in my view, it's not unreasonable to consider them as outputs. In fact, in the same way that you can think of an output as what the environment produces, you can think of an input as what the environment requires to take a step. However, not all the tensors we define in the input_spec fit this definition. For example, when using the RewardSum transform, the episode_reward specs are part of the input_specs, even though they are not something the environment uses to perform steps.
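The reset/step behaviour described in this comment can be sketched with plain dicts. This is a toy illustration only: `ToyPrimedEnv` and `toy_policy` are hypothetical stand-ins and do not reflect how TensorDictPrimer is implemented in torchrl.

```python
class ToyPrimedEnv:
    """Toy stand-in for an env wrapped with a primer (not the torchrl API)."""

    def __init__(self, hidden_size=2):
        self.hidden_size = hidden_size

    def reset(self):
        # The primer creates a zeroed hidden state at reset time.
        return {"observation": 0.0, "hidden": [0.0] * self.hidden_size}

    def step(self, td):
        # The env only needs the action to compute the next observation...
        obs = td["observation"] + td["action"]
        # ...but it still returns the hidden state it was handed, unchanged.
        return {"observation": obs, "hidden": td["hidden"]}


def toy_policy(td):
    # Hypothetical recurrent policy: updates the hidden state, emits an action.
    hidden = [h + 1.0 for h in td["hidden"]]
    return {**td, "action": 1.0, "hidden": hidden}


env = ToyPrimedEnv()
td = env.reset()
for _ in range(3):
    td = env.step(toy_policy(td))
```

In this toy, the hidden state is present right after `reset` and in every step output, yet `step` never reads it to compute the observation, which is the ambiguity between "input" and "output" discussed above.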
Wouldn't forcing every module in the library to be Safe be quite bc-breaking? It seems like keeping the primers in the output spec is the best solution as of now. At the end of the day, they are still data that needs to be given to step_mdp and to the policy as input.
Description
Fixes #1406 by ensuring that env.step copies any preallocated "next" value to the next td.
cc @albertbou92
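As a rough illustration of what "copying any preallocated next value to the next td" could mean, here is a sketch using nested dicts in place of a TensorDict. `toy_step` is a hypothetical helper, and the precedence rule (fresh env outputs override preallocated entries) is an assumption made for the example, not something stated in this PR.

```python
def toy_step(td, env_step):
    """Step a toy env while keeping entries already written under td["next"].

    Assumption for illustration: values produced by the env take precedence,
    and preallocated entries fill in whatever the env does not produce.
    """
    preallocated = dict(td.get("next", {}))
    fresh = env_step(td)
    td["next"] = {**preallocated, **fresh}
    return td


# A recurrent state was written under "next" before the env was stepped.
td = {"action": 1, "next": {"recurrent_state": [0.5]}}
out = toy_step(td, lambda td: {"observation": td["action"] * 2, "reward": 1.0})
```

After the call, `out["next"]` holds the env's `observation` and `reward` together with the preallocated `recurrent_state`, rather than dropping the latter.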