Skip to content

Commit

Permalink
added required grad for td
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Aug 26, 2024
1 parent 37b733f commit 69922fa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
14 changes: 4 additions & 10 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,19 +413,13 @@ def _log_probs(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict_clone)

def check_requires_grad(tensor):
if tensor.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
return tensor

if action.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
if isinstance(action, torch.Tensor):
check_requires_grad(action)
log_prob = dist.log_prob(action)
else:
action.apply(check_requires_grad)
tensordict = dist.log_prob(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
log_prob = log_prob.unsqueeze(-1)
Expand Down
16 changes: 8 additions & 8 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,19 +477,19 @@ def _log_weight(
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)

def check_requires_grad(tensor, key=self.tensor_keys.action):
if tensor.requires_grad:
raise RuntimeError(f"tensordict stored {key} requires grad.")
return tensor

prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
check_requires_grad(prev_log_prob, self.tensor_keys.sample_log_prob)
if prev_log_prob.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
)

if action.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
if isinstance(action, torch.Tensor):
check_requires_grad(action, self.tensor_keys.action)
log_prob = dist.log_prob(action)
else:
action.apply(check_requires_grad)
tensordict = dist.log_prob(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)

Expand Down

0 comments on commit 69922fa

Please sign in to comment.