[Feature] Dispatch REDQ loss module (pytorch#1251)
Blonck authored Jun 9, 2023
1 parent a25a878 commit 0d67d39
Showing 2 changed files with 257 additions and 14 deletions.
116 changes: 104 additions & 12 deletions test/test_cost.py
@@ -2383,13 +2383,20 @@ def test_discrete_sac_notensordict(
class TestREDQ(LossModuleTestBase):
seed = 0

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_actor(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
observation_key="observation",
):
# Actor
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
module = SafeModule(net, in_keys=[observation_key], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
in_keys=["loc", "scale"],
@@ -2399,7 +2406,16 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
)
return actor.to(device)

def _create_mock_qvalue(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
def _create_mock_qvalue(
self,
batch=2,
obs_dim=3,
action_dim=4,
device="cpu",
observation_key="observation",
action_key="action",
out_keys=None,
):
class ValueClass(nn.Module):
def __init__(self):
super().__init__()
@@ -2410,8 +2426,7 @@ def forward(self, obs, act):

module = ValueClass()
qvalue = ValueOperator(
module=module,
in_keys=["observation", "action"],
module=module, in_keys=[observation_key, action_key], out_keys=out_keys
)
return qvalue.to(device)

@@ -2454,7 +2469,16 @@ def forward(self, hidden, act):
return model.to(device)

def _create_mock_data_redq(
self, batch=16, obs_dim=3, action_dim=4, atoms=None, device="cpu"
self,
batch=16,
obs_dim=3,
action_dim=4,
atoms=None,
device="cpu",
observation_key="observation",
action_key="action",
reward_key="reward",
done_key="done",
):
# create a tensordict
obs = torch.randn(batch, obs_dim, device=device)
@@ -2468,13 +2492,13 @@ def _create_mock_data_redq(
td = TensorDict(
batch_size=(batch,),
source={
"observation": obs,
observation_key: obs,
"next": {
"observation": next_obs,
"done": done,
"reward": reward,
observation_key: next_obs,
done_key: done,
reward_key: reward,
},
"action": action,
action_key: action,
},
device=device,
)
@@ -2869,6 +2893,8 @@ def test_redq_tensordict_keys(self, td_est):
"value": "state_value",
"sample_log_prob": "sample_log_prob",
"state_action_value": "state_action_value",
"reward": "reward",
"done": "done",
}
self.tensordict_keys_test(
loss_fn,
@@ -2883,9 +2909,75 @@ def test_redq_tensordict_keys(self, td_est):
loss_function="l2",
)

key_mapping = {"value": ("value", "state_value_test")}
key_mapping = {
"value": ("value", "state_value_test"),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_redq_notensordict(self, action_key, observation_key, reward_key, done_key):
torch.manual_seed(self.seed)
td = self._create_mock_data_redq(
action_key=action_key,
observation_key=observation_key,
reward_key=reward_key,
done_key=done_key,
)

actor = self._create_mock_actor(
observation_key=observation_key,
)
qvalue = self._create_mock_qvalue(
observation_key=observation_key,
action_key=action_key,
out_keys=["state_action_value"],
)

loss = REDQLoss(
actor_network=actor,
qvalue_network=qvalue,
)
loss.set_keys(action=action_key, reward=reward_key, done=done_key)

kwargs = {
action_key: td.get(action_key),
observation_key: td.get(observation_key),
f"next_{reward_key}": td.get(("next", reward_key)),
f"next_{done_key}": td.get(("next", done_key)),
f"next_{observation_key}": td.get(("next", observation_key)),
}
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

torch.manual_seed(self.seed)
loss_val = loss(**kwargs)
torch.manual_seed(self.seed)
loss_val_td = loss(td)

torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1])
torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2])
torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3])
torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4])
torch.testing.assert_close(
loss_val_td.get("state_action_value_actor"), loss_val[5]
)
torch.testing.assert_close(
loss_val_td.get("action_log_prob_actor"), loss_val[6]
)
torch.testing.assert_close(loss_val_td.get("next.state_value"), loss_val[7])
torch.testing.assert_close(loss_val_td.get("target_value"), loss_val[8])
# test select
torch.manual_seed(self.seed)
loss.select_out_keys("loss_actor", "loss_alpha")
loss_actor, loss_alpha = loss(**kwargs)
assert loss_actor == loss_val_td["loss_actor"]
assert loss_alpha == loss_val_td["loss_alpha"]


class TestCQL(LossModuleTestBase):
seed = 0
155 changes: 153 additions & 2 deletions torchrl/objectives/redq.py
@@ -11,7 +11,7 @@
import numpy as np
import torch

from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torch import Tensor
@@ -78,6 +78,108 @@ class REDQLoss(LossModule):
for prioritized replay buffers. Default is
``"td_error"``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.redq import REDQLoss
>>> from tensordict.tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = REDQLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
... "observation": torch.randn(*batch, n_obs),
... "action": action,
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
... ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
fields={
action_log_prob_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
next.state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict-based modules too and can be
used without resorting to any tensordict-related primitive. In this case,
the expected keyword arguments are:
``["action", "next_reward", "next_done"]`` + the in_keys of the actor and qvalue networks.
The return value is a tuple of tensors in the following order:
``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy",
"state_action_value_actor", "action_log_prob_actor", "next.state_value", "target_value",]``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.redq import REDQLoss
>>> n_act, n_obs = 4, 3
>>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
... module=module,
... in_keys=["loc", "scale"],
... spec=spec,
... distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.Linear(n_obs + n_act, 1)
... def forward(self, obs, act):
... return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
... module=module,
... in_keys=['observation', 'action'])
>>> loss = REDQLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> # filter output keys to "loss_actor", and "loss_qvalue"
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_reward=torch.randn(*batch, 1),
... next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()
"""

@dataclass
@@ -97,17 +199,35 @@ class _AcceptedKeys:
priority is written to. Defaults to ``"td_error"``.
state_action_value (NestedKey): The input tensordict key where the
state action value is expected. Defaults to ``"state_action_value"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

action: NestedKey = "action"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
priority: NestedKey = "td_error"
state_action_value: NestedKey = "state_action_value"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
delay_actor: bool = False
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
"loss_qvalue",
"loss_alpha",
"alpha",
"entropy",
"state_action_value_actor",
"action_log_prob_actor",
"next.state_value",
"target_value",
]

def __init__(
self,
@@ -131,6 +251,7 @@ def __init__(
raise ImportError("Failed to import functorch.") from FUNCTORCH_ERR

super().__init__()
self._in_keys = None
self._set_deprecated_ctor_keys(priority_key=priority_key)

self.convert_to_functional(
@@ -200,7 +321,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
value=self._tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)
self._set_in_keys()

@property
def alpha(self):
@@ -209,6 +333,29 @@ def alpha(self):
alpha = self.log_alpha.exp()
return alpha

def _set_in_keys(self):
keys = [
self.tensor_keys.action,
self.tensor_keys.sample_log_prob,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.qvalue_network.in_keys,
]
self._in_keys = list(set(keys))

@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys

@in_keys.setter
def in_keys(self, values):
self._in_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
obs_keys = self.actor_network.in_keys
tensordict_select = tensordict.clone(False).select(
@@ -395,5 +542,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
else:
raise NotImplementedError(f"Unknown value type {value_type}")

tensor_keys = {"value": self.tensor_keys.value}
tensor_keys = {
"value": self.tensor_keys.value,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)
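
For reference, a minimal sketch (not part of this commit) of how the dispatched loss can be driven with plain tensors and remapped reward/done keys. It mirrors test_redq_notensordict above and reuses the toy actor/qvalue modules from the docstring example, so all dimensions and module names are illustrative assumptions rather than requirements:

    # Sketch only: toy modules as in the REDQLoss docstring example above.
    import torch
    from torch import nn
    from torchrl.data import BoundedTensorSpec
    from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
    from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
    from torchrl.modules.tensordict_module.common import SafeModule
    from torchrl.objectives.redq import REDQLoss

    n_act, n_obs, batch = 4, 3, 2
    spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
    net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
    actor = ProbabilisticActor(
        module=SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]),
        in_keys=["loc", "scale"],
        spec=spec,
        distribution_class=TanhNormal,
    )

    class QNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(n_obs + n_act, 1)

        def forward(self, obs, act):
            return self.linear(torch.cat([obs, act], -1))

    qvalue = ValueOperator(module=QNet(), in_keys=["observation", "action"])

    loss = REDQLoss(actor_network=actor, qvalue_network=qvalue)
    # Remap the new reward/done entries of _AcceptedKeys; the loss in_keys are recomputed.
    loss.set_keys(reward="reward2", done="done2")

    # @dispatch lets the loss accept plain tensors named after the flattened in_keys.
    out = loss(
        observation=torch.randn(batch, n_obs),
        action=spec.rand([batch]),
        next_observation=torch.randn(batch, n_obs),
        next_reward2=torch.randn(batch, 1),
        next_done2=torch.zeros(batch, 1, dtype=torch.bool),
    )
    loss_actor, loss_qvalue = out[0], out[1]  # tuple follows the module's out_keys order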
