JAX segment_sum is two times slower for FP16 inputs than FP32 inputs #23136
Comments
I don't know the answer to this, but maybe @jakevdp does? Some notes in the meantime. It's worth checking out the JAX microbenchmark FAQ entry, because benchmarking like you're doing here can lead to incorrect conclusions since it includes the tracing and compilation overhead. Updating this doesn't seem to change the specific conclusions though! Here's how I would write the benchmark:

```python
import numpy as np
import jax
import jax.numpy as jnp

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments), np.random.randint(40, 977, size=num_segments))

@jax.jit
def do_sum(data):
    return jax.ops.segment_sum(data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True)

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float32)
do_sum(data).block_until_ready()  # compile
%timeit do_sum(data).block_until_ready()

key = jax.random.PRNGKey(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype=jnp.float16)
do_sum(data).block_until_ready()  # compile
%timeit do_sum(data).block_until_ready()
```

Regardless, I do find that the float16 version is consistently slower. Perhaps @jakevdp can lead us in the right direction!
Interesting question! I suspect the reason for the performance difference here is that the GPU hardware is designed and tuned for float32 computation, and not for float16 computation. It would be interesting to compare this across different generations of GPU hardware.
But I think a GPU's FP16 performance shouldn't be slower than its FP32 performance. For example, the A100's FP16 FLOPS are twice its FP32 FLOPS, and some data shows that the NVIDIA 4090 has equal FP16 and FP32 performance. Is it possible that JAX somehow internally converts FP16 to FP32, performs the computation, and converts the result back to FP16?
No, I don't think such conversions are happening – you can see exactly what operations the compiler is emitting by using ahead-of-time lowering to print the compiled HLO. For example, on a T4 GPU:

```python
key = jax.random.key(0)
data = jax.random.uniform(key, shape=(len(segment_ids),), dtype='float16')
print(jax.jit(lambda data: jax.ops.segment_sum(
    data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
)).lower(data).compile().as_text())
```
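As a quick heuristic for spotting hidden FP16→FP32 conversions, one could also search the compiled HLO text for convert instructions. A sketch reusing data, segment_ids, and num_segments from above (string matching on HLO is crude, but explicit dtype conversions do show up as convert ops):

```python
hlo_text = jax.jit(lambda data: jax.ops.segment_sum(
    data, segment_ids=segment_ids, num_segments=num_segments, indices_are_sorted=True
)).lower(data).compile().as_text()

# If the scatter really runs in f16, there should be no f32 convert on the hot path.
for line in hlo_text.splitlines():
    if "convert" in line:
        print(line.strip())
```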
Thanks for the clarification! What might be the problem then? I am curious how we can debug this issue.
My best guess is still that the hardware you're using is not optimized for the kinds of operations you're performing (i.e. scatters) in float16, and is more optimized for float32. Appendix A of https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf suggests that for the GeForce RTX 4090, non-tensor ops are no faster in FP16 than in FP32, though it doesn't indicate that they should be slower. It may be that performance is worse for FP16 scatters – I'm not sure.
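One way to test the scatter hypothesis in isolation would be to benchmark a bare scatter-add in both dtypes via .at[...].add(...), which lowers to the same HLO scatter that segment_sum is built on. A self-contained sketch (the sizes and repeat count are arbitrary):

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

idx = np.random.randint(0, 1700, size=1_000_000)

@jax.jit
def scatter_add(out, vals):
    # out.at[idx].add(vals) lowers to an HLO scatter with an add combiner.
    return out.at[idx].add(vals)

for dtype in (jnp.float32, jnp.float16):
    out = jnp.zeros(1700, dtype=dtype)
    vals = jnp.ones(idx.shape, dtype=dtype)
    scatter_add(out, vals).block_until_ready()  # compile once per dtype
    start = time.perf_counter()
    for _ in range(100):
        scatter_add(out, vals).block_until_ready()
    print(f"{dtype.__name__}: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms per call")
```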
I wrote a similar benchmark with PyTorch 2.3.1 and the torch_scatter library, and I now agree that non-tensor ops are no faster in FP16 than in FP32 on the GeForce RTX 4090. However, it seems that PyTorch's FP16 performance is 380 times faster than JAX's FP16 performance on the RTX 4090. If the following benchmark code is correct, then there is still much room for improvement in JAX?

PyTorch code:

PyTorch result:

JAX code:

JAX result:
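A comparable torch_scatter benchmark would look roughly like the sketch below (this is not the poster's exact code; it assumes CUDA builds of torch and torch_scatter are installed, and times each dtype with explicit device synchronization):

```python
import time
import numpy as np
import torch
from torch_scatter import scatter

num_segments = 1700
segment_ids = np.repeat(np.arange(num_segments),
                        np.random.randint(40, 977, size=num_segments))
index = torch.as_tensor(segment_ids, device="cuda")

def bench(dtype):
    src = torch.rand(index.shape[0], device="cuda", dtype=dtype)
    scatter(src, index, dim=0, dim_size=num_segments, reduce="sum")  # warm-up
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        scatter(src, index, dim=0, dim_size=num_segments, reduce="sum")
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / 100 * 1e3

print("float32:", bench(torch.float32), "ms per call")
print("float16:", bench(torch.float16), "ms per call")
```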
I think these are not equivalent operations – wouldn't |
Hi, this is not the case. According to the documentation (https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter), scatter reduces all values from the src tensor into out at the indices specified in the index tensor, and the default reduction is "sum" – so it computes the same thing as segment_sum. We can also verify this in the following code:

Output:

JAX:

Output:
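A small equivalence check along those lines (a sketch with a tiny hand-checked input; it assumes both torch_scatter and JAX are installed):

```python
import numpy as np
import torch
from torch_scatter import scatter
import jax
import jax.numpy as jnp

src = np.array([1., 2., 3., 4., 5.], dtype=np.float32)
ids = np.array([0, 0, 1, 2, 2])

# torch_scatter: the default reduce="sum" accumulates src values per index.
torch_out = scatter(torch.from_numpy(src), torch.from_numpy(ids),
                    dim=0, dim_size=3, reduce="sum")

# JAX: segment_sum over the same segment ids.
jax_out = jax.ops.segment_sum(jnp.asarray(src), jnp.asarray(ids), num_segments=3)

print(torch_out.numpy())    # [3. 3. 9.]
print(np.asarray(jax_out))  # [3. 3. 9.]
```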
Ah, thanks for the clarification. Looks like it is doing the same thing – I'm not sure why JAX's version is slower.
Description
I find that JAX's segment_sum is two times slower for FP16 inputs than for FP32 inputs. Here is an example:
Outputs:
This happens with or without `jit()`. Why does this happen? And is there a way to optimize the computation for FP16 input?

System info (python version, jaxlib version, accelerator, etc.)