[Refactor] Rename specs to simpler names (#2368)
vmoens authored Aug 7, 2024
1 parent a41da21 commit c848a79
Showing 110 changed files with 56,070 additions and 2,787 deletions.
20 changes: 6 additions & 14 deletions benchmarks/test_objectives_benchmarks.py
@@ -16,7 +16,7 @@
     TensorDictSequential as Seq,
 )
 from torch.nn import functional as F
-from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import Bounded, Unbounded
 from torchrl.modules import MLP, QValueActor, TanhNormal
 from torchrl.objectives import (
     A2CLoss,
@@ -253,9 +253,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
     value = Seq(common, value_head)
     value(actor(td))

-    loss = SACLoss(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
@@ -312,9 +310,7 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
     value = Seq(common, value_head)
     value(actor(td))

-    loss = REDQLoss(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
@@ -373,9 +369,7 @@ def test_redq_deprec_speed(
     value = Seq(common, value_head)
     value(actor(td))

-    loss = REDQLoss_deprecated(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
@@ -435,7 +429,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
     loss = TD3Loss(
         actor,
         value,
-        action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+        action_spec=Bounded(shape=(n_act,), low=-1, high=1),
     )

     loss(td)
@@ -490,9 +484,7 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
     value = Seq(common, value_head)
     value(actor(td))

-    loss = CQLLoss(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
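
The hunks in this file all apply the same mechanical substitution of a long spec name by a short one. A minimal sketch of how such a rename could be scripted with the standard library is given here. It covers only the two pairs visible in this file's diff, and `RENAMES` and `rename_specs` are hypothetical names introduced for illustration; they are not part of TorchRL or of this commit.

```python
import re

# Old -> new names, limited to the two pairs visible in this file's diff.
RENAMES = {
    "BoundedTensorSpec": "Bounded",
    "UnboundedContinuousTensorSpec": "Unbounded",
}

# \b restricts matches to whole identifiers, so a longer name that merely
# contains an old name (e.g. MyBoundedTensorSpecWrapper) is left untouched.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def rename_specs(source: str) -> str:
    """Return `source` with every old spec name replaced by its new name."""
    return _PATTERN.sub(lambda match: RENAMES[match.group(1)], source)


old_line = "loss = TD3Loss(actor, value, action_spec=BoundedTensorSpec(shape=(4,), low=-1, high=1))"
new_line = rename_specs(old_line)
print(new_line)
```

A text-level substitution like this does not understand scoping or string literals, so its output would still need to be reviewed as a diff before being committed.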
82 changes: 75 additions & 7 deletions docs/source/reference/data.rst
@@ -877,11 +877,58 @@ TensorSpec

 .. _ref_specs:

-The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
-as shape, device, dtype and domain.
+The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations,
+actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain.

 It is important that your environment specs match the input and output that it sends and receives, as
-:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
-Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
+:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
+Check the :func:`torchrl.envs.utils.check_env_specs` function for a sanity check.

+If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td`
+function.
+
+Specs fall into two main categories, numerical and categorical.
+
+.. table:: Numerical TensorSpec subclasses.
+
+    +-------------------------------------------------------------------------------+
+    | Numerical                                                                     |
+    +=====================================+=========================================+
+    | Bounded                             | Unbounded                               |
+    +-----------------+-------------------+-------------------+---------------------+
+    | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous |
+    +-----------------+-------------------+-------------------+---------------------+
+
+Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or
+explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous`
+or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class.
+See these classes for further information.
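
The dispatch described in the paragraph above, where constructing the parent class yields an instance of a continuous or discrete subclass, can be pictured with a small pure-Python sketch. Everything in it is hypothetical: the `Sketch*` class names, the use of dtype strings, and the rule that float dtypes imply a continuous domain are assumptions made for illustration, and TorchRL's actual mechanism may differ.

```python
class SketchBounded:
    """Hypothetical parent spec that picks a subclass at construction time."""

    def __new__(cls, low, high, dtype="float32", domain=None):
        target = cls
        if cls is SketchBounded:
            # An explicit domain wins; otherwise infer it from the dtype name
            # (assumption: float dtypes are continuous, all others discrete).
            if domain is None:
                domain = "continuous" if dtype.startswith("float") else "discrete"
            if domain == "continuous":
                target = SketchBoundedContinuous
            elif domain == "discrete":
                target = SketchBoundedDiscrete
            else:
                raise ValueError(f"unknown domain: {domain!r}")
        return object.__new__(target)

    def __init__(self, low, high, dtype="float32", domain=None):
        self.low, self.high, self.dtype = low, high, dtype


class SketchBoundedContinuous(SketchBounded):
    pass


class SketchBoundedDiscrete(SketchBounded):
    pass


print(type(SketchBounded(-1, 1)).__name__)
print(type(SketchBounded(0, 5, dtype="int64")).__name__)
```

Because the returned object is still an instance of the parent class, `isinstance` checks against the parent keep working regardless of which subclass was selected.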

+.. table:: Categorical TensorSpec subclasses.
+
+    +------------------------------------------------------------------+
+    | Categorical                                                      |
+    +========+=============+=============+==================+==========+
+    | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary   |
+    +--------+-------------+-------------+------------------+----------+
+
+Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be
+combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as
+:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these
+cases is the :class:`~torchrl.data.Composite` spec.
+
+Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be
+expanded accordingly.
+Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class.
+
+Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and
+:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``)
+or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be.
+
+Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding
+data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as
+:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are
+not predictable.
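
The rule for dynamic specs stated above can be illustrated with plain shape tuples. This is only a sketch: `is_dynamic` and `buffer_numel` are hypothetical helpers written for this example, not TorchRL functions, and they model only the decision "can a fixed-size buffer be allocated for this shape?".

```python
from math import prod


def is_dynamic(shape):
    """A shape is treated as dynamic if any of its dimensions is -1."""
    return any(dim == -1 for dim in shape)


def buffer_numel(shape):
    """Number of elements to preallocate, or None when the shape is dynamic."""
    if is_dynamic(shape):
        # The size is not predictable, so no buffer should be created.
        return None
    return prod(shape)


print(buffer_numel((3, 4)))
print(buffer_numel((3, -1, 4)))
```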

.. currentmodule:: torchrl.data

@@ -890,19 +937,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
     :template: rl_template.rst

     TensorSpec
+    Binary
+    Bounded
+    Categorical
+    Composite
+    MultiCategorical
+    MultiOneHot
+    NonTensor
+    OneHotDiscrete
+    Stacked
+    StackedComposite
+    Unbounded
+    UnboundedContinuous
+    UnboundedDiscrete
+
+The following classes are deprecated and just point to the classes above:
+
+.. currentmodule:: torchrl.data
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
     BinaryDiscreteTensorSpec
     BoundedTensorSpec
     CompositeSpec
     DiscreteTensorSpec
+    LazyStackedCompositeSpec
+    LazyStackedTensorSpec
     MultiDiscreteTensorSpec
     MultiOneHotDiscreteTensorSpec
+    NonTensorSpec
     OneHotDiscreteTensorSpec
     UnboundedContinuousTensorSpec
     UnboundedDiscreteTensorSpec
-    LazyStackedTensorSpec
-    LazyStackedCompositeSpec
-    NonTensorSpec
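
One common way to make a deprecated name "just point to" a renamed class is a thin subclass that emits a `DeprecationWarning` and otherwise defers to the new class. The sketch below shows that general pattern with the standard `warnings` module. `ShortName` and `LongNameTensorSpec` are hypothetical placeholders, and nothing here is taken from TorchRL's actual implementation of the deprecated names.

```python
import warnings


class ShortName:
    """Hypothetical stand-in for a class under its new, shorter name."""

    def __init__(self, low, high):
        self.low, self.high = low, high


class LongNameTensorSpec(ShortName):
    """Hypothetical deprecated alias: warns, then behaves like ShortName."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "LongNameTensorSpec is deprecated, use ShortName instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(*args, **kwargs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    spec = LongNameTensorSpec(-1, 1)

print(isinstance(spec, ShortName), len(caught))
```

Subclassing keeps `isinstance` checks against the new name valid for objects built through the old one, which lets existing code migrate gradually.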

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------