[Feature] Dispatch TD3 loss module (#1254)
Blonck authored Jun 12, 2023
1 parent acd65f6 commit ea6f872
Showing 2 changed files with 217 additions and 16 deletions.
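For orientation before the diff: forward() of TD3Loss is now decorated with tensordict.nn.dispatch, so the loss can be called either with a TensorDict or with plain tensor keyword arguments, where nested ("next", key) entries become "next_<key>" keywords. The sketch below condenses the docstring and test added in this commit; the QValue class and variable names are illustrative, and the snippet has not been checked against any specific TorchRL release.

import torch
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
from torchrl.objectives.td3 import TD3Loss

n_act, n_obs = 4, 3
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(module=nn.Linear(n_obs, n_act), spec=spec)

class QValue(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))

qvalue = ValueOperator(module=QValue(), in_keys=["observation", "action"])
loss = TD3Loss(actor, qvalue, action_spec=actor.spec)

# Restrict the outputs so the plain-tensor call returns exactly two values.
loss.select_out_keys("loss_actor", "loss_qvalue")

batch = 2
loss_actor, loss_qvalue = loss(
    observation=torch.randn(batch, n_obs),
    action=spec.rand([batch]),
    next_observation=torch.randn(batch, n_obs),
    next_reward=torch.randn(batch, 1),
    next_done=torch.zeros(batch, 1, dtype=torch.bool),
)
loss_actor.backward()
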
94 changes: 80 additions & 14 deletions test/test_cost.py
@@ -1011,20 +1011,34 @@ def test_ddpg_notensordict(self):
class TestTD3(LossModuleTestBase):
seed = 0

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_actor(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
in_keys=None,
out_keys=None,
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
module = nn.Linear(obs_dim, action_dim)
actor = Actor(
spec=action_spec,
module=module,
spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys
)
return actor.to(device)

def _create_mock_value(
self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
out_keys=None,
action_key="action",
observation_key="observation",
):
# Value network
class ValueClass(nn.Module):
@@ -1038,7 +1052,7 @@ def forward(self, obs, act):
module = ValueClass()
value = ValueOperator(
module=module,
in_keys=["observation", "action"],
in_keys=[observation_key, action_key],
out_keys=out_keys,
)
return value.to(device)
@@ -1049,7 +1063,16 @@ def _create_mock_distributional_actor(
raise NotImplementedError

def _create_mock_data_td3(
self, batch=8, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
batch=8,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
action_key="action",
observation_key="observation",
reward_key="reward",
done_key="done",
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -1063,13 +1086,13 @@ def _create_mock_data_td3(
td = TensorDict(
batch_size=(batch,),
source={
"observation": obs,
observation_key: obs,
"next": {
"observation": next_obs,
"done": done,
"reward": reward,
observation_key: next_obs,
done_key: done,
reward_key: reward,
},
"action": action,
action_key: action,
},
device=device,
)
@@ -1311,6 +1334,8 @@ def test_td3_tensordict_keys(self, td_est):
"priority": "td_error",
"state_action_value": "state_action_value",
"action": "action",
"reward": "reward",
"done": "done",
}

self.tensordict_keys_test(
@@ -1320,12 +1345,16 @@
)

value = self._create_mock_value(out_keys=["state_action_value_test"])
loss_fn = DDPGLoss(
loss_fn = TD3Loss(
actor,
value,
loss_function="l2",
action_spec=actor.spec,
)
key_mapping = {"state_action_value": ("value", "state_action_value_test")}
key_mapping = {
"state_action_value": ("value", "state_action_value_test"),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("spec", [True, False])
@@ -1353,6 +1382,43 @@ def test_constructor(self, spec, bounds):
bounds=bounds,
)

# TODO: test for action_key; at the moment the action key of the TD3 loss is not configurable,
# since it is used in its constructor
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_td3_notensordict(self, observation_key, reward_key, done_key):

torch.manual_seed(self.seed)
actor = self._create_mock_actor(in_keys=[observation_key])
qvalue = self._create_mock_value(
observation_key=observation_key, out_keys=["state_action_value"]
)
td = self._create_mock_data_td3(
observation_key=observation_key, reward_key=reward_key, done_key=done_key
)
loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
loss.set_keys(reward=reward_key, done=done_key)
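# With @dispatch, nested ("next", <key>) entries are accepted as flat keyword
# arguments named "next_<key>"; the kwargs below follow that convention, and
# unflatten_keys("_") rebuilds the nested layout for the TensorDict-based call.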

kwargs = {
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": td.get(("next", done_key)),
f"next_{observation_key}": td.get(("next", observation_key)),
"action": td.get("action"),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

loss_val_td = loss(td)
loss_val = loss(**kwargs)
for i, key in enumerate(loss_val_td.keys()):
torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
# test select
loss.select_out_keys("loss_actor", "loss_qvalue")
loss_actor, loss_qvalue = loss(**kwargs)
assert loss_actor == loss_val_td["loss_actor"]
assert loss_qvalue == loss_val_td["loss_qvalue"]


@pytest.mark.skipif(
not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
139 changes: 137 additions & 2 deletions torchrl/objectives/td3.py
@@ -7,7 +7,7 @@
from typing import Optional, Tuple

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn import dispatch, TensorDictModule

from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
@@ -68,6 +68,96 @@ class TD3Loss(LossModule):
delay_qvalue (bool, optional): Whether to separate the target Q value
networks from the Q value networks used
for data collection. Default is ``True``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.td3 import TD3Loss
>>> from tensordict.tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
... spec=spec)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
fields={
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict based modules too and can be
used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done"]`` + the in_keys of the actor and qvalue networks.
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value"]``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3 import TD3Loss
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
... module=module,
... spec=spec)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec)
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward=torch.randn(*batch, 1),
... next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()
"""

@dataclass
@@ -84,14 +174,29 @@ class _AcceptedKeys:
Will be used for the underlying value estimator. Defaults to ``"state_action_value"``.
priority (NestedKey): The input tensordict key where the target priority is written to.
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

action: NestedKey = "action"
state_action_value: NestedKey = "state_action_value"
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
"loss_qvalue",
"pred_value",
"state_action_value_actor",
"next_state_value",
"target_value",
]
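# When the module is called with plain tensors (see the @dispatch-decorated
# forward below), keyword arguments are matched against in_keys (nested
# ("next", key) entries are flattened to "next_<key>") and the outputs are
# returned as a tuple ordered like out_keys, or like the subset chosen via
# select_out_keys().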

def __init__(
self,
@@ -115,6 +220,7 @@ def __init__(
)

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority=priority_key)

self.delay_actor = delay_actor
@@ -178,8 +284,33 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.state_action_value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)
self._set_in_keys()

def _set_in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.qvalue_network.in_keys,
]
self._in_keys = list(set(keys))

@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys

@in_keys.setter
def in_keys(self, values):
self._in_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
obs_keys = self.actor_network.in_keys
tensordict_save = tensordict
@@ -333,5 +464,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
else:
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {"value": self.tensor_keys.state_action_value}
tensor_keys = {
"value": self.tensor_keys.state_action_value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)
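
As a usage note, the reward and done entries made configurable here can be remapped through set_keys and are forwarded to the underlying value estimator (see _forward_value_estimator_keys and make_value_estimator above). A minimal sketch, assuming the loss object from the example near the top of this page; the key names "reward2" and "done2" are purely illustrative, and the ValueEstimators import path is assumed to be torchrl.objectives.utils, as referenced in the diffed module.

from torchrl.objectives.utils import ValueEstimators

# Remap the tensordict keys the loss will read for reward and done signals;
# "reward2" and "done2" are illustrative names only.
loss.set_keys(reward="reward2", done="done2")

# (Re)building the value estimator propagates the same mapping, since
# make_value_estimator now forwards the reward/done keys to the estimator.
loss.make_value_estimator(ValueEstimators.TD0)

# Input data is then expected to carry ("next", "reward2") and ("next", "done2")
# entries, or next_reward2 / next_done2 keyword arguments on the dispatch path.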
