diff --git a/test/test_actors.py b/test/test_actors.py
index 8b432e9ac21..ddefcea274c 100644
--- a/test/test_actors.py
+++ b/test/test_actors.py
@@ -8,12 +8,12 @@
 import torch
 
 from _utils_internal import get_default_devices
-
 from mocking_classes import NestedCountingEnv
 from tensordict import TensorDict
-from tensordict.nn import TensorDictModule
+from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
-from torch import nn
+
+from torch import distributions as dist, nn
 from torchrl.data import (
     BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
@@ -800,6 +800,40 @@ def test_actorcritic(device):
     ) == len(policy_params)
 
 
+def test_compound_actor():  # ProbabilisticActor driven by a CompositeDistribution (Normal + Categorical)
+    class Module(nn.Module):  # stand-in network: it only slices its input
+        def forward(self, x):
+            return x[..., :3], x[..., 3:6], x[..., 6:]  # three outputs, one per out_key below
+
+    module = TensorDictModule(
+        Module(),
+        in_keys=["x"],
+        out_keys=[  # nested keys: ("params", <distribution name>, <parameter name>)
+            ("params", "normal", "loc"),
+            ("params", "normal", "scale"),
+            ("params", "categ", "logits"),
+        ],
+    )
+    actor = ProbabilisticActor(
+        module,
+        in_keys=["params"],  # the "params" sub-tensordict is read as a whole by the distribution
+        distribution_class=CompositeDistribution,
+        distribution_kwargs={
+            "distribution_map": {"normal": dist.Normal, "categ": dist.Categorical}
+        },
+    )
+    data = TensorDict({"x": torch.rand(10)}, [])  # 10 features -> slices of size 3, 3 and 4
+    actor(data)  # return value unused: the assertion below inspects `data` itself
+    assert set(data.keys(True, True)) == {  # presumably (include_nested, leaves_only) -- confirm
+        "categ",  # root-level entries named after the distribution_map keys
+        "normal",
+        ("params", "categ", "logits"),
+        ("params", "normal", "loc"),
+        ("params", "normal", "scale"),
+        "x",
+    }
+
+
 @pytest.mark.skipif(not _has_transformers, reason="missing dependencies")
 @pytest.mark.parametrize("device", get_default_devices())
 def test_lmhead_actorvalueoperator(device):
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py
index a5b5051bab5..1e5a557546a 100644
--- a/torchrl/modules/tensordict_module/actors.py
+++ b/torchrl/modules/tensordict_module/actors.py
@@ -210,6 +210,62 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
             device=None,
             is_shared=False)
 
+    Probabilistic
actors also support compound actions through the
+    :class:`tensordict.nn.CompositeDistribution` class. This distribution takes
+    a tensordict as input (typically ``"params"``) and reads it as a whole: the
+    content of this tensordict is the input to the distributions contained in the
+    compound one.
+
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> from tensordict.nn import CompositeDistribution, TensorDictModule
+        >>> from torchrl.modules import ProbabilisticActor
+        >>> from torch import nn, distributions as d
+        >>> import torch
+        >>>
+        >>> class Module(nn.Module):
+        ...     def forward(self, x):
+        ...         return x[..., :3], x[..., 3:6], x[..., 6:]
+        >>> module = TensorDictModule(Module(),
+        ...                           in_keys=["x"],
+        ...                           out_keys=[("params", "normal", "loc"),
+        ...                                     ("params", "normal", "scale"),
+        ...                                     ("params", "categ", "logits")])
+        >>> actor = ProbabilisticActor(module,
+        ...                            in_keys=["params"],
+        ...                            distribution_class=CompositeDistribution,
+        ...                            distribution_kwargs={"distribution_map": {
+        ...                                "normal": d.Normal, "categ": d.Categorical}}
+        ...                            )
+        >>> data = TensorDict({"x": torch.rand(10)}, [])
+        >>> actor(data)
+        TensorDict(
+            fields={
+                categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
+                normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
+                params: TensorDict(
+                    fields={
+                        categ: TensorDict(
+                            fields={
+                                logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False),
+                        normal: TensorDict(
+                            fields={
+                                loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
+                                scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
+                            batch_size=torch.Size([]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=None,
+                    is_shared=False),
+                x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+
     """
 
     def __init__(