
erf_inv not supported in Pallas GPU #23173

Closed
ayaka14732 opened this issue Aug 21, 2024 · 0 comments · Fixed by #23192

@ayaka14732 (Member)

Description

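import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import pallas as pl


def test_erf_inv():
  @jax.jit
  @functools.partial(
      pl.pallas_call,
      out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
  )
  def kernel(x_ref, o_ref):
    o_ref[...] = lax.erf_inv(x_ref[...])

  # -3.2 lies outside erf_inv's domain (-1, 1), so both results are NaN;
  # assert_array_equal treats NaNs in matching positions as equal.
  val = -3.2
  x = jnp.full((8, 128), val)
  out = kernel(x)  # raises NotImplementedError on Pallas GPU
  expected = lax.erf_inv(x)
  np.testing.assert_array_equal(out, expected)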

Error:

NotImplementedError: Unimplemented primitive in Pallas GPU lowering: erf_inv. Please file an issue on https://github.com/google/jax/issues.
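Until a lowering for erf_inv exists, one possible in-kernel workaround is to build it from elementwise operations. The sketch below uses M. Giles' single-precision polynomial approximation of erfinv. The helper name `erf_inv_approx` is hypothetical, and it assumes that log, sqrt, abs and select all lower in Pallas GPU; that assumption is not confirmed by this issue. The kernel is run with `interpret=True` so the sketch can also be checked on CPU.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import pallas as pl

# Coefficients of M. Giles' single-precision erfinv approximation,
# highest-degree term first (evaluated with Horner's rule).
_CENTRAL = (2.81022636e-08, 3.43273939e-07, -3.5233877e-06, -4.39150654e-06,
            0.00021858087, -0.00125372503, -0.00417768164, 0.246640727,
            1.50140941)
_TAIL = (-0.000200214257, 0.000100950558, 0.00134934322, -0.00367342844,
         0.00573950773, -0.0076224613, 0.00943887047, 1.00167406, 2.83297682)


def _horner(coeffs, w):
  # Evaluate the polynomial with the given coefficients at w.
  p = jnp.full_like(w, coeffs[0])
  for c in coeffs[1:]:
    p = c + p * w
  return p


def erf_inv_approx(x):
  """Hypothetical helper: erf_inv built only from log/sqrt/abs/where."""
  w = -jnp.log((1.0 - x) * (1.0 + x))  # NaN when |x| > 1
  p_central = _horner(_CENTRAL, w - 2.5)          # used when w < 5
  p_tail = _horner(_TAIL, jnp.sqrt(w) - 3.0)      # used when w >= 5
  p = jnp.where(w < 5.0, p_central, p_tail)
  # erf_inv(+-1) = +-inf; the tail polynomial alone gives -inf at x = 1.
  return jnp.where(jnp.abs(x) == 1.0, x * jnp.inf, p * x)


@jax.jit
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,  # lets the sketch run on CPU; drop it to compile for GPU
)
def kernel(x_ref, o_ref):
  o_ref[...] = erf_inv_approx(x_ref[...])


x = jnp.linspace(-0.999, 0.999, 8 * 128, dtype=jnp.float32).reshape(8, 128)
out = kernel(x)
```

Both branches of each `jnp.where` are evaluated for every element, so the unselected branch may produce inf or NaN without affecting the result. Inputs outside (-1, 1), such as the -3.2 used in the test, come out as NaN.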

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.32
jaxlib: 0.4.32
numpy:  1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (a0a9bf5152507beacd2a72dda42d054391494c4a)]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='nonet5.prod.google.com', release='5.10.0-smp-1101.34.0.0', version='#1 [v5.10.0-1101.34.0.0] SMP @1712273364', machine='x86_64')


$ nvidia-smi
Wed Aug 21 09:45:21 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:DB:00.0 Off |                    0 |
| N/A   37C    P0    63W / 400W |    419MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
@ayaka14732 added the bug (Something isn't working) label Aug 21, 2024
@ayaka14732 self-assigned this Aug 21, 2024
shahid45754 added a commit to shahid45754/jax that referenced this issue Aug 22, 2024
…ving pallas_call
- Removed the pallas_call decorator from the kernel function because the erf_inv operation is unsupported in Pallas.
- Simplified the kernel function to return the result of lax.erf_inv directly, eliminating the need for o_ref.
- These changes address the NotImplementedError and improve code stability on both CPU and GPU.
@ayaka14732 linked a pull request Aug 22, 2024 that will close this issue
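The change described in shahid45754's commit can be sketched as follows, assuming the rest of the test stays as in the description; the in-domain input value 0.5 is chosen here only for illustration. This sidesteps the issue rather than fixing it: without pallas_call the function is lowered by XLA, so it no longer exercises the Pallas GPU lowering that the issue reports as missing.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


# Sketch of the commit's approach: no pallas_call and no o_ref; the jitted
# function returns lax.erf_inv directly and XLA lowers it for CPU or GPU.
@jax.jit
def kernel(x):
  return lax.erf_inv(x)


x = jnp.full((8, 128), 0.5, dtype=jnp.float32)
out = kernel(x)
print(out.shape, out.dtype)
```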