[Feature] Support for GRU #1586
Conversation
Name | Max | Mean | Ops | Ops on Repo HEAD | Change
---|---|---|---|---|---
test_single | 0.1031s | 0.1022s | 9.7833 Ops/s | 9.7649 Ops/s | |
test_sync | 58.4572ms | 54.4339ms | 18.3709 Ops/s | 17.6545 Ops/s | |
test_async | 90.3134ms | 53.3863ms | 18.7314 Ops/s | 18.9899 Ops/s | |
test_simple | 0.8979s | 0.8279s | 1.2079 Ops/s | 1.2125 Ops/s | |
test_transformed | 1.0981s | 1.0369s | 0.9645 Ops/s | 0.9433 Ops/s | |
test_serial | 2.2668s | 2.2042s | 0.4537 Ops/s | 0.4508 Ops/s | |
test_parallel | 1.9131s | 1.8831s | 0.5310 Ops/s | 0.5298 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.3062ms | 44.4846μs | 22.4797 KOps/s | 22.2468 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 0.3030ms | 25.3514μs | 39.4456 KOps/s | 39.1394 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 58.1010μs | 31.2923μs | 31.9567 KOps/s | 31.4288 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 48.1010μs | 17.3129μs | 57.7603 KOps/s | 56.5134 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.2306ms | 46.0573μs | 21.7121 KOps/s | 21.3508 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.1165ms | 26.8969μs | 37.1791 KOps/s | 36.4559 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.1203ms | 33.6319μs | 29.7337 KOps/s | 29.5780 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 45.7010μs | 19.2955μs | 51.8256 KOps/s | 50.2864 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 99.1010μs | 48.3106μs | 20.6994 KOps/s | 20.4550 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 62.1010μs | 28.7694μs | 34.7592 KOps/s | 34.2076 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.1110ms | 33.5746μs | 29.7844 KOps/s | 29.2930 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 44.7010μs | 19.1762μs | 52.1481 KOps/s | 50.4715 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 73.9010μs | 49.3115μs | 20.2793 KOps/s | 19.8742 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 52.7000μs | 30.4732μs | 32.8157 KOps/s | 32.2213 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 74.5000μs | 34.8706μs | 28.6774 KOps/s | 27.7626 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 48.9010μs | 20.8564μs | 47.9470 KOps/s | 46.7362 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1412ms | 47.6408μs | 20.9904 KOps/s | 20.2016 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 98.9010μs | 28.9665μs | 34.5226 KOps/s | 33.8655 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 0.1109ms | 37.1839μs | 26.8933 KOps/s | 26.6511 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 56.4000μs | 21.8449μs | 45.7773 KOps/s | 45.7434 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 77.2000μs | 49.1036μs | 20.3651 KOps/s | 19.9765 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 93.9010μs | 30.8607μs | 32.4036 KOps/s | 32.1362 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 63.0000μs | 38.9406μs | 25.6802 KOps/s | 25.6094 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 55.0000μs | 23.7155μs | 42.1665 KOps/s | 42.1101 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 91.5010μs | 51.0137μs | 19.6026 KOps/s | 19.2124 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 59.7010μs | 32.3679μs | 30.8948 KOps/s | 30.1567 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1013ms | 38.7905μs | 25.7795 KOps/s | 25.4915 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 48.8010μs | 23.4913μs | 42.5690 KOps/s | 42.0048 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 82.2010μs | 52.1746μs | 19.1664 KOps/s | 18.7112 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 57.9000μs | 33.9210μs | 29.4803 KOps/s | 28.7635 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 86.0010μs | 40.1102μs | 24.9313 KOps/s | 24.7838 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 45.6010μs | 24.9107μs | 40.1433 KOps/s | 39.8593 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 16.3688ms | 13.3072ms | 75.1472 Ops/s | 72.5766 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 82.7470ms | 45.4507ms | 22.0019 Ops/s | 23.0195 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.8018ms | 0.5132ms | 1.9486 KOps/s | 2.1963 KOps/s | |
test_values[td1_return_estimate-False-False] | 14.2167ms | 12.7839ms | 78.2236 Ops/s | 75.3801 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 47.7961ms | 43.3699ms | 23.0575 Ops/s | 23.3550 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 31.6770ms | 31.3178ms | 31.9308 Ops/s | 31.6961 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 68.1950ms | 44.8107ms | 22.3161 Ops/s | 21.9697 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.4179ms | 11.2111ms | 89.1970 Ops/s | 87.2095 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 6.4125ms | 4.5396ms | 220.2826 Ops/s | 247.7928 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8722ms | 0.5594ms | 1.7877 KOps/s | 1.7568 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 58.4616ms | 56.6543ms | 17.6509 Ops/s | 18.1174 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 5.9892ms | 3.4968ms | 285.9762 Ops/s | 293.3687 Ops/s | |
test_dqn_speed | 28.0451ms | 2.5107ms | 398.2952 Ops/s | 412.5639 Ops/s | |
test_ddpg_speed | 9.6303ms | 4.3924ms | 227.6658 Ops/s | 228.2371 Ops/s | |
test_sac_speed | 18.1245ms | 12.0443ms | 83.0269 Ops/s | 83.0898 Ops/s | |
test_redq_speed | 27.9333ms | 20.5553ms | 48.6494 Ops/s | 49.9200 Ops/s | |
test_redq_deprec_speed | 24.1642ms | 17.9829ms | 55.6084 Ops/s | 56.3842 Ops/s | |
test_td3_speed | 14.3262ms | 12.9422ms | 77.2667 Ops/s | 77.9363 Ops/s | |
test_cql_speed | 40.9718ms | 34.8096ms | 28.7277 Ops/s | 27.7868 Ops/s | |
test_a2c_speed | 17.0230ms | 8.0949ms | 123.5352 Ops/s | 136.3088 Ops/s | |
test_ppo_speed | 14.0482ms | 8.3944ms | 119.1276 Ops/s | 132.5872 Ops/s | |
test_reinforce_speed | 10.2739ms | 6.5347ms | 153.0295 Ops/s | 181.5480 Ops/s | |
test_iql_speed | 45.5387ms | 34.5527ms | 28.9413 Ops/s | 34.6699 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.3870ms | 2.7245ms | 367.0418 Ops/s | 367.2151 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.7940ms | 2.8688ms | 348.5779 Ops/s | 347.1899 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 4.4114ms | 2.8536ms | 350.4289 Ops/s | 349.3742 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 0.2424s | 3.3732ms | 296.4570 Ops/s | 367.9122 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 4.9938ms | 2.8569ms | 350.0265 Ops/s | 349.5225 Ops/s | |
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 4.6902ms | 2.8651ms | 349.0233 Ops/s | 349.1161 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.2691ms | 2.6985ms | 370.5702 Ops/s | 365.4642 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 5.0894ms | 2.8719ms | 348.2070 Ops/s | 350.1125 Ops/s | |
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.0019ms | 2.9116ms | 343.4499 Ops/s | 347.7931 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 3.6712ms | 2.7053ms | 369.6489 Ops/s | 368.0191 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 5.6247ms | 2.8690ms | 348.5535 Ops/s | 352.5837 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 5.5059ms | 2.9065ms | 344.0524 Ops/s | 348.1628 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 3.6354ms | 2.7240ms | 367.1123 Ops/s | 366.6386 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.1490s | 3.2955ms | 303.4451 Ops/s | 350.0042 Ops/s | |
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.1438s | 3.2810ms | 304.7855 Ops/s | 350.8307 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 3.8942ms | 2.7101ms | 368.9867 Ops/s | 371.0105 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 4.5261ms | 2.8866ms | 346.4302 Ops/s | 346.8222 Ops/s | |
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 5.3493ms | 2.8637ms | 349.1964 Ops/s | 349.8144 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 0.2653s | 29.7072ms | 33.6619 Ops/s | 31.6736 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 0.1507s | 29.8976ms | 33.4475 Ops/s | 33.0632 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 0.1543s | 30.0831ms | 33.2413 Ops/s | 36.5092 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.1476s | 27.3554ms | 36.5558 Ops/s | 33.2745 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 0.1483s | 29.8583ms | 33.4916 Ops/s | 36.7406 Ops/s | |
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 0.1452s | 27.0742ms | 36.9355 Ops/s | 33.2963 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.1538s | 29.9478ms | 33.3914 Ops/s | 36.4342 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 0.1524s | 27.6096ms | 36.2193 Ops/s | 33.1728 Ops/s | |
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 0.1503s | 30.1574ms | 33.1594 Ops/s | 36.5810 Ops/s | |

@matteobettini this is ready for an initial review cycle
LGTM, some comments.
In particular, for MARL, I am interested in understanding if it could work with
- an input tensor of shape [*B,T,*N,F]
- in a tensordict with shape [*B,T]
> If in temporal mode, it is expected that the last dimension of the tensordict
> marks the number of steps. There is no constraint on the dimensionality of the
> tensordict (except that it must be greater than one for temporal inputs).
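The temporal-mode convention quoted above can be sketched as follows (a shape-logic sketch only, using numpy in place of torch tensors; all names and shapes are hypothetical):

```python
import numpy as np

# Hypothetical shapes: two leading batch dims, T time steps, F features.
B1, B2, T, F = 4, 2, 10, 8

# Input tensor living in a tensordict of batch_size [B1, B2, T]:
obs = np.zeros((B1, B2, T, F))

# Temporal-mode convention: the LAST dimension of the tensordict's batch_size
# (here T) marks the number of steps; every dimension before it is batch.
rnn_input = obs.reshape(-1, T, F)  # fold [B1, B2] -> [B1*B2]
assert rnn_input.shape == (B1 * B2, T, F)
```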
so, just to clarify, if I have
- an input tensor of shape [*B,T,*N,F]
- in a tensordict with shape [*B,T]
my *N dimensions are treated as batch?
Is this use case supported?
Yes, see the `test_single_step_vs_multi` test.
In that test, the *B is part of the tensordict.
Here I am asking about the *N, like the agent dimension in MARL for example, which sits after the T but before the F.
    else:
        tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

    is_init = tensordict_shaped.get("is_init").squeeze(-1)
Here we are assuming `is_init` has the tensordict shape [*B,T]?
Not sure what makes you think that? If so, is it a problem? Can you elaborate?
- What makes me think that is that we consider its last dimension as T in the lines that follow.
- What might become a problem is that 1) we are hardcoding the name "is_init", and 2) it might have a different shape, since it follows a done, no?
I think this issue and the one above boil down to what we want to accept as input to this module.
We need to have a way to tell what the time dimension is.
- One option is to check the dim names, but it's hard to get people to adopt that. I tried and got multiple issues from people saying that their td was lacking the "time" dim name and they did not know what to do.
- Another option is to pass the time dim in the constructor, but it's brittle. If you say `time_dim=3` you won't be able to change the batch size; if you say `time_dim=-3` you need to say with respect to what (obs, init, tensordict?).
Regular usage assumes that time is the last dim of the tensordict, but it's true that we could cover other cases. Otherwise we assume that the td provided is always the root and always has the time dim at the end. Handling batch dims after the time dim can also be considered.
Another issue is that we assume that the hidden state can more or less be indexed by is_init, but I'm not sure what this looks like in cases where you have multiple done states and multiple observations. The batch size of the obs could differ from that of the done state (e.g., in VMAS, IIUC, you have one done state per group of agents but one obs per agent).
In MARL settings you will also need to consider whether the LSTM reads the aggregated observations or not, whether you have one LSTM per agent or not, etc. Overall it starts to look like a different class to me.
TL;DR
Happy to have follow-up work on GRU and LSTM to make them compatible with more convoluted data structures, but do you think we can integrate GRU as it is now, since it already applies to most single-agent use cases?
If you have a precise list of requirements for what this class should also cover, I'd be glad to look at it.
What I would need is: what are the data structures that need to be considered, what behaviour we expect, and what we want to explicitly forbid because there is a solution. To give you an example, IMO the case of a done that has a shape differing from the obs can easily be solved by forbidding it and telling users to append a rename + unsqueeze + expand transform before (which all work as regular nn.Modules).
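The unsqueeze + expand fix suggested above can be sketched like this (numpy in place of torch; all shapes are hypothetical, chosen only to illustrate the broadcast):

```python
import numpy as np

# Hypothetical shapes: one done/is_init flag per group [B, T],
# but one observation per agent [B, T, N, F].
B, T, N, F = 2, 5, 3, 8
is_init = np.zeros((B, T), dtype=bool)
obs = np.zeros((B, T, N, F))

# Unsqueeze then expand (broadcast) the flag so its batch dims
# match those of the obs, mirroring the suggested transform chain.
is_init_expanded = np.broadcast_to(is_init[..., None], (B, T, N))
assert is_init_expanded.shape == obs.shape[:-1]
```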
Yes, let's go ahead with this as is and we can keep discussing. The LSTM has the same considerations, so we can tackle this in a future PR.
The example

    TensorDict{
        "is_init": Tensor(shape=[*B, T]),
        "group": TensorDict{
            "obs": Tensor(shape=[*B, T, *N, F]),
            "recurrent": Tensor(shape=[*B, T, *N, H]),
            batch_size=[*B, T, *N]
        },
        batch_size=[*B, T]
    }
The assumptions
- I would assume that is_init, the input, and the recurrent state all have the same batch_size, with last dims of 1, F, and H respectively. In the example above, is_init might have to be expanded into the group before passing it to the module.
- I would assume (like in value functions) that the root td is always the one passed to the module and that its last dim is the time dimension. That way I know where the time dimension is (without the need to mark it) and I know that all the other dimensions in my input, apart from that one and the last, are batches.
What we expect
We would then expect the input and recurrent state to be reshaped with batch=[*B x *N] and time=[T], ready to be passed to the inner GRU/LSTM, and then reshaped back.
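The reshape round-trip described above can be sketched as follows (a shape-logic sketch using numpy in place of torch; all names and shapes are hypothetical):

```python
import numpy as np

# Hypothetical MARL shapes: [*B, T, *N, F] with B=batch, T=time,
# N=agent dims, F=features.
B, T, N, F = 2, 5, 3, 8
x = np.arange(B * T * N * F, dtype=np.float64).reshape(B, T, N, F)

# Move the agent dims next to the batch dims, then fold them together so the
# inner GRU/LSTM sees [batch, time, features] = [B*N, T, F].
folded = x.transpose(0, 2, 1, 3).reshape(B * N, T, F)

# ... run the recurrent module on `folded` here ...

# Reshape back to the original [*B, T, *N, F] layout.
restored = folded.reshape(B, N, T, F).transpose(0, 2, 1, 3)
assert np.array_equal(restored, x)
```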
> I would assume that is_init, the input, and the recurrent state all have the same batch_size, with last dims of 1, F, and H respectively. In the example above, is_init might have to be expanded into the group before passing it to the module.

My view in this case is to (1) raise an exception, because we can't really tell in advance what to do with an is_init that has a mismatching shape, and (2) provide info on how to do the shape modification.
If only the first batch dims worked with RNNs, sigh!