Scatter on Sharded Matrices has bugs #23052

Open
ymahlau opened this issue Aug 14, 2024 · 1 comment
Labels
bug Something isn't working

Comments


ymahlau commented Aug 14, 2024

Description

Hi,
I intended to use the multi-GPU capabilities of JAX to work with sharded arrays. However, there are two issues:

  1. An update written with the array.at[].set() syntax is lowered to an all-gather operation, which leads to OOM errors because my arrays are too large to fit on a single device. Technically, an all-gather should not be necessary for a scatter, since the updates can be applied independently on each device (a minimal sketch of this variant follows this list).

  2. As a workaround, I tried to implement the update with jax.lax.scatter, in the hope that it would not be lowered to an all-gather (which indeed it is not). However, I encountered a second bug: jax.lax.scatter produces a different result on sharded arrays than on normal, non-sharded arrays.
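
For reference, a minimal sketch of the array.at[].set() variant from point 1. The function name fn_at_set is only illustrative; it performs the same update as the scatter in the MWE below:

def fn_at_set(x):
    # Write ones into the first three rows. On a sharded array, this form
    # is what gets lowered to an all-gather in my runs.
    updates = jnp.ones((3, x.shape[1]), dtype=jnp.float32)
    return x.at[0:3, :].set(updates)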

Here is an MWE:

import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4 --xla_gpu_force_compilation_parallelism=4'

import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import jax.numpy as jnp

def main():
    
    print(f"{jax.devices()=}")
    print(f"{jax.device_count()=}")
    
    sharding = PositionalSharding(
        devices=mesh_utils.create_device_mesh(
            (4, 1),
            devices=jax.devices(),
        )
    )
    print(f"{sharding=}")
    
    # does not work with sharded matrix
    arr = jnp.zeros(
        (8, 2),
        dtype=jnp.float32,
        device=sharding,
    )
    arr = jax.device_put(arr, sharding)
    
    # this would work
    # arr = jnp.zeros((8, 2), dtype=jnp.float32)
    
    s = slice(0, 3)
    def fn(x):        
        b = s.indices(x.shape[0])
        indices = jnp.arange(b[0], b[1], b[2])[:, None]
        
        updates = jnp.ones((3, x.shape[1]), dtype=jnp.float32)
        
        res = jax.lax.scatter(
            x,
            indices,
            updates,
            jax.lax.ScatterDimensionNumbers(
                update_window_dims=(1,),
                inserted_window_dims=(0,),
                scatter_dims_to_operand_dims=(0,)
            )
        )
        return res
        
    
    jit_fn = (
        jax.jit(fn)
        .lower(arr)
        .compile()
    )
    
    # print("compiled")
    # print(jit_fn.as_text())
    
    r = jit_fn(arr)
    print(f"{r=}")
    print(f"{r.sum()=}")

if __name__ == '__main__':
    main()

When executed with a sharded array, only the first row of the array is set to one, but not the second and third rows as one would expect:

jax.devices()=[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
jax.device_count()=4
sharding=PositionalSharding([[{CPU 0}]
                    [{CPU 1}]
                    [{CPU 2}]
                    [{CPU 3}]], shape=(4, 1))
r=Array([[1., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)
r.sum()=Array(2., dtype=float32)

Expected output (can be obtained by using a non-sharded array):

r=Array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)
r.sum()=Array(6., dtype=float32)

To see the all-gather behavior of the normal array.at[].set() syntax, uncomment the print of the compiled HLO in the MWE (see the sketch below). With jax.lax.scatter there is no all-gather, but the result is incorrect.
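
A sketch of that HLO inspection, assuming the hypothetical fn_at_set from above and the arr from the MWE:

# Lower and compile the .at[].set() variant, then dump its HLO text;
# with the sharded array, the dump contains the all-gather.
compiled = jax.jit(fn_at_set).lower(arr).compile()
print(compiled.as_text())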

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='URL REDACTED FOR PRIVACY', release='5.15.0-113-generic', version='#123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed Aug 14 10:36:51 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:2C:00.0 Off |                  Off |
|  0%   60C    P0             47W /  450W |     396MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   3293702      C   python                                        386MiB |
+-----------------------------------------------------------------------------------------+
ymahlau added the bug label Aug 14, 2024

ymahlau commented Aug 14, 2024

The problematic all-gather operation is also described in issue #20381.
