[Feature] Indexing specs #1105
Conversation
…orSpec, UnboundedContinuousTensorSpec & UnboundedDiscreteTensorSpec
Fantastic work, thanks a mil for this!
@@ -135,7 +135,12 @@ def _slice_indexing(shape: list[int], idx: slice):
     return [n_items] + shape[1:]


-def _shape_indexing(shape: list[int], idx: SHAPE_INDEX_TYPING):
+def _shape_indexing(
This seems pretty crucial, and we may refactor it at some point (e.g., using fake tensors).
Can we add a short docstring to say what it is about?
We have a similar function in tensordict called `_getitem_batch_size`. Is it inspired by that?
Docstring added! `_shape_indexing` is definitely similar to `_getitem_batch_size`. I couldn't use the latter, though, as it doesn't perform some indexing checks, which are only executed when actually indexing the tensor. I assumed we wanted those checks, hence the reimplementation:
import torch
from tensordict.utils import _getitem_batch_size
from torchrl.data.tensor_specs import _shape_indexing

_getitem_batch_size(torch.Size((3, 2)), 5)  # torch.Size([2])
_shape_indexing(torch.Size((3, 2)), 5)  # IndexError: index 5 is out of bounds for axis 0 with size 3
If we can work with fake_mode to have both fast shape indexing and such checks, that would definitely be the best of both worlds!
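Along those lines, here is a minimal sketch (not part of this PR) of what that could look like with meta tensors. The helper name `_shape_indexing_via_meta` is hypothetical, and it assumes the indexing kernels still run their bounds checks on data-free meta tensors:

```python
import torch


def _shape_indexing_via_meta(shape, idx):
    """Compute the shape obtained by indexing a tensor of `shape` with `idx`.

    Uses a "meta" tensor, so no real memory proportional to `shape` is
    allocated, while the usual indexing checks (e.g. out-of-bounds integer
    indices) should still be performed by the indexing ops.
    """
    t = torch.empty(shape, device="meta")
    return t[idx].shape


print(_shape_indexing_via_meta(torch.Size((3, 2)), 1))  # torch.Size([2])
# _shape_indexing_via_meta(torch.Size((3, 2)), 5)       # raises IndexError
```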
Description
Add indexing support to remaining specs:
Already supported by previous PR #1081:
Note: although indexing and tests have been implemented for BoundedTensorSpec and MultiDiscreteTensorSpec, a NotImplementedError is currently raised to prevent behavior that diverges from the other specs until pytorch/pytorch#100080 is addressed.
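For illustration, a hedged usage sketch of the spec indexing this PR enables. The constructor arguments and supported index types below are assumptions based on the current specs API, not verbatim from the PR:

```python
import torch
from torchrl.data import UnboundedContinuousTensorSpec

# A spec with a leading batch-like shape; indexing it is expected to slice
# the shape the same way indexing a tensor of that shape would.
spec = UnboundedContinuousTensorSpec(shape=torch.Size([3, 4]))
print(spec[0].shape)   # torch.Size([4])
print(spec[:2].shape)  # torch.Size([2, 4])
```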
Motivation and Context
Addresses the feature request for adding indexing to specs: #1051.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!
cc @matteobettini