[Feature] Dispatch IQL loss module (#1230)
Blonck authored Jun 4, 2023
1 parent 3aab983 commit 340d47e
Showing 2 changed files with 229 additions and 14 deletions.
106 changes: 93 additions & 13 deletions test/test_cost.py
@@ -4256,13 +4256,20 @@ def test_dreamer_value_tensordict_keys(self, device):
class TestIQL(LossModuleTestBase):
seed = 0

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_actor(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
observation_key="observation",
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
module = SafeModule(net, in_keys=[observation_key], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
in_keys=["loc", "scale"],
@@ -4271,7 +4278,16 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
)
return actor.to(device)

def _create_mock_qvalue(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_qvalue(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
out_keys=None,
observation_key="observation",
action_key="action",
):
class ValueClass(nn.Module):
def __init__(self):
super().__init__()
@@ -4283,15 +4299,24 @@ def forward(self, obs, act):
module = ValueClass()
qvalue = ValueOperator(
module=module,
in_keys=["observation", "action"],
in_keys=[observation_key, action_key],
out_keys=out_keys,
)
return qvalue.to(device)

def _create_mock_value(
self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
out_keys=None,
observation_key="observation",
):
module = nn.Linear(obs_dim, 1)
value = ValueOperator(module=module, in_keys=["observation"], out_keys=out_keys)
value = ValueOperator(
module=module, in_keys=[observation_key], out_keys=out_keys
)
return value.to(device)

def _create_mock_distributional_actor(
@@ -4300,7 +4325,16 @@ def _create_mock_distributional_actor(
raise NotImplementedError

def _create_mock_data_iql(
self, batch=16, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
batch=16,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
observation_key="observation",
action_key="action",
done_key="done",
reward_key="reward",
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -4314,13 +4348,13 @@ def _create_mock_data_iql(
td = TensorDict(
batch_size=(batch,),
source={
"observation": obs,
observation_key: obs,
"next": {
"observation": next_obs,
"done": done,
"reward": reward,
observation_key: next_obs,
done_key: done,
reward_key: reward,
},
"action": action,
action_key: action,
},
device=device,
)
@@ -4587,6 +4621,8 @@ def test_iql_tensordict_keys(self, td_est):
"action": "action",
"state_action_value": "state_action_value",
"value": "state_value",
"reward": "reward",
"done": "done",
}

self.tensordict_keys_test(
@@ -4603,9 +4639,53 @@
loss_function="l2",
)

key_mapping = {"value": ("value", "value_test")}
key_mapping = {
"value": ("value", "value_test"),
"done": ("done", "done_test"),
"reward": ("reward", ("reward", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_iql_notensordict(self, action_key, observation_key, reward_key, done_key):
torch.manual_seed(self.seed)
td = self._create_mock_data_iql(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
)

actor = self._create_mock_actor(observation_key=observation_key)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
action_key=action_key,
out_keys=["state_action_value"],
)
value = self._create_mock_value(observation_key=observation_key)

loss = IQLLoss(actor_network=actor, qvalue_network=qvalue, value_network=value)
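# remap the loss module's key names to the custom keys used in this test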
loss.set_keys(action=action_key, reward=reward_key, done=done_key)

kwargs = {
action_key: td.get(action_key),
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": td.get(("next", done_key)),
}
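# unflatten_keys("_") turns the flat "next_<reward/done>" entries back into nested ("next", <key>) entries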
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

loss_val = loss(**kwargs)
loss_val_td = loss(td)
assert len(loss_val) == 4
torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[2])
torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[3])


def test_hold_out():
net = torch.nn.Linear(3, 4)
137 changes: 136 additions & 1 deletion torchrl/objectives/iql.py
@@ -7,7 +7,7 @@
from typing import Optional, Tuple

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn import dispatch, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import Tensor
@@ -57,6 +57,109 @@ class IQLLoss(LossModule):
priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead]
tensordict key where the priority is written (for prioritized replay
buffer usage). Default is `"td_error"`.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import IQLLoss
>>> from tensordict.tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> loss = IQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... }, batch)
>>> loss(data)
TensorDict(
fields={
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict based modules too and can be
used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done"]`` + the in_keys of the actor, value, and qvalue networks.
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "loss_value", "entropy"]``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.iql import IQLLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
... module=module,
... in_keys=["observation"])
>>> loss = IQLLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_val = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward=torch.randn(*batch, 1))
>>> loss_val
(tensor(1.4535, grad_fn=<MeanBackward0>), tensor(0.8389, grad_fn=<MeanBackward0>), tensor(0.3406, grad_fn=<MeanBackward0>), tensor(3.3441))
"""

@dataclass
@@ -78,13 +181,20 @@ class _AcceptedKeys:
state_action_value (NestedKey): The input tensordict key where the
state-action value is expected. Will be used as the value key for the
underlying value estimator. Defaults to ``"state_action_value"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

value: NestedKey = "state_value"
action: NestedKey = "action"
log_prob: NestedKey = "_log_prob"
priority: NestedKey = "td_error"
state_action_value: NestedKey = "state_action_value"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0
@@ -153,6 +263,19 @@ def device(self) -> torch.device:
"At least one of the networks of SACLoss must have trainable " "parameters."
)

@property
def in_keys(self):
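# keys read by forward() from the input tensordict; through dispatch, the same names are accepted as keyword arguments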
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
]
keys.extend(self.actor_network.in_keys)
keys.extend(self.qvalue_network.in_keys)
keys.extend(self.value_network.in_keys)

return list(set(keys))

@staticmethod
def loss_value_diff(diff, expectile=0.8):
"""Loss function for iql expectile value difference."""
@@ -163,8 +286,18 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)

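# dispatch makes forward() callable with plain tensors passed as keyword arguments (named after in_keys); in that case the entries listed in dest are returned as a tuple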
@dispatch(
dest=[
"loss_actor",
"loss_qvalue",
"loss_value",
"entropy",
]
)
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
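# inputs with more than one batch dimension are flattened before the losses are computed, and the original shape is restored afterwards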
shape = None
if tensordict.ndimension() > 1:
@@ -317,5 +450,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
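# forward the configured value/reward/done keys to the underlying value estimator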
tensor_keys = {
"value_target": "value_target",
"value": self.tensor_keys.value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)
