-
Notifications
You must be signed in to change notification settings - Fork 201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove the Nx.take backend callback #1439
Remove the Nx.take backend callback #1439
Conversation
This commit adds a temporary implementation of Nx.take based on Nx.gather which only works for an axis of 0.
This commit removes: - The Nx.take backend callback - The BinaryBackend implementation
@polvalente, I tried basing this on the EXLA implementation, but it was very cryptic to me. As far as I can understand, XLA's gather gathers slices, taking the starting indices and the slices dimension, and slicing accordingly. However, Nx's gather takes the actual indices of every element, making the XLA implementation incompatible with Nx's gather. Would basing on top of Torchx's current implementation work? |
@Benjamin-Philip I believe you can experiment a bit with the I think just passing |
I already tried that, and got one of the following errors:
(first 2 from tensor
|> gather(indices, axes: [axis])
|> reshape(inner_shape, names: inner_names) From comparing EXLA's Line 1765 in 51c7122
Lines 1830 to 1831 in 51c7122
Which is why I feel they are incompatible (unless of course I misunderstood something, in which case I am wrong). Alternatively, I could try copy pasting the old Torchx implementation, which should work out of the box. |
You can find the same discrepancy in Line 1367 in bead39c
Lines 1638 to 1646 in bead39c
|
@polvalente, let me know if just copy pasting this will work, or if we should go with some other approach: nx/torchx/lib/torchx/backend.ex Lines 340 to 384 in bead39c
|
That would work, but I think the best way is still to go with mimicking EXLA's gather instead. The Torchx implementation will probably go away so it delegates to the default one. Please ping me on slack with the questions, or list them in the documentation issue, you have about the |
This commit adds the take implementation. The solution was derived by comparing take's and gather's Nx.Shape functions.
ea8f65a
to
6857590
Compare
@polvalente, looking through the failed cases in EXLA, there appears to be a bug in gathering subsets with multi-dimensional tensors along the axis 1, resulting in either a Runtime Error or invalid output. I doubt that the gather call is invalid since it worked for both the |
|
||
tensor | ||
|> gather(gather_indices, axes: [axis]) | ||
|> transpose(axes: transpose_axes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this transpose needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took a look at this. take
will retain the shape according to the given axis, this is also true in Python:
>>> np.take(n, [0, 1, 0, 1], axis=1)
array([[1, 2, 1, 2],
[3, 4, 3, 4]])
>>> np.take(n, [0, 1, 0, 1], axis=0)
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]])
On the other hand, gather will always preserve the input shape. This means the underlying offset_dims
in the gather operation is different. So it may be that we cannot really replace take by gather (unless we want to use lose performance when axis != 0, which may be fine). Another option is to extend gather with a new option that controls the output shape but that's debatable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will merge this and make it optional in a follow up commit.
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
CI was failing due to a bug on exla, I have pushed a fix. |
💚 💙 💜 💛 ❤️ |
Sounds good.
…-- bp
On Sun, 12 May, 2024, 8:10 pm José Valim, ***@***.***> wrote:
💚 💙 💜 💛 ❤️
—
Reply to this email directly, view it on GitHub
<#1439 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AQDKWBJ6NDOUH7375ZKY76DZB55MPAVCNFSM6AAAAABDCSCIRGVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCMBWGI3DQOBRG4>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
This commit removes the Nx.take backend callback, and replaces it with
a Nx.gather based implementation.
Closes #1366.