[Refactor] Rename specs to simpler names (pytorch#2368)
vmoens committed Aug 7, 2024
1 parent a41da21 commit 607db8b
Showing 113 changed files with 3,403 additions and 2,879 deletions.
20 changes: 6 additions & 14 deletions benchmarks/test_objectives_benchmarks.py
@@ -16,7 +16,7 @@
    TensorDictSequential as Seq,
)
from torch.nn import functional as F
from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec
from torchrl.data.tensor_specs import Bounded, Unbounded
from torchrl.modules import MLP, QValueActor, TanhNormal
from torchrl.objectives import (
    A2CLoss,
@@ -253,9 +253,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
    value = Seq(common, value_head)
    value(actor(td))

    loss = SACLoss(
        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
    )
    loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

    loss(td)
    benchmark(loss, td)
@@ -312,9 +310,7 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
    value = Seq(common, value_head)
    value(actor(td))

    loss = REDQLoss(
        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
    )
    loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

    loss(td)
    benchmark(loss, td)
@@ -373,9 +369,7 @@ def test_redq_deprec_speed(
    value = Seq(common, value_head)
    value(actor(td))

    loss = REDQLoss_deprecated(
        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
    )
    loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

    loss(td)
    benchmark(loss, td)
@@ -435,7 +429,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
    loss = TD3Loss(
        actor,
        value,
        action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
        action_spec=Bounded(shape=(n_act,), low=-1, high=1),
    )

    loss(td)
@@ -490,9 +484,7 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
    value = Seq(common, value_head)
    value(actor(td))

    loss = CQLLoss(
        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
    )
    loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

    loss(td)
    benchmark(loss, td)
82 changes: 75 additions & 7 deletions docs/source/reference/data.rst
@@ -877,11 +877,58 @@ TensorSpec

.. _ref_specs:

The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
as shape, device, dtype and domain.
The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations,
actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain.

It is important that your environment specs match the input and output that it sends and receives, as
:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
Use the :func:`torchrl.envs.utils.check_env_specs` function for a sanity check.

If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td`
function.

Specs fall into two main categories, numerical and categorical.

.. table:: Numerical TensorSpec subclasses.

    +-------------------------------------------------------------------------------+
    |                                   Numerical                                   |
    +=====================================+=========================================+
    |               Bounded               |                Unbounded                |
    +-----------------+-------------------+-------------------+---------------------+
    | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous |
    +-----------------+-------------------+-------------------+---------------------+

Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or
explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous`
or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class.
See these classes for further information.
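
The dispatch described above can be pictured with a small, self-contained sketch. The classes below are toy stand-ins and not TorchRL's implementation: the metaclass name `_DispatchMeta`, the use of plain strings for dtypes and the simplified constructor signature are all assumptions made for illustration.

```python
# Toy illustration only: these classes are NOT TorchRL's implementation.
# Dtypes are plain strings here to keep the sketch dependency-free.


class _DispatchMeta(type):
    """Hypothetical metaclass redirecting Bounded(...) to a domain-specific subclass."""

    def __call__(cls, *args, dtype="float32", domain=None, **kwargs):
        target = cls
        if cls is Bounded:
            # An explicit domain wins; otherwise it is inferred from the dtype.
            if domain is None:
                domain = "continuous" if dtype.startswith("float") else "discrete"
            if domain not in ("continuous", "discrete"):
                raise ValueError(f"unknown domain: {domain!r}")
            target = BoundedContinuous if domain == "continuous" else BoundedDiscrete
        # Build an instance of the selected class (runs __new__ and __init__).
        return type.__call__(target, *args, dtype=dtype, **kwargs)


class Bounded(metaclass=_DispatchMeta):
    def __init__(self, low, high, shape, dtype="float32"):
        self.low = low
        self.high = high
        self.shape = tuple(shape)
        self.dtype = dtype


class BoundedContinuous(Bounded):
    pass


class BoundedDiscrete(Bounded):
    pass


continuous = Bounded(low=-1, high=1, shape=(2,))
discrete = Bounded(low=0, high=5, shape=(2,), dtype="int64")
forced = Bounded(low=0, high=5, shape=(2,), dtype="float32", domain="discrete")
print(type(continuous).__name__, type(discrete).__name__, type(forced).__name__)
```

In this sketch the caller always writes `Bounded(...)`, and the returned object is still an instance of `Bounded`, so `isinstance` checks against the parent class keep working.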

.. table:: Categorical TensorSpec subclasses.

    +------------------------------------------------------------------+
    |                           Categorical                            |
    +========+=============+=============+==================+==========+
    | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary   |
    +--------+-------------+-------------+------------------+----------+

Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be
combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as
:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these
cases is the :class:`~torchrl.data.Composite` spec.

Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be
expanded accordingly.
Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class.

Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and
:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``)
or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be.

Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding
data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as
:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are
not predictable.
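
As a rough sketch of the rule in the paragraph above, the check below treats any ``-1`` entry as a sign that no fixed-size buffer can be allocated. The helper names `is_dynamic` and `buffer_numel` are hypothetical and are not part of TorchRL; shapes are plain tuples.

```python
from math import prod


def is_dynamic(shape):
    """Return True if any dimension is -1, i.e. unknown ahead of time."""
    return any(dim == -1 for dim in shape)


def buffer_numel(shape):
    """Return the number of elements to preallocate, or None for a dynamic shape."""
    if is_dynamic(shape):
        return None  # no buffer: the tensor size is not predictable
    return prod(shape)


static_shape = (3, 4, 2)
dynamic_shape = (3, -1, 2)
print(buffer_numel(static_shape), buffer_numel(dynamic_shape))
```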

.. currentmodule:: torchrl.data

@@ -890,19 +937,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
    :template: rl_template.rst

    TensorSpec
    Binary
    Bounded
    Categorical
    Composite
    MultiCategorical
    MultiOneHot
    NonTensor
    OneHotDiscrete
    Stacked
    StackedComposite
    Unbounded
    UnboundedContinuous
    UnboundedDiscrete

The following classes are deprecated and just point to the classes above:

.. currentmodule:: torchrl.data

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    BinaryDiscreteTensorSpec
    BoundedTensorSpec
    CompositeSpec
    DiscreteTensorSpec
    LazyStackedCompositeSpec
    LazyStackedTensorSpec
    MultiDiscreteTensorSpec
    MultiOneHotDiscreteTensorSpec
    NonTensorSpec
    OneHotDiscreteTensorSpec
    UnboundedContinuousTensorSpec
    UnboundedDiscreteTensorSpec
    LazyStackedTensorSpec
    LazyStackedCompositeSpec
    NonTensorSpec
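
For downstream code that still uses the deprecated names, a mechanical rename along the following lines may help. This is a minimal sketch, not a tool shipped with TorchRL: `RENAMES` and `migrate` are hypothetical names, and the mapping is deliberately limited to the old/new pairs that appear side by side in this diff.

```python
import re

# Old name -> new name, limited to the renames visible in this diff.
RENAMES = {
    "UnboundedContinuousTensorSpec": "Unbounded",
    "BoundedTensorSpec": "Bounded",
    "BinaryDiscreteTensorSpec": "Binary",
    "DiscreteTensorSpec": "Categorical",
    "CompositeSpec": "Composite",
}

# Word boundaries keep "DiscreteTensorSpec" from matching inside
# "BinaryDiscreteTensorSpec" or other longer identifiers.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def migrate(source: str) -> str:
    """Replace the deprecated spec names listed in RENAMES with the new ones."""
    return _PATTERN.sub(lambda match: RENAMES[match.group(1)], source)


old_line = "loss = SACLoss(actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)))"
print(migrate(old_line))
```

Names that are not in the mapping, such as `LazyStackedCompositeSpec`, are left untouched, so the helper can be extended one rename at a time.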

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
40 changes: 20 additions & 20 deletions docs/source/reference/envs.rst
@@ -28,9 +28,9 @@ Each env will have the following attributes:
  This is especially useful for transforms (see below). For parametric environments (e.g.
  model-based environments), the device does represent the hardware that will be used to
  compute the operations.
- :obj:`env.observation_spec`: a :class:`~torchrl.data.CompositeSpec` object
- :obj:`env.observation_spec`: a :class:`~torchrl.data.Composite` object
  containing all the observation key-spec pairs.
- :obj:`env.state_spec`: a :class:`~torchrl.data.CompositeSpec` object
- :obj:`env.state_spec`: a :class:`~torchrl.data.Composite` object
  containing all the input key-spec pairs (except action). For most stateful
  environments, this container will be empty.
- :obj:`env.action_spec`: a :class:`~torchrl.data.TensorSpec` object
@@ -39,10 +39,10 @@ Each env will have the following attributes:
  the reward spec.
- :obj:`env.done_spec`: a :class:`~torchrl.data.TensorSpec` object representing
  the done-flag spec. See the section on trajectory termination below.
- :obj:`env.input_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
- :obj:`env.input_spec`: a :class:`~torchrl.data.Composite` object containing
  all the input keys (:obj:`"full_action_spec"` and :obj:`"full_state_spec"`).
  It is locked and should not be modified directly.
- :obj:`env.output_spec`: a :class:`~torchrl.data.CompositeSpec` object containing
- :obj:`env.output_spec`: a :class:`~torchrl.data.Composite` object containing
  all the output keys (:obj:`"full_observation_spec"`, :obj:`"full_reward_spec"` and :obj:`"full_done_spec"`).
  It is locked and should not be modified directly.

@@ -433,28 +433,28 @@ only the done flag is shared across agents (as in VMAS):
...     action_specs.append(agent_i_action_spec)
...     reward_specs.append(agent_i_reward_spec)
...     observation_specs.append(agent_i_observation_spec)
>>> env.action_spec = CompositeSpec(
>>> env.action_spec = Composite(
...     {
...         "agents": CompositeSpec(
...         "agents": Composite(
...             {"action": torch.stack(action_specs)}, shape=(env.n_agents,)
...         )
...     }
...)
>>> env.reward_spec = CompositeSpec(
>>> env.reward_spec = Composite(
...     {
...         "agents": CompositeSpec(
...         "agents": Composite(
...             {"reward": torch.stack(reward_specs)}, shape=(env.n_agents,)
...         )
...     }
...)
>>> env.observation_spec = CompositeSpec(
>>> env.observation_spec = Composite(
...     {
...         "agents": CompositeSpec(
...         "agents": Composite(
...             {"observation": torch.stack(observation_specs)}, shape=(env.n_agents,)
...         )
...     }
...)
>>> env.done_spec = DiscreteTensorSpec(
>>> env.done_spec = Categorical(
...     n=2,
...     shape=torch.Size((1,)),
...     dtype=torch.bool,
@@ -582,23 +582,23 @@ the ``return_contiguous=False`` argument.
Here is a working example:

>>> from torchrl.envs import EnvBase
>>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, BoundedTensorSpec, BinaryDiscreteTensorSpec
>>> from torchrl.data import Unbounded, Composite, Bounded, Binary
>>> import torch
>>> from tensordict import TensorDict, TensorDictBase
>>>
>>> class EnvWithDynamicSpec(EnvBase):
...     def __init__(self, max_count=5):
...         super().__init__(batch_size=())
...         self.observation_spec = CompositeSpec(
...             observation=UnboundedContinuousTensorSpec(shape=(3, -1, 2)),
...         self.observation_spec = Composite(
...             observation=Unbounded(shape=(3, -1, 2)),
...         )
...         self.action_spec = BoundedTensorSpec(low=-1, high=1, shape=(2,))
...         self.full_done_spec = CompositeSpec(
...             done=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
...             terminated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
...             truncated=BinaryDiscreteTensorSpec(1, shape=(1,), dtype=torch.bool),
...         self.action_spec = Bounded(low=-1, high=1, shape=(2,))
...         self.full_done_spec = Composite(
...             done=Binary(1, shape=(1,), dtype=torch.bool),
...             terminated=Binary(1, shape=(1,), dtype=torch.bool),
...             truncated=Binary(1, shape=(1,), dtype=torch.bool),
...         )
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float)
...         self.reward_spec = Unbounded((1,), dtype=torch.float)
...         self.count = 0
...         self.max_count = max_count
...