Hi,
I intended to use the multi-GPU capabilities of JAX to work with sharded arrays. However, there are two issues:
Using the array.at[].set() syntax is lowered to an all-gather operation, which leads to OOM errors since my arrays are too large to fit on a single device. Technically, an all-gather should not be necessary for a scatter, because the updates can be applied independently on each device.
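For reference, a minimal sketch of the at-set form meant above (the helper update_at and the slice bounds are hypothetical, chosen to match the MWE below):
def update_at(x):
    # when jitted on a sharded x, the HLO for this contains an all-gather
    return x.at[0:3, :].set(1.0)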
As a workaround, I tried to implement the update using jax.lax.scatter, hoping it would not be lowered to an all-gather (which indeed it is not). However, I encountered a second bug: jax.lax.scatter produces a different result when used with sharded arrays than with normal non-sharded arrays.
Here is an MWE:
import os

os.environ['JAX_PLATFORM_NAME'] = 'cpu'
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4 --xla_gpu_force_compilation_parallelism=4'

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding


def main():
    print(f"{jax.devices()=}")
    print(f"{jax.device_count()=}")

    # shard the leading axis across all 4 (host) devices
    sharding = PositionalSharding(
        devices=mesh_utils.create_device_mesh(
            (4, 1),
            devices=jax.devices(),
        )
    )
    print(f"{sharding=}")

    # does not work with a sharded matrix
    arr = jnp.zeros(
        (8, 2),
        dtype=jnp.float32,
        device=sharding,
    )
    arr = jax.device_put(arr, sharding)
    # this would work
    # arr = jnp.zeros((8, 2), dtype=jnp.float32)

    s = slice(0, 3)

    def fn(x):
        # turn the slice into explicit scatter indices [[0], [1], [2]]
        b = s.indices(x.shape[0])
        indices = jnp.arange(b[0], b[1], b[2])[:, None]
        updates = jnp.ones((3, x.shape[1]), dtype=jnp.float32)
        res = jax.lax.scatter(
            x,
            indices,
            updates,
            jax.lax.ScatterDimensionNumbers(
                update_window_dims=(1,),           # axis 1 of `updates` is the window (full row)
                inserted_window_dims=(0,),         # axis 0 of `x` is indexed, not windowed
                scatter_dims_to_operand_dims=(0,)  # index column 0 addresses operand axis 0
            )
        )
        return res

    jit_fn = (
        jax.jit(fn)
        .lower(arr)
        .compile()
    )
    # print("compiled")
    # print(jit_fn.as_text())

    r = jit_fn(arr)
    print(f"{r=}")
    print(f"{r.sum()=}")


if __name__ == '__main__':
    main()
When executed with the sharded array, only the first row of the array is set to one; the second and third rows are not updated as one would expect.
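Expected output (can be achieved by using a non-sharded array; reconstructed here from the shapes and the slice in the MWE, so the exact repr may differ slightly):
r=Array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)
r.sum()=Array(6., dtype=float32)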
To see the all-gather behavior of the normal array.at[].set() syntax, uncomment the print of the jitted HLO code. With scatter there is no all-gather, but the result is incorrect.
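To back up the claim that no all-gather is needed: the same update can be written so that each device patches only its own shard. A minimal sketch, assuming a Mesh/NamedSharding setup instead of PositionalSharding (shard_map and the helper local_update are only illustrative, not part of the MWE):
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over the same 4 (host) devices; rows are sharded over axis 'i'
mesh = Mesh(mesh_utils.create_device_mesh((4,)), axis_names=('i',))
arr = jax.device_put(
    jnp.zeros((8, 2), dtype=jnp.float32),
    NamedSharding(mesh, P('i', None)),
)

def local_update(x):
    # x is this device's (2, 2) shard of the global (8, 2) array
    offset = jax.lax.axis_index('i') * x.shape[0]  # global row of the shard's first row
    rows = offset + jnp.arange(x.shape[0])
    mask = (rows >= 0) & (rows < 3)  # rows covered by the global slice [0, 3)
    return jnp.where(mask[:, None], 1.0, x)

update = shard_map(local_update, mesh=mesh,
                   in_specs=P('i', None), out_specs=P('i', None))
r = jax.jit(update)(arr)
print(f"{r.sum()=}")  # expect 6.0: each device writes locally, no all-gather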
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.31
jaxlib: 0.4.31
numpy: 1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='URL REDACTED FOR PRIVACY', release='5.15.0-113-generic', version='#123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed Aug 14 10:36:51 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06 Driver Version: 555.42.06 CUDA Version: 12.5 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 On | 00000000:2C:00.0 Off | Off |
| 0% 60C P0 47W / 450W | 396MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 3293702 C python 386MiB |
+-----------------------------------------------------------------------------------------+