From ee897281fd72f4d70280ceb0051cc794b4730cf9 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 7 Dec 2023 04:47:27 +0000
Subject: [PATCH] [BugFix, Feature] Fix DDQN implementation (#1737)

---
 test/test_cost.py                         | 10 ++--
 torchrl/envs/libs/robohive.py             |  2 +-
 torchrl/modules/distributions/discrete.py |  4 +-
 torchrl/objectives/dqn.py                 | 64 ++++++++++++++++++++---
 torchrl/objectives/multiagent/qmixer.py   | 12 ++++-
 5 files changed, 78 insertions(+), 14 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 12c2a6b6d6f..8bae683c5d5 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -426,11 +426,13 @@ def _create_seq_mock_data_dqn(
         )
         return td

-    @pytest.mark.parametrize("delay_value", (False, True))
+    @pytest.mark.parametrize(
+        "delay_value,double_dqn", ([False, False], [True, False], [True, True])
+    )
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    def test_dqn(self, delay_value, device, action_spec_type, td_est):
+    def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
         torch.manual_seed(self.seed)
         actor = self._create_mock_actor(
             action_spec_type=action_spec_type, device=device
@@ -438,7 +440,9 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est):
         td = self._create_mock_data_dqn(
             action_spec_type=action_spec_type, device=device
         )
-        loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
+        loss_fn = DQNLoss(
+            actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
+        )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
                 loss_fn.make_value_estimator(td_est)
diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py
index a3ee1dfa893..7ce0938facb 100644
--- a/torchrl/envs/libs/robohive.py
+++ b/torchrl/envs/libs/robohive.py
@@ -68,7 +68,7 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):

     Args:
         env_name (str): the environment name to build.
-        read_info (bool, optional): whether the the info should be parsed.
+        read_info (bool, optional): whether the info should be parsed.
             Defaults to ``True``.
         device (torch.device, optional): the device on which the input/output
             are expected. Defaults to torch default device.
diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py
index bb98b1412a8..d73457b2261 100644
--- a/torchrl/modules/distributions/discrete.py
+++ b/torchrl/modules/distributions/discrete.py
@@ -163,7 +163,7 @@ class MaskedCategorical(D.Categorical):
             must be taken into account. Exclusive with ``mask``.
         neg_inf (float, optional): The log-probability value allocated to invalid
             (out-of-mask) indices. Defaults to -inf.
-        padding_value: The padding value in the then mask tensor when
+        padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.

         >>> torch.manual_seed(0)
@@ -314,7 +314,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
             must be taken into account. Exclusive with ``mask``.
         neg_inf (float, optional): The log-probability value allocated to invalid
             (out-of-mask) indices. Defaults to -inf.
-        padding_value: The padding value in the then mask tensor when
+        padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.
         grad_method (ReparamGradientStrategy, optional): strategy to gather
             reparameterized samples.
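For reference, the ``padding_value`` entries fixed above belong to the masking machinery of ``MaskedCategorical``. A rough usage sketch follows; the ``logits``/``mask`` keyword arguments are assumed from the surrounding docstring rather than shown in this diff.

import torch
from torchrl.modules import MaskedCategorical

torch.manual_seed(0)
logits = torch.randn(4)
mask = torch.tensor([True, False, True, True])  # index 1 can never be sampled

dist = MaskedCategorical(logits=logits, mask=mask)
sample = dist.sample((10,))    # drawn samples avoid the masked-out index
log_p = dist.log_prob(sample)  # out-of-mask indices receive log-prob ``neg_inf``

In the sparse form (valid indices listed explicitly instead of a dense boolean mask), ``padding_value`` marks padding entries in the index tensor and is ignored, which is what the ``padding_value`` docstring lines above describe.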
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 07ffd7f463c..59c3f32697f 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -44,7 +44,9 @@ class DQNLoss(LossModule):
             Defaults to "l2".
         delay_value (bool, optional): whether to duplicate the value network
             into a new target value network to
-            create a double DQN. Default is ``False``.
+            create a DQN with a target network. Default is ``False``.
+        double_dqn (bool, optional): whether or not to use Double DQN, as described in
+            https://arxiv.org/abs/1509.06461. Defaults to ``False``.
         action_space (str or TensorSpec, optional): Action space. Must be one of
             ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
             or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
@@ -164,13 +166,27 @@ def __init__(
         self,
         value_network: Union[QValueActor, nn.Module],
         *,
         loss_function: Optional[str] = "l2",
-        delay_value: bool = False,
+        delay_value: bool = None,
+        double_dqn: bool = False,
         gamma: float = None,
         action_space: Union[str, TensorSpec] = None,
         priority_key: str = None,
     ) -> None:
+        if delay_value is None:
+            warnings.warn(
+                f"You did not provide a delay_value argument for {type(self)}. "
+                "Currently (v0.3) the default for delay_value is `False` but as of "
+                "v0.4 it will be `True`. Make sure to adapt your code depending "
+                "on your preferred configuration. "
+                "To remove this warning, indicate the value of delay_value in your "
+                "script."
+            )
+            delay_value = False
         super().__init__()
         self._in_keys = None
+        if double_dqn and not delay_value:
+            raise ValueError("double_dqn=True requires delay_value=True.")
+        self.double_dqn = double_dqn
         self._set_deprecated_ctor_keys(priority=priority_key)
         self.delay_value = delay_value
         value_network = ensure_tensordict_compatible(
@@ -296,7 +312,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         pred_val = td_copy.get(self.tensor_keys.action_value)

         if self.action_space == "categorical":
-            if action.shape != pred_val.shape:
+            if action.ndim != pred_val.ndim:
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
             pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
@@ -304,8 +320,32 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             action = action.to(torch.float)
             pred_val_index = (pred_val * action).sum(-1)

+        if self.double_dqn:
+            step_td = step_mdp(td_copy, keep_other=False)
+            step_td_copy = step_td.clone(False)
+            # Use online network to compute the action
+            with self.value_network_params.data.to_module(self.value_network):
+                self.value_network(step_td)
+            next_action = step_td.get(self.tensor_keys.action)
+
+            # Use target network to compute the values
+            with self.target_value_network_params.to_module(self.value_network):
+                self.value_network(step_td_copy)
+            next_pred_val = step_td_copy.get(self.tensor_keys.action_value)
+
+            if self.action_space == "categorical":
+                if next_action.ndim != next_pred_val.ndim:
+                    # unsqueeze the action if it lacks a trailing singleton dim
+                    next_action = next_action.unsqueeze(-1)
+                next_value = torch.gather(next_pred_val, -1, index=next_action)
+            else:
+                next_value = (next_pred_val * next_action).sum(-1, keepdim=True)
+        else:
+            next_value = None
         target_value = self.value_estimator.value_estimate(
-            td_copy, target_params=self.target_value_network_params
+            td_copy,
+            target_params=self.target_value_network_params,
+            next_value=next_value,
         ).squeeze(-1)

         with torch.no_grad():
@@ -369,9 +409,9 @@ class _AcceptedKeys:
Defaults to ``"td_error"``. reward (NestedKey): The input tensordict key where the reward is expected. Defaults to ``"reward"``. - done (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected. + done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected. Defaults to ``"done"``. - terminated (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected. + terminated (NestedKey): The input tensordict key where the flag if a trajectory is done is expected. Defaults to ``"terminated"``. steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is exptected. Defaults to ``"steps_to_next_obs"``. @@ -392,9 +432,19 @@ def __init__( self, value_network: Union[DistributionalQValueActor, nn.Module], gamma: float, - delay_value: bool = False, + delay_value: bool = None, priority_key: str = None, ): + if delay_value is None: + warnings.warn( + f"You did not provide a delay_value argument for {type(self)}. " + "Currently (v0.3) the default for delay_value is `False` but as of " + "v0.4 it will be `True`. Make sure to adapt your code depending " + "on your preferred configuration. " + "To remove this warning, indicate the value of delay_value in your " + "script." + ) + delay_value = False super().__init__() self._set_deprecated_ctor_keys(priority=priority_key) self.register_buffer("gamma", torch.tensor(gamma)) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index 23947696c9f..38f56108784 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -189,11 +189,21 @@ def __init__( mixer_network: Union[TensorDictModule, nn.Module], *, loss_function: Optional[str] = "l2", - delay_value: bool = False, + delay_value: bool = None, gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, ) -> None: + if delay_value is None: + warnings.warn( + f"You did not provide a delay_value argument for {type(self)}. " + "Currently (v0.3) the default for delay_value is `False` but as of " + "v0.4 it will be `True`. Make sure to adapt your code depending " + "on your preferred configuration. " + "To remove this warning, indicate the value of delay_value in your " + "script." + ) + delay_value = False super().__init__() self._in_keys = None self._set_deprecated_ctor_keys(priority=priority_key)