
[BUG] Documentation of BinaryDiscreteTensorSpec can be confusing #2364

Closed
@albertbou92

Description

Describe the bug

The documentation of the class BinaryDiscreteTensorSpec can be confusing, as pointed out in #2344 (reply in thread).

Here is the documentation (https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py#L3228), which reads:

"""A binary discrete tensor spec.

Args:
    n (int): length of the binary vector.
    shape (torch.Size, optional): total shape of the sampled tensors.
        If provided, the last dimension must match n.
    device (str, int or torch.device, optional): device of the tensors.
    dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.

Examples:
    >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
    >>> print(spec.zero())
"""

The n argument

At the moment, since BinaryDiscreteTensorSpec inherits from DiscreteTensorSpec, n controls the number of possible sampled values: n=1 yields only False, n=2 allows both True and False, and n>2 still allows only True and False. This is not intuitive given the docstring, which describes n as the length of the binary vector. Also, is this the desired behaviour?
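The behaviour described above can be illustrated with a small pure-Python model. It assumes, as a simplification, that sampling draws an integer uniformly from [0, n) the way a discrete spec with n categories would, and then casts it to bool. `sample_as_bool` is a hypothetical helper written for this illustration, not torchrl code.

```python
import random
from collections import Counter
from typing import List


def sample_as_bool(n: int, num_samples: int, seed: int = 0) -> List[bool]:
    """Hypothetical model: draw integers uniformly in [0, n) and cast to bool.

    This mimics a binary spec that inherits integer sampling from a
    discrete spec with n categories; it is NOT torchrl's implementation.
    """
    rng = random.Random(seed)
    return [bool(rng.randrange(n)) for _ in range(num_samples)]


for n in (1, 2, 4):
    # Count how often each boolean value appears for a given n.
    counts = Counter(sample_as_bool(n, num_samples=1000))
    print(n, dict(counts))
```

Under this model, n>2 introduces no new values; it only makes True more likely, with probability (n - 1) / n instead of 1/2.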

Additionally, n has to match the last dimension of the shape and an error is raised otherwise. Is this necessary?
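For reference, the constraint amounts to a check like the one below. This is a minimal sketch: the function name `check_binary_shape` and the use of `ValueError` are assumptions, and torchrl's actual error type and message may differ.

```python
from typing import Tuple


def check_binary_shape(n: int, shape: Tuple[int, ...]) -> None:
    """Hypothetical version of the constraint: shape[-1] must equal n."""
    if len(shape) == 0 or shape[-1] != n:
        raise ValueError(f"last dimension of shape {shape} must equal n={n}")


check_binary_shape(4, (5, 4))  # passes: the last dimension matches n
try:
    check_binary_shape(4, (5, 3))  # mismatch: raises
except ValueError as err:
    print(err)
```

Whenever shape is given, this check makes n redundant: its value can always be read off as shape[-1].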

Possible simplification

We could remove the n parameter (fixing it to 2) and simplify the __init__ signature so that the spec is defined by the shape parameter alone.
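A minimal sketch of what the proposed signature could look like, using plain nested lists instead of torch tensors so it runs without dependencies. `SimpleBinarySpec` and its `zero` and `rand` methods are hypothetical stand-ins that only loosely mirror the spec API; they are not torchrl code.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass
class SimpleBinarySpec:
    """Hypothetical simplified binary spec: defined by shape alone (n fixed to 2)."""

    shape: Tuple[int, ...]

    def _fill(self, dims: Tuple[int, ...], make_value: Callable[[], bool]) -> Any:
        # Build a nested list with the given dims, filling leaves with make_value().
        if not dims:
            return make_value()
        return [self._fill(dims[1:], make_value) for _ in range(dims[0])]

    def zero(self) -> Any:
        # All-False values with the spec's shape.
        return self._fill(self.shape, lambda: False)

    def rand(self, seed: Optional[int] = None) -> Any:
        # Each entry is True or False with equal probability.
        rng = random.Random(seed)
        return self._fill(self.shape, lambda: rng.random() < 0.5)


spec = SimpleBinarySpec(shape=(5, 4))
print(spec.zero())  # 5 rows of 4 False values
```

Compared with the docstring example, the `n=4` argument disappears because the last dimension of `shape` already carries that information.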




Labels: bug (Something isn't working)
