Errors from interaction between shard_map with auto and lax.map #23019

Open

jeffgortmaker opened this issue Aug 12, 2024 · 0 comments
Labels: bug (Something isn't working)

jeffgortmaker commented Aug 12, 2024

Description

I'm converting #23015 into an issue per @yashk2810's request.

Here's a minimal working example of what I'm trying to do. My goal in the function h is to map a second function g over the first dimension (mesh axis 'a') of a 2x2 array. Within g, I map a third function f over the array's second dimension (mesh axis 'b'). Each map wraps lax.map in shard_map; vmap isn't a good replacement because the batching rule for my code's actual f is quite expensive.

```python
import os

# This flag must be set before the XLA backend is initialized,
# so set it before importing jax.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental.mesh_utils import create_device_mesh

print(jax.__version__)  # 0.4.31
jax.config.update('jax_platform_name', 'cpu')

# 2x2 mesh of host CPU devices.
mesh = Mesh(create_device_mesh((2, 2)), ('a', 'b'))

def f(x):
    return x + 1

def g(x):
    # Inner map: lax.map over f, sharded along mesh axis 'b'.
    return shard_map(partial(jax.lax.map, f), mesh, P('b'), P('b'))(x) + 2

def h(x):
    # Outer map: lax.map over g, sharded along mesh axis 'a',
    # with 'b' left to automatic sharding inside the body.
    return shard_map(partial(jax.lax.map, g), mesh, P('a'), P('a'),
                     check_rep=False, auto=frozenset({'b'}))(x) + 3

print(g(np.zeros(2)))        # [3. 3.]
print(h(np.zeros((2, 2))))   # error
```

My understanding is that auto should make this nesting work. But I get a NotImplementedError raised at the start of jax.experimental.shard_map._shard_map_impl. When I wrap h in jax.jit before calling it, I instead get the following:

```
F external/xla/xla/hlo/utils/hlo_sharding_util.cc:2806] Check failed: sharding.IsManualSubgroup()
```
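As a control case (my own addition, not from the original report): a single, non-nested shard_map over both mesh axes runs fine, which suggests the failure is specific to nesting shard_map with auto rather than to combining shard_map with lax.map as such. Since f, g, and h are all elementwise here, the nested version should be equivalent to this flat sketch, which just computes x + 1 + 2 + 3:

```python
import os

# Must be set before jax initializes its backend.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental.mesh_utils import create_device_mesh

jax.config.update('jax_platform_name', 'cpu')
mesh = Mesh(create_device_mesh((2, 2)), ('a', 'b'))

def f(x):
    return x + 1

def h_flat(x):
    # Shard both dimensions in one shard_map: each device gets a 1x1
    # block, and lax.map runs f over that block's (single) row.
    # Elementwise this matches the intended nested computation,
    # x + 1 + 2 + 3.
    return shard_map(partial(jax.lax.map, f), mesh,
                     P('a', 'b'), P('a', 'b'))(x) + 2 + 3

print(h_flat(np.zeros((2, 2))))  # all entries 6.0
```

This doesn't reproduce the bug, but it narrows it down: the error only appears once one shard_map (with auto for the inner axis) wraps another.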

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.0.1
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', release='5.10.16.3-microsoft-standard-WSL2', version='1 SMP Fri Apr 2 22:23:49 UTC 2021', machine='x86_64')

@jeffgortmaker jeffgortmaker added the bug Something isn't working label Aug 12, 2024