[Feature] Make ProbabilisticActor compatible with Composite distributions #2220

Merged · 3 commits · Jun 11, 2024
52 changes: 52 additions & 0 deletions examples/agents/composite_actor.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
This code exemplifies how a composite actor can be built.

The actor has two components: a categorical and a normal distribution.

We use a ProbabilisticActor and explicitly pass it the key map we want through the
'name_map' argument.

"""

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from torch import distributions as d, nn

from torchrl.modules import ProbabilisticActor


class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]


module = TensorDictModule(
Module(),
in_keys=["x"],
out_keys=[
("params", "normal", "loc"),
("params", "normal", "scale"),
("params", "categ", "logits"),
],
)
actor = ProbabilisticActor(
module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {"normal": d.Normal, "categ": d.Categorical},
"name_map": {"normal": ("action", "normal"), "categ": ("action", "categ")},
},
)
print(actor.out_keys)

data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))
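
A complementary sketch (not part of this PR's example file): the docstring note added in
torchrl/modules/tensordict_module/actors.py below states that when CompositeDistribution is supplied
through another constructor (e.g. a functools.partial), the keys cannot be inferred and out_keys must
be passed explicitly. A minimal, unverified variant of the example above, reusing the same module and
key layout; the explicit out_keys simply repeat the name_map values the actor would otherwise infer.

# Hypothetical variant (a sketch, not part of this PR): with the distribution class
# wrapped in a partial, ProbabilisticActor cannot inspect distribution_map/name_map,
# so the out_keys are spelled out explicitly.
import functools

actor_explicit = ProbabilisticActor(
    module,
    in_keys=["params"],
    distribution_class=functools.partial(
        CompositeDistribution,
        distribution_map={"normal": d.Normal, "categ": d.Categorical},
        name_map={"normal": ("action", "normal"), "categ": ("action", "categ")},
    ),
    out_keys=[("action", "normal"), ("action", "categ")],
)
print(actor_explicit.out_keys)  # ends with ('action', 'normal') and ('action', 'categ')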
31 changes: 25 additions & 6 deletions test/test_actors.py
@@ -800,7 +800,8 @@ def test_actorcritic(device):
) == len(policy_params)


def test_compound_actor():
@pytest.mark.parametrize("name_map", [True, False])
def test_compound_actor(name_map):
class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]
@@ -814,19 +815,37 @@ def forward(self, x):
("params", "categ", "logits"),
],
)
distribution_kwargs = {
"distribution_map": {"normal": dist.Normal, "categ": dist.Categorical}
}
if name_map:
distribution_kwargs.update(
{
"name_map": {
"normal": ("action", "normal"),
"categ": ("action", "categ"),
},
}
)
actor = ProbabilisticActor(
module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {"normal": dist.Normal, "categ": dist.Categorical}
},
distribution_kwargs=distribution_kwargs,
)
if not name_map:
assert actor.out_keys == module.out_keys + ["normal", "categ"]
else:
assert actor.out_keys == module.out_keys + [
("action", "normal"),
("action", "categ"),
]

data = TensorDict({"x": torch.rand(10)}, [])
actor(data)
assert set(data.keys(True, True)) == {
"categ",
"normal",
"categ" if not name_map else ("action", "categ"),
"normal" if not name_map else ("action", "normal"),
("params", "categ", "logits"),
("params", "normal", "loc"),
("params", "normal", "scale"),
94 changes: 93 additions & 1 deletion torchrl/modules/tensordict_module/actors.py
Expand Up @@ -11,6 +11,7 @@

from tensordict import TensorDictBase, unravel_key
from tensordict.nn import (
CompositeDistribution,
dispatch,
TensorDictModule,
TensorDictModuleBase,
@@ -174,6 +175,14 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
A :class:`torch.distributions.Distribution` class to
be used for sampling.
Default is :class:`tensordict.nn.distributions.Delta`.

.. note:: If ``distribution_class`` is of type :class:`~tensordict.nn.distributions.CompositeDistribution`,
the keys will be inferred from the ``distribution_map`` / ``name_map`` keyword arguments of that
distribution. If the distribution is built through another constructor (e.g., a partial or lambda
function), the ``out_keys`` will need to be provided explicitly.
Note also that actions will **not** be prefixed with an ``"action"`` key; see the example below
for how this can be achieved with a ``ProbabilisticActor``.

distribution_kwargs (dict, optional): keyword-only argument.
Keyword-argument pairs to be passed to the distribution.
return_log_prob (bool, optional): keyword-only argument.
@@ -276,6 +285,75 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
device=None,
is_shared=False)

The following example shows how to use a probabilistic actor with a composite
distribution:

Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import CompositeDistribution
>>> from tensordict.nn import TensorDictModule
>>> from torch import distributions as d
>>> from torch import nn
>>>
>>> from torchrl.modules import ProbabilisticActor
>>>
>>>
>>> class Module(nn.Module):
... def forward(self, x):
... return x[..., :3], x[..., 3:6], x[..., 6:]
...
>>>
>>> module = TensorDictModule(Module(),
... in_keys=["x"],
... out_keys=[
... ("params", "normal", "loc"), ("params", "normal", "scale"), ("params", "categ", "logits")
... ])
>>> actor = ProbabilisticActor(module,
... in_keys=["params"],
... distribution_class=CompositeDistribution,
... distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical},
... "name_map": {"normal": ("action", "normal"),
... "categ": ("action", "categ")}}
... )
>>> print(actor.out_keys)
[('params', 'normal', 'loc'), ('params', 'normal', 'scale'), ('params', 'categ', 'logits'), ('action', 'normal'), ('action', 'categ')]
>>>
>>> data = TensorDict({"x": torch.rand(10)}, [])
>>> module(data)
>>> print(actor(data))
TensorDict(
fields={
action: TensorDict(
fields={
categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
params: TensorDict(
fields={
categ: TensorDict(
fields={
logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
normal: TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)

"""

def __init__(
@@ -287,8 +365,22 @@ def __init__(
spec: Optional[TensorSpec] = None,
**kwargs,
):
distribution_class = kwargs.get("distribution_class")
if out_keys is None:
out_keys = ["action"]
if distribution_class is CompositeDistribution:
if "distribution_map" not in kwargs.get("distribution_kwargs", {}):
raise KeyError(
"'distribution_map' must be provided within "
"distribution_kwargs whenever the distribution is of type CompositeDistribution."
)
distribution_map = kwargs["distribution_kwargs"]["distribution_map"]
name_map = kwargs["distribution_kwargs"].get("name_map", None)
if name_map is not None:
out_keys = list(name_map.values())
else:
out_keys = list(distribution_map.keys())
else:
out_keys = ["action"]
if (
len(out_keys) == 1
and spec is not None
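
Closing sketch, mirroring the parametrized test above rather than code from this PR's diff: the new
out-key inference in __init__ falls back to the distribution_map keys when no name_map is given, so
samples are written at the root of the tensordict instead of under an ("action", ...) prefix. SplitNet
is a hypothetical name for the same three-way split used in the example file.

# Sketch only (mirrors test_compound_actor with name_map=False, not code from this PR):
# without a name_map, the inferred out_keys are the distribution_map keys, so the
# samples land at the root of the tensordict as "normal" and "categ".
import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from torch import distributions as d, nn
from torchrl.modules import ProbabilisticActor


class SplitNet(nn.Module):
    # Splits the last dimension into the three parameter blocks expected below.
    def forward(self, x):
        return x[..., :3], x[..., 3:6], x[..., 6:]


module = TensorDictModule(
    SplitNet(),
    in_keys=["x"],
    out_keys=[
        ("params", "normal", "loc"),
        ("params", "normal", "scale"),
        ("params", "categ", "logits"),
    ],
)
actor = ProbabilisticActor(
    module,
    in_keys=["params"],
    distribution_class=CompositeDistribution,
    distribution_kwargs={
        "distribution_map": {"normal": d.Normal, "categ": d.Categorical},
    },
)
print(actor.out_keys)  # ends with 'normal' and 'categ' -- no ("action", ...) prefix
data = TensorDict({"x": torch.rand(10)}, [])
print(actor(data))     # samples are written at the root as "normal" and "categ"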