Skip to content

Commit

Permalink
No need to unsqueeze
Browse files Browse the repository at this point in the history
  • Loading branch information
Pau Riba committed Nov 6, 2024
1 parent ef794c7 commit 1f33881
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,8 @@ def reset(self) -> None:
def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
try:
entropy = dist.entropy()
if is_tensor_collection(entropy) and hasattr(dist, "entropy_key"):
entropy = entropy.get(dist.entropy_key).unsqueeze(-1)
if is_tensor_collection(entropy):
entropy = entropy.get(dist.entropy_key)
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
Expand Down

0 comments on commit 1f33881

Please sign in to comment.