[BugFix] DQN loss dispatch respect configured tensordict keys #1285

Merged (1 commit, Jun 15, 2023)
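In short: before this PR, the @dispatch decorator on DQNLoss.forward carried a hard-coded list of source keys ("observation", ("next", "observation"), "action", ...), so keyword-argument calls ignored anything configured through set_keys(). With this PR, DQNLoss derives an in_keys property from its configured tensor_keys and its value network, and the bare @dispatch picks those up. A minimal usage sketch, mirroring the new test_dqn_notensordict test below (key names such as "observation2", "reward2" and "done2" are only illustrative):

import torch
from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

n_obs, n_action = 3, 4
action_spec = OneHotDiscreteTensorSpec(n_action)
# a Q-value actor reading a custom observation key
actor = QValueActor(
    spec=action_spec,
    action_space="one_hot",
    module=nn.Linear(n_obs, n_action),
    in_keys=["observation2"],
)
loss = DQNLoss(actor)
loss.set_keys(reward="reward2", done="done2")

# the keyword-argument call now follows the configured keys
loss_val = loss(
    observation2=torch.randn(n_obs),
    next_observation2=torch.randn(n_obs),
    next_reward2=torch.randn(1),
    next_done2=torch.zeros(1, dtype=torch.bool),
    action=action_spec.rand(),
)

Before this change, the same call would have had to use the default key names regardless of what was passed to set_keys().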
73 changes: 50 additions & 23 deletions test/test_cost.py
@@ -480,20 +480,46 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):
p.data += torch.randn_like(p)
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))

def test_dqn_tensordict_keys(self):
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
)
def test_dqn_tensordict_keys(self, td_est):
torch.manual_seed(self.seed)
action_spec_type = "one_hot"
actor = self._create_mock_actor(action_spec_type=action_spec_type)
loss_fn = DQNLoss(actor)

default_keys = {
"advantage": "advantage",
"value_target": "value_target",
"value": "chosen_action_value",
"priority": "td_error",
"action_value": "action_value",
"action": "action",
"reward": "reward",
"done": "done",
}

self.tensordict_keys_test(loss_fn, default_keys=default_keys)

loss_fn = DQNLoss(actor)
key_mapping = {
"advantage": ("advantage", "advantage_2"),
"value_target": ("value_target", ("value_target", "nested")),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

actor = self._create_mock_actor(
action_spec_type=action_spec_type, action_value_key="chosen_action_value_2"
)
loss_fn = DQNLoss(actor)
key_mapping = {
"value": ("value", "chosen_action_value_2"),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("action_spec_type", ("categorical", "one_hot"))
@pytest.mark.parametrize(
"td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda]
@@ -577,37 +603,38 @@ def test_distributional_dqn(
p.data += torch.randn_like(p)
assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))

def test_dqn_notensordict(self):
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
Contributor:
do we skip the action key by design?

Contributor Author:
Yes, it is intended.
At the moment, the action key cannot really be configured, since it is used in the constructor of the DQN loss via _find_action_space(action_space).

Either I need to remove the action key from the configurable keys, or the following part of the constructor has to be deferred until .set_keys() has been called:

        if action_space is None:
            # infer from value net
            try:
                action_space = value_network.spec
            except AttributeError:
                # let's try with action_space then
                try:
                    action_space = value_network.action_space
                except AttributeError:
                    raise ValueError(self.ACTION_SPEC_ERROR)
        if action_space is None:
            warnings.warn(
                "action_space was not specified. DQNLoss will default to 'one-hot'."
                "This behaviour will be deprecated soon and a space will have to be passed."
                "Check the DQNLoss documentation to see how to pass the action space. "
            )
            action_space = "one-hot"
        self.action_space = _find_action_space(action_space)

Contributor:
Got it thanks!
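To make the deferral option above concrete, a purely illustrative sketch (not part of this PR) of what moving that resolution out of __init__ could look like, e.g. as a helper that runs once the keys are final:

# Hypothetical refactor, not part of this PR: resolve the action space lazily
# instead of inside __init__, so it could run after .set_keys().
# (uses `warnings` and `_find_action_space` from the surrounding module,
# exactly as in the constructor snippet quoted above)
def _resolve_action_space(self, action_space=None):
    if action_space is None:
        # infer from the value net, as the constructor does today
        try:
            action_space = self.value_network.spec
        except AttributeError:
            try:
                action_space = self.value_network.action_space
            except AttributeError:
                raise ValueError(self.ACTION_SPEC_ERROR)
    if action_space is None:
        warnings.warn(
            "action_space was not specified. DQNLoss will default to 'one-hot'. "
            "This behaviour will be deprecated soon and a space will have to be passed."
        )
        action_space = "one-hot"
    self.action_space = _find_action_space(action_space)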

def test_dqn_notensordict(self, observation_key, reward_key, done_key):
n_obs = 3
n_action = 4
action_spec = OneHotDiscreteTensorSpec(n_action)
value_network = nn.Linear(n_obs, n_action) # a simple value model
dqn_loss = DQNLoss(value_network, action_space=action_spec)
module = nn.Linear(n_obs, n_action) # a simple value model
actor = QValueActor(
spec=action_spec,
action_space="one_hot",
module=module,
in_keys=[observation_key],
)
dqn_loss = DQNLoss(actor)
dqn_loss.set_keys(reward=reward_key, done=done_key)
# define data
observation = torch.randn(n_obs)
next_observation = torch.randn(n_obs)
action = action_spec.rand()
next_reward = torch.randn(1)
next_done = torch.zeros(1, dtype=torch.bool)
loss_val = dqn_loss(
observation=observation,
next_observation=next_observation,
next_reward=next_reward,
next_done=next_done,
action=action,
)
loss_val_td = dqn_loss(
TensorDict(
{
"observation": observation,
"next_observation": next_observation,
"next_reward": next_reward,
"next_done": next_done,
"action": action,
},
[],
).unflatten_keys("_")
)
kwargs = {
observation_key: observation,
f"next_{observation_key}": next_observation,
f"next_{reward_key}": next_reward,
f"next_{done_key}": next_done,
"action": action,
}
td = TensorDict(kwargs, []).unflatten_keys("_")
loss_val = dqn_loss(**kwargs)
loss_val_td = dqn_loss(td)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)

def test_distributional_dqn_tensordict_keys(self):
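As a reading aid for the key-mapping checks in test_dqn_tensordict_keys above: they essentially assert that keys passed to set_keys() end up in loss_fn.tensor_keys, and hence in the keys the loss reads. A simplified, self-contained sketch of that check (the exact helper logic in the test suite differs):

from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

actor = QValueActor(
    spec=OneHotDiscreteTensorSpec(4),
    action_space="one_hot",
    module=nn.Linear(3, 4),
    in_keys=["observation"],
)
loss_fn = DQNLoss(actor)
loss_fn.set_keys(advantage="advantage_2", reward="reward_test")
assert loss_fn.tensor_keys.advantage == "advantage_2"
assert loss_fn.tensor_keys.reward == "reward_test"
# the configured reward key shows up (under "next") among the keys the loss reads
assert ("next", "reward_test") in loss_fn.in_keys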
67 changes: 53 additions & 14 deletions torchrl/objectives/dqn.py
@@ -119,20 +119,37 @@ class _AcceptedKeys:
default values.

Attributes:
advantage (NestedKey): The input tensordict key where the advantage is expected.
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator. Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the chosen action value is expected.
Will be used for the underlying value estimator. Defaults to ``"chosen_action_value"``.
action_value (NestedKey): The input tensordict key where the action value is expected.
Defaults to ``"action_value"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
priority (NestedKey): The input tensordict key where the target priority is written to.
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "chosen_action_value"
action_value: NestedKey = "action_value"
action: NestedKey = "action"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
out_keys = ["loss"]

def __init__(
self,
@@ -146,6 +163,7 @@ def __init__(
) -> None:

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)
self.delay_value = delay_value
value_network = ensure_tensordict_compatible(
@@ -187,7 +205,35 @@ def __init__(
self.gamma = gamma

def _forward_value_estimator_keys(self, **kwargs) -> None:
pass
if self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self.tensor_keys.advantage,
value_target=self.tensor_keys.value_target,
value=self.tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)
self._set_in_keys()

def _set_in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.value_network.in_keys,
*[("next", key) for key in self.value_network.in_keys],
]
self._in_keys = list(set(keys))

@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys

@in_keys.setter
def in_keys(self, values):
self._in_keys = values

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
if value_type is None:
@@ -213,22 +259,15 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {
"advantage": "advantage",
"value_target": "value_target",
"value": "chosen_action_value",
"advantage": self.tensor_keys.advantage,
"value_target": self.tensor_keys.value_target,
"value": self.tensor_keys.value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)

@dispatch(
source=[
"observation",
("next", "observation"),
"action",
("next", "reward"),
("next", "done"),
],
dest=["loss"],
)
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
"""Computes the DQN loss given a tensordict sampled from the replay buffer.

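The final hunk above is the core of the fix: the hard-coded source/dest lists are removed, so the bare @dispatch falls back to the module's in_keys and out_keys (["loss"]), which now track the configured keys. For a value network with in_keys=["observation"], after loss.set_keys(reward="reward2", done="done2"), loss.in_keys would contain something like the list below (the order is arbitrary, since _set_in_keys() goes through a set):

# illustrative contents of loss.in_keys; order is arbitrary
[
    "action",
    ("next", "reward2"),
    ("next", "done2"),
    "observation",
    ("next", "observation"),
]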