[BugFix] DQN loss dispatch respects configured tensordict keys (#1285)
Blonck authored Jun 15, 2023
1 parent 1faea14 commit 2dbdec9
Showing 2 changed files with 103 additions and 37 deletions.
73 changes: 50 additions & 23 deletions test/test_cost.py
@@ -480,20 +480,46 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
p.data += torch.randn_like(p)
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))

def test_dqn_tensordict_keys(self):
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
def test_dqn_tensordict_keys(self, td_est):
torch.manual_seed(self.seed)
action_spec_type = "one_hot"
actor = self._create_mock_actor(action_spec_type=action_spec_type)
loss_fn = DQNLoss(actor)

default_keys = {
"advantage": "advantage",
"value_target": "value_target",
"value": "chosen_action_value",
"priority": "td_error",
"action_value": "action_value",
"action": "action",
"reward": "reward",
"done": "done",
}

self.tensordict_keys_test(loss_fn, default_keys=default_keys)

loss_fn = DQNLoss(actor)
key_mapping = {
"advantage": ("advantage", "advantage_2"),
"value_target": ("value_target", ("value_target", "nested")),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

actor = self._create_mock_actor(
action_spec_type=action_spec_type, action_value_key="chosen_action_value_2"
)
loss_fn = DQNLoss(actor)
key_mapping = {
"value": ("value", "chosen_action_value_2"),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("action_spec_type", ("categorical", "one_hot"))
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
@@ -577,37 +603,38 @@ def test_distributional_dqn(
p.data += torch.randn_like(p)
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))

def test_dqn_notensordict(self):
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_dqn_notensordict(self, observation_key, reward_key, done_key):
n_obs = 3
n_action = 4
action_spec = OneHotDiscreteTensorSpec(n_action)
value_network = nn.Linear(n_obs, n_action) # a simple value model
dqn_loss = DQNLoss(value_network, action_space=action_spec)
module = nn.Linear(n_obs, n_action) # a simple value model
actor = QValueActor(
spec=action_spec,
action_space="one_hot",
module=module,
in_keys=[observation_key],
)
dqn_loss = DQNLoss(actor)
dqn_loss.set_keys(reward=reward_key, done=done_key)
# define data
observation = torch.randn(n_obs)
next_observation = torch.randn(n_obs)
action = action_spec.rand()
next_reward = torch.randn(1)
next_done = torch.zeros(1, dtype=torch.bool)
loss_val = dqn_loss(
observation=observation,
next_observation=next_observation,
next_reward=next_reward,
next_done=next_done,
action=action,
)
loss_val_td = dqn_loss(
TensorDict(
{
"observation": observation,
"next_observation": next_observation,
"next_reward": next_reward,
"next_done": next_done,
"action": action,
},
[],
).unflatten_keys("_")
)
kwargs = {
observation_key: observation,
f"next_{observation_key}": next_observation,
f"next_{reward_key}": next_reward,
f"next_{done_key}": next_done,
"action": action,
}
td = TensorDict(kwargs, []).unflatten_keys("_")
loss_val = dqn_loss(**kwargs)
loss_val_td = dqn_loss(td)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)

def test_distributional_dqn_tensordict_keys(self):
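For orientation, a minimal standalone sketch of the behaviour the new tests above exercise, assuming torchrl and tensordict are installed; the sizes and the ``"observation2"``/``"reward2"``/``"done2"`` names are illustrative stand-ins for the parametrized keys, not part of the commit:

import torch
from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

n_obs, n_action = 3, 4
action_spec = OneHotDiscreteTensorSpec(n_action)
actor = QValueActor(
    spec=action_spec,
    action_space="one_hot",
    module=nn.Linear(n_obs, n_action),  # a simple value model
    in_keys=["observation2"],           # non-default observation key
)
loss = DQNLoss(actor)
loss.set_keys(reward="reward2", done="done2")  # non-default reward/done keys

# Keyword-argument (dispatch) calls now follow the configured keys.
loss_val = loss(
    observation2=torch.randn(n_obs),
    next_observation2=torch.randn(n_obs),
    next_reward2=torch.randn(1),
    next_done2=torch.zeros(1, dtype=torch.bool),
    action=action_spec.rand(),
)
# loss_val corresponds to the "loss" entry of the tensordict output.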
67 changes: 53 additions & 14 deletions torchrl/objectives/dqn.py
@@ -119,20 +119,37 @@ class _AcceptedKeys:
default values.
Attributes:
advantage (NestedKey): The input tensordict key where the advantage is expected.
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator. Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the chosen action value is expected.
Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
action_value (NestedKey): The input tensordict key where the action value is expected.
Defaults to ``"action_value"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
priority (NestedKey): The input tensordict key where the target priority is written to.
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "chosen_action_value"
action_value: NestedKey = "action_value"
action: NestedKey = "action"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
out_keys = ["loss"]

def __init__(
self,
@@ -146,6 +163,7 @@ def __init__(
) -> None:

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)
self.delay_value = delay_value
value_network = ensure_tensordict_compatible(
@@ -187,7 +205,35 @@ def __init__(
self.gamma = gamma

def _forward_value_estimator_keys(self, **kwargs) -> None:
pass
if self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self.tensor_keys.advantage,
value_target=self.tensor_keys.value_target,
value=self.tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)
self._set_in_keys()

def _set_in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.value_network.in_keys,
*[("next", key) for key in self.value_network.in_keys],
]
self._in_keys = list(set(keys))

@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys

@in_keys.setter
def in_keys(self, values):
self._in_keys = values

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
if value_type is None:
@@ -213,22 +259,15 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {
"advantage": "advantage",
"value_target": "value_target",
"value": "chosen_action_value",
"advantage": self.tensor_keys.advantage,
"value_target": self.tensor_keys.value_target,
"value": self.tensor_keys.value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)

@dispatch(
source=[
"observation",
("next", "observation"),
"action",
("next", "reward"),
("next", "done"),
],
dest=["loss"],
)
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
"""Computes the DQN loss given a tensordict sampled from the replay buffer.
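The dqn.py changes above boil down to this: the bare ``@dispatch`` decorator no longer carries a hard-coded source/dest list and instead falls back to the module's ``in_keys``/``out_keys``, which ``_set_in_keys`` now derives from the configured tensordict keys. A hedged sketch of that relationship, reusing the imports from the previous snippet; the ``"pixels"``/``"my_reward"``/``"my_done"`` names are illustrative only:

from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

actor = QValueActor(
    spec=OneHotDiscreteTensorSpec(4),
    action_space="one_hot",
    module=nn.Linear(3, 4),
    in_keys=["pixels"],  # illustrative observation key
)
loss = DQNLoss(actor)
loss.set_keys(reward="my_reward", done="my_done")

# in_keys is rebuilt from the configured keys, so the bare @dispatch exposes
# exactly these entries as keyword arguments instead of the previously
# hard-coded "observation"/"reward"/"done" names.
assert "pixels" in loss.in_keys
assert ("next", "pixels") in loss.in_keys
assert ("next", "my_reward") in loss.in_keys
assert ("next", "my_done") in loss.in_keys
assert "action" in loss.in_keys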
