Skip to content

Commit

Permalink
[BugFix] Fix LSTM in GAE with vmap (pytorch#2376)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 7, 2024
1 parent 607db8b commit 342450e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def forward(self, tensordict: TensorDictBase):
# packed sequences do not help to get the accurate last hidden values
# if splits is not None:
# value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
if is_init.any() and hidden0 is not None:
if hidden0 is not None:
is_init_expand = expand_as_right(is_init, hidden0)
hidden0 = torch.where(is_init_expand, 0, hidden0)
hidden1 = torch.where(is_init_expand, 0, hidden1)
Expand Down

0 comments on commit 342450e

Please sign in to comment.