
[Feature] Make ProbabilisticActor compatible with Composite distributions (#2220)
vmoens authored Jun 11, 2024
1 parent 3787a9e commit 672b50e
Showing 3 changed files with 170 additions and 7 deletions.
52 changes: 52 additions & 0 deletions examples/agents/composite_actor.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
This code exemplifies how a composite actor can be built.
The actor has two components: a categorical and a normal distribution.
We use a ProbabilisticActor and explicitly pass it the key map we want through the 'name_map'
argument.
"""

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from torch import distributions as d, nn

from torchrl.modules import ProbabilisticActor


class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]


module = TensorDictModule(
Module(),
in_keys=["x"],
out_keys=[
("params", "normal", "loc"),
("params", "normal", "scale"),
("params", "categ", "logits"),
],
)
actor = ProbabilisticActor(
module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {"normal": d.Normal, "categ": d.Categorical},
"name_map": {"normal": ("action", "normal"), "categ": ("action", "categ")},
},
)
print(actor.out_keys)

data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))
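As a follow-on illustration (not part of the committed example), a minimal sketch of how the composite distribution behind this actor could be inspected directly. It assumes that ProbabilisticActor exposes get_dist as other probabilistic sequentials do and that CompositeDistribution.sample returns a TensorDict of per-key samples; treat the calls as assumptions rather than documented output of this commit.

# Hedged sketch: build the composite distribution and draw a sample from it.
dist = actor.get_dist(data)  # assumed to construct a CompositeDistribution from the ("params", ...) entries
sample = dist.sample()       # assumed to return a TensorDict with ("action", "normal") and ("action", "categ")
print(sample)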
31 changes: 25 additions & 6 deletions test/test_actors.py
@@ -800,7 +800,8 @@ def test_actorcritic(device):
) == len(policy_params)


def test_compound_actor():
@pytest.mark.parametrize("name_map", [True, False])
def test_compound_actor(name_map):
class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]
@@ -814,19 +815,37 @@ def forward(self, x):
("params", "categ", "logits"),
],
)
distribution_kwargs = {
"distribution_map": {"normal": dist.Normal, "categ": dist.Categorical}
}
if name_map:
distribution_kwargs.update(
{
"name_map": {
"normal": ("action", "normal"),
"categ": ("action", "categ"),
},
}
)
actor = ProbabilisticActor(
module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {"normal": dist.Normal, "categ": dist.Categorical}
},
distribution_kwargs=distribution_kwargs,
)
if not name_map:
assert actor.out_keys == module.out_keys + ["normal", "categ"]
else:
assert actor.out_keys == module.out_keys + [
("action", "normal"),
("action", "categ"),
]

data = TensorDict({"x": torch.rand(10)}, [])
actor(data)
assert set(data.keys(True, True)) == {
"categ",
"normal",
"categ" if not name_map else ("action", "categ"),
"normal" if not name_map else ("action", "normal"),
("params", "categ", "logits"),
("params", "normal", "loc"),
("params", "normal", "scale"),
94 changes: 93 additions & 1 deletion torchrl/modules/tensordict_module/actors.py
@@ -11,6 +11,7 @@

from tensordict import TensorDictBase, unravel_key
from tensordict.nn import (
CompositeDistribution,
dispatch,
TensorDictModule,
TensorDictModuleBase,
@@ -174,6 +175,14 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
A :class:`torch.distributions.Distribution` class to
be used for sampling.
Default is :class:`tensordict.nn.distributions.Delta`.
.. note:: If ``distribution_class`` is :class:`~tensordict.nn.distributions.CompositeDistribution`,
the keys will be inferred from the ``distribution_map`` / ``name_map`` keyword arguments of that
distribution. If this distribution class is wrapped in another constructor (e.g., a partial or lambda
function), the ``out_keys`` must be provided explicitly (see the sketch after this excerpt).
Note also that actions will **not** be prefixed with an ``"action"`` key; see the example below
for how this can be achieved with a ``ProbabilisticActor``.
distribution_kwargs (dict, optional): keyword-only argument.
Keyword-argument pairs to be passed to the distribution.
return_log_prob (bool, optional): keyword-only argument.
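For illustration only (this snippet is not part of the diff), a minimal sketch of the explicit ``out_keys`` case mentioned in the note above. It assumes that wrapping the class in a bare ``functools.partial`` defeats the ``is CompositeDistribution`` identity check added below, so the keys must be listed by hand; the argument layout mirrors the docstring example rather than a verified API contract.

from functools import partial

# Hedged sketch: with a wrapped constructor the out_keys are not inferred,
# so they are spelled out to match the name_map values.
actor = ProbabilisticActor(
    module,
    in_keys=["params"],
    out_keys=[("action", "normal"), ("action", "categ")],
    distribution_class=partial(CompositeDistribution),  # bare partial, only to illustrate the identity check
    distribution_kwargs={
        "distribution_map": {"normal": d.Normal, "categ": d.Categorical},
        "name_map": {"normal": ("action", "normal"), "categ": ("action", "categ")},
    },
)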
@@ -276,6 +285,75 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
device=None,
is_shared=False)
A probabilistic actor can be used with a composite distribution as shown in the following example:
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import CompositeDistribution
>>> from tensordict.nn import TensorDictModule
>>> from torch import distributions as d
>>> from torch import nn
>>>
>>> from torchrl.modules import ProbabilisticActor
>>>
>>>
>>> class Module(nn.Module):
... def forward(self, x):
... return x[..., :3], x[..., 3:6], x[..., 6:]
...
>>>
>>> module = TensorDictModule(Module(),
... in_keys=["x"],
... out_keys=[
... ("params", "normal", "loc"), ("params", "normal", "scale"), ("params", "categ", "logits")
... ])
>>> actor = ProbabilisticActor(module,
... in_keys=["params"],
... distribution_class=CompositeDistribution,
... distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical},
... "name_map": {"normal": ("action", "normal"),
... "categ": ("action", "categ")}}
... )
>>> print(actor.out_keys)
[('params', 'normal', 'loc'), ('params', 'normal', 'scale'), ('params', 'categ', 'logits'), ('action', 'normal'), ('action', 'categ')]
>>>
>>> data = TensorDict({"x": torch.rand(10)}, [])
>>> data = module(data)
>>> print(actor(data))
TensorDict(
fields={
action: TensorDict(
fields={
categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
params: TensorDict(
fields={
categ: TensorDict(
fields={
logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
normal: TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""

def __init__(
@@ -287,8 +365,22 @@ def __init__(
spec: Optional[TensorSpec] = None,
**kwargs,
):
distribution_class = kwargs.get("distribution_class")
if out_keys is None:
out_keys = ["action"]
if distribution_class is CompositeDistribution:
if "distribution_map" not in kwargs.get("distribution_kwargs", {}):
raise KeyError(
"'distribution_map' must be provided within "
"distribution_kwargs whenever the distribution is of type CompositeDistribution."
)
distribution_map = kwargs["distribution_kwargs"]["distribution_map"]
name_map = kwargs["distribution_kwargs"].get("name_map", None)
if name_map is not None:
out_keys = list(name_map.values())
else:
out_keys = list(distribution_map.keys())
else:
out_keys = ["action"]
if (
len(out_keys) == 1
and spec is not None
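To illustrate the inference rule added above, a hedged sketch that mirrors the new test rather than any additional committed code: with a ``name_map`` the actor's ``out_keys`` follow the mapped entries, and without one they fall back to the ``distribution_map`` keys.

# Hedged sketch of the two inference branches (reusing the module from the example above).
actor_named = ProbabilisticActor(
    module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {"normal": d.Normal, "categ": d.Categorical},
        "name_map": {"normal": ("action", "normal"), "categ": ("action", "categ")},
    },
)
# out_keys now end with ("action", "normal") and ("action", "categ")

actor_flat = ProbabilisticActor(
    module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical}},
)
# out_keys now end with "normal" and "categ"; omitting "distribution_map" entirely raises a KeyError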
