
Pallas Unexpected Numeric Results With tf32 on A100 GPU #23182

Closed
axelfeldmann opened this issue Aug 21, 2024 · 3 comments
Labels
bug Something isn't working

Comments


axelfeldmann commented Aug 21, 2024

Description

Hi,

I'm trying to implement an fp32/tf32 matrix multiplication kernel using Pallas, but the numeric results have more error than I expected. Specifically, I have the following code:

import jax
from jax import numpy as jnp
from jax.experimental import pallas as pl
from jax import lax
import functools

def kernel_func(A_ref, B_ref, C_ref, N, allow_tf32):
    A_tile = pl.load(A_ref, (pl.dslice(None), pl.dslice(None))).astype(jnp.float32)
    B_tile = pl.load(B_ref, (pl.dslice(None), pl.dslice(None))).astype(jnp.float32)
    C_tile = pl.dot(A_tile, B_tile, allow_tf32=allow_tf32)
    pl.store(C_ref, (pl.dslice(None), pl.dslice(None)), C_tile.astype(C_ref.dtype))
    
def matmul(A, B, allow_tf32):
    N = A.shape[0]
    grid = (1, 1)
    in_specs = [
        pl.BlockSpec(lambda r, c: (0, 0), (N, N)),
        pl.BlockSpec(lambda r, c: (0, 0), (N, N))
    ]
    C = jax.ShapeDtypeStruct(shape=(N, N), dtype=A.dtype)

    kernel = functools.partial(
        kernel_func,
        N = N,
        allow_tf32 = allow_tf32
    )
    out, = pl.pallas_call(kernel,
        grid=grid, in_specs=in_specs,
        out_specs=[
            pl.BlockSpec(lambda r, c: (r, c), (N, N))
        ],
        out_shape=[ C ], name="matmul"
    )(A, B)
    return out

dtype = jnp.float32

N = 64
A = jax.random.uniform(jax.random.PRNGKey(0), (N, N), dtype=jnp.float32).astype(dtype)
B = jax.random.uniform(jax.random.PRNGKey(1), (N, N), dtype=jnp.float32).astype(dtype)

C_ref_no_tf32 = jnp.dot(A, B, precision="highest")
print(f"{C_ref_no_tf32[0,0] = }, {C_ref_no_tf32.dtype = }")

C_ref_tf32 = jnp.dot(A, B, precision="high")
print(f"{C_ref_tf32[0,0] = }, {C_ref_tf32.dtype = }")

C_pallas_no_tf32 = matmul(A, B, allow_tf32=False)
print(f"{C_pallas_no_tf32[0,0] = }, {C_pallas_no_tf32.dtype = }")

C_pallas_tf32 = matmul(A, B, allow_tf32=True)
print(f"{C_pallas_tf32[0,0] = }, {C_pallas_tf32.dtype = }")

And this outputs:

C_ref_no_tf32[0,0] = Array(16.450489, dtype=float32), C_ref_no_tf32.dtype = dtype('float32')
C_ref_tf32[0,0] = Array(16.451378, dtype=float32), C_ref_tf32.dtype = dtype('float32')
C_pallas_no_tf32[0,0] = Array(16.450489, dtype=float32), C_pallas_no_tf32.dtype = dtype('float32')
C_pallas_tf32[0,0] = Array(16.438375, dtype=float32), C_pallas_tf32.dtype = dtype('float32')

While the numeric difference may seem small, the gap between C_pallas_tf32 and the others becomes significant in larger applications. I am specifically curious why C_ref_tf32 and C_pallas_tf32 differ: both should be using tf32, so I expected them to be very close to equal, much like C_ref_no_tf32 and C_pallas_no_tf32.

Two main questions:

  • Do you know why this might be the case?
  • Is there any way to get Pallas/JAX to dump the Pallas kernel's PTX? That way I could at least inspect what it's doing.

I know it's unreasonable to expect bitwise equality with floating-point numbers, but this error does seem really hard to understand.
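For reference, tf32 keeps only 10 explicit mantissa bits, so just rounding the inputs to tf32 introduces a relative error on the order of 2^-11 per element, which can add up over the 64-term reduction. A rough way to gauge the expected magnitude is to emulate the input rounding in NumPy (a truncation-based sketch of my own, reusing A and B from above; the hardware may round rather than truncate):

import numpy as np

def emulate_tf32(x):
    # Keep only the top 10 explicit mantissa bits of each float32
    # by zeroing the low 13 bits.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

# Multiply the tf32-rounded inputs in float64 to isolate the input-rounding error.
A_t = emulate_tf32(np.asarray(A)).astype(np.float64)
B_t = emulate_tf32(np.asarray(B)).astype(np.float64)
print((A_t @ B_t)[0, 0])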

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='deep-chungus-8.csail.mit.edu', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed Aug 21 18:16:52 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:01:00.0 Off |                    0 |
| N/A   70C    P0            136W /  300W |    9161MiB /  81920MiB |     41%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          On  |   00000000:24:00.0 Off |                    0 |
| N/A   83C    P0            212W /  300W |    9161MiB /  81920MiB |     55%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          On  |   00000000:41:00.0 Off |                    0 |
| N/A   69C    P0            295W /  300W |    9161MiB /  81920MiB |     49%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          On  |   00000000:61:00.0 Off |                    0 |
| N/A   78C    P0            335W /  300W |   59973MiB /  81920MiB |     55%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100 80GB PCIe          On  |   00000000:81:00.0 Off |                    0 |
| N/A   64C    P0            340W /  300W |   38701MiB /  81920MiB |     62%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100 80GB PCIe          On  |   00000000:A1:00.0 Off |                    0 |
| N/A   62C    P0             90W /  300W |    9161MiB /  81920MiB |     45%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100 80GB PCIe          On  |   00000000:C1:00.0 Off |                    0 |
| N/A   40C    P0             65W /  300W |     425MiB /  81920MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100 80GB PCIe          On  |   00000000:E1:00.0 Off |                    0 |
| N/A   73C    P0            257W /  300W |    9161MiB /  81920MiB |     40%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|

@axelfeldmann axelfeldmann added the bug Something isn't working label Aug 21, 2024
dfm (Member) commented Aug 22, 2024

Pinging @sharadmv who will know best.

@justinjfu justinjfu assigned justinjfu and unassigned sharadmv Aug 27, 2024
justinjfu (Collaborator) commented Aug 27, 2024

Hi,

> Do you know why this might be the case?

This is the same issue encountered here: triton-lang/triton#4574. I applied the fix recommended there and was able to get matching TF32 results between XLA and Pallas. You can try it by pulling this branch: #23262.
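As a quick sanity check (reusing matmul, A, and B from the reproducer above), you can compare the two TF32 paths directly; after the fix the difference should be small:

# Compare the Pallas tf32 kernel against XLA's tf32 ("high" precision) dot.
max_err = jnp.max(jnp.abs(matmul(A, B, allow_tf32=True) - jnp.dot(A, B, precision="high")))
print(max_err)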

> Is there any way to get Pallas/JAX to dump the Pallas kernel's PTX? That way I could at least inspect what it's doing.

You can pass debug=True to pallas_call and it will dump the Triton IR (not PTX). In this case, though, you wouldn't see anything suspicious, since the discrepancy comes down to rounding.
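Concretely, on the matmul reproducer above that would look something like this (the only change is the debug=True flag):

    out, = pl.pallas_call(kernel,
        grid=grid, in_specs=in_specs,
        out_specs=[
            pl.BlockSpec(lambda r, c: (r, c), (N, N))
        ],
        out_shape=[ C ], name="matmul",
        debug=True,  # print the lowered kernel IR when the call is traced
    )(A, B)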

The right solution here is probably to allow inline assembly in Pallas since we don't have that functionality yet.

axelfeldmann (Author) commented

Yes, the fix we arrived at in triton-lang/triton#4574 addresses this problem. Closing the issue here.
