[BugFix, Feature] Fix DDQN implementation (#1737)
vmoens authored Dec 7, 2023
1 parent a772a43 commit ee89728
Showing 5 changed files with 78 additions and 14 deletions.
10 changes: 7 additions & 3 deletions test/test_cost.py
@@ -426,19 +426,23 @@ def _create_seq_mock_data_dqn(
         )
         return td

-    @pytest.mark.parametrize("delay_value", (False, True))
+    @pytest.mark.parametrize(
+        "delay_value,double_dqn", ([False, False], [True, False], [True, True])
+    )
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    def test_dqn(self, delay_value, device, action_spec_type, td_est):
+    def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
         torch.manual_seed(self.seed)
         actor = self._create_mock_actor(
             action_spec_type=action_spec_type, device=device
         )
         td = self._create_mock_data_dqn(
             action_spec_type=action_spec_type, device=device
         )
-        loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
+        loss_fn = DQNLoss(
+            actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
+        )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
                 loss_fn.make_value_estimator(td_est)
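Note that the [False, True] combination is deliberately absent from the parametrization above: Double DQN needs a target network to evaluate the action picked by the online network. A minimal sketch of the rejected configuration (this snippet is illustrative, not part of the commit; the toy QValueActor construction is an assumption):

    import torch
    from torch import nn
    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss

    # A toy Q-network over a 4-dimensional observation and 4 discrete actions.
    actor = QValueActor(
        nn.Linear(4, 4), in_keys=["observation"], spec=OneHotDiscreteTensorSpec(4)
    )
    # double_dqn without a target network is rejected by the new constructor check:
    # raises ValueError("double_dqn=True requires delay_value=True.")
    DQNLoss(actor, loss_function="l2", delay_value=False, double_dqn=True)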
2 changes: 1 addition & 1 deletion torchrl/envs/libs/robohive.py
@@ -68,7 +68,7 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
     Args:
         env_name (str): the environment name to build.
-        read_info (bool, optional): whether the the info should be parsed.
+        read_info (bool, optional): whether the info should be parsed.
             Defaults to ``True``.
         device (torch.device, optional): the device on which the input/output
             are expected. Defaults to torch default device.
4 changes: 2 additions & 2 deletions torchrl/modules/distributions/discrete.py
@@ -163,7 +163,7 @@ class MaskedCategorical(D.Categorical):
             must be taken into account. Exclusive with ``mask``.
         neg_inf (float, optional): The log-probability value allocated to
             invalid (out-of-mask) indices. Defaults to -inf.
-        padding_value: The padding value in the then mask tensor when
+        padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.
         >>> torch.manual_seed(0)
@@ -314,7 +314,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
             must be taken into account. Exclusive with ``mask``.
         neg_inf (float, optional): The log-probability value allocated to
             invalid (out-of-mask) indices. Defaults to -inf.
-        padding_value: The padding value in the then mask tensor when
+        padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.
         grad_method (ReparamGradientStrategy, optional): strategy to gather
             reparameterized samples.
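For context on the padding_value docstring fixed above, a small sketch of how MaskedCategorical is typically used with a dense boolean mask (values here are assumptions; the class docstring holds the authoritative examples):

    import torch
    from torchrl.modules import MaskedCategorical

    logits = torch.randn(4)
    # Dense boolean mask: only actions 0, 2 and 3 are valid.
    mask = torch.tensor([True, False, True, True])
    dist = MaskedCategorical(logits=logits, mask=mask)
    sample = dist.sample()          # never returns the masked-out index 1
    dist.log_prob(torch.tensor(1))  # masked index gets the neg_inf log-probability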
64 changes: 57 additions & 7 deletions torchrl/objectives/dqn.py
@@ -44,7 +44,9 @@ class DQNLoss(LossModule):
             Defaults to "l2".
         delay_value (bool, optional): whether to duplicate the value network
             into a new target value network to
-            create a double DQN. Default is ``False``.
+            create a DQN with a target network. Default is ``False``.
+        double_dqn (bool, optional): whether or not to use Double DQN, as described in
+            https://arxiv.org/abs/1509.06461. Defaults to ``False``.
         action_space (str or TensorSpec, optional): Action space. Must be one of
             ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``,
             or an instance of the corresponding specs (:class:`torchrl.data.OneHotDiscreteTensorSpec`,
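The intended configuration pairs double_dqn=True with delay_value=True and a target-network updater. A hedged usage sketch (reusing the `actor` from the earlier snippet; the SoftUpdate parameters are illustrative):

    from torchrl.objectives import DQNLoss, SoftUpdate

    loss_fn = DQNLoss(actor, loss_function="l2", delay_value=True, double_dqn=True)
    updater = SoftUpdate(loss_fn, eps=0.995)  # keeps the target net trailing the online net
    # after each optimization step, call:
    #     updater.step()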
@@ -164,13 +166,27 @@ def __init__(
         value_network: Union[QValueActor, nn.Module],
         *,
         loss_function: Optional[str] = "l2",
-        delay_value: bool = False,
+        delay_value: bool = None,
+        double_dqn: bool = False,
         gamma: float = None,
         action_space: Union[str, TensorSpec] = None,
         priority_key: str = None,
     ) -> None:
+        if delay_value is None:
+            warnings.warn(
+                f"You did not provide a delay_value argument for {type(self)}. "
+                "Currently (v0.3) the default for delay_value is `False` but as of "
+                "v0.4 it will be `True`. Make sure to adapt your code depending "
+                "on your preferred configuration. "
+                "To remove this warning, indicate the value of delay_value in your "
+                "script."
+            )
+            delay_value = False
         super().__init__()
         self._in_keys = None
+        if double_dqn and not delay_value:
+            raise ValueError("double_dqn=True requires delay_value=True.")
+        self.double_dqn = double_dqn
         self._set_deprecated_ctor_keys(priority=priority_key)
         self.delay_value = delay_value
         value_network = ensure_tensordict_compatible(
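Until the default flips in v0.4, constructing the loss without delay_value emits the warning above. A sketch of how to surface it in a test (again assuming the `actor` from the first snippet):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        DQNLoss(actor)  # delay_value not specified
    assert any("delay_value" in str(w.message) for w in caught)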
@@ -296,16 +312,40 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         pred_val = td_copy.get(self.tensor_keys.action_value)

         if self.action_space == "categorical":
-            if action.shape != pred_val.shape:
+            if action.ndim != pred_val.ndim:
+                # unsqueeze the action if it lacks a trailing singleton dim
                 action = action.unsqueeze(-1)
             pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
         else:
             action = action.to(torch.float)
             pred_val_index = (pred_val * action).sum(-1)

+        if self.double_dqn:
+            step_td = step_mdp(td_copy, keep_other=False)
+            step_td_copy = step_td.clone(False)
+            # Use the online network to compute the action
+            with self.value_network_params.data.to_module(self.value_network):
+                self.value_network(step_td)
+            next_action = step_td.get(self.tensor_keys.action)
+
+            # Use the target network to compute the values
+            with self.target_value_network_params.to_module(self.value_network):
+                self.value_network(step_td_copy)
+            next_pred_val = step_td_copy.get(self.tensor_keys.action_value)
+
+            if self.action_space == "categorical":
+                if next_action.ndim != next_pred_val.ndim:
+                    # unsqueeze the action if it lacks a trailing singleton dim
+                    next_action = next_action.unsqueeze(-1)
+                next_value = torch.gather(next_pred_val, -1, index=next_action)
+            else:
+                next_value = (next_pred_val * next_action).sum(-1, keepdim=True)
+        else:
+            next_value = None
         target_value = self.value_estimator.value_estimate(
-            td_copy, target_params=self.target_value_network_params
+            td_copy,
+            target_params=self.target_value_network_params,
+            next_value=next_value,
         ).squeeze(-1)

         with torch.no_grad():
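Stripped of the tensordict plumbing, the block above implements the standard Double DQN target: the online network chooses the next action, the target network scores it. A plain-tensor sketch of that computation (function name and shapes are assumptions, not part of the commit):

    import torch

    def double_dqn_target(q_online_next, q_target_next, reward, done, gamma=0.99):
        # q_*_next: [batch, n_actions]; reward: [batch]; done: [batch], bool
        next_action = q_online_next.argmax(dim=-1, keepdim=True)        # online net selects
        next_value = q_target_next.gather(-1, next_action).squeeze(-1)  # target net evaluates
        return reward + gamma * (~done).float() * next_value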
@@ -369,9 +409,9 @@ class _AcceptedKeys:
             Defaults to ``"td_error"``.
         reward (NestedKey): The input tensordict key where the reward is expected.
             Defaults to ``"reward"``.
-        done (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected.
+        done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
             Defaults to ``"done"``.
-        terminated (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected.
+        terminated (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
             Defaults to ``"terminated"``.
         steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is expected.
             Defaults to ``"steps_to_next_obs"``.
@@ -392,9 +432,19 @@ def __init__(
         self,
         value_network: Union[DistributionalQValueActor, nn.Module],
         gamma: float,
-        delay_value: bool = False,
+        delay_value: bool = None,
         priority_key: str = None,
     ):
+        if delay_value is None:
+            warnings.warn(
+                f"You did not provide a delay_value argument for {type(self)}. "
+                "Currently (v0.3) the default for delay_value is `False` but as of "
+                "v0.4 it will be `True`. Make sure to adapt your code depending "
+                "on your preferred configuration. "
+                "To remove this warning, indicate the value of delay_value in your "
+                "script."
+            )
+            delay_value = False
         super().__init__()
         self._set_deprecated_ctor_keys(priority=priority_key)
         self.register_buffer("gamma", torch.tensor(gamma))
12 changes: 11 additions & 1 deletion torchrl/objectives/multiagent/qmixer.py
@@ -189,11 +189,21 @@ def __init__(
         mixer_network: Union[TensorDictModule, nn.Module],
         *,
         loss_function: Optional[str] = "l2",
-        delay_value: bool = False,
+        delay_value: bool = None,
         gamma: float = None,
         action_space: Union[str, TensorSpec] = None,
         priority_key: str = None,
     ) -> None:
+        if delay_value is None:
+            warnings.warn(
+                f"You did not provide a delay_value argument for {type(self)}. "
+                "Currently (v0.3) the default for delay_value is `False` but as of "
+                "v0.4 it will be `True`. Make sure to adapt your code depending "
+                "on your preferred configuration. "
+                "To remove this warning, indicate the value of delay_value in your "
+                "script."
+            )
+            delay_value = False
         super().__init__()
         self._in_keys = None
         self._set_deprecated_ctor_keys(priority=priority_key)
