[BUG] Documentation of BinaryDiscreteTensorSpec can be confusing #2364

Closed · albertbou92 opened this issue Aug 5, 2024 · 2 comments · Fixed by #2368

Labels: bug (Something isn't working)

Comments

albertbou92 (Contributor) commented Aug 5, 2024

Describe the bug

The documentation of the class BinaryDiscreteTensorSpec can be confusing, as pointed out in #2344 (reply in thread).

The current docstring (https://github.com/pytorch/rl/blob/main/torchrl/data/tensor_specs.py#L3228) reads:

"""A binary discrete tensor spec.

Args:
    n (int): length of the binary vector.
    shape (torch.Size, optional): total shape of the sampled tensors.
        If provided, the last dimension must match n.
    device (str, int or torch.device, optional): device of the tensors.
    dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.

Examples:
    >>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
    >>> print(spec.zero())
"""

The n argument

At the moment, because BinaryDiscreteTensorSpec inherits from DiscreteTensorSpec, n controls the number of categories that values are drawn from rather than the length of the vector: n=1 yields only False values, n=2 yields both True and False, and n>2 still yields only True and False. This does not match the docstring, which describes n as the length of the binary vector. Is this the desired behaviour?
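A minimal pure-Python sketch of the behaviour described above, assuming that sampling draws integers uniformly from range(n), as a discrete spec would, and then casts each draw to bool. The function `sample_like_discrete` is hypothetical and is not torchrl code.

```python
import random
from typing import List


def sample_like_discrete(n: int, size: int, rng: random.Random) -> List[bool]:
    """Hypothetical model of the current behaviour (not torchrl code).

    Draws integers uniformly from range(n), as a discrete spec would,
    then casts each draw to bool.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    # 0 becomes False; every other draw (1 .. n-1) becomes True.
    return [bool(rng.randrange(n)) for _ in range(size)]


rng = random.Random(0)
for n in (1, 2, 4):
    values = sample_like_discrete(n, 1000, rng)
    # Print the distinct values and the fraction of True draws.
    print(n, sorted(set(values)), sum(values) / len(values))
```

Under this model, n=1 can only produce False and any n >= 2 produces both values; n merely shifts the expected share of True to (n - 1) / n, which has nothing to do with the length of the binary vector.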

Additionally, n must match the last dimension of shape, and an error is raised otherwise. Is this check necessary?
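The consistency rule can be modelled as a small validation step. This is a sketch under the assumption that omitting shape means a single vector of length n; `resolve_binary_shape` is a hypothetical name, not part of torchrl.

```python
from typing import Optional, Sequence, Tuple


def resolve_binary_shape(n: int, shape: Optional[Sequence[int]] = None) -> Tuple[int, ...]:
    """Hypothetical helper mirroring the rule described in the issue.

    With no shape, assume a single vector of length n.
    Otherwise the last dimension of shape must equal n.
    """
    if shape is None:
        return (n,)
    shape = tuple(shape)
    if len(shape) == 0 or shape[-1] != n:
        last = shape[-1] if shape else None
        raise ValueError(f"The last dimension of shape ({last}) must match n={n}.")
    return shape


print(resolve_binary_shape(4))          # (4,)
print(resolve_binary_shape(4, (5, 4)))  # (5, 4)
```

With this rule, whenever shape is passed, n carries no information that shape does not already contain, which is the redundancy the question points at.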

Possible simplification

We could remove the n parameter (fixing the number of categories to 2) and simplify the __init__ signature so that the user defines the spec through the shape parameter alone.
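A minimal sketch of the proposed simplification, assuming nested Python lists stand in for tensors. `ShapeOnlyBinarySpec` is a hypothetical class, not torchrl's API: its only argument is shape, and the number of categories is fixed to 2. The `zero()` method is named by analogy with the docstring example; `rand()` and `is_in()` are illustrative.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class ShapeOnlyBinarySpec:
    """Hypothetical shape-only binary spec: every element is True or False."""

    shape: Tuple[int, ...]

    def __post_init__(self) -> None:
        # Normalise lists such as [5, 4] to a tuple so the spec stays immutable.
        object.__setattr__(self, "shape", tuple(self.shape))

    def _build(self, dims: Tuple[int, ...], leaf: Callable[[], bool]) -> Any:
        # Recursively build nested lists matching dims; leaves come from leaf().
        if not dims:
            return leaf()
        return [self._build(dims[1:], leaf) for _ in range(dims[0])]

    def _check(self, dims: Tuple[int, ...], value: Any) -> bool:
        # A value belongs to the spec if its nesting matches dims and leaves are bools.
        if not dims:
            return isinstance(value, bool)
        return (
            isinstance(value, list)
            and len(value) == dims[0]
            and all(self._check(dims[1:], v) for v in value)
        )

    def zero(self) -> Any:
        return self._build(self.shape, lambda: False)

    def rand(self, rng: random.Random) -> Any:
        # Each element is True or False with equal probability.
        return self._build(self.shape, lambda: rng.random() < 0.5)

    def is_in(self, value: Any) -> bool:
        return self._check(self.shape, value)


spec = ShapeOnlyBinarySpec(shape=(5, 4))
print(spec.zero()[0])                        # [False, False, False, False]
print(spec.is_in(spec.rand(random.Random(0))))  # True
```

Because only two values are ever possible, nothing is lost by dropping n, and the shape check from the current constructor becomes unnecessary.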


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

albertbou92 added the bug (Something isn't working) label Aug 5, 2024
albertbou92 changed the title [BUG] Documentation of BinaryDiscreteTensorSpec is confusing to [BUG] Documentation of BinaryDiscreteTensorSpec can be confusing Aug 5, 2024

vmoens (Contributor) commented Aug 5, 2024

Can you check #2366 and let me know if that fixes it?

albertbou92 (Contributor, Author)

Yes, the class Binary fixes it.

vmoens linked a pull request Aug 7, 2024 that will close this issue