[BugFix] Allow for composite action distributions in PPO/A2C losses #2391

Merged
merged 19 commits on Sep 4, 2024
a2c tests
albertbou92 committed Aug 14, 2024
commit 07ec262ee298f5e2d1eea643e36009a2c33c5ff8
16 changes: 13 additions & 3 deletions test/test_cost.py
@@ -8860,14 +8860,24 @@ def test_a2c(
functional=functional,
)

def set_requires_grad(tensor, requires_grad):
tensor.requires_grad = requires_grad
return tensor

# Check error is raised when actions require grads
td["action"].requires_grad = True
if composite_action_dist:
td["action"].apply_(lambda x: set_requires_grad(x, True))
else:
td["action"].requires_grad = True
with pytest.raises(
RuntimeError,
match="tensordict stored action require grad.",
match="tensordict stored action requires grad.",
):
_ = loss_fn._log_probs(td)
td["action"].requires_grad = False
if composite_action_dist:
td["action"].apply_(lambda x: set_requires_grad(x, False))
else:
td["action"].requires_grad = False

td = td.exclude(loss_fn.tensor_keys.value_target)
if advantage is not None:
16 changes: 11 additions & 5 deletions torchrl/objectives/a2c.py
@@ -384,7 +384,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
Reviewer comment (Contributor):
Previously, was this a bug or did we sum the log-probs automatically?

Author reply (@albertbou92, Aug 26, 2024):
This is simply because the log_prob() method of a composite distribution returns a TensorDict instead of a Tensor, so we compute the entropy in two steps.

This is the old version:

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            entropy = -dist.log_prob(x).mean(0)
        return entropy.unsqueeze(-1)

This is the new version. It simply retrieves the log-prob tensor from the TensorDict before computing the entropy.

    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
        try:
            entropy = dist.entropy()
        except NotImplementedError:
            x = dist.rsample((self.samples_mc_entropy,))
            log_prob = dist.log_prob(x)
            if is_tensor_collection(log_prob):
                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
            entropy = -log_prob.mean(0)
        return entropy.unsqueeze(-1)
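
For context (not part of the PR), a minimal runnable sketch of that two-step pattern. It assumes the default sample_log_prob key and made-up shapes; the plain TensorDict below only stands in for whatever dist.log_prob(x) returns for a composite distribution:

    import torch
    from tensordict import TensorDict, is_tensor_collection

    # Stand-in for what dist.log_prob(x) might return for a composite distribution:
    # a tensordict carrying an aggregated "sample_log_prob" entry.
    # Shapes are hypothetical: (samples_mc_entropy, batch).
    log_prob = TensorDict(
        {"sample_log_prob": torch.randn(16, 4)},
        batch_size=[16, 4],
    )

    # Step 1: if the result is a tensor collection, extract the log-prob tensor.
    if is_tensor_collection(log_prob):
        log_prob = log_prob.get("sample_log_prob")

    # Step 2: Monte-Carlo entropy estimate, reshaped to (batch, 1) as in the loss.
    entropy = -log_prob.mean(0)
    entropy = entropy.unsqueeze(-1)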

if isinstance(log_prob, TensorDict):
Reviewer comment (Contributor):
A lazy stack is not a TensorDict but a TensorDictBase.
Also, ideally we would want this to work with tensorclasses.
The way to go should be to use is_tensor_collection from the tensordict library.
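
As an illustrative sketch (not from the PR) of why the broader check matters; the key name and shapes are made up, and the lazy_stack call assumes a recent tensordict version:

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict, is_tensor_collection

    td = TensorDict({"sample_log_prob": torch.zeros(3)}, batch_size=[3])
    lazy = LazyStackedTensorDict.lazy_stack([td.clone(), td.clone()])

    print(isinstance(lazy, TensorDict))  # False: a lazy stack subclasses TensorDictBase, not TensorDict
    print(is_tensor_collection(lazy))    # True: also covers lazy stacks and tensorclasses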

log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)
@@ -394,20 +394,26 @@ def _log_probs(
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get(self.tensor_keys.action)
if action.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} require grad."
)
tensordict_clone = tensordict.select(
*self.actor_network.in_keys, strict=False
).clone()
with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict_clone)

def check_requires_grad(tensor):
if tensor.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
return tensor

if isinstance(action, torch.Tensor):
check_requires_grad(action)
log_prob = dist.log_prob(action)
else:
action.apply(check_requires_grad)
tensordict = dist.log_prob(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
log_prob = log_prob.unsqueeze(-1)
14 changes: 8 additions & 6 deletions torchrl/objectives/ppo.py
@@ -460,24 +460,26 @@ def _log_weight(
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get(self.tensor_keys.action)
if action.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)

with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)
# dist = TransformedDistribution(dist, ExpTransform())

def check_requires_grad(tensor, key=self.tensor_keys.action):
if tensor.requires_grad:
raise RuntimeError(f"tensordict stored {key} requires grad.")
return tensor

prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
if prev_log_prob.requires_grad:
raise RuntimeError("tensordict prev_log_prob requires grad.")
check_requires_grad(prev_log_prob, self.tensor_keys.sample_log_prob)

if isinstance(action, torch.Tensor):
check_requires_grad(action, self.tensor_keys.action)
log_prob = dist.log_prob(action)
else:
action.apply(check_requires_grad)
tensordict = dist.log_prob(tensordict)
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
