Skip to content

Commit

Permalink
[Doc] Document (and test) compound actor (pytorch#1673)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 3, 2023
1 parent 8ca7a39 commit 97abb93
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 3 deletions.
40 changes: 37 additions & 3 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import torch

from _utils_internal import get_default_devices

from mocking_classes import NestedCountingEnv
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn import CompositeDistribution, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torch import distributions as dist, nn
from torchrl.data import (
BinaryDiscreteTensorSpec,
BoundedTensorSpec,
Expand Down Expand Up @@ -800,6 +800,40 @@ def test_actorcritic(device):
) == len(policy_params)


def test_compound_actor():
class Module(nn.Module):
def forward(self, x):
return x[..., :3], x[..., 3:6], x[..., 6:]

module = TensorDictModule(
Module(),
in_keys=["x"],
out_keys=[
("params", "normal", "loc"),
("params", "normal", "scale"),
("params", "categ", "logits"),
],
)
actor = ProbabilisticActor(
module,
in_keys=["params"],
distribution_class=CompositeDistribution,
distribution_kwargs={
"distribution_map": {"normal": dist.Normal, "categ": dist.Categorical}
},
)
data = TensorDict({"x": torch.rand(10)}, [])
actor(data)
assert set(data.keys(True, True)) == {
"categ",
"normal",
("params", "categ", "logits"),
("params", "normal", "loc"),
("params", "normal", "scale"),
"x",
}


@pytest.mark.skipif(not _has_transformers, reason="missing dependencies")
@pytest.mark.parametrize("device", get_default_devices())
def test_lmhead_actorvalueoperator(device):
Expand Down
56 changes: 56 additions & 0 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,62 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
device=None,
is_shared=False)
Probabilistic actors also support compound actions through the
:class:`tensordict.nn.CompositeDistribution` class. This distribution takes
a tensordict as input (typically `"params"`) and reads it as a whole: the
content of this tensordict is the input to the distributions contained in the
compound one.
Examples:
>>> from tensordict import TensorDict
>>> from tensordict.nn import CompositeDistribution, TensorDictModule
>>> from torchrl.modules import ProbabilisticActor
>>> from torch import nn, distributions as d
>>> import torch
>>>
>>> class Module(nn.Module):
... def forward(self, x):
... return x[..., :3], x[..., 3:6], x[..., 6:]
>>> module = TensorDictModule(Module(),
... in_keys=["x"],
... out_keys=[("params", "normal", "loc"),
... ("params", "normal", "scale"),
... ("params", "categ", "logits")])
>>> actor = ProbabilisticActor(module,
... in_keys=["params"],
... distribution_class=CompositeDistribution,
... distribution_kwargs={"distribution_map": {
... "normal": d.Normal, "categ": d.Categorical}}
... )
>>> data = TensorDict({"x": torch.rand(10)}, [])
>>> actor(data)
TensorDict(
fields={
categ: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
normal: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
params: TensorDict(
fields={
categ: TensorDict(
fields={
logits: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
normal: TensorDict(
fields={
loc: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
x: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""

def __init__(
Expand Down

0 comments on commit 97abb93

Please sign in to comment.