[Refactor] Rename specs to simpler names #2368

Merged
merged 23 commits into from
Aug 7, 2024
amend
vmoens committed Aug 6, 2024
commit 89f846760b70cfee34dc164d0005ca7a871ccc91
50 changes: 46 additions & 4 deletions docs/source/reference/data.rst
@@ -877,11 +877,53 @@ TensorSpec

.. _ref_specs:

The `TensorSpec` parent class and subclasses define the basic properties of observations and actions in TorchRL, such
as shape, device, dtype and domain.
The :class:`~torchrl.data.TensorSpec` parent class and subclasses define the basic properties of state, observations,
actions, rewards and done status in TorchRL, such as their shape, device, dtype and domain.

It is important that your environment specs match the input and output that it sends and receives, as
:obj:`ParallelEnv` will create buffers from these specs to communicate with the spawn processes.
Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:class:`~torchrl.envs.ParallelEnv` will create buffers from these specs to communicate with the spawned processes.
Use the :func:`torchrl.envs.utils.check_env_specs` function for a sanity check.

If needed, specs can be automatically generated from data using the :func:`~torchrl.envs.utils.make_composite_from_td`
function.

Specs fall in two main categories, numerical and categorical:

+-------------------------------------------------------------------------------+------------------------------------------------------------------+
| Numerical                                                                     | Categorical                                                      |
+=====================================+=========================================+========+=============+=============+==================+==========+
| Bounded                             | Unbounded                               | OneHot | MultiOneHot | Categorical | MultiCategorical | Discrete |
+-----------------+-------------------+-------------------+---------------------+--------+-------------+-------------+------------------+----------+
| BoundedDiscrete | BoundedContinuous | UnboundedDiscrete | UnboundedContinuous |                                                                  |
+-----------------+-------------------+-------------------+---------------------+------------------------------------------------------------------+

Whenever a :class:`~torchrl.data.Bounded` instance is created, its domain (defined either implicitly by its dtype or
explicitly by the `"domain"` keyword argument) will determine if the instantiated class will be of :class:`~torchrl.data.BoundedContinuous`
or :class:`~torchrl.data.BoundedDiscrete` type. The same applies to the :class:`~torchrl.data.Unbounded` class.
See these classes for further information.
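
The dispatch described above can be pictured with a plain-Python toy model. This is only a sketch of the idea: ``ToyBounded``, ``ToyBoundedContinuous`` and ``ToyBoundedDiscrete`` are hypothetical names invented for illustration, Python's ``float``/``int`` stand in for tensor dtypes, and TorchRL's actual mechanism may differ.

```python
# Toy model only: these classes are hypothetical and are NOT TorchRL classes.
class ToyBounded:
    def __new__(cls, low, high, dtype=float, domain=None):
        if cls is ToyBounded:
            # Infer the domain from the dtype unless it is given explicitly.
            if domain is None:
                domain = "continuous" if dtype is float else "discrete"
            if domain == "continuous":
                cls = ToyBoundedContinuous
            elif domain == "discrete":
                cls = ToyBoundedDiscrete
            else:
                raise ValueError(f"unknown domain: {domain!r}")
        # Create an instance of the selected (sub)class.
        return object.__new__(cls)

    def __init__(self, low, high, dtype=float, domain=None):
        self.low, self.high, self.dtype = low, high, dtype


class ToyBoundedContinuous(ToyBounded):
    pass


class ToyBoundedDiscrete(ToyBounded):
    pass


continuous = ToyBounded(-1.0, 1.0)  # domain implied by the float dtype
discrete = ToyBounded(0, 5, dtype=int)  # domain implied by the int dtype
forced = ToyBounded(0, 5, dtype=float, domain="discrete")  # explicit domain wins
print(type(continuous).__name__, type(discrete).__name__, type(forced).__name__)
```

In this toy, calling the parent class never returns a bare ``ToyBounded``: the caller always gets one of the two subclasses, which is the behavior the paragraph above describes for ``Bounded`` and ``Unbounded``.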

Unlike ``gymnasium``, TorchRL does not have the concept of an arbitrary list of specs. If multiple specs have to be
combined together, TorchRL assumes that the data will be presented as dictionaries (more specifically, as
:class:`~tensordict.TensorDict` or related formats). The corresponding :class:`~torchrl.data.TensorSpec` class in these
cases is the :class:`~torchrl.data.Composite` spec.

Nevertheless, specs can be stacked together using :func:`~torch.stack`: if they are identical, their shape will be
expanded accordingly.
Otherwise, a lazy stack will be created through the :class:`~torchrl.data.Stacked` class.

Similarly, ``TensorSpecs`` possess some common behavior with :class:`~torch.Tensor` and
:class:`~tensordict.TensorDict`: they can be reshaped, indexed, squeezed, unsqueezed, moved to another device (``to``)
or unbound (``unbind``) as regular :class:`~torch.Tensor` instances would be.

Specs where some dimensions are ``-1`` are said to be "dynamic" and the negative dimensions indicate that the corresponding
data has an inconsistent shape. When seen by an optimizer or an environment (e.g., batched environment such as
:class:`~torchrl.envs.ParallelEnv`), these negative shapes tell TorchRL to avoid using buffers as the tensor shapes are
not predictable.
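
The buffer decision described above can be sketched in plain Python. The helpers ``is_dynamic`` and ``buffer_shape`` are hypothetical names used only for illustration and are not part of TorchRL; shapes are modelled as tuples of integers.

```python
from typing import Optional, Sequence, Tuple


def is_dynamic(shape: Sequence[int]) -> bool:
    """Return True if any dimension is -1, i.e. its size is not known in advance."""
    return any(dim == -1 for dim in shape)


def buffer_shape(shape: Sequence[int], num_workers: int) -> Optional[Tuple[int, ...]]:
    """Shape of a buffer preallocated for num_workers, or None if the spec is dynamic."""
    if is_dynamic(shape):
        # The data shape can change from one step to the next: no fixed buffer.
        return None
    return (num_workers, *shape)


static_buffer = buffer_shape((3, 84, 84), num_workers=4)
dynamic_buffer = buffer_shape((-1, 84, 84), num_workers=4)
print(static_buffer, dynamic_buffer)
```

A fully specified shape yields a fixed buffer shape with one leading entry per worker, while any ``-1`` dimension makes the helper decline to preallocate.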

.. currentmodule:: torchrl.data

4 changes: 2 additions & 2 deletions test/test_actors.py
@@ -187,7 +187,7 @@ def test_distributional_qvalue_hook_conflicting_spec(self):
):
_process_action_space_spec(OneHot(3), spec)
with pytest.raises(
ValueError, match="action_space cannot be of type CompositeSpec"
ValueError, match="action_space cannot be of type Composite"
):
_process_action_space_spec(Composite(), spec)
with pytest.raises(KeyError, match="action could not be found in the spec"):
@@ -240,7 +240,7 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5):
_process_action_space_spec(Binary(n=1), action_spec)
_process_action_space_spec(Binary(n=1), leaf_action_spec)
with pytest.raises(
ValueError, match="action_space cannot be of type CompositeSpec"
ValueError, match="action_space cannot be of type Composite"
):
_process_action_space_spec(action_spec, None)

17 changes: 8 additions & 9 deletions test/test_specs.py
@@ -414,7 +414,7 @@ def test_setitem_matches_device(self, shape, is_complete, device, dtype, dest):
contextlib.nullcontext()
if (device == dest) or (device is None)
else pytest.raises(
RuntimeError, match="All devices of CompositeSpec must match"
RuntimeError, match="All devices of Composite must match"
)
)
with cm:
@@ -725,15 +725,15 @@ def test_lock(recurse):
assert not spec.locked
spec.lock_(recurse=recurse)
assert spec.locked
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"] = spec["a"].clone()
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec.set("a", spec["a"].clone())
if recurse:
assert spec["a"].locked
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a"].set("b", spec["a", "b"].clone())
with pytest.raises(RuntimeError, match="Cannot modify a locked CompositeSpec."):
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
spec["a", "b"] = spec["a", "b"].clone()
else:
assert not spec["a"].locked
@@ -1830,7 +1830,7 @@ def test_composite_encode_err(self):
2,
),
)
with pytest.raises(KeyError, match="The CompositeSpec instance with keys"):
with pytest.raises(KeyError, match="The Composite instance with keys"):
c.encode({"c": 0})
with pytest.raises(
RuntimeError, match="raised a RuntimeError. Scroll up to know more"
@@ -2374,15 +2374,15 @@ def test_malformed_stack(self, shape, stack_dim):
torch.stack([c1, c2], stack_dim)


class TestDenseStackedCompositeSpecs:
class TestDenseStackedComposite:
def test_stack(self):
c1 = Composite(a=Unbounded())
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
assert isinstance(c, Composite)


class TestLazyStackedCompositeSpecs:
class TestLazyStackedComposite:
def _get_heterogeneous_specs(
self,
batch_size=(),
@@ -3283,7 +3283,6 @@ def test_valid_indexing(spec_class):
assert spec_3d[None, 1, ..., None].shape == torch.Size([1, 3, 1, 4])
assert spec_4d[:, None, ..., None, :].shape == torch.Size([5, 1, 3, 1, 4, 6])

# BoundedTensorSpec, DiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, CompositeSpec
else:
# Integers
assert spec_2d[0, 1].shape == torch.Size([])
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
@@ -533,7 +533,7 @@ class TensorSpec:
TensorSpecs are dataclasses that always share the following fields: `shape`, `space`, `dtype` and `device`.

As such, TensorSpecs possess some common behavior with :class:`~torch.Tensor` and :class:`~tensordict.TensorDict`:
they can be reshaped, indexed, squeezed, unqueezed, moved to another device etc.
they can be reshaped, indexed, squeezed, unsqueezed, moved to another device etc.

Args:
shape (torch.Size): size of the tensor. The shape includes the batch dimensions as well as the feature