From 672b50e40892a20b33f1bddd56b43e5c744df3ec Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Jun 2024 09:32:10 +0100 Subject: [PATCH] [Feature] Make ProbabilisticActor compatible with Composite distributions (#2220) --- examples/agents/composite_actor.py | 52 ++++++++++++ test/test_actors.py | 31 +++++-- torchrl/modules/tensordict_module/actors.py | 94 ++++++++++++++++++++- 3 files changed, 170 insertions(+), 7 deletions(-) create mode 100644 examples/agents/composite_actor.py diff --git a/examples/agents/composite_actor.py b/examples/agents/composite_actor.py new file mode 100644 index 00000000000..ae08062e084 --- /dev/null +++ b/examples/agents/composite_actor.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +""" +This code exemplifies how a composite actor can be built. + +The actor has two components: a categorical and a normal distribution. + +We use a ProbabilisticActor and explicitly pass it the key-map that we want through a 'name_map' +argument. 
+ +""" + +import torch +from tensordict import TensorDict +from tensordict.nn import CompositeDistribution, TensorDictModule +from torch import distributions as d, nn + +from torchrl.modules import ProbabilisticActor + + +class Module(nn.Module): + def forward(self, x): + return x[..., :3], x[..., 3:6], x[..., 6:] + + +module = TensorDictModule( + Module(), + in_keys=["x"], + out_keys=[ + ("params", "normal", "loc"), + ("params", "normal", "scale"), + ("params", "categ", "logits"), + ], +) +actor = ProbabilisticActor( + module, + in_keys=["params"], + distribution_class=CompositeDistribution, + distribution_kwargs={ + "distribution_map": {"normal": d.Normal, "categ": d.Categorical}, + "name_map": {"normal": ("action", "normal"), "categ": ("action", "categ")}, + }, +) +print(actor.out_keys) + +data = TensorDict({"x": torch.rand(10)}, []) +module(data) +print(actor(data)) diff --git a/test/test_actors.py b/test/test_actors.py index 560566286ae..388120a4ba7 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -800,7 +800,8 @@ def test_actorcritic(device): ) == len(policy_params) -def test_compound_actor(): +@pytest.mark.parametrize("name_map", [True, False]) +def test_compound_actor(name_map): class Module(nn.Module): def forward(self, x): return x[..., :3], x[..., 3:6], x[..., 6:] @@ -814,19 +815,37 @@ def forward(self, x): ("params", "categ", "logits"), ], ) + distribution_kwargs = { + "distribution_map": {"normal": dist.Normal, "categ": dist.Categorical} + } + if name_map: + distribution_kwargs.update( + { + "name_map": { + "normal": ("action", "normal"), + "categ": ("action", "categ"), + }, + } + ) actor = ProbabilisticActor( module, in_keys=["params"], distribution_class=CompositeDistribution, - distribution_kwargs={ - "distribution_map": {"normal": dist.Normal, "categ": dist.Categorical} - }, + distribution_kwargs=distribution_kwargs, ) + if not name_map: + assert actor.out_keys == module.out_keys + ["normal", "categ"] + else: + assert actor.out_keys == 
module.out_keys + [ + ("action", "normal"), + ("action", "categ"), + ] + data = TensorDict({"x": torch.rand(10)}, []) actor(data) assert set(data.keys(True, True)) == { - "categ", - "normal", + "categ" if not name_map else ("action", "categ"), + "normal" if not name_map else ("action", "normal"), ("params", "categ", "logits"), ("params", "normal", "loc"), ("params", "normal", "scale"), diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 7c712bc4630..d2ea7af0af8 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -11,6 +11,7 @@ from tensordict import TensorDictBase, unravel_key from tensordict.nn import ( + CompositeDistribution, dispatch, TensorDictModule, TensorDictModuleBase, @@ -174,6 +175,14 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): A :class:`torch.distributions.Distribution` class to be used for sampling. Default is :class:`tensordict.nn.distributions.Delta`. + + .. note:: if ``distribution_class`` is of type :class:`~tensordict.nn.distributions.CompositeDistribution`, + the keys will be inferred from the ``distribution_map`` / ``name_map`` keyword arguments of that + distribution. If this distribution is used with another constructor (e.g., partial or lambda function) + then the out_keys will need to be provided explicitly. + Note also that actions will **not** be prefixed with an ``"action"`` key, see the example below + on how this can be achieved with a ``ProbabilisticActor``. + distribution_kwargs (dict, optional): keyword-only argument. Keyword-argument pairs to be passed to the distribution. return_log_prob (bool, optional): keyword-only argument. 
@@ -276,6 +285,75 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): device=None, is_shared=False) + Using a probabilistic actor with a composite distribution can be achieved using the following + example code: + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import CompositeDistribution + >>> from tensordict.nn import TensorDictModule + >>> from torch import distributions as d + >>> from torch import nn + >>> + >>> from torchrl.modules import ProbabilisticActor + >>> + >>> + >>> class Module(nn.Module): + ... def forward(self, x): + ... return x[..., :3], x[..., 3:6], x[..., 6:] + ... + >>> + >>> module = TensorDictModule(Module(), + ... in_keys=["x"], + ... out_keys=[ + ... ("params", "normal", "loc"), ("params", "normal", "scale"), ("params", "categ", "logits") + ... ]) + >>> actor = ProbabilisticActor(module, + ... in_keys=["params"], + ... distribution_class=CompositeDistribution, + ... distribution_kwargs={"distribution_map": {"normal": d.Normal, "categ": d.Categorical}, + ... "name_map": {"normal": ("action", "normal"), + ... "categ": ("action", "categ")}} + ... 
) + >>> print(actor.out_keys) + [('params', 'normal', 'loc'), ('params', 'normal', 'scale'), ('params', 'categ', 'logits'), ('action', 'normal'), ('action', 'categ')] + >>> + >>> data = TensorDict({"x": torch.rand(10)}, []) + >>> module(data) + >>> print(actor(data)) + TensorDict( + fields={ + action: TensorDict( + fields={ + categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + params: TensorDict( + fields={ + categ: TensorDict( + fields={ + logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + normal: TensorDict( + fields={ + loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + """ def __init__( @@ -287,8 +365,22 @@ def __init__( spec: Optional[TensorSpec] = None, **kwargs, ): + distribution_class = kwargs.get("distribution_class") if out_keys is None: - out_keys = ["action"] + if distribution_class is CompositeDistribution: + if "distribution_map" not in kwargs.get("distribution_kwargs", {}): + raise KeyError( + "'distribution_map' must be provided within " + "distribution_kwargs whenever the distribution is of type CompositeDistribution." 
+ ) + distribution_map = kwargs["distribution_kwargs"]["distribution_map"] + name_map = kwargs["distribution_kwargs"].get("name_map", None) + if name_map is not None: + out_keys = list(name_map.values()) + else: + out_keys = list(distribution_map.keys()) + else: + out_keys = ["action"] if ( len(out_keys) == 1 and spec is not None