
[BugFix] Fix LSTM - VecEnv compatibility #1427

Merged
merged 31 commits into from
Sep 2, 2023
Conversation

@vmoens (Contributor) commented Jul 28, 2023

Description

Fixes #1406 by ensuring that env.step copies any preallocated "next" value to the next td.

cc @albertbou92
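For context, the intended behaviour can be sketched with plain Python dicts standing in for TensorDicts (the `step` helper below is an illustrative stand-in, not TorchRL's actual `env.step`):

```python
# Minimal sketch of the bug and fix, using plain dicts as stand-ins for
# TensorDicts. Before the fix, entries preallocated under "next" (e.g. by a
# TensorDictPrimer for LSTM hidden states) were dropped when assembling the
# next step's tensordict, so the policy saw zeroed hidden states every step.

def step(td: dict, env_output: dict) -> dict:
    """Build the "next" entry of a step, keeping preallocated values."""
    next_td = dict(td.get("next", {}))  # fix: start from the preallocated "next"
    next_td.update(env_output)          # env-produced values take precedence
    return {**td, "next": next_td}

td = {"observation": 1.0, "next": {"hidden": [0.5, 0.5]}}  # primed hidden state
out = step(td, {"observation": 2.0, "reward": 1.0})
assert out["next"]["hidden"] == [0.5, 0.5]  # carried over instead of lost
assert out["next"]["observation"] == 2.0    # env output still wins
```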

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 28, 2023
@vmoens vmoens marked this pull request as ready for review July 28, 2023 15:01
Comment on lines 2965 to 2985
```python
# def transform_input_spec(
#     self, input_spec: CompositeSpec
# ) -> CompositeSpec:
#     if not isinstance(input_spec, CompositeSpec):
#         raise ValueError(
#             f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead."
#         )
#     state_spec = input_spec["_state_spec"]
#     for key, spec in self.primers.items():
#         if spec.shape[: len(state_spec.shape)] != state_spec.shape:
#             raise RuntimeError(
#                 f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
#                 f"Got state_spec.shape={state_spec.shape} but the '{key}' entry's shape is {spec.shape}."
#             )
#     try:
#         device = state_spec.device
#     except RuntimeError:
#         device = self.device
#     state_spec[key] = spec.to(device)
#     return input_spec
```

@vmoens (Contributor, Author) commented:
I'd like to keep this commented out here.
I'm a bit unsure what to do with it: the idea is that values set by the primer can be inputs, but I don't think that will systematically be the case.
An observation can be an input too if it is not cleared.
However, one may want to consider these as inputs to the env...

@github-actions (bot) commented Jul 28, 2023

⚠ **Warning**: Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: 5. Worsened: 3.

Detailed results:
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
test_single 0.1394s 0.1391s 7.1895 Ops/s 7.1787 Ops/s $\color{#35bf28}+0.15\%$
test_sync 0.1498s 78.2219ms 12.7841 Ops/s 12.6920 Ops/s $\color{#35bf28}+0.73\%$
test_async 0.1994s 72.2234ms 13.8459 Ops/s 14.0156 Ops/s $\color{#d91a1a}-1.21\%$
test_simple 0.6918s 0.6221s 1.6074 Ops/s 1.5921 Ops/s $\color{#35bf28}+0.96\%$
test_transformed 1.7016s 1.6416s 0.6091 Ops/s 0.6079 Ops/s $\color{#35bf28}+0.20\%$
test_serial 1.8080s 1.7495s 0.5716 Ops/s 0.5688 Ops/s $\color{#35bf28}+0.49\%$
test_parallel 1.5762s 1.4999s 0.6667 Ops/s 0.6677 Ops/s $\color{#d91a1a}-0.15\%$
test_step_mdp_speed[True-True-True-True-True] 0.4088ms 44.8511μs 22.2960 KOps/s 22.5315 KOps/s $\color{#d91a1a}-1.05\%$
test_step_mdp_speed[True-True-True-True-False] 51.4000μs 25.7877μs 38.7781 KOps/s 39.2294 KOps/s $\color{#d91a1a}-1.15\%$
test_step_mdp_speed[True-True-True-False-True] 50.6010μs 31.5594μs 31.6863 KOps/s 31.8699 KOps/s $\color{#d91a1a}-0.58\%$
test_step_mdp_speed[True-True-True-False-False] 79.4000μs 17.6342μs 56.7081 KOps/s 57.2098 KOps/s $\color{#d91a1a}-0.88\%$
test_step_mdp_speed[True-True-False-True-True] 70.0000μs 46.8076μs 21.3640 KOps/s 21.6022 KOps/s $\color{#d91a1a}-1.10\%$
test_step_mdp_speed[True-True-False-True-False] 51.4000μs 27.5048μs 36.3573 KOps/s 36.7106 KOps/s $\color{#d91a1a}-0.96\%$
test_step_mdp_speed[True-True-False-False-True] 0.1352ms 34.3069μs 29.1486 KOps/s 29.8510 KOps/s $\color{#d91a1a}-2.35\%$
test_step_mdp_speed[True-True-False-False-False] 0.2167ms 19.7975μs 50.5116 KOps/s 51.3908 KOps/s $\color{#d91a1a}-1.71\%$
test_step_mdp_speed[True-False-True-True-True] 83.0000μs 48.7447μs 20.5150 KOps/s 20.7623 KOps/s $\color{#d91a1a}-1.19\%$
test_step_mdp_speed[True-False-True-True-False] 47.9000μs 29.3582μs 34.0620 KOps/s 34.4631 KOps/s $\color{#d91a1a}-1.16\%$
test_step_mdp_speed[True-False-True-False-True] 58.0010μs 33.5485μs 29.8076 KOps/s 29.7724 KOps/s $\color{#35bf28}+0.12\%$
test_step_mdp_speed[True-False-True-False-False] 47.0000μs 20.0087μs 49.9782 KOps/s 51.2048 KOps/s $\color{#d91a1a}-2.40\%$
test_step_mdp_speed[True-False-False-True-True] 0.1158ms 49.8874μs 20.0452 KOps/s 19.8906 KOps/s $\color{#35bf28}+0.78\%$
test_step_mdp_speed[True-False-False-True-False] 96.4000μs 31.1968μs 32.0545 KOps/s 32.8892 KOps/s $\color{#d91a1a}-2.54\%$
test_step_mdp_speed[True-False-False-False-True] 54.6000μs 35.0626μs 28.5204 KOps/s 28.7501 KOps/s $\color{#d91a1a}-0.80\%$
test_step_mdp_speed[True-False-False-False-False] 55.6010μs 21.5796μs 46.3401 KOps/s 47.5952 KOps/s $\color{#d91a1a}-2.64\%$
test_step_mdp_speed[False-True-True-True-True] 74.4010μs 48.4405μs 20.6439 KOps/s 20.8166 KOps/s $\color{#d91a1a}-0.83\%$
test_step_mdp_speed[False-True-True-True-False] 0.1039ms 29.5325μs 33.8610 KOps/s 34.4712 KOps/s $\color{#d91a1a}-1.77\%$
test_step_mdp_speed[False-True-True-False-True] 58.3000μs 37.3767μs 26.7547 KOps/s 27.0398 KOps/s $\color{#d91a1a}-1.05\%$
test_step_mdp_speed[False-True-True-False-False] 3.4922ms 22.1768μs 45.0922 KOps/s 44.4717 KOps/s $\color{#35bf28}+1.40\%$
test_step_mdp_speed[False-True-False-True-True] 0.1220ms 49.6498μs 20.1411 KOps/s 20.1557 KOps/s $\color{#d91a1a}-0.07\%$
test_step_mdp_speed[False-True-False-True-False] 54.4000μs 31.0229μs 32.2342 KOps/s 32.4732 KOps/s $\color{#d91a1a}-0.74\%$
test_step_mdp_speed[False-True-False-False-True] 70.0000μs 39.0839μs 25.5860 KOps/s 25.9828 KOps/s $\color{#d91a1a}-1.53\%$
test_step_mdp_speed[False-True-False-False-False] 59.7010μs 23.5734μs 42.4208 KOps/s 42.8602 KOps/s $\color{#d91a1a}-1.03\%$
test_step_mdp_speed[False-False-True-True-True] 82.9010μs 51.7663μs 19.3176 KOps/s 19.4252 KOps/s $\color{#d91a1a}-0.55\%$
test_step_mdp_speed[False-False-True-True-False] 61.5000μs 32.9723μs 30.3285 KOps/s 30.4036 KOps/s $\color{#d91a1a}-0.25\%$
test_step_mdp_speed[False-False-True-False-True] 88.5000μs 39.4409μs 25.3544 KOps/s 25.7644 KOps/s $\color{#d91a1a}-1.59\%$
test_step_mdp_speed[False-False-True-False-False] 82.8010μs 23.3943μs 42.7454 KOps/s 43.4843 KOps/s $\color{#d91a1a}-1.70\%$
test_step_mdp_speed[False-False-False-True-True] 0.1370ms 53.1640μs 18.8097 KOps/s 19.0104 KOps/s $\color{#d91a1a}-1.06\%$
test_step_mdp_speed[False-False-False-True-False] 0.6747ms 34.6048μs 28.8978 KOps/s 29.4549 KOps/s $\color{#d91a1a}-1.89\%$
test_step_mdp_speed[False-False-False-False-True] 84.8010μs 40.4878μs 24.6988 KOps/s 25.2218 KOps/s $\color{#d91a1a}-2.07\%$
test_step_mdp_speed[False-False-False-False-False] 48.7010μs 24.9365μs 40.1018 KOps/s 40.6406 KOps/s $\color{#d91a1a}-1.33\%$
test_values[generalized_advantage_estimate-True-True] 13.8457ms 13.2438ms 75.5072 Ops/s 73.5982 Ops/s $\color{#35bf28}+2.59\%$
test_values[vec_generalized_advantage_estimate-True-True] 49.6258ms 41.8195ms 23.9123 Ops/s 23.0938 Ops/s $\color{#35bf28}+3.54\%$
test_values[td0_return_estimate-False-False] 0.6586ms 0.2180ms 4.5875 KOps/s 4.2265 KOps/s $\textbf{\color{#35bf28}+8.54\%}$
test_values[td1_return_estimate-False-False] 13.2415ms 12.8526ms 77.8050 Ops/s 75.8780 Ops/s $\color{#35bf28}+2.54\%$
test_values[vec_td1_return_estimate-False-False] 51.2700ms 41.6907ms 23.9861 Ops/s 23.4260 Ops/s $\color{#35bf28}+2.39\%$
test_values[td_lambda_return_estimate-True-False] 31.3629ms 30.9133ms 32.3486 Ops/s 31.0802 Ops/s $\color{#35bf28}+4.08\%$
test_values[vec_td_lambda_return_estimate-True-False] 46.4901ms 41.5350ms 24.0761 Ops/s 23.4192 Ops/s $\color{#35bf28}+2.80\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 11.6593ms 11.4756ms 87.1411 Ops/s 85.6316 Ops/s $\color{#35bf28}+1.76\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 6.1906ms 3.3174ms 301.4439 Ops/s 300.9391 Ops/s $\color{#35bf28}+0.17\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.8924ms 0.4736ms 2.1116 KOps/s 2.1041 KOps/s $\color{#35bf28}+0.36\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 57.5788ms 52.5711ms 19.0219 Ops/s 17.8918 Ops/s $\textbf{\color{#35bf28}+6.32\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 11.2277ms 2.8364ms 352.5567 Ops/s 354.2147 Ops/s $\color{#d91a1a}-0.47\%$
test_dqn_speed 8.8088ms 1.8957ms 527.5002 Ops/s 532.0392 Ops/s $\color{#d91a1a}-0.85\%$
test_ddpg_speed 9.6478ms 2.8011ms 357.0057 Ops/s 359.8453 Ops/s $\color{#d91a1a}-0.79\%$
test_sac_speed 15.7404ms 8.1774ms 122.2879 Ops/s 118.5418 Ops/s $\color{#35bf28}+3.16\%$
test_redq_speed 21.7682ms 16.2396ms 61.5780 Ops/s 60.7531 Ops/s $\color{#35bf28}+1.36\%$
test_redq_deprec_speed 20.3900ms 13.0498ms 76.6296 Ops/s 76.8492 Ops/s $\color{#d91a1a}-0.29\%$
test_td3_speed 11.8095ms 10.2777ms 97.2981 Ops/s 96.8364 Ops/s $\color{#35bf28}+0.48\%$
test_cql_speed 33.2727ms 26.9613ms 37.0902 Ops/s 37.2744 Ops/s $\color{#d91a1a}-0.49\%$
test_a2c_speed 12.8022ms 5.5746ms 179.3861 Ops/s 183.8191 Ops/s $\color{#d91a1a}-2.41\%$
test_ppo_speed 10.6342ms 5.7884ms 172.7586 Ops/s 174.8142 Ops/s $\color{#d91a1a}-1.18\%$
test_reinforce_speed 11.6292ms 4.1339ms 241.9034 Ops/s 237.6728 Ops/s $\color{#35bf28}+1.78\%$
test_iql_speed 27.1844ms 21.2543ms 47.0493 Ops/s 45.7227 Ops/s $\color{#35bf28}+2.90\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.0549ms 2.6615ms 375.7331 Ops/s 374.5212 Ops/s $\color{#35bf28}+0.32\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 5.0921ms 2.8355ms 352.6704 Ops/s 354.8226 Ops/s $\color{#d91a1a}-0.61\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 5.6716ms 2.8285ms 353.5481 Ops/s 354.6927 Ops/s $\color{#d91a1a}-0.32\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.5814ms 2.6721ms 374.2352 Ops/s 375.2004 Ops/s $\color{#d91a1a}-0.26\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 5.2328ms 2.8098ms 355.8966 Ops/s 357.7303 Ops/s $\color{#d91a1a}-0.51\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.9375ms 2.8555ms 350.1964 Ops/s 359.2631 Ops/s $\color{#d91a1a}-2.52\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.7453ms 2.6792ms 373.2471 Ops/s 378.8513 Ops/s $\color{#d91a1a}-1.48\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.6086ms 2.8017ms 356.9249 Ops/s 352.4715 Ops/s $\color{#35bf28}+1.26\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 5.5610ms 2.8847ms 346.6526 Ops/s 356.2753 Ops/s $\color{#d91a1a}-2.70\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.6620ms 2.7474ms 363.9768 Ops/s 380.0462 Ops/s $\color{#d91a1a}-4.23\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 4.7573ms 2.8909ms 345.9135 Ops/s 356.4181 Ops/s $\color{#d91a1a}-2.95\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.8827ms 2.8473ms 351.2047 Ops/s 352.9711 Ops/s $\color{#d91a1a}-0.50\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.3580ms 2.6919ms 371.4816 Ops/s 375.1595 Ops/s $\color{#d91a1a}-0.98\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 6.0098ms 2.8361ms 352.5910 Ops/s 358.8161 Ops/s $\color{#d91a1a}-1.73\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.9902ms 2.8577ms 349.9263 Ops/s 355.3564 Ops/s $\color{#d91a1a}-1.53\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.6377ms 2.7591ms 362.4378 Ops/s 384.5618 Ops/s $\textbf{\color{#d91a1a}-5.75\%}$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 5.0708ms 2.9070ms 343.9954 Ops/s 354.9698 Ops/s $\color{#d91a1a}-3.09\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 5.0887ms 2.8852ms 346.5958 Ops/s 355.3806 Ops/s $\color{#d91a1a}-2.47\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.2593s 29.2304ms 34.2110 Ops/s 32.0119 Ops/s $\textbf{\color{#35bf28}+6.87\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1444s 29.5313ms 33.8624 Ops/s 33.9138 Ops/s $\color{#d91a1a}-0.15\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1445s 26.8537ms 37.2389 Ops/s 36.7757 Ops/s $\color{#35bf28}+1.26\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1454s 29.2808ms 34.1520 Ops/s 33.8305 Ops/s $\color{#35bf28}+0.95\%$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1474s 29.2362ms 34.2042 Ops/s 36.8514 Ops/s $\textbf{\color{#d91a1a}-7.18\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1438s 26.6803ms 37.4809 Ops/s 33.7118 Ops/s $\textbf{\color{#35bf28}+11.18\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1430s 29.0018ms 34.4806 Ops/s 36.8805 Ops/s $\textbf{\color{#d91a1a}-6.51\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1490s 26.7430ms 37.3930 Ops/s 33.8815 Ops/s $\textbf{\color{#35bf28}+10.36\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1473s 26.7151ms 37.4320 Ops/s 36.7965 Ops/s $\color{#35bf28}+1.73\%$

@vmoens vmoens added the bug Something isn't working label Aug 10, 2023
@vmoens (Contributor, Author) commented Aug 10, 2023

Putting this on hold until #1448 has landed.
The `_step` method will be drastically simplified, and as a result I expect the behaviour of TensorDictPrimer to become clearer too.

vmoens added 20 commits August 11, 2023 05:52
# Conflicts:
#	torchrl/collectors/collectors.py
#	torchrl/envs/common.py
#	torchrl/envs/transforms/transforms.py
#	torchrl/envs/utils.py
#	torchrl/envs/vec_env.py
# Conflicts:
#	torchrl/envs/common.py
#	torchrl/envs/vec_env.py
# Conflicts:
#	torchrl/data/tensor_specs.py
#	torchrl/envs/common.py
@vmoens (Contributor, Author) commented Sep 1, 2023

@albertbou92 tests are now passing, but I'm still not 100% confident about TensorDictPrimer.

First let me explain what TensorDictPrimer is about:

  • Collectors allocate buffers to communicate from process to process based on the environment specs. We basically do env.fake_tensordict() + some shape expansion and use that as a data container for communication.
  • If some data needs to be carried along but is produced not by the env but by the policy, we need to pass it too. Therefore, we designed TensorDictPrimer, which adds the missing tensors to the env specs such that env.fake_tensordict is complete.
  • In our case, we want to add the input and output hidden states to the specs.
  • The question is whether we want these to be added to the input or the output spec. By definition, the output is what the env produces, so it would not make much sense to put them there (in the context of LSTMs). Therefore, we should probably put them in the input spec. However, when calling reset in batched environments, we filter out any data that is not part of output_spec. This means that, when calling reset, we won't see the preallocated tensors from TensorDictPrimer. One possible fix would be to also keep the tensors from input_spec when they can be found.
  • The contract of TensorDictPrimer is nevertheless that the values will be available after reset, and as such this should be part of output_spec too.
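The mechanism described in the first two bullets can be illustrated with a toy sketch; `fake_tensordict` and `prime` here are hypothetical stand-ins using plain dicts, not TorchRL's API:

```python
# Toy illustration of spec priming: the env's fake output alone carries no
# slot for the policy's recurrent state, so collector buffers built from it
# could not transport it. A "primer" adds zero-valued placeholder entries.

def fake_tensordict(env_specs: dict) -> dict:
    """Build a zeroed data container from the env's specs (shapes as ints)."""
    return {key: [0.0] * size for key, size in env_specs.items()}

def prime(fake: dict, primer_specs: dict) -> dict:
    """Add missing (policy-produced) entries so the buffer template is complete."""
    out = dict(fake)
    for key, size in primer_specs.items():
        out.setdefault(key, [0.0] * size)  # never overwrite env-provided entries
    return out

env_specs = {"observation": 4, "reward": 1}
primer_specs = {"hidden0": 8, "hidden1": 8}  # LSTM hidden and cell states
buffer_template = prime(fake_tensordict(env_specs), primer_specs)
assert set(buffer_template) == {"observation", "reward", "hidden0", "hidden1"}
```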

My main concern is basically that this is all a bit hacky: we're telling the env that it must expect tensors as input when, in fact, they have nothing to do with the env. Perhaps we could add a spec to the policy instead and tell collectors to fetch the specs from the policy as well as the env, and combine the two. That way we'd keep things separated.
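The alternative floated above (fetching specs from both the policy and the env and combining the two) could look roughly like this; `combine_specs` is a hypothetical helper, not an existing TorchRL interface:

```python
# Sketch of merging env specs with policy-declared specs, so the collector
# can build complete buffers without the env ever claiming the hidden
# states as its own inputs. Specs are simplified to name -> size.

def combine_specs(env_specs: dict, policy_specs: dict) -> dict:
    """Merge env and policy specs; refuse silently conflicting entries."""
    combined = dict(env_specs)
    for key, size in policy_specs.items():
        if key in combined and combined[key] != size:
            raise ValueError(f"conflicting spec for {key!r}")
        combined[key] = size
    return combined

specs = combine_specs({"observation": 4}, {"hidden0": 8, "hidden1": 8})
assert specs == {"observation": 4, "hidden0": 8, "hidden1": 8}
```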

One long-term solution I could propose is to have SafeModules that also have input and output specs.
The advantages are:

  • Clear separation between data origin (policy vs env)
  • transforms can be moved from env to policy without problems
  • We will be able to call parent or container on a transform stored within the policy.

There are drawbacks:

  • Currently, SafeModules are only an optional extra (they just do some more checks than regular modules). This would force users to use SafeModules everywhere, which may not be the best experience.
  • The parent and container interface won't be very clean in the transforms.
  • It could be error-prone (turning env output specs into policy input specs may not be straightforward). Overall, I would expect the user experience to be negatively impacted, which is a high price to pay. For instance, it could look something like:
```python
env = make_env()
policy = SafeActor(network, input_spec=CompositeSpec(obs=...), output_spec=CompositeSpec(action=...))
```

@smorad @matteobettini wdyt?

@albertbou92 (Contributor) commented:
Safe modules are a nice optional feature, but forcing practitioners to use them is not very user-friendly.

Regarding the LSTM hidden states, the TensorDictPrimer transform places them in the TensorDict and you already obtain them after calling reset. During the reset, the environment creates these hidden states, and in every subsequent step, it passes them forward without changing them but still returns them. So, in my view, it's not unreasonable to consider them as outputs.

In fact, in the same way that you can think of the output as what the environment produces, you can think of the input as what the environment requires to take a step. However, not all the tensors we define in the input_spec fit this definition. For example, when using the RewardSum transform, the episode_reward specs are part of the input_specs, even though they are not something the environment uses to perform steps.

@matteobettini (Contributor) commented:

Wouldn't forcing every module in the library to be Safe be quite backward-compatibility-breaking?
IMO the impact of that solution might be substantial.

It seems like keeping the primers in the output spec is the best solution for now. At the end of the day, they are still data that needs to be given to step_mdp and used as policy input.

@vmoens vmoens merged commit 2982515 into main Sep 2, 2023
@vmoens vmoens deleted the fix_lstm_penv branch September 2, 2023 05:39
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023
Linked issue closed by this pull request: [BUG] LSTMModule sets hidden state to 0.0 at every step when combined with ParallelEnv.