
Remove the Nx.take backend callback #1439

Merged

Conversation

@Benjamin-Philip (Contributor) commented Feb 10, 2024

This commit removes the Nx.take backend callback and replaces it with
an Nx.gather-based implementation.

Closes #1366.

This commit adds a temporary implementation of Nx.take based on
Nx.gather, which only works for axis 0 (a sketch follows the list below).
This commit removes:

- The Nx.take backend callback
- The BinaryBackend implementation
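
For reference, a minimal sketch (illustrative values, not code from this PR) of the axis-0 equivalence the temporary implementation relies on:

t = Nx.iota({2, 3})                             # [[0, 1, 2], [3, 4, 5]]
i = Nx.tensor([1, 0, 1])

Nx.take(t, i)                                   # :axis defaults to 0
Nx.gather(t, Nx.reshape(i, {3, 1}), axes: [0])
# both yield [[3, 4, 5], [0, 1, 2], [3, 4, 5]]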
@Benjamin-Philip (Contributor, Author):

@polvalente, I tried basing this on the EXLA implementation, but it was very cryptic to me. As far as I can understand, XLA's gather gathers slices: it takes the starting indices and the slice dimensions, and slices accordingly. Nx's gather, however, takes the actual indices of every element, which makes the XLA implementation incompatible with Nx's gather. Would basing this on Torchx's current implementation work?
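
To make the semantic difference concrete, here is a minimal sketch (illustrative values) of how Nx's gather addresses individual elements by their full coordinates:

t = Nx.tensor([[1, 2], [3, 4]])
# each row of the indices tensor is one complete coordinate
Nx.gather(t, Nx.tensor([[0, 1], [1, 0]]))
# yields [2, 3], i.e. t[0][1] and t[1][0]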

@polvalente (Contributor):
@Benjamin-Philip I believe you can experiment a bit with the :axes option from gather to replicate take. What I meant by looking into the EXLA implementation is that in this section you can see that the :axis for take is passed as the :axes for gather. Basically, both gather and take use XLA's gather, so you can find the equivalence there.

I think just passing :axis as :axes might work without too much of a hassle.
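
A sketch of the suggested call shape (illustrative values; note the result comes out transposed relative to take, which becomes relevant later in this thread):

t = Nx.iota({2, 3})                             # [[0, 1, 2], [3, 4, 5]]
i = Nx.tensor([2, 0])

Nx.gather(t, Nx.reshape(i, {2, 1}), axes: [1])
# yields [[2, 5], [0, 3]]; Nx.take(t, i, axis: 1) would yield [[2, 0], [5, 3]]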

@Benjamin-Philip (Contributor, Author) commented Feb 20, 2024

I think just passing :axis as :axes might work without too much of a hassle

I already tried that and got the following errors:

test/nx/vectorize_test.exs:447
  ** (ArgumentError) expected indices rank to be at least 1, got: 0

test/nx_test.exs:2822
  ** (ArgumentError) expected the last indices dimension size (2) to be less than or equal to the tensor rank (1)

test/nx/doctest_test.exs:4
  ** (ArgumentError) errors were found at the given arguments:

    * 3rd argument: out of range

  stacktrace:
    :erlang.binary_part(<<1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 101, 0, 0, 0, 0, 0, 0, 0, 102, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 111, 0, ...>>, 56, 32)
    (nx 0.7.0-dev) lib/nx/binary_backend.ex:2004: Nx.BinaryBackend."-gather/3-lbc$^0/2-0-"/9
    (nx 0.7.0-dev) lib/nx/binary_backend.ex:1998: Nx.BinaryBackend.gather/3
    (nx 0.7.0-dev) lib/nx.ex:14478: Nx.gather/3
    (nx 0.7.0-dev) lib/nx.ex:14129: Nx.take/3
    (for doctest at) lib/nx.ex:13946: (test)

(the first two raised from Nx.Shape.gather) with the following implementation:

tensor
|> gather(indices, axes: [axis])
|> reshape(inner_shape, names: inner_names)
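
In hindsight, the first two errors are consistent with gather treating the last indices dimension as a coordinate tuple: a rank-1 indices tensor reads as one coordinate list rather than a list of scalar indices. A hedged sketch of the reshape that would sidestep them (assuming the axes-based gather semantics discussed in this thread):

tensor
|> gather(reshape(indices, {size(indices), 1}), axes: [axis])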

From comparing EXLA's take and gather implementations side by side, I found a discrepancy in a single parameter, offset_dims, while everything else is the same. From take's implementation:

offset_dims = result_rank |> axes_for_rank() |> delete_slice(axis, indices_rank)

And from gather's, nx/exla/lib/exla/defn.ex (lines 1830-1831 at 51c7122):

batch_size = tensor_rank - length(axes)
offset_dims = count_up(batch_size, batch_size)

This is why I feel they are incompatible (unless, of course, I misunderstood something).
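
To make the discrepancy concrete, a sketch under assumed helper semantics (count_up(n, offset) building [offset, ..., offset + n - 1], and delete_slice/3 dropping count entries starting at the given position), for a rank-3 tensor, rank-2 indices, and axis 1:

result_rank = 3 - 1 + 2                 # tensor rank 3, indices rank 2
axis = 1
indices_rank = 2

# take: drop the indices block at position `axis` from the result axes
all_axes = Enum.to_list(0..(result_rank - 1))
take_offset_dims = all_axes -- Enum.slice(all_axes, axis, indices_rank)
# [0, 3]

# gather: always count up from the batch size
batch_size = 3 - length([axis])
gather_offset_dims = Enum.to_list(batch_size..(batch_size + batch_size - 1))
# [2, 3]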


Alternatively, I could try copy-pasting the old Torchx implementation, which should work out of the box.

@Benjamin-Philip (Contributor, Author):
You can find the same discrepancy in Nx.Shape. From take's shape function:

result_shape = leading_lengths ++ indices_shape ++ trailing_lengths

And from gather's, nx/nx/lib/nx/shape.ex (lines 1638-1646 at bead39c):

{outer_shape, [last_size]} = indices_shape |> Tuple.to_list() |> Enum.split(-1)

if last_size > rank do
  raise ArgumentError,
        "expected the last indices dimension size (#{last_size}) to be less than or equal to the tensor rank (#{rank})"
end

inner_shape = for i <- Nx.axes(shape), i not in axes, do: elem(shape, i)
shape = List.to_tuple(outer_shape ++ inner_shape)
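
For a concrete shape comparison (illustrative values; the gather result ordering follows the outer_shape ++ inner_shape line above):

t = Nx.iota({2, 3, 4})
i = Nx.tensor([[0], [1], [2], [0], [1]])

Nx.take(t, Nx.reshape(i, {5}), axis: 1).shape   # {2, 5, 4}: leading ++ indices ++ trailing
Nx.gather(t, i, axes: [1]).shape                # {5, 2, 4}: indices outer dims come first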

@Benjamin-Philip (Contributor, Author):
@polvalente, let me know if just copy-pasting this will work, or if we should go with some other approach:

@impl true
def take(out, t, i, axis) do
  axes = Nx.axes(t)

  indices_shape =
    axes
    |> Enum.map(fn
      ^axis -> Tuple.product(i.shape)
      _ -> 1
    end)
    |> List.to_tuple()

  idx_tiling =
    t.shape
    |> Tuple.to_list()
    |> Enum.with_index(fn
      _x, ^axis -> 1
      x, _ -> x
    end)

  indices_for_axis =
    i
    |> Nx.reshape(indices_shape)
    |> Nx.tile(idx_tiling)

  num_elements = Tuple.product(indices_for_axis.shape)

  indices =
    axes
    |> Enum.map(fn
      ^axis ->
        Nx.reshape(indices_for_axis, {num_elements, 1})

      current ->
        # current when current < axis ->
        indices_for_axis
        |> Nx.shape()
        |> Nx.iota(axis: current, backend: __MODULE__)
        |> Nx.reshape({num_elements, 1})
    end)
    |> Nx.concatenate(axis: 1)

  # TODO: maybe rewrite it as gather now behaves differently
  gather(out, t, indices, [])
end
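
For intuition, a sketch of the same coordinate-building trick using only public Nx functions (illustrative values, not the backend code itself), for a 2-D tensor and axis 1:

t = Nx.iota({2, 3})                                       # [[0, 1, 2], [3, 4, 5]]
i = Nx.tensor([2, 0])

rows = Nx.iota({2, 2}, axis: 0) |> Nx.reshape({4, 1})     # row coordinate of each pick
cols = Nx.tile(Nx.reshape(i, {1, 2}), [2, 1]) |> Nx.reshape({4, 1})
indices = Nx.concatenate([rows, cols], axis: 1)           # [[0, 2], [0, 0], [1, 2], [1, 0]]

Nx.gather(t, indices) |> Nx.reshape({2, 2})
# [[2, 0], [5, 3]], the same as Nx.take(t, i, axis: 1)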

@polvalente (Contributor):
@polvalente, let me know if just copy-pasting this will work, or if we should go with some other approach: […]

That would work, but I think the best way is still to mimic EXLA's gather instead. The Torchx implementation will probably go away so that it delegates to the default one.

Please ping me on Slack with the questions you have about the :axes option, or list them in the documentation issue, so I can write better documentation for it :)

This commit adds the take implementation.

The solution was derived by comparing take's and gather's Nx.Shape functions.
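
A hedged sketch (names illustrative, not this PR's actual code) of the axis bookkeeping such a derivation suggests: gather with axes: [axis] emits the indices dimensions first, so a transpose must move that block back to axis's position:

# tensor rank 3, indices rank 1, take axis 1
indices_rank = 1
axis = 1
result_rank = 3 - 1 + indices_rank

transpose_axes =
  Enum.to_list(indices_rank..(indices_rank + axis - 1)//1) ++
    Enum.to_list(0..(indices_rank - 1)//1) ++
    Enum.to_list((indices_rank + axis)..(result_rank - 1)//1)
# [1, 0, 2]: leading tensor dims, then the indices block, then trailing dims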
@Benjamin-Philip force-pushed the bp-remove-take-backend-callback branch 2 times, most recently from ea8f65a to 6857590 on May 11, 2024 16:40
@Benjamin-Philip marked this pull request as ready for review May 11, 2024 17:23
@Benjamin-Philip (Contributor, Author):
@polvalente, looking through the failed cases in EXLA, there appears to be a bug in gathering subsets of multi-dimensional tensors along axis 1, resulting in either a runtime error or invalid output.

I doubt that the gather call is invalid since it worked for both the BinaryBackend and Torchx.


tensor
|> gather(gather_indices, axes: [axis])
|> transpose(axes: transpose_axes)
Contributor

Why is this transpose needed?

Collaborator
I took a look at this. take will retain the shape according to the given axis; this is also true in Python:

>>> np.take(n, [0, 1, 0, 1], axis=1)
array([[1, 2, 1, 2],
       [3, 4, 3, 4]])
>>> np.take(n, [0, 1, 0, 1], axis=0)
array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]])

On the other hand, gather will always preserve the input shape. This means the underlying offset_dims in the gather operation is different. So it may be that we cannot really replace take with gather (unless we are willing to lose performance when axis != 0, which may be fine). Another option is to extend gather with a new option that controls the output shape, but that's debatable.
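
A sketch of the same contrast in Nx terms (shapes assume the semantics described above):

t = Nx.tensor([[1, 2], [3, 4]])
i = Nx.tensor([0, 1, 0, 1])

Nx.take(t, i, axis: 1).shape                          # {2, 4}: indices expand in place of axis 1
Nx.take(t, i, axis: 0).shape                          # {4, 2}
Nx.gather(t, Nx.reshape(i, {4, 1}), axes: [1]).shape  # {4, 2}: indices shape stays out front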

Collaborator
I will merge this and make it optional in a follow-up commit.

nx/lib/nx.ex (outdated review thread, resolved)
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
@josevalim (Collaborator):
CI was failing due to a bug in EXLA; I have pushed a fix.

@josevalim closed this May 12, 2024
@josevalim reopened this May 12, 2024
@josevalim merged commit 24ed6fa into elixir-nx:main May 12, 2024
13 of 16 checks passed
@josevalim (Collaborator):
💚 💙 💜 💛 ❤️

@Benjamin-Philip (Contributor, Author) commented May 12, 2024 via email

@Benjamin-Philip deleted the bp-remove-take-backend-callback branch May 13, 2024 11:23
Successfully merging this pull request may close these issues: Remove take_along_axis and take