
[Feature] Support for GRU #1586

Merged
merged 37 commits into from
Oct 5, 2023

Conversation

vmoens
Contributor

@vmoens vmoens commented Oct 1, 2023

No description provided.

# Conflicts:
#	test/mocking_classes.py
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 1, 2023
@vmoens vmoens added the enhancement New feature or request label Oct 1, 2023
@vmoens vmoens marked this pull request as draft October 1, 2023 05:18
@github-actions

github-actions bot commented Oct 1, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 89. Improved: $\large\color{#35bf28}4$. Worsened: $\large\color{#d91a1a}13$.

Detailed results:
Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1031s 0.1022s 9.7833 Ops/s 9.7649 Ops/s $\color{#35bf28}+0.19\%$
test_sync 58.4572ms 54.4339ms 18.3709 Ops/s 17.6545 Ops/s $\color{#35bf28}+4.06\%$
test_async 90.3134ms 53.3863ms 18.7314 Ops/s 18.9899 Ops/s $\color{#d91a1a}-1.36\%$
test_simple 0.8979s 0.8279s 1.2079 Ops/s 1.2125 Ops/s $\color{#d91a1a}-0.38\%$
test_transformed 1.0981s 1.0369s 0.9645 Ops/s 0.9433 Ops/s $\color{#35bf28}+2.24\%$
test_serial 2.2668s 2.2042s 0.4537 Ops/s 0.4508 Ops/s $\color{#35bf28}+0.64\%$
test_parallel 1.9131s 1.8831s 0.5310 Ops/s 0.5298 Ops/s $\color{#35bf28}+0.24\%$
test_step_mdp_speed[True-True-True-True-True] 0.3062ms 44.4846μs 22.4797 KOps/s 22.2468 KOps/s $\color{#35bf28}+1.05\%$
test_step_mdp_speed[True-True-True-True-False] 0.3030ms 25.3514μs 39.4456 KOps/s 39.1394 KOps/s $\color{#35bf28}+0.78\%$
test_step_mdp_speed[True-True-True-False-True] 58.1010μs 31.2923μs 31.9567 KOps/s 31.4288 KOps/s $\color{#35bf28}+1.68\%$
test_step_mdp_speed[True-True-True-False-False] 48.1010μs 17.3129μs 57.7603 KOps/s 56.5134 KOps/s $\color{#35bf28}+2.21\%$
test_step_mdp_speed[True-True-False-True-True] 0.2306ms 46.0573μs 21.7121 KOps/s 21.3508 KOps/s $\color{#35bf28}+1.69\%$
test_step_mdp_speed[True-True-False-True-False] 0.1165ms 26.8969μs 37.1791 KOps/s 36.4559 KOps/s $\color{#35bf28}+1.98\%$
test_step_mdp_speed[True-True-False-False-True] 0.1203ms 33.6319μs 29.7337 KOps/s 29.5780 KOps/s $\color{#35bf28}+0.53\%$
test_step_mdp_speed[True-True-False-False-False] 45.7010μs 19.2955μs 51.8256 KOps/s 50.2864 KOps/s $\color{#35bf28}+3.06\%$
test_step_mdp_speed[True-False-True-True-True] 99.1010μs 48.3106μs 20.6994 KOps/s 20.4550 KOps/s $\color{#35bf28}+1.19\%$
test_step_mdp_speed[True-False-True-True-False] 62.1010μs 28.7694μs 34.7592 KOps/s 34.2076 KOps/s $\color{#35bf28}+1.61\%$
test_step_mdp_speed[True-False-True-False-True] 0.1110ms 33.5746μs 29.7844 KOps/s 29.2930 KOps/s $\color{#35bf28}+1.68\%$
test_step_mdp_speed[True-False-True-False-False] 44.7010μs 19.1762μs 52.1481 KOps/s 50.4715 KOps/s $\color{#35bf28}+3.32\%$
test_step_mdp_speed[True-False-False-True-True] 73.9010μs 49.3115μs 20.2793 KOps/s 19.8742 KOps/s $\color{#35bf28}+2.04\%$
test_step_mdp_speed[True-False-False-True-False] 52.7000μs 30.4732μs 32.8157 KOps/s 32.2213 KOps/s $\color{#35bf28}+1.84\%$
test_step_mdp_speed[True-False-False-False-True] 74.5000μs 34.8706μs 28.6774 KOps/s 27.7626 KOps/s $\color{#35bf28}+3.30\%$
test_step_mdp_speed[True-False-False-False-False] 48.9010μs 20.8564μs 47.9470 KOps/s 46.7362 KOps/s $\color{#35bf28}+2.59\%$
test_step_mdp_speed[False-True-True-True-True] 0.1412ms 47.6408μs 20.9904 KOps/s 20.2016 KOps/s $\color{#35bf28}+3.90\%$
test_step_mdp_speed[False-True-True-True-False] 98.9010μs 28.9665μs 34.5226 KOps/s 33.8655 KOps/s $\color{#35bf28}+1.94\%$
test_step_mdp_speed[False-True-True-False-True] 0.1109ms 37.1839μs 26.8933 KOps/s 26.6511 KOps/s $\color{#35bf28}+0.91\%$
test_step_mdp_speed[False-True-True-False-False] 56.4000μs 21.8449μs 45.7773 KOps/s 45.7434 KOps/s $\color{#35bf28}+0.07\%$
test_step_mdp_speed[False-True-False-True-True] 77.2000μs 49.1036μs 20.3651 KOps/s 19.9765 KOps/s $\color{#35bf28}+1.95\%$
test_step_mdp_speed[False-True-False-True-False] 93.9010μs 30.8607μs 32.4036 KOps/s 32.1362 KOps/s $\color{#35bf28}+0.83\%$
test_step_mdp_speed[False-True-False-False-True] 63.0000μs 38.9406μs 25.6802 KOps/s 25.6094 KOps/s $\color{#35bf28}+0.28\%$
test_step_mdp_speed[False-True-False-False-False] 55.0000μs 23.7155μs 42.1665 KOps/s 42.1101 KOps/s $\color{#35bf28}+0.13\%$
test_step_mdp_speed[False-False-True-True-True] 91.5010μs 51.0137μs 19.6026 KOps/s 19.2124 KOps/s $\color{#35bf28}+2.03\%$
test_step_mdp_speed[False-False-True-True-False] 59.7010μs 32.3679μs 30.8948 KOps/s 30.1567 KOps/s $\color{#35bf28}+2.45\%$
test_step_mdp_speed[False-False-True-False-True] 0.1013ms 38.7905μs 25.7795 KOps/s 25.4915 KOps/s $\color{#35bf28}+1.13\%$
test_step_mdp_speed[False-False-True-False-False] 48.8010μs 23.4913μs 42.5690 KOps/s 42.0048 KOps/s $\color{#35bf28}+1.34\%$
test_step_mdp_speed[False-False-False-True-True] 82.2010μs 52.1746μs 19.1664 KOps/s 18.7112 KOps/s $\color{#35bf28}+2.43\%$
test_step_mdp_speed[False-False-False-True-False] 57.9000μs 33.9210μs 29.4803 KOps/s 28.7635 KOps/s $\color{#35bf28}+2.49\%$
test_step_mdp_speed[False-False-False-False-True] 86.0010μs 40.1102μs 24.9313 KOps/s 24.7838 KOps/s $\color{#35bf28}+0.60\%$
test_step_mdp_speed[False-False-False-False-False] 45.6010μs 24.9107μs 40.1433 KOps/s 39.8593 KOps/s $\color{#35bf28}+0.71\%$
test_values[generalized_advantage_estimate-True-True] 16.3688ms 13.3072ms 75.1472 Ops/s 72.5766 Ops/s $\color{#35bf28}+3.54\%$
test_values[vec_generalized_advantage_estimate-True-True] 82.7470ms 45.4507ms 22.0019 Ops/s 23.0195 Ops/s $\color{#d91a1a}-4.42\%$
test_values[td0_return_estimate-False-False] 0.8018ms 0.5132ms 1.9486 KOps/s 2.1963 KOps/s $\textbf{\color{#d91a1a}-11.28\%}$
test_values[td1_return_estimate-False-False] 14.2167ms 12.7839ms 78.2236 Ops/s 75.3801 Ops/s $\color{#35bf28}+3.77\%$
test_values[vec_td1_return_estimate-False-False] 47.7961ms 43.3699ms 23.0575 Ops/s 23.3550 Ops/s $\color{#d91a1a}-1.27\%$
test_values[td_lambda_return_estimate-True-False] 31.6770ms 31.3178ms 31.9308 Ops/s 31.6961 Ops/s $\color{#35bf28}+0.74\%$
test_values[vec_td_lambda_return_estimate-True-False] 68.1950ms 44.8107ms 22.3161 Ops/s 21.9697 Ops/s $\color{#35bf28}+1.58\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 11.4179ms 11.2111ms 89.1970 Ops/s 87.2095 Ops/s $\color{#35bf28}+2.28\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 6.4125ms 4.5396ms 220.2826 Ops/s 247.7928 Ops/s $\textbf{\color{#d91a1a}-11.10\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.8722ms 0.5594ms 1.7877 KOps/s 1.7568 KOps/s $\color{#35bf28}+1.76\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 58.4616ms 56.6543ms 17.6509 Ops/s 18.1174 Ops/s $\color{#d91a1a}-2.57\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 5.9892ms 3.4968ms 285.9762 Ops/s 293.3687 Ops/s $\color{#d91a1a}-2.52\%$
test_dqn_speed 28.0451ms 2.5107ms 398.2952 Ops/s 412.5639 Ops/s $\color{#d91a1a}-3.46\%$
test_ddpg_speed 9.6303ms 4.3924ms 227.6658 Ops/s 228.2371 Ops/s $\color{#d91a1a}-0.25\%$
test_sac_speed 18.1245ms 12.0443ms 83.0269 Ops/s 83.0898 Ops/s $\color{#d91a1a}-0.08\%$
test_redq_speed 27.9333ms 20.5553ms 48.6494 Ops/s 49.9200 Ops/s $\color{#d91a1a}-2.55\%$
test_redq_deprec_speed 24.1642ms 17.9829ms 55.6084 Ops/s 56.3842 Ops/s $\color{#d91a1a}-1.38\%$
test_td3_speed 14.3262ms 12.9422ms 77.2667 Ops/s 77.9363 Ops/s $\color{#d91a1a}-0.86\%$
test_cql_speed 40.9718ms 34.8096ms 28.7277 Ops/s 27.7868 Ops/s $\color{#35bf28}+3.39\%$
test_a2c_speed 17.0230ms 8.0949ms 123.5352 Ops/s 136.3088 Ops/s $\textbf{\color{#d91a1a}-9.37\%}$
test_ppo_speed 14.0482ms 8.3944ms 119.1276 Ops/s 132.5872 Ops/s $\textbf{\color{#d91a1a}-10.15\%}$
test_reinforce_speed 10.2739ms 6.5347ms 153.0295 Ops/s 181.5480 Ops/s $\textbf{\color{#d91a1a}-15.71\%}$
test_iql_speed 45.5387ms 34.5527ms 28.9413 Ops/s 34.6699 Ops/s $\textbf{\color{#d91a1a}-16.52\%}$
test_sample_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.3870ms 2.7245ms 367.0418 Ops/s 367.2151 Ops/s $\color{#d91a1a}-0.05\%$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 5.7940ms 2.8688ms 348.5779 Ops/s 347.1899 Ops/s $\color{#35bf28}+0.40\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 4.4114ms 2.8536ms 350.4289 Ops/s 349.3742 Ops/s $\color{#35bf28}+0.30\%$
test_sample_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 0.2424s 3.3732ms 296.4570 Ops/s 367.9122 Ops/s $\textbf{\color{#d91a1a}-19.42\%}$
test_sample_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 4.9938ms 2.8569ms 350.0265 Ops/s 349.5225 Ops/s $\color{#35bf28}+0.14\%$
test_sample_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 4.6902ms 2.8651ms 349.0233 Ops/s 349.1161 Ops/s $\color{#d91a1a}-0.03\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.2691ms 2.6985ms 370.5702 Ops/s 365.4642 Ops/s $\color{#35bf28}+1.40\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 5.0894ms 2.8719ms 348.2070 Ops/s 350.1125 Ops/s $\color{#d91a1a}-0.54\%$
test_sample_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 5.0019ms 2.9116ms 343.4499 Ops/s 347.7931 Ops/s $\color{#d91a1a}-1.25\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 3.6712ms 2.7053ms 369.6489 Ops/s 368.0191 Ops/s $\color{#35bf28}+0.44\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 5.6247ms 2.8690ms 348.5535 Ops/s 352.5837 Ops/s $\color{#d91a1a}-1.14\%$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 5.5059ms 2.9065ms 344.0524 Ops/s 348.1628 Ops/s $\color{#d91a1a}-1.18\%$
test_iterate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 3.6354ms 2.7240ms 367.1123 Ops/s 366.6386 Ops/s $\color{#35bf28}+0.13\%$
test_iterate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.1490s 3.2955ms 303.4451 Ops/s 350.0042 Ops/s $\textbf{\color{#d91a1a}-13.30\%}$
test_iterate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.1438s 3.2810ms 304.7855 Ops/s 350.8307 Ops/s $\textbf{\color{#d91a1a}-13.12\%}$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 3.8942ms 2.7101ms 368.9867 Ops/s 371.0105 Ops/s $\color{#d91a1a}-0.55\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 4.5261ms 2.8866ms 346.4302 Ops/s 346.8222 Ops/s $\color{#d91a1a}-0.11\%$
test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 5.3493ms 2.8637ms 349.1964 Ops/s 349.8144 Ops/s $\color{#d91a1a}-0.18\%$
test_populate_rb[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.2653s 29.7072ms 33.6619 Ops/s 31.6736 Ops/s $\textbf{\color{#35bf28}+6.28\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 0.1507s 29.8976ms 33.4475 Ops/s 33.0632 Ops/s $\color{#35bf28}+1.16\%$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 0.1543s 30.0831ms 33.2413 Ops/s 36.5092 Ops/s $\textbf{\color{#d91a1a}-8.95\%}$
test_populate_rb[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1476s 27.3554ms 36.5558 Ops/s 33.2745 Ops/s $\textbf{\color{#35bf28}+9.86\%}$
test_populate_rb[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 0.1483s 29.8583ms 33.4916 Ops/s 36.7406 Ops/s $\textbf{\color{#d91a1a}-8.84\%}$
test_populate_rb[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 0.1452s 27.0742ms 36.9355 Ops/s 33.2963 Ops/s $\textbf{\color{#35bf28}+10.93\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1538s 29.9478ms 33.3914 Ops/s 36.4342 Ops/s $\textbf{\color{#d91a1a}-8.35\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 0.1524s 27.6096ms 36.2193 Ops/s 33.1728 Ops/s $\textbf{\color{#35bf28}+9.18\%}$
test_populate_rb[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 0.1503s 30.1574ms 33.1594 Ops/s 36.5810 Ops/s $\textbf{\color{#d91a1a}-9.35\%}$

@vmoens vmoens marked this pull request as ready for review October 4, 2023 14:31
@vmoens
Contributor Author

vmoens commented Oct 4, 2023

@matteobettini this is ready for an initial review cycle

test/mocking_classes.py (outdated review thread; resolved)
Contributor

@matteobettini matteobettini left a comment

LGTM, some comments

In particular, for MARL, I am interested in understanding whether it could work with

  • an input tensor of shape [*B,T,*N,F]
  • in a tensordict with shape [*B,T]


If in temporal mode, it is expected that the last dimension of the tensordict
marks the number of steps. There is no constraint on the dimensionality of the
tensordict (except that it must be greater than one for temporal inputs).
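The temporal-mode convention quoted above can be illustrated with a plain `torch.nn.GRU` (a sketch, not the torchrl wrapper added in this PR): with `batch_first=True` the input is `[B, T, F]`, so `T` is the step dimension, matching a tensordict of batch shape `[B, T]`. The shapes below are illustrative.

```python
import torch
from torch import nn

# Temporal mode: per-key tensors are [B, T, F], time is the last batch dim.
B, T, F, H = 4, 10, 8, 16
gru = nn.GRU(input_size=F, hidden_size=H, batch_first=True)
x = torch.randn(B, T, F)    # T steps per batch entry
h0 = torch.zeros(1, B, H)   # [num_layers, B, H]
out, hT = gru(x, h0)
print(out.shape, hT.shape)  # torch.Size([4, 10, 16]) torch.Size([1, 4, 16])
```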
Contributor

so, just to clarify, if I have

  • an input tensor of shape [*B,T,*N,F]
  • in a tensordict with shape [*B,T]

are my *N dimensions treated as batch?
Is this use case supported?

Contributor Author

Yes, see the test_single_step_vs_multi test

Contributor

In that test the *B is part of the tensordict.

Here I am asking about the *N: the agent dimension in MARL, for example, which comes after the T but before the F.

torchrl/modules/tensordict_module/rnn.py (six outdated review threads; all resolved)
else:
    tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1)

is_init = tensordict_shaped.get("is_init").squeeze(-1)
Contributor

here we are assuming is_init has the tensordict shape [*B,T]?

Contributor Author

Not sure what makes you think that. If so, is it a problem? Can you elaborate?

Contributor

  • What makes me think that is that we consider its last dimension as T in the lines that follow.
  • What might become a problem is that (1) we are hardcoding the name "is_init", and (2) it might have a different shape, since it follows a done, no?

Contributor Author

I think this issue and the one above boil down to what we want to accept as input to this module.
We need to have a way to tell what the time dimension is.

  • One option is to check the dim names, but it's hard to get people to adopt that. I tried and got multiple issues from people saying that their td was lacking the "time" dim name and they did not know what to do.
  • Another option is to pass the time dim in the constructor, but it's brittle: if you say time_dim=3 you won't be able to change the batch size, and if you say time_dim=-3 you need to say with respect to what (obs, init, the tensordict?).

Regular usage assumes that time is the last dim of the tensordict, but it's true that we could cover other layouts. Otherwise we assume that the TD provided is always the root and always has the time dim at the end. Handling batch dims after the time dim can also be considered.

Another issue is that we assume the hidden state can more or less be indexed by is_init, but I'm not sure what this looks like in cases where you have multiple done states and multiple observations. The batch size of the obs could differ from the done state (e.g., in VMAS, IIUC, you have one done state per group of agents but one obs per agent).
In MARL settings you will also need to consider whether the LSTM reads the aggregated observations or not, whether you have one LSTM per agent or not, etc. Overall it starts to look like a different class to me.

TL;DR
Happy to have follow-up work on GRU and LSTM to make them compatible with more convoluted data structures, but do you think we can integrate GRU as it is now, since it already applies to most single-agent use cases?
If you have a precise list of requirements for what this class should also cover, I'd be glad to look at it.
What I would need is: the data structures that need to be considered, the behaviour we expect, and what we want to explicitly forbid because a workaround exists. For example, IMO the case of a done whose shape differs from the obs can easily be solved by forbidding it and telling users to prepend a rename + unsqueeze + expand transform (which all work as regular nn.Modules).
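The unsqueeze + expand fix mentioned above can be sketched with raw torch ops (illustrative shapes, not an actual torchrl transform): a per-group is_init of shape [B, T, 1] is broadcast to a per-agent shape [B, T, N, 1] so it matches the observation's batch dims.

```python
import torch

# Broadcast a per-group reset flag to one flag per agent.
B, T, N = 4, 10, 3
is_init = torch.zeros(B, T, 1, dtype=torch.bool)          # [B, T, 1]
is_init_per_agent = is_init.unsqueeze(-2).expand(B, T, N, 1)  # [B, T, N, 1]
print(is_init_per_agent.shape)  # torch.Size([4, 10, 3, 1])
```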

Contributor

Yes, let's go ahead with this as is and keep discussing. The LSTM has the same considerations, so we can tackle this in a future PR.

The example

TensorDict{
  "is_init": Tensor(shape=[*B,T]),
  "group": TensorDict{
    "obs": Tensor(shape=[*B,T,*N,F]),
    "recurrent": Tensor(shape=[*B,T,*N,H]),
    batch_size=[*B,T,*N]
  },
  batch_size=[*B,T]
}

The assumptions

  • I would assume that is_init, the input, and the recurrent state all have the same batch_size, with last dims of 1, F, and H respectively. In the example above, is_init might have to be expanded into the group before being passed to the module.
  • I would assume (as with value functions) that the root td is always the one passed to the module and that its last dim is the time dimension. That way I know where the time dimension is (without needing to mark it) and know that all the other dimensions in my input, apart from that one and the last, are batches.

What we expect

We would then expect the input and recurrent state to be reshaped with batch=[*B x *N] and time=[T], ready to be passed to the inner GRU/LSTM, and then reshaped back.
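The reshape round-trip described above can be sketched with a plain `torch.nn.GRU` (shapes are illustrative): fold the extra *N dims into the batch, run the recurrent net over [B*N, T, F], then unfold back to [B, T, N, H].

```python
import torch
from torch import nn

B, T, N, F, H = 2, 5, 3, 4, 8
x = torch.randn(B, T, N, F)            # [*B, T, *N, F]
gru = nn.GRU(F, H, batch_first=True)

# Fold the agent dim into the batch, keeping time as the step dim.
x_flat = x.permute(0, 2, 1, 3).reshape(B * N, T, F)      # [B*N, T, F]
out_flat, _ = gru(x_flat)                                # [B*N, T, H]
# Unfold back to the original layout.
out = out_flat.reshape(B, N, T, H).permute(0, 2, 1, 3)   # [B, T, N, H]
print(out.shape)  # torch.Size([2, 5, 3, 8])
```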

Contributor Author

i would assume that is_init, the input, and the recurrent state have all the same batch_size and the last dim respectively of 1, F, H. In the example above, is_init might have to be expanded into the group before passing it to the module

My view in this case is to (1) raise an exception, because we can't really tell ahead of time what to do with an is_init that has a mismatching shape, and (2) provide info on how to do the shape modification.

Contributor Author

If only leading batch dims worked with RNNs, sigh!

torchrl/modules/tensordict_module/rnn.py (two outdated review threads; resolved)
@vmoens vmoens merged commit 244f93a into main Oct 5, 2023
@vmoens vmoens deleted the gru branch October 5, 2023 10:43
vmoens added a commit to hyerra/rl that referenced this pull request Oct 10, 2023