[BUG] Documentation of BinaryDiscreteTensorSpec can be confusing #2364
Describe the bug
The documentation of the BinaryDiscreteTensorSpec class can be confusing, as pointed out in #2344 (reply in thread).
The docstring (https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py#L3228) reads:
"""A binary discrete tensor spec.

Args:
    n (int): length of the binary vector.
    shape (torch.Size, optional): total shape of the sampled tensors.
        If provided, the last dimension must match n.
    device (str, int or torch.device, optional): device of the tensors.
    dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.

Examples:
    >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
    >>> print(spec.zero())
"""
The n argument
At the moment, since BinaryDiscreteTensorSpec inherits from DiscreteTensorSpec, n controls the number of possible outcomes: n=1 produces only False values, n=2 produces both True and False values, and n>2 likewise produces True and False values. This is not intuitive given the docstring, which describes n as the length of the binary vector. Also, is this the desired behaviour?
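A minimal pure-Python model of the behaviour described above. It assumes, without confirming it against the torchrl source, that sampling draws integers uniformly from {0, ..., n - 1} (the way a discrete spec with n outcomes would) and then casts each one to bool. sample_binary is a hypothetical helper, not part of torchrl.

```python
import random
from typing import List


def sample_binary(n: int, size: int, seed: int = 0) -> List[bool]:
    """Hypothetical model (not torchrl code): draw integers uniformly from
    {0, ..., n - 1} and cast each one to bool."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    return [bool(rng.randrange(n)) for _ in range(size)]


for n in (1, 2, 4):
    samples = sample_binary(n, size=1000)
    observed = sorted(set(samples))
    true_rate = sum(samples) / len(samples)
    print(f"n={n}: values={observed}, fraction True={true_rate:.2f}")
```

Under this model, n=1 can only ever yield False, while any n >= 2 yields both values, with True appearing with probability (n - 1) / n. So n would act as a count of outcomes that also skews the distribution, not as a vector length.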
Additionally, n has to match the last dimension of shape, and an error is raised otherwise. Is this check necessary?
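If the check stays, n adds no information whenever shape is passed, because it can only ever equal shape[-1]. The following sketch writes the documented rule as a hypothetical standalone function, resolve_shape (not a torchrl API), to make that redundancy explicit.

```python
from typing import Optional, Sequence, Tuple


def resolve_shape(n: Optional[int], shape: Optional[Sequence[int]]) -> Tuple[int, ...]:
    """Hypothetical helper (not torchrl code) mirroring the documented rule:
    if both n and shape are given, shape[-1] must equal n."""
    if n is None and not shape:
        raise TypeError("Provide either n or shape.")
    if not shape:
        # Only n given: the spec is a single binary vector of length n.
        return (n,)
    shape = tuple(shape)
    if n is not None and shape[-1] != n:
        raise ValueError(f"shape[-1]={shape[-1]} does not match n={n}")
    # When shape is given, n can only confirm shape[-1], never change it.
    return shape


print(resolve_shape(4, (5, 4)))     # (5, 4)
print(resolve_shape(None, (5, 4)))  # (5, 4)
print(resolve_shape(4, None))       # (4,)
```

Passing n alongside shape can only succeed when it repeats shape[-1], so the check guards against inconsistency that the extra argument itself introduces.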
Possible simplification
We could remove the n parameter (fixing it to 2) and simplify the __init__ signature, letting the user define the spec shape with the shape parameter alone.
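A minimal pure-Python sketch of what that signature could look like. BinarySpecSketch is a hypothetical stand-in, not torchrl code: a real spec would hold torch tensors with a device and dtype, which are left out here, and nested lists play the role of tensors. The shape is the only argument, and sampling treats every element as an independent fair coin, which is what fixing n to 2 would mean.

```python
import math
import random
from typing import List, Sequence, Tuple


class BinarySpecSketch:
    """Hypothetical simplified binary spec (not torchrl code): the shape alone
    defines the spec, and each element is independently True or False."""

    def __init__(self, shape: Sequence[int]):
        self.shape: Tuple[int, ...] = tuple(shape)
        if not self.shape or any(d <= 0 for d in self.shape):
            raise ValueError("shape must be a non-empty sequence of positive ints")

    def _nest(self, flat: List[bool]) -> list:
        # Regroup a flat list into nested lists, last dimension innermost.
        for d in reversed(self.shape[1:]):
            flat = [flat[i:i + d] for i in range(0, len(flat), d)]
        return flat

    def zero(self) -> list:
        # All-False sample covering the full shape.
        return self._nest([False] * math.prod(self.shape))

    def rand(self, seed: int = 0) -> list:
        # Each element is True with probability 0.5 (n fixed to 2).
        rng = random.Random(seed)
        return self._nest([rng.random() < 0.5 for _ in range(math.prod(self.shape))])


spec = BinarySpecSketch(shape=(5, 4))
zero = spec.zero()
print(len(zero), len(zero[0]))  # 5 4
```

With this signature, the docstring example would need only shape=(5, 4), and there would be no separate n to keep in sync with the last dimension.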
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)