jax-metal: reduce window not supported #21387

Open
jonatanklosko opened this issue May 23, 2024 · 0 comments
Labels: Apple GPU (Metal) plugin, bug


Description

import jax
import jax.numpy as jnp

def f(x):
  # Sliding-window sum over a 1-D array: window size 2, stride 1, no padding.
  return jax.lax.reduce_window(
      x, 0, jnp.add, window_dimensions=(2,), window_strides=(1,),
      padding=[(0, 0)], base_dilation=(1,), window_dilation=(1,))

x = jnp.array([1, 2, 4])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
# Run the jitted function (this is the call that fails on Metal)
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<4xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %0 = "stablehlo.reduce_window"(%arg0, %c) <{base_dilations = array<i64: 1>, padding = dense<1> : tensor<1x2xi64>, window_dilations = array<i64: 1>, window_dimensions = array<i64: 2>, window_strides = array<i64: 1>}> ({
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):
      %1 = stablehlo.add %arg1, %arg2 : tensor<i32>
      stablehlo.return %1 : tensor<i32>
    }) : (tensor<3xi32>, tensor<i32>) -> tensor<4xi32>
    return %0 : tensor<4xi32>
  }
}
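As a sanity check on the result types that appear in this report, the length of a `reduce_window` result along one dimension can be computed by hand. The sketch below follows the shape rule described in the StableHLO specification. The helper name `reduce_window_output_size` is made up for illustration and is not a JAX API.

```python
def reduce_window_output_size(size, window, stride=1, padding=(0, 0),
                              base_dilation=1, window_dilation=1):
    """Number of windows along one dimension (hypothetical helper, not a JAX API)."""
    # Base dilation inserts (base_dilation - 1) holes between input elements.
    dilated_input = 0 if size == 0 else (size - 1) * base_dilation + 1
    padded_input = dilated_input + padding[0] + padding[1]
    # Window dilation spreads the window taps apart in the same way.
    dilated_window = (window - 1) * window_dilation + 1
    if padded_input == 0 or dilated_window > padded_input:
        return 0
    return (padded_input - dilated_window) // stride + 1

no_padding = reduce_window_output_size(3, window=2, padding=(0, 0))
one_each_side = reduce_window_output_size(3, window=2, padding=(1, 1))
print(no_padding, one_each_side)  # 2 4
```

For the length-3 input, padding `(0, 0)` gives 2 elements, which matches the `tensor<2xsi32>` result type in the error message, while padding `(1, 1)` gives 4, which matches the `tensor<4xi32>` result type in the lowered module.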

Running the jitted function fails with:

Traceback (most recent call last):
  File "/Users/jonatanklosko/tmp/jax_mlir.py", line 71, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/jonatanklosko/tmp/jax_mlir.py:63:0: error: failed to legalize operation 'mhlo.reduce_window'
/Users/jonatanklosko/tmp/jax_mlir.py:70:0: note: called from
/Users/jonatanklosko/tmp/jax_mlir.py:63:0: note: see current operation:
%2 = "mhlo.reduce_window"(%arg0, %1) ({
^bb0(%arg1: tensor<si32>, %arg2: tensor<si32>):
  %3 = "mhlo.add"(%arg1, %arg2) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  "mhlo.return"(%3) : (tensor<si32>) -> ()
}) {base_dilations = dense<1> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<1> : tensor<1xi64>, window_dimensions = dense<2> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<3xsi32>, tensor<si32>) -> tensor<2xsi32>
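Until the plugin can legalize `mhlo.reduce_window`, one possible workaround for this specific case (a 1-D sum with stride 1 and no padding or dilation) is to add shifted slices of the input, which avoids emitting `reduce_window` at all. This is only a sketch: the helper name `sliding_sum` is hypothetical, and it has not been verified on jax-metal, so whether the slice and add ops it relies on are supported there is an assumption.

```python
import jax
import jax.numpy as jnp

def sliding_sum(x, window):
    """Sum over every length-`window` slice of a 1-D array (stride 1, no padding)."""
    n_out = x.shape[0] - window + 1
    # Add `window` shifted views of x. Shapes and offsets are Python ints,
    # so under jit these are static slices followed by elementwise adds.
    out = x[0:n_out]
    for offset in range(1, window):
        out = out + x[offset:offset + n_out]
    return out

x = jnp.array([1, 2, 4])
# `window` must be static because it determines the output shape.
result = jax.jit(sliding_sum, static_argnums=1)(x, 2)
print(result)  # [3 6]
```

The same idea does not extend cheaply to large windows, since the Python loop unrolls into one slice and one add per window element.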

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7
