[Refactor] Rename specs to simpler names #2368

Merged
merged 23 commits into from
Aug 7, 2024
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/environment.yml
Expand Up @@ -24,7 +24,8 @@ dependencies:
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run_all.sh
Expand Up @@ -91,7 +91,7 @@ echo "installing gymnasium"
pip3 install "gymnasium"
pip3 install ale_py
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install mujoco -U
pip3 install "mujoco<3.2.1" -U

# sanity check: remove?
python3 -c """
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_distributed/scripts/environment.yml
Expand Up @@ -23,7 +23,8 @@ dependencies:
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_examples/scripts/environment.yml
Expand Up @@ -21,7 +21,8 @@ dependencies:
- scipy
- hydra-core
- imageio==2.26.0
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_libs/scripts_envpool/environment.yml
Expand Up @@ -18,5 +18,6 @@ dependencies:
- expecttest
- pyyaml
- scipy
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- coverage
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- scipy
- hydra-core
- dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
- mujoco<3.2.1
- patchelf
- pyopengl==3.1.4
- ray
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks.yml
Expand Up @@ -35,7 +35,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: Run benchmarks
run: |
Expand Down Expand Up @@ -97,7 +97,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: check GPU presence
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks_pr.yml
Expand Up @@ -34,7 +34,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: Setup benchmarks
run: |
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: check GPU presence
run: |
Expand Down
20 changes: 6 additions & 14 deletions benchmarks/test_objectives_benchmarks.py
Expand Up @@ -16,7 +16,7 @@
TensorDictSequential as Seq,
)
from torch.nn import functional as F
from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec
from torchrl.data.tensor_specs import Bounded, Unbounded
from torchrl.modules import MLP, QValueActor, TanhNormal
from torchrl.objectives import (
A2CLoss,
Expand Down Expand Up @@ -253,9 +253,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
value = Seq(common, value_head)
value(actor(td))

loss = SACLoss(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down Expand Up @@ -312,9 +310,7 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
value = Seq(common, value_head)
value(actor(td))

loss = REDQLoss(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down Expand Up @@ -373,9 +369,7 @@ def test_redq_deprec_speed(
value = Seq(common, value_head)
value(actor(td))

loss = REDQLoss_deprecated(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down Expand Up @@ -435,7 +429,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
loss = TD3Loss(
actor,
value,
action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
action_spec=Bounded(shape=(n_act,), low=-1, high=1),
)

loss(td)
Expand Down Expand Up @@ -490,9 +484,7 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
value = Seq(common, value_head)
value(actor(td))

loss = CQLLoss(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down
3 changes: 2 additions & 1 deletion docs/requirements.txt
Expand Up @@ -14,7 +14,8 @@ docutils
sphinx_design

torchvision
dm_control
dm_control<1.0.21
mujoco<3.2.1
atari-py
ale-py
gym[classic_control,accept-rom-license]
Expand Down
82 changes: 75 additions & 7 deletions docs/source/reference/data.rst
Expand Up @@ -877,11 +877,58 @@ TensorSpec

.. _ref_specs:

The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
as shape, device, dtype and domain.
The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations,
actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain.

It is important that your environment specs match the input and output that it sends and receives, as
:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
Use the :func:`torchrl.envs.utils.check_env_specs` function for a sanity check.

If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td`
function.

Specs fall into two main categories, numerical and categorical.

.. table:: Numerical TensorSpec subclasses.

    +-------------------------------------------------------------------------------+
    | Numerical                                                                     |
    +=====================================+=========================================+
    | Bounded                             | Unbounded                               |
    +-----------------+-------------------+-------------------+---------------------+
    | BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous |
    +-----------------+-------------------+-------------------+---------------------+

Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or
explicitly by the ``"domain"`` keyword argument) determines whether the instance is of :class:`~torchrl.data.BoundedContinuous`
or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class.
See these classes for further information.
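
The domain-based dispatch described above can be pictured with a small, torch-free toy. This is only a sketch: the names ``ToyBounded``, ``ToyBoundedContinuous``, ``ToyBoundedDiscrete`` and ``infer_domain``, as well as the use of dtype strings, are hypothetical stand-ins and do not reflect how TorchRL implements the behaviour.

```python
# Toy model of domain-based dispatch; NOT TorchRL code.
# Every name below is a hypothetical stand-in used for illustration only.


def infer_domain(dtype: str) -> str:
    """Derive a domain from a dtype name: floats are continuous, the rest discrete."""
    return "continuous" if dtype.startswith("float") else "discrete"


class ToyBounded:
    """Parent class: instantiating it yields one of the two subclasses below."""

    def __new__(cls, low, high, dtype="float32", domain=None):
        if cls is ToyBounded:
            # An explicit ``domain`` takes precedence over the one implied by the dtype.
            resolved = domain if domain is not None else infer_domain(dtype)
            if resolved == "continuous":
                cls = ToyBoundedContinuous
            else:
                cls = ToyBoundedDiscrete
        return super().__new__(cls)

    def __init__(self, low, high, dtype="float32", domain=None):
        self.low = low
        self.high = high
        self.dtype = dtype
        self.domain = domain if domain is not None else infer_domain(dtype)


class ToyBoundedContinuous(ToyBounded):
    pass


class ToyBoundedDiscrete(ToyBounded):
    pass


float_spec = ToyBounded(-1, 1, dtype="float32")  # domain implied by the dtype
int_spec = ToyBounded(0, 5, dtype="int64")  # domain implied by the dtype
forced_spec = ToyBounded(0, 5, dtype="float32", domain="discrete")  # explicit domain
```

In this toy, the caller always writes ``ToyBounded(...)`` and receives an instance whose concrete type reflects the domain, which is the user-facing behaviour the paragraph above describes for ``Bounded`` and ``Unbounded``.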

.. table:: Categorical TensorSpec subclasses.

    +------------------------------------------------------------------+
    | Categorical                                                      |
    +========+=============+=============+==================+==========+
    | OneHot | MultiOneHot | Categorical | MultiCategorical | Binary   |
    +--------+-------------+-------------+------------------+----------+

Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be
combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as
:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these
cases is the :class:`~torchrl.data.Composite` spec.

Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be
expanded accordingly.
Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class.
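
The stacking rule above can be illustrated on plain shape tuples. The sketch below is a simplification: ``stacked_shape`` is a hypothetical helper, it compares shapes only (real specs would also need matching dtype, device and domain to count as identical), and it returns ``None`` where TorchRL would fall back to a lazy stack.

```python
# Toy illustration of the stacking rule on shape tuples; NOT TorchRL code.
# ``stacked_shape`` is a hypothetical helper introduced for this example.
from typing import Optional, Sequence, Tuple


def stacked_shape(
    shapes: Sequence[Tuple[int, ...]], dim: int = 0
) -> Optional[Tuple[int, ...]]:
    """Return the shape of a dense stack, or None when the shapes differ."""
    first = tuple(shapes[0])
    if any(tuple(shape) != first for shape in shapes):
        # Heterogeneous inputs: a dense stack is impossible, so a lazy
        # representation would be needed instead.
        return None
    out = list(first)
    out.insert(dim, len(shapes))  # new dimension of size len(shapes) at ``dim``
    return tuple(out)


same = stacked_shape([(3,), (3,)])  # identical shapes: dense stack
last = stacked_shape([(3, 4), (3, 4)], dim=2)  # stack along a trailing dim
mixed = stacked_shape([(3,), (4,)])  # differing shapes: no dense stack
```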

Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and
:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``)
or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be.

Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding
data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as
:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are
not predictable.
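
As a rough illustration of why dynamic shapes rule out pre-allocated buffers, the sketch below checks a shape for ``-1`` entries before computing a buffer size. The helpers ``is_dynamic`` and ``buffer_numel`` are hypothetical and are not part of TorchRL.

```python
# Toy check for dynamic shapes; NOT TorchRL code.
# ``is_dynamic`` and ``buffer_numel`` are hypothetical helpers.
from math import prod
from typing import Optional, Tuple


def is_dynamic(shape: Tuple[int, ...]) -> bool:
    """A shape is dynamic when at least one dimension is -1."""
    return any(dim == -1 for dim in shape)


def buffer_numel(shape: Tuple[int, ...]) -> Optional[int]:
    """Number of elements to pre-allocate, or None if the shape is not predictable."""
    if is_dynamic(shape):
        return None
    return prod(shape)


fixed = buffer_numel((4, 3))
variable = buffer_numel((4, -1, 3))
```

With a fixed shape the element count is known up front, so a buffer can be allocated once and reused; with a ``-1`` dimension no size can be computed and the data has to be passed around without a shared buffer.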

.. currentmodule:: torchrl.data

Expand All @@ -890,19 +937,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:template: rl_template.rst

TensorSpec
Binary
Bounded
Categorical
Composite
MultiCategorical
MultiOneHot
NonTensor
OneHotDiscrete
Stacked
StackedComposite
Unbounded
UnboundedContinuous
UnboundedDiscrete

The following classes are deprecated and just point to the classes above:

.. currentmodule:: torchrl.data

.. autosummary::
:toctree: generated/
:template: rl_template.rst

BinaryDiscreteTensorSpec
BoundedTensorSpec
CompositeSpec
DiscreteTensorSpec
LazyStackedCompositeSpec
LazyStackedTensorSpec
MultiDiscreteTensorSpec
MultiOneHotDiscreteTensorSpec
NonTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec
LazyStackedTensorSpec
LazyStackedCompositeSpec
NonTensorSpec

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
Expand Down