Skip to content

Commit

Permalink
[BugFix] No grad on collector reset (pytorch#1927)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Feb 20, 2024
1 parent 78b31a9 commit eacad37
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ def __init__(
self.return_same_td = return_same_td

# Shuttle is a deviceless tensordict that just carries data from env to policy and policy to env
self._shuttle = self.env.reset()
with torch.no_grad():
self._shuttle = self.env.reset()
if self.policy_device != self.env_device or self.env_device is None:
self._shuttle_has_no_device = True
self._shuttle.clear_device_()
Expand Down Expand Up @@ -1145,6 +1146,7 @@ def _update_device_wise(tensor0, tensor1):
return tensor1
return tensor1.to(tensor0.device, non_blocking=True)

@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
"""Resets the environments to a new initial state."""
# metadata
Expand Down

0 comments on commit eacad37

Please sign in to comment.