Explicit managing the output buffer of jax.jit function #23084

Open
shawnwang18 opened this issue Aug 15, 2024 · 2 comments

Comments

@shawnwang18

I found that, for the common training pattern in JAX:

new_state, other_output = jitted_train_step_fn(old_state, other_input)

the current XLA runtime may assign different backing device memory buffers to old_state and new_state.

This behavior is surprising: users may observe their model's state frequently changing memory location. It is also bad for performance, because a memory pointer that changes across command buffer launches incurs a command buffer update cost, and it can cause some cache misses.
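
A minimal sketch of how the pointer change could be observed, assuming no donation is requested. The names train_step, batch and same_buffer are illustrative and not from the original report. Without donation, old_state stays alive after the call, so the output has to live in a different buffer:

```python
import jax
import jax.numpy as jnp

# Hypothetical training step: no donation is requested here.
@jax.jit
def train_step(state, batch):
    new_state = state - 0.1 * batch
    loss = jnp.sum(batch ** 2)
    return new_state, loss

old_state = jnp.ones(1024)
batch = jnp.full(1024, 2.0)

old_ptr = old_state.unsafe_buffer_pointer()
new_state, loss = train_step(old_state, batch)
new_ptr = new_state.unsafe_buffer_pointer()

# old_state is still a valid array, so the two buffers cannot coincide.
same_buffer = old_ptr == new_ptr
print("same buffer:", same_buffer)
```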

Sergei Lebedev from Google also mentioned that implicit output buffer allocation is an issue for people using jax.pure_callback. There are a few bug reports where people wanted the callback to be zero-copy.
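
For context, a small sketch of the jax.pure_callback pattern referred to above. The function names host_double and f are made up for illustration. The callback returns a freshly allocated host array, and the call site only declares the result's shape and dtype; there is no argument for saying where the result should be stored:

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_double(x):
    # Runs on the host; the result is a newly allocated NumPy array.
    return np.asarray(x) * 2

@jax.jit
def f(x):
    # Only the shape and dtype of the result are declared, not its buffer.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_double, out_spec, x)

y = f(jnp.arange(4.0))
print(y)
```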

I think the reason for this behavior is that the jax.jit API only has buffer donation parameters and has no input/output aliasing parameter.

Ideally, the user could specify through a jax.jit API parameter that old_state and new_state are aliased, and XLA buffer assignment would simply give new_state the buffer of old_state. That would be both more performance-friendly and semantically natural.
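
One way to see what donation currently turns into is to look at the lowered module text. The sketch below is only a probe under assumptions: the attribute names searched for ("tf.aliasing_output" and "jax.buffer_donor") are my understanding of how recent JAX versions annotate donated arguments, and they may be absent on other versions or on backends without donation support. train_step is again a hypothetical function:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def train_step(state, batch):
    return state - 0.1 * batch, jnp.sum(batch)

state = jnp.ones(8)
batch = jnp.ones(8)

# Lowering does not execute the function, so `state` is not consumed here.
lowered_text = train_step.lower(state, batch).as_text()

# Assumed annotation names; treat an empty result as "unknown", not "no aliasing".
markers = ("tf.aliasing_output", "jax.buffer_donor")
found = [m for m in markers if m in lowered_text]
print("donation-related annotations found:", found)
```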

@shawnwang18 shawnwang18 added the enhancement New feature or request label Aug 15, 2024
@jakevdp
Collaborator

jakevdp commented Aug 16, 2024

You may be able to do what you want via the donate_argnums/donate_argnames parameter of jax.jit; see the jax.jit documentation for a description.

For example:

from functools import partial
import jax
import jax.numpy as jnp

state = (jnp.zeros(4), jnp.arange(4))

@partial(jax.jit, donate_argnums=0)
def f(state):
  x, y = state
  x += 1
  y *= 2
  return (x, y)

pointer0 = state[0].unsafe_buffer_pointer()
pointer1 = state[1].unsafe_buffer_pointer()

state = f(state)

assert state[0].unsafe_buffer_pointer() == pointer0  # same memory
assert state[1].unsafe_buffer_pointer() == pointer1  # same memory
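
The same idea applied to the two-argument training pattern from the issue might look like the sketch below, using donate_argnames instead of donate_argnums. The update rule and the names train_step, batch and metric are placeholders. The key point is that the donated input must not be used after the call, so the loop rebinds state to the returned value:

```python
from functools import partial
import jax
import jax.numpy as jnp

# Donate only the state argument; the batch is left untouched.
@partial(jax.jit, donate_argnames=("state",))
def train_step(state, batch):
    params, step = state
    new_params = params - 0.1 * batch
    return (new_params, step + 1), jnp.mean(batch)

state = (jnp.zeros(4), jnp.array(0))
for _ in range(3):
    # Rebind `state`: the old buffers may have been handed to the outputs.
    state, metric = train_step(state, jnp.ones(4))

print(state, metric)
```

On backends where donation is not supported, JAX is expected to emit a warning and fall back to allocating new buffers, so the numerical result is unchanged either way.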

@jakevdp jakevdp self-assigned this Aug 16, 2024
@Gattocrucco
Contributor

> I think the reason for this behavior is because jax.jit API only has buffer donation parameters, and not having input/output aliasing parameter.

In my experience (not tested on the latest JAX, though), jit with donate_argnums does reuse the input state's memory for the output state if the arrays are large enough, and doesn't if they are small.
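
A quick probe along these lines could check the size dependence on a given setup. This is a sketch, not a benchmark: step and buffer_reused are illustrative names, the sizes are arbitrary, and pointer equality is only a proxy for buffer reuse. No particular outcome is assumed, since it may depend on the backend and JAX version:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def step(x):
    return x + 1

def buffer_reused(n):
    # Record the input pointer before the call, since x may be invalidated after it.
    x = jnp.zeros(n, dtype=jnp.float32)
    before = x.unsafe_buffer_pointer()
    y = step(x)
    return y.unsafe_buffer_pointer() == before

# Arbitrary small, medium and large sizes.
results = {n: buffer_reused(n) for n in (4, 1024, 1 << 20)}
for n, reused in results.items():
    print(f"n={n}: output reuses input buffer: {reused}")
```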
