Skip to content

Commit

Permalink
[BugFix] skip_done_states in SAC
Browse files Browse the repository at this point in the history
ghstack-source-id: f534c53d30af035edb2e3b5291d4db71313086fd
Pull Request resolved: #2613
  • Loading branch information
vmoens committed Nov 27, 2024
1 parent d537dcb commit 8078906
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 54 deletions.
2 changes: 2 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4493,6 +4493,7 @@ def test_sac_terminating(
actor_network=actor,
qvalue_network=qvalue,
value_network=value,
skip_done_states=True,
)
loss.set_keys(
action=action_key,
Expand Down Expand Up @@ -5204,6 +5205,7 @@ def test_discrete_sac_terminating(
qvalue_network=qvalue,
num_actions=actor.spec[action_key].space.n,
action_space="one-hot",
skip_done_states=True,
)
loss.set_keys(
action=action_key,
Expand Down
151 changes: 97 additions & 54 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ class SACLoss(LossModule):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
skip_done_states (bool, optional): whether the actor network should only be run on valid, non-terminating
next states. If ``True``, it is assumed that the done state can be broadcast to the shape of the
data and that masking the data results in a valid data structure. Among other things, this may not
be true in MARL settings or when using RNNs. Defaults to ``False``.
Examples:
>>> import torch
Expand Down Expand Up @@ -320,6 +324,7 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
reduction: str = None,
skip_done_states: bool = False,
) -> None:
self._in_keys = None
self._out_keys = None
Expand Down Expand Up @@ -418,6 +423,7 @@ def __init__(
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction
self.skip_done_states = skip_done_states

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
Expand Down Expand Up @@ -712,36 +718,44 @@ def _compute_target_v2(self, tensordict) -> Tensor:
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
next_tensordict = tensordict.get("next").copy()
# Check done state and avoid passing these to the actor
done = next_tensordict.get(self.tensor_keys.done)
if done is not None and done.any():
next_tensordict_select = next_tensordict[~done.squeeze(-1)]
else:
next_tensordict_select = next_tensordict
next_dist = self.actor_network.get_dist(next_tensordict_select)
next_action = next_dist.rsample()
next_sample_log_prob = compute_log_prob(
next_dist, next_action, self.tensor_keys.log_prob
)
if next_tensordict_select is not next_tensordict:
mask = ~done.squeeze(-1)
if mask.ndim < next_action.ndim:
mask = expand_right(
mask, (*mask.shape, *next_action.shape[mask.ndim :])
)
next_action = next_action.new_zeros(mask.shape).masked_scatter_(
mask, next_action
if self.skip_done_states:
# Check done state and avoid passing these to the actor
done = next_tensordict.get(self.tensor_keys.done)
if done is not None and done.any():
next_tensordict_select = next_tensordict[~done.squeeze(-1)]
else:
next_tensordict_select = next_tensordict
next_dist = self.actor_network.get_dist(next_tensordict_select)
next_action = next_dist.rsample()
next_sample_log_prob = compute_log_prob(
next_dist, next_action, self.tensor_keys.log_prob
)
mask = ~done.squeeze(-1)
if mask.ndim < next_sample_log_prob.ndim:
mask = expand_right(
mask,
(*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
if next_tensordict_select is not next_tensordict:
mask = ~done.squeeze(-1)
if mask.ndim < next_action.ndim:
mask = expand_right(
mask, (*mask.shape, *next_action.shape[mask.ndim :])
)
next_action = next_action.new_zeros(mask.shape).masked_scatter_(
mask, next_action
)
next_sample_log_prob = next_sample_log_prob.new_zeros(
mask.shape
).masked_scatter_(mask, next_sample_log_prob)
next_tensordict.set(self.tensor_keys.action, next_action)
mask = ~done.squeeze(-1)
if mask.ndim < next_sample_log_prob.ndim:
mask = expand_right(
mask,
(*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
)
next_sample_log_prob = next_sample_log_prob.new_zeros(
mask.shape
).masked_scatter_(mask, next_sample_log_prob)
next_tensordict.set(self.tensor_keys.action, next_action)
else:
next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = compute_log_prob(
next_dist, next_action, self.tensor_keys.log_prob
)

# get q-values
next_tensordict_expand = self._vmap_qnetworkN0(
Expand Down Expand Up @@ -877,6 +891,10 @@ class DiscreteSACLoss(LossModule):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
skip_done_states (bool, optional): whether the actor network should only be run on valid, non-terminating
next states. If ``True``, it is assumed that the done state can be broadcast to the shape of the
data and that masking the data results in a valid data structure. Among other things, this may not
be true in MARL settings or when using RNNs. Defaults to ``False``.
Examples:
>>> import torch
Expand Down Expand Up @@ -1051,6 +1069,7 @@ def __init__(
priority_key: str = None,
separate_losses: bool = False,
reduction: str = None,
skip_done_states: bool = False,
):
if reduction is None:
reduction = "mean"
Expand Down Expand Up @@ -1133,6 +1152,7 @@ def __init__(
)
self._make_vmap()
self.reduction = reduction
self.skip_done_states = skip_done_states

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
Expand Down Expand Up @@ -1218,35 +1238,58 @@ def _compute_target(self, tensordict) -> Tensor:
with torch.no_grad():
next_tensordict = tensordict.get("next").clone(False)

done = next_tensordict.get(self.tensor_keys.done)
if done is not None and done.any():
next_tensordict_select = next_tensordict[~done.squeeze(-1)]
else:
next_tensordict_select = next_tensordict
if self.skip_done_states:
done = next_tensordict.get(self.tensor_keys.done)
if done is not None and done.any():
next_tensordict_select = next_tensordict[~done.squeeze(-1)]
else:
next_tensordict_select = next_tensordict

# get probs and log probs for actions computed from "next"
with self.actor_network_params.to_module(self.actor_network):
next_dist = self.actor_network.get_dist(next_tensordict_select)
next_log_prob = next_dist.logits
next_prob = next_log_prob.exp()
# get probs and log probs for actions computed from "next"
with self.actor_network_params.to_module(self.actor_network):
next_dist = self.actor_network.get_dist(next_tensordict_select)
next_log_prob = next_dist.logits
next_prob = next_log_prob.exp()

# get q-values for all actions
next_tensordict_expand = self._vmap_qnetworkN0(
next_tensordict_select, self.target_qvalue_network_params
)
next_action_value = next_tensordict_expand.get(
self.tensor_keys.action_value
)
# get q-values for all actions
next_tensordict_expand = self._vmap_qnetworkN0(
next_tensordict_select, self.target_qvalue_network_params
)
next_action_value = next_tensordict_expand.get(
self.tensor_keys.action_value
)

# like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
# unlike in continuous SAC, we can compute the exact expectation over all discrete actions
next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
if next_tensordict_select is not next_tensordict:
mask = ~done
next_state_value = next_state_value.new_zeros(
mask.shape
).masked_scatter_(mask, next_state_value)
# like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
next_state_value = (
next_action_value.min(0)[0] - self._alpha * next_log_prob
)
# unlike in continuous SAC, we can compute the exact expectation over all discrete actions
next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
if next_tensordict_select is not next_tensordict:
mask = ~done
next_state_value = next_state_value.new_zeros(
mask.shape
).masked_scatter_(mask, next_state_value)
else:
# get probs and log probs for actions computed from "next"
with self.actor_network_params.to_module(self.actor_network):
next_dist = self.actor_network.get_dist(next_tensordict)
next_prob = next_dist.probs
next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))

# get q-values for all actions
next_tensordict_expand = self._vmap_qnetworkN0(
next_tensordict, self.target_qvalue_network_params
)
next_action_value = next_tensordict_expand.get(
self.tensor_keys.action_value
)
# like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term
next_state_value = (
next_action_value.min(0)[0] - self._alpha * next_log_prob
)
# unlike in continuous SAC, we can compute the exact expectation over all discrete actions
next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)

tensordict.set(
("next", self.value_estimator.tensor_keys.value), next_state_value
Expand Down

0 comments on commit 8078906

Please sign in to comment.