From ec04c353c9fa4e5f1874f3cbcf3016dedd5506a3 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Fri, 11 Oct 2024 08:18:36 -0700 Subject: [PATCH] [Feature] SAC compatibility with composite distributions. (#2447) --- test/test_cost.py | 150 ++++++++++++++++++++++++++++++-------- torchrl/objectives/sac.py | 30 ++++++-- 2 files changed, 140 insertions(+), 40 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 1c00d4d965f..3530fff825d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -48,7 +48,7 @@ from mocking_classes import ContinuousActionConvMockEnv # from torchrl.data.postprocs.utils import expand_as_right -from tensordict import assert_allclose_td, TensorDict +from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict.nn import NormalParamExtractor, TensorDictModule from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key @@ -3450,21 +3450,40 @@ def _create_mock_actor( device="cpu", observation_key="observation", action_key="action", + composite_action_dist=False, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - in_keys=["loc", "scale"], - spec=action_spec, - distribution_class=TanhNormal, + distribution_class=distribution_class, + in_keys=actor_in_keys, out_keys=[action_key], + spec=action_spec, ) return actor.to(device) @@ -3484,6 +3503,8 @@ def __init__(self): self.linear = nn.Linear(obs_dim + action_dim, 1) def forward(self, obs, act): + if isinstance(act, TensorDictBase): + act = act.get("action1") return self.linear(torch.cat([obs, act], -1)) module = ValueClass() @@ -3512,8 +3533,26 @@ def _create_mock_value( return value.to(device) def _create_mock_common_layer_setup( - self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + self, + n_obs=3, + n_act=4, + ncells=4, + batch=2, + n_hidden=2, + composite_action_dist=False, ): + class QValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(n_hidden + n_act, n_hidden) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(n_hidden, 1) + + def forward(self, obs, act): + if isinstance(act, TensorDictBase): + act = act.get("action1") + return self.linear2(self.relu(self.linear1(torch.cat([obs, act], -1)))) + common = MLP( num_cells=ncells, in_features=n_obs, @@ -3526,17 +3565,13 @@ def _create_mock_common_layer_setup( depth=1, out_features=2 * n_act, ) - qvalue = MLP( - in_features=n_hidden + n_act, - num_cells=ncells, - depth=1, - out_features=1, - ) + qvalue = QValueClass() batch = [batch] + action = torch.randn(*batch, n_act) td = TensorDict( { "obs": torch.randn(*batch, n_obs), - "action": torch.randn(*batch, n_act), + "action": {"action1": action} if composite_action_dist else action, "done": torch.zeros(*batch, 1, 
dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -3549,14 +3584,30 @@ def _create_mock_common_layer_setup( batch, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + aggregate_probabilities=True, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] actor = ProbSeq( common, Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), - Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys), ProbMod( - in_keys=["loc", "scale"], + in_keys=actor_in_keys, out_keys=["action"], - distribution_class=TanhNormal, + distribution_class=distribution_class, ), ) qvalue_head = Mod( @@ -3582,6 +3633,7 @@ def _create_mock_data_sac( done_key="done", terminated_key="terminated", reward_key="reward", + composite_action_dist=False, ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -3603,14 +3655,21 @@ def _create_mock_data_sac( terminated_key: terminated, reward_key: reward, }, - action_key: action, + action_key: {"action1": action} if composite_action_dist else action, }, device=device, ) return td def _create_seq_mock_data_sac( - self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + self, + batch=8, + T=4, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -3626,6 +3685,7 @@ def _create_seq_mock_data_sac( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -3637,7 +3697,7 @@ def _create_seq_mock_data_sac( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + "action": {"action1": action} if composite_action_dist else action, }, names=[None, "time"], device=device, @@ -3650,6 +3710,7 @@ def _create_seq_mock_data_sac( @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac( self, delay_value, @@ -3659,14 +3720,19 @@ def test_sac( device, version, td_est, + composite_action_dist, ): if (delay_actor or delay_qvalue) and not delay_value: pytest.skip("incompatible config") torch.manual_seed(self.seed) - td = self._create_mock_data_sac(device=device) + td = self._create_mock_data_sac( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -3816,6 +3882,7 @@ def test_sac( @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("num_qvalue", [2]) 
@pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac_state_dict( self, delay_value, @@ -3824,13 +3891,16 @@ def test_sac_state_dict( num_qvalue, device, version, + composite_action_dist, ): if (delay_actor or delay_qvalue) and not delay_value: pytest.skip("incompatible config") torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -3866,15 +3936,19 @@ def test_sac_state_dict( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [False, True]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac_separate_losses( self, device, separate_losses, version, + composite_action_dist, n_act=4, ): torch.manual_seed(self.seed) - actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) + actor, qvalue, common, td = self._create_mock_common_layer_setup( + n_act=n_act, composite_action_dist=composite_action_dist + ) loss_fn = SACLoss( actor_network=actor, @@ -3960,6 +4034,7 @@ def test_sac_separate_losses( @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_sac_batcher( self, n, @@ -3969,13 +4044,18 @@ def test_sac_batcher( num_qvalue, device, version, + composite_action_dist, ): if (delay_actor or delay_qvalue) and not delay_value: pytest.skip("incompatible config") torch.manual_seed(self.seed) - td = self._create_seq_mock_data_sac(device=device) + td = self._create_seq_mock_data_sac( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) qvalue = self._create_mock_qvalue(device=device) if version == 1: value = self._create_mock_value(device=device) @@ -4126,10 +4206,11 @@ def test_sac_batcher( @pytest.mark.parametrize( "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] ) - def test_sac_tensordict_keys(self, td_est, version): - td = self._create_mock_data_sac() + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_sac_tensordict_keys(self, td_est, version, composite_action_dist): + td = self._create_mock_data_sac(composite_action_dist=composite_action_dist) - actor = self._create_mock_actor() + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) qvalue = self._create_mock_qvalue() if version == 1: value = self._create_mock_value() @@ -4149,7 +4230,7 @@ def test_sac_tensordict_keys(self, td_est, version): "value": "state_value", "state_action_value": "state_action_value", "action": "action", - "log_prob": "_log_prob", + "log_prob": "sample_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", @@ -4311,15 +4392,20 @@ def test_state_dict(self, version): loss.load_state_dict(state) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_sac_reduction(self, reduction, version): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_sac_reduction(self, reduction, version, composite_action_dist): 
        torch.manual_seed(self.seed)
         device = (
             torch.device("cpu")
             if torch.cuda.device_count() == 0
             else torch.device("cuda")
         )
-        td = self._create_mock_data_sac(device=device)
-        actor = self._create_mock_actor(device=device)
+        td = self._create_mock_data_sac(
+            device=device, composite_action_dist=composite_action_dist
+        )
+        actor = self._create_mock_actor(
+            device=device, composite_action_dist=composite_action_dist
+        )
         qvalue = self._create_mock_qvalue(device=device)
         if version == 1:
             value = self._create_mock_value(device=device)
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index bd21e33c30d..6350538db16 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -46,6 +46,19 @@ def new_func(self, *args, **kwargs):
         return new_func
 
 
+def compute_log_prob(action_dist, action_or_tensordict, tensor_key):
+    """Compute an action's log-probability as a tensor, reading the ``tensor_key`` entry when the distribution returns a TensorDict of log-probs."""
+    if isinstance(action_or_tensordict, torch.Tensor):
+        log_p = action_dist.log_prob(action_or_tensordict)
+    else:
+        maybe_log_prob = action_dist.log_prob(action_or_tensordict)
+        if not isinstance(maybe_log_prob, torch.Tensor):
+            log_p = maybe_log_prob.get(tensor_key)
+        else:
+            log_p = maybe_log_prob
+    return log_p
+
+
 class SACLoss(LossModule):
     """TorchRL implementation of the SAC loss.
 
@@ -251,7 +264,7 @@ class _AcceptedKeys:
             state_action_value (NestedKey): The input tensordict key where the state action value is expected.
                 Defaults to ``"state_action_value"``.
             log_prob (NestedKey): The input tensordict key where the log probability is expected.
-                Defaults to ``"_log_prob"``.
+                Defaults to ``"sample_log_prob"``.
             priority (NestedKey): The input tensordict key where the target priority is written to.
                 Defaults to ``"td_error"``.
             reward (NestedKey): The input tensordict key where the reward is expected.
@@ -267,7 +280,7 @@ class _AcceptedKeys: action: NestedKey = "action" value: NestedKey = "state_value" state_action_value: NestedKey = "state_action_value" - log_prob: NestedKey = "_log_prob" + log_prob: NestedKey = "sample_log_prob" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" @@ -450,9 +463,7 @@ def target_entropy(self): else: action_container_shape = action_spec.shape target_entropy = -float( - action_spec[self.tensor_keys.action] - .shape[len(action_container_shape) :] - .numel() + action_spec.shape[len(action_container_shape) :].numel() ) delattr(self, "_target_entropy") self.register_buffer( @@ -622,7 +633,7 @@ def _actor_loss( ), self.actor_network_params.to_module(self.actor_network): dist = self.actor_network.get_dist(tensordict) a_reparm = dist.rsample() - log_prob = dist.log_prob(a_reparm) + log_prob = compute_log_prob(dist, a_reparm, self.tensor_keys.log_prob) td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q.set(self.tensor_keys.action, a_reparm) @@ -713,7 +724,9 @@ def _compute_target_v2(self, tensordict) -> Tensor: next_dist = self.actor_network.get_dist(next_tensordict) next_action = next_dist.rsample() next_tensordict.set(self.tensor_keys.action, next_action) - next_sample_log_prob = next_dist.log_prob(next_action) + next_sample_log_prob = compute_log_prob( + next_dist, next_action, self.tensor_keys.log_prob + ) # get q-values next_tensordict_expand = self._vmap_qnetworkN0( @@ -780,7 +793,8 @@ def _value_loss( td_copy.get(self.tensor_keys.state_action_value).squeeze(-1).min(0)[0] ) - log_p = action_dist.log_prob(action) + log_p = compute_log_prob(action_dist, action, self.tensor_keys.log_prob) + if log_p.shape != min_qval.shape: raise RuntimeError( f"Losses shape mismatch: {min_qval.shape} and {log_p.shape}"
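
Usage sketch (not part of the patch): a minimal end-to-end example of what this change enables, assembled from the test helpers above (_create_mock_actor, _create_mock_qvalue and _create_mock_data_sac). The leaf key "action1", the sizes (obs_dim=6, action_dim=4, batch=8) and the QValue module are illustrative assumptions mirroring the tests, not values fixed by the patch; import locations follow recent tensordict/torchrl releases and may differ across versions.

import functools

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import CompositeDistribution, NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.data import Bounded, Composite
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import SACLoss

obs_dim, action_dim, batch = 6, 4, 8

# Composite action spec: the "action" entry is itself a tensordict with one leaf.
action_spec = Composite(
    {
        "action": {
            "action1": Bounded(
                -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
            )
        }
    }
)

# The policy writes per-leaf distribution parameters under ("params", "action1", ...).
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
module = TensorDictModule(
    net,
    in_keys=["observation"],
    out_keys=[("params", "action1", "loc"), ("params", "action1", "scale")],
)
actor = ProbabilisticActor(
    module=module,
    in_keys=["params"],
    out_keys=["action"],
    spec=action_spec,
    distribution_class=functools.partial(
        CompositeDistribution,
        distribution_map={"action1": TanhNormal},
        # With aggregate_probabilities=True, log_prob returns a plain tensor,
        # which compute_log_prob passes through unchanged.
        aggregate_probabilities=True,
    ),
)


class QValue(nn.Module):
    """Q-network that unpacks the composite action, as ValueClass does in the tests."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(obs_dim + action_dim, 1)

    def forward(self, obs, act):
        # Composite actions arrive as a TensorDict; pull out the leaf tensor.
        if isinstance(act, TensorDictBase):
            act = act.get("action1")
        return self.linear(torch.cat([obs, act], -1))


qvalue = ValueOperator(QValue(), in_keys=["observation", "action"])
loss_fn = SACLoss(actor_network=actor, qvalue_network=qvalue)

# A fake transition batch shaped like _create_mock_data_sac with
# composite_action_dist=True: the action lives at ("action", "action1").
td = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": {"action1": torch.randn(batch, action_dim).clamp(-1, 1)},
        "next": {
            "observation": torch.randn(batch, obs_dim),
            "reward": torch.randn(batch, 1),
            "done": torch.zeros(batch, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch],
)
loss_vals = loss_fn(td)  # keys include "loss_actor", "loss_qvalue", "loss_alpha"

With aggregate_probabilities=True the composite distribution's log_prob returns a plain tensor, which the new compute_log_prob helper passes through; when a distribution instead returns a tensordict of log-probabilities, the helper reads the entry at tensor_keys.log_prob. That is why this patch renames the default log_prob key from "_log_prob" to "sample_log_prob", matching the default log-prob key used by CompositeDistribution.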