[Feature] Dispatch reinforce loss module (#1252)
Blonck authored Jun 12, 2023
1 parent 79bb70c commit acd65f6
Showing 2 changed files with 172 additions and 4 deletions.
56 changes: 56 additions & 0 deletions test/test_cost.py
@@ -4374,6 +4374,8 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_key


class TestReinforce(LossModuleTestBase):
seed = 0

@pytest.mark.parametrize("delay_value", [True, False])
@pytest.mark.parametrize("gradient_mode", [True, False])
@pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None])
@@ -4501,6 +4503,8 @@ def test_reinforce_tensordict_keys(self, td_est):
"value_target": "value_target",
"value": "state_value",
"sample_log_prob": "sample_log_prob",
"reward": "reward",
"done": "done",
}

self.tensordict_keys_test(
@@ -4522,9 +4526,61 @@ def test_reinforce_tensordict_keys(self, td_est):
"advantage": ("advantage", "advantage_test"),
"value_target": ("value_target", "value_target_test"),
"value": ("value", "state_value_test"),
"reward": ("reward", "reward_test"),
"done": ("done", ("done", "test")),
}
self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping)

@pytest.mark.parametrize("action_key", ["action", "action2"])
@pytest.mark.parametrize("observation_key", ["observation", "observation2"])
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_reinforce_notensordict(
self, action_key, observation_key, reward_key, done_key
):
torch.manual_seed(self.seed)
n_obs = 3
n_act = 5
batch = 4
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=[observation_key])
net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
module = SafeModule(net, in_keys=[observation_key], out_keys=["loc", "scale"])
actor_net = ProbabilisticActor(
module,
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
spec=UnboundedContinuousTensorSpec(n_act),
)
loss = ReinforceLoss(actor=actor_net, critic=value_net)
loss.set_keys(reward=reward_key, done=done_key, action=action_key)

observation = torch.randn(batch, n_obs)
action = torch.randn(batch, n_act)
next_observation = torch.randn(batch, n_obs)
next_reward = torch.randn(batch, 1)
next_done = torch.zeros(batch, 1, dtype=torch.bool)

kwargs = {
action_key: action,
observation_key: observation,
f"next_{reward_key}": next_reward,
f"next_{done_key}": next_done,
f"next_{observation_key}": next_observation,
}
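# unflatten_keys("_") below rebuilds the nested ("next", <key>) structure the loss expects from the flat "next_<key>" kwargs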
td = TensorDict(kwargs, [batch]).unflatten_keys("_")

loss_val = loss(**kwargs)
loss_val_td = loss(td)
torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0])
torch.testing.assert_close(loss_val_td.get("loss_value"), loss_val[1])
# test that select_out_keys restricts the output of the dispatch call to loss_actor
torch.manual_seed(self.seed)
loss.select_out_keys("loss_actor")
loss_actor = loss(**kwargs)
assert loss_actor == loss_val_td["loss_actor"]


@pytest.mark.parametrize("device", get_default_devices())
class TestDreamer(LossModuleTestBase):
120 changes: 116 additions & 4 deletions torchrl/objectives/reinforce.py
@@ -8,7 +8,7 @@

import torch

from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.utils import NestedKey
from torchrl.objectives.common import LossModule
@@ -67,6 +67,79 @@ class ReinforceLoss(LossModule):
>>> data = next(datacollector)
>>> losses = reinforce_loss(data)
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.reinforce import ReinforceLoss
>>> from tensordict.tensordict import TensorDict
>>> n_obs, n_act = 3, 5
>>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor_net = ProbabilisticActor(
... module,
... distribution_class=TanhNormal,
... return_log_prob=True,
... in_keys=["loc", "scale"],
... spec=UnboundedContinuousTensorSpec(n_act),)
>>> loss = ReinforceLoss(actor_net, value_net)
>>> batch = 2
>>> data = TensorDict({
... "observation": torch.randn(batch, n_obs),
... "next": {
... "observation": torch.randn(batch, n_obs),
... "reward": torch.randn(batch, 1),
... "done": torch.zeros(batch, 1, dtype=torch.bool),
... },
... "action": torch.randn(batch, n_act),
... }, [batch])
>>> loss(data)
TensorDict(
fields={
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
This class is compatible with non-tensordict based modules too and can be
used without resorting to any tensordict-related primitive. In that case,
the expected keyword arguments are
``["action", "next_reward", "next_done"]`` plus the in_keys of the actor and critic networks.
The return value is a tuple of tensors in the following order: ``["loss_actor", "loss_value"]``.
Examples:
>>> import torch
>>> from torch import nn
>>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
>>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.reinforce import ReinforceLoss
>>> n_obs, n_act = 3, 5
>>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
>>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act))
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor_net = ProbabilisticActor(
... module,
... distribution_class=TanhNormal,
... return_log_prob=True,
... in_keys=["loc", "scale"],
... spec=UnboundedContinuousTensorSpec(n_act),)
>>> loss = ReinforceLoss(actor_net, value_net)
>>> batch = 2
>>> loss_actor, loss_value = loss(
... observation=torch.randn(batch, n_obs),
... next_observation=torch.randn(batch, n_obs),
... next_reward=torch.randn(batch, 1),
... next_done=torch.zeros(batch, 1, dtype=torch.bool),
... action=torch.randn(batch, n_act),)
>>> loss_actor.backward()
"""

@dataclass
@@ -85,15 +158,26 @@ class _AcceptedKeys:
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
sample_log_prob (NestedKey): The input tensordict key where the sample log probability is expected.
Defaults to ``"sample_log_prob"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
"""

advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
action: NestedKey = "action"
reward: NestedKey = "reward"
done: NestedKey = "done"

default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.GAE
out_keys = ["loss_actor", "loss_value"]
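# when forward() is called with plain tensors via dispatch, the losses are returned as a tuple in this order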

@classmethod
def __new__(cls, *args, **kwargs):
Expand All @@ -112,6 +196,7 @@ def __init__(
value_target_key: str = None,
) -> None:
super().__init__()
self.in_keys = None
self._set_deprecated_ctor_keys(
advantage=advantage_key, value_target=value_target_key
)
@@ -141,11 +226,36 @@ def __init__(
def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
self._value_estimator.set_keys(
advantage=self._tensor_keys.advantage,
value_target=self._tensor_keys.value_target,
value=self._tensor_keys.value,
advantage=self.tensor_keys.advantage,
value_target=self.tensor_keys.value_target,
value=self.tensor_keys.value,
reward=self.tensor_keys.reward,
done=self.tensor_keys.done,
)
self._set_in_keys()

def _set_in_keys(self):
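# tensordict entries read by the loss; the dispatch decorator on forward() uses this list to map plain-tensor keyword arguments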
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.critic.in_keys,
]
self._in_keys = list(set(keys))

@property
def in_keys(self):
if self._in_keys is None:
self._set_in_keys()
return self._in_keys

@in_keys.setter
def in_keys(self, values):
self._in_keys = values

@dispatch
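# dispatch lets forward() also be called with plain tensors passed as keyword arguments named after the entries of in_keys / out_keys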
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
@@ -217,5 +327,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
"advantage": self.tensor_keys.advantage,
"value": self.tensor_keys.value,
"value_target": self.tensor_keys.value_target,
"reward": self.tensor_keys.reward,
"done": self.tensor_keys.done,
}
self._value_estimator.set_keys(**tensor_keys)
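
Not part of the diff above: a minimal, self-contained sketch of how the newly exposed reward/done entries of _AcceptedKeys can be remapped through set_keys, mirroring the docstring example and the new test. The key names "my_reward" and "my_done" are illustrative assumptions, not names used by the library.

import torch
from torch import nn
from tensordict.tensordict import TensorDict
from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.reinforce import ReinforceLoss

n_obs, n_act, batch = 3, 5, 2
value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])
module = SafeModule(
    NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor_net = ProbabilisticActor(
    module,
    distribution_class=TanhNormal,
    return_log_prob=True,
    in_keys=["loc", "scale"],
    spec=UnboundedContinuousTensorSpec(n_act),
)
loss = ReinforceLoss(actor_net, value_net)
# redirect the reward/done lookups of the loss and its value estimator
# ("my_reward"/"my_done" are hypothetical key names used for illustration)
loss.set_keys(reward="my_reward", done="my_done")
data = TensorDict(
    {
        "observation": torch.randn(batch, n_obs),
        "action": torch.randn(batch, n_act),
        "next": {
            "observation": torch.randn(batch, n_obs),
            "my_reward": torch.randn(batch, 1),
            "my_done": torch.zeros(batch, 1, dtype=torch.bool),
        },
    },
    [batch],
)
loss_td = loss(data)  # TensorDict with "loss_actor" and "loss_value" entries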

1 comment on commit acd65f6

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The result for this commit is worse than the previous benchmark result, exceeding the alert threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_iterate_rb[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000]
Current (acd65f6): 13.533600302550024 iter/sec (stddev: 0.09182880769651193)
Previous (99afe8b): 31.08329658928429 iter/sec (stddev: 0.0020068433890160302)
Ratio: 2.30
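
For reference, one possible way to re-run the flagged benchmark locally is sketched below; it assumes the benchmarks/ suite ships with the repository (as the node id above suggests) and that the pytest-benchmark plugin is installed.

# hypothetical local reproduction of the regressed case; the node id is taken
# verbatim from the report above, pytest-benchmark is an assumed dependency
import pytest

pytest.main([
    "benchmarks/test_replaybuffer_benchmark.py::test_iterate_rb"
    "[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000]",
    "--benchmark-only",
])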

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
