From c8676f4a87df65bff0ccc42ea09942ef73ce4d9a Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 25 Nov 2024 13:34:30 +0000
Subject: [PATCH] [BugFix] Account for terminating data in SAC losses

ghstack-source-id: dc1870292786c262b4ab6a221b3afb551e0efb9b
Pull Request resolved: https://github.com/pytorch/rl/pull/2606
---
 test/test_cost.py         | 119 ++++++++++++++++++++++++++++++++++++++
 torchrl/objectives/sac.py |  51 +++++++++++++---
 2 files changed, 162 insertions(+), 8 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 598b9ba004d..c48b4a28b99 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -4459,6 +4459,69 @@ def test_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key, version
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+            action_key=action_key,
+            out_keys=["state_action_value"],
+        )
+        if version == 1:
+            value = self._create_mock_value(observation_key=observation_key)
+        else:
+            value = None
+
+        loss = SACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            value_network=value,
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        torch.manual_seed(self.seed)
+
+        SoftUpdate(loss, eps=0.5)
+
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done.bernoulli_(0.1)
+        obs_nan = td.get(("next", terminated_key))
+        obs_nan[done.squeeze(-1)] = float("nan")
+
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": obs_nan,
+            f"next_{observation_key}": td.get(("next", observation_key)),
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     def test_state_dict(self, version):
         if version == 1:
             pytest.skip("Test not implemented for version 1.")
@@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
+    @pytest.mark.parametrize("action_key", ["action", "action2"])
+    @pytest.mark.parametrize("observation_key", ["observation", "observation2"])
+    @pytest.mark.parametrize("reward_key", ["reward", "reward2"])
+    @pytest.mark.parametrize("done_key", ["done", "done2"])
+    @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"])
+    def test_discrete_sac_terminating(
+        self, action_key, observation_key, reward_key, done_key, terminated_key
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_mock_data_sac(
+            action_key=action_key,
+            observation_key=observation_key,
+            reward_key=reward_key,
+            done_key=done_key,
+            terminated_key=terminated_key,
+        )
+
+        actor = self._create_mock_actor(
+            observation_key=observation_key, action_key=action_key
+        )
+        qvalue = self._create_mock_qvalue(
+            observation_key=observation_key,
+        )
+
+        loss = DiscreteSACLoss(
+            actor_network=actor,
+            qvalue_network=qvalue,
+            num_actions=actor.spec[action_key].space.n,
+            action_space="one-hot",
+        )
+        loss.set_keys(
+            action=action_key,
+            reward=reward_key,
+            done=done_key,
+            terminated=terminated_key,
+        )
+
+        SoftUpdate(loss, eps=0.5)
+
+        torch.manual_seed(0)
+        done = td.get(("next", done_key))
+        while not (done.any() and not done.all()):
+            done = done.bernoulli_(0.1)
+        obs_none = td.get(("next", observation_key))
+        obs_none[done.squeeze(-1)] = float("nan")
+        kwargs = {
+            action_key: td.get(action_key),
+            observation_key: td.get(observation_key),
+            f"next_{reward_key}": td.get(("next", reward_key)),
+            f"next_{done_key}": done,
+            f"next_{terminated_key}": td.get(("next", terminated_key)),
+            f"next_{observation_key}": obs_none,
+        }
+        td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
+        assert loss(td).isfinite().all()
+
     @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
     def test_discrete_sac_reduction(self, reduction):
         torch.manual_seed(self.seed)
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 3fb34678d02..52efb3d312b 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -16,7 +16,7 @@
 from tensordict import TensorDict, TensorDictBase, TensorDictParams
 from tensordict.nn import dispatch, TensorDictModule
-from tensordict.utils import NestedKey
+from tensordict.utils import expand_right, NestedKey
 from torch import Tensor
 
 from torchrl.data.tensor_specs import Composite, TensorSpec
 from torchrl.data.utils import _find_action_space
@@ -711,13 +711,37 @@ def _compute_target_v2(self, tensordict) -> Tensor:
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
-            next_tensordict = tensordict.get("next").clone(False)
-            next_dist = self.actor_network.get_dist(next_tensordict)
+            next_tensordict = tensordict.get("next").copy()
+            # Check done state and avoid passing these to the actor
+            done = next_tensordict.get(self.tensor_keys.done)
+            if done is not None and done.any():
+                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+            else:
+                next_tensordict_select = next_tensordict
+            next_dist = self.actor_network.get_dist(next_tensordict_select)
             next_action = next_dist.rsample()
-            next_tensordict.set(self.tensor_keys.action, next_action)
             next_sample_log_prob = compute_log_prob(
                 next_dist, next_action, self.tensor_keys.log_prob
             )
+            if next_tensordict_select is not next_tensordict:
+                mask = ~done.squeeze(-1)
+                if mask.ndim < next_action.ndim:
+                    mask = expand_right(
+                        mask, (*mask.shape, *next_action.shape[mask.ndim :])
+                    )
+                next_action = next_action.new_zeros(mask.shape).masked_scatter_(
+                    mask, next_action
+                )
+                mask = ~done.squeeze(-1)
+                if mask.ndim < next_sample_log_prob.ndim:
+                    mask = expand_right(
+                        mask,
+                        (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]),
+                    )
+                next_sample_log_prob = next_sample_log_prob.new_zeros(
+                    mask.shape
+                ).masked_scatter_(mask, next_sample_log_prob)
+            next_tensordict.set(self.tensor_keys.action, next_action)
 
             # get q-values
             next_tensordict_expand = self._vmap_qnetworkN0(
@@ -1194,15 +1218,21 @@ def _compute_target(self, tensordict) -> Tensor:
         with torch.no_grad():
             next_tensordict = tensordict.get("next").clone(False)
 
+            done = next_tensordict.get(self.tensor_keys.done)
+            if done is not None and done.any():
+                next_tensordict_select = next_tensordict[~done.squeeze(-1)]
+            else:
+                next_tensordict_select = next_tensordict
+
             # get probs and log probs for actions computed from "next"
             with self.actor_network_params.to_module(self.actor_network):
-                next_dist = self.actor_network.get_dist(next_tensordict)
-                next_prob = next_dist.probs
-                next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob))
+                next_dist = self.actor_network.get_dist(next_tensordict_select)
+                next_log_prob = next_dist.logits
+                next_prob = next_log_prob.exp()
 
             # get q-values for all actions
             next_tensordict_expand = self._vmap_qnetworkN0(
-                next_tensordict, self.target_qvalue_network_params
+                next_tensordict_select, self.target_qvalue_network_params
             )
             next_action_value = next_tensordict_expand.get(
                 self.tensor_keys.action_value
@@ -1212,6 +1242,11 @@
             next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob
             # unlike in continuous SAC, we can compute the exact expectation over all discrete actions
             next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1)
+            if next_tensordict_select is not next_tensordict:
+                mask = ~done.squeeze(-1)
+                next_state_value = next_state_value.new_zeros(
+                    mask.shape
+                ).masked_scatter_(mask, next_state_value)
 
             tensordict.set(
                 ("next", self.value_estimator.tensor_keys.value), next_state_value
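
The masking pattern used by both losses can be sketched in isolation. The snippet below is not part of the patch: it uses toy tensors and a stand-in for the actor call (names such as next_obs, batch and the fake action head are assumptions made for the example) to show how terminal "done" next states are filtered out before sampling and how the sub-batch results are scattered back into zero-filled, full-batch tensors with expand_right and masked_scatter_. This is what keeps the loss finite when terminal next observations contain NaNs, which is exactly what the new tests assert.

# Standalone sketch of the masking pattern above (toy shapes, fake actor; not the TorchRL API).
# Terminal rows never reach the "actor"; their slots are filled with zeros afterwards.
import torch
from tensordict.utils import expand_right

batch, obs_dim, act_dim = 8, 4, 2
next_obs = torch.randn(batch, obs_dim)
done = torch.zeros(batch, 1, dtype=torch.bool)
done[::3] = True  # mark a few transitions as terminal

mask = ~done.squeeze(-1)       # [batch], True where the next state is valid
selected_obs = next_obs[mask]  # only non-terminal next states are kept

# stand-in for actor_network.get_dist(...).rsample() and its log-prob on the sub-batch
next_action_sel = torch.tanh(selected_obs @ torch.randn(obs_dim, act_dim))
next_log_prob_sel = -next_action_sel.pow(2).sum(-1)

# scatter back: expand the mask over the action's trailing dims, zeros elsewhere
action_mask = expand_right(mask, (*mask.shape, *next_action_sel.shape[mask.ndim:]))
next_action = next_action_sel.new_zeros(action_mask.shape).masked_scatter_(
    action_mask, next_action_sel
)
# the log-prob has no trailing dims here, so the mask is used as-is
next_log_prob = next_log_prob_sel.new_zeros(mask.shape).masked_scatter_(
    mask, next_log_prob_sel
)

assert next_action.shape == (batch, act_dim)
assert (next_action[done.squeeze(-1)] == 0).all()  # terminal rows stay zero

The discrete loss applies the same scatter to next_state_value after taking the expectation over actions. The zero placeholders written for terminal entries are harmless because the value estimator discounts the next-state value by the done/terminated flags anyway.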