diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 657bf6649d7..6fefda2dd5d 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -687,7 +687,7 @@ def forward(self, tensordict: TensorDictBase): # packed sequences do not help to get the accurate last hidden values # if splits is not None: # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True) - if is_init.any() and hidden0 is not None: + if hidden0 is not None: is_init_expand = expand_as_right(is_init, hidden0) hidden0 = torch.where(is_init_expand, 0, hidden0) hidden1 = torch.where(is_init_expand, 0, hidden1)