int_mm seems broken due to Triton upgrade #144705

Open
@cpuhrsch

Description

🐛 Describe the bug

import torch
from torch._higher_order_ops.out_dtype import out_dtype


def quantized_matmul(x_vals_int8, x_scales, w_vals_int8):
    # int8 x int8 matmul with an int32 result, then scaled per row by x_scales
    return out_dtype(torch.ops.aten.mm.default, torch.int32, x_vals_int8, w_vals_int8) * x_scales


x_vals_int8 = torch.randn(65536, 144).to(dtype=torch.int8).cuda()
x_scales = torch.randn(65536, 1).to(dtype=torch.float32).cuda()
# .t() yields a transposed view of shape (144, 432)
w_vals_int8 = torch.randn(432, 144).to(dtype=torch.int8).cuda().t()

qcm = torch.compile(quantized_matmul, mode='max-autotune-no-cudagraphs')

# The first call triggers compilation and autotuning, which is where the abort occurs
qcm(x_vals_int8, x_scales, w_vals_int8)

produces

python: /root/.triton/llvm/llvm-86b69c31-almalinux-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::FloatAttr, From = mlir::Attribute]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Aborted (core dumped)

This works on nightly20241126py312 with pytorch-triton 3.1.0+cf34004b8a. Can do more fine-grained bisection if needed.
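
For anyone comparing the working and failing environments, here is a minimal sketch for confirming which torch and pytorch-triton builds are installed. The helper name `installed_versions` is hypothetical and not part of PyTorch or Triton. It only uses `importlib.metadata` from the standard library.

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Distribution is not installed in this environment
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names as they appear in the pip listing of this report
    for name, version in installed_versions(["torch", "pytorch-triton"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in both the 20241126 nightly environment and the failing one should show whether the pytorch-triton build is the only difference between them.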

Versions

Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.2.0+git0d4682f0
[pip3] torch==2.7.0.dev20250113+cu124
[pip3] torchaudio==2.6.0.dev20250113+cu124
[pip3] torchvision==0.22.0.dev20250113+cu124
[conda] numpy                     2.1.2                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pytorch-triton            3.2.0+git0d4682f0        pypi_0    pypi
[conda] torch                     2.7.0.dev20250113+cu124  pypi_0    pypi
[conda] torchaudio                2.6.0.dev20250113+cu124  pypi_0    pypi
[conda] torchvision               0.22.0.dev20250113+cu124 pypi_0    pypi
