Explicit managing the output buffer of jax.jit function #23084
Labels: enhancement (New feature or request)
I found that in the common JAX training pattern:
new_state, other_output = jitted_train_step_fn(old_state, other_input)
the current XLA runtime may assign different backing device memory buffers to `old_state` and `new_state`.
This behavior is surprising, because users may observe that their model's state frequently changes memory location. It also hurts performance: the memory pointer changes across command buffer launches, which incurs a command buffer update cost, and it causes some cache misses.
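One way to check whether the state moves is to record the device address of the state buffer before each step. The following is a minimal sketch, assuming a single-device array and a toy `train_step` that stands in for a real training step (both the function and the update rule are hypothetical). It only records the addresses; whether they differ depends on the backend and the allocator.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(state, batch):
    # Toy update rule (hypothetical), standing in for a real optimizer step.
    new_state = state - 0.1 * batch
    loss = jnp.sum(batch ** 2)
    return new_state, loss


state = jnp.ones((1024,))
batch = jnp.ones((1024,))

addresses = []
for _ in range(3):
    # Address of the buffer currently backing `state` on its device.
    addresses.append(state.unsafe_buffer_pointer())
    state, loss = train_step(state, batch)

# Number of distinct addresses the state occupied across the steps.
print(len(set(addresses)), "distinct addresses over", len(addresses), "steps")
```

If the printed count is greater than one, the state was placed in a different buffer on at least one step.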
Sergei Lebedev from Google also mentioned that implicit output buffer allocation is an issue for people using `jax.pure_callback`. There are a few bug reports where people wanted the callback to be zero-copy.

I think the reason for this behavior is that the `jax.jit` API only has buffer donation parameters and no input/output aliasing parameter.

Ideally, the user could specify through a `jax.jit` API parameter that `old_state` and `new_state` are aliased, and XLA buffer allocation would then simply assign the buffer of `old_state` to `new_state`. That would be both better for performance and semantically more natural.
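For reference, the buffer donation parameter mentioned above is `donate_argnums`. The sketch below donates the old state on a toy `train_step` (hypothetical, as is the update rule). Donation only permits XLA to reuse the donated input buffer for an output of matching shape and dtype. It does not let the user declare which output aliases which input, which is the gap this request describes. Some backends may ignore the donation and emit a warning.

```python
import functools

import jax
import jax.numpy as jnp


# Donate argument 0 (the old state) so XLA is allowed to reuse its buffer.
@functools.partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    # Toy update rule (hypothetical), standing in for a real optimizer step.
    new_state = state - 0.1 * batch
    loss = jnp.sum(batch ** 2)
    return new_state, loss


state = jnp.ones((4,))
batch = jnp.full((4,), 2.0)

# Read the address before the call: a donated array may be invalidated afterwards.
old_address = state.unsafe_buffer_pointer()
new_state, loss = train_step(state, batch)
new_address = new_state.unsafe_buffer_pointer()

# Whether the buffer was actually reused is decided by the runtime, not the user.
print("buffer reused:", old_address == new_address)
```

After the call, `state` should be treated as consumed and only `new_state` should be used.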