🐛 Describe the bug
```python
import torch
from torch._higher_order_ops.out_dtype import out_dtype

def quantized_matmul(x_vals_int8, x_scales, w_vals_int8):
    # int8 x int8 matmul with an int32 output, then per-row float32 rescaling
    return out_dtype(torch.ops.aten.mm.default, torch.int32, x_vals_int8, w_vals_int8) * x_scales

x_vals_int8 = torch.randn(65536, 144).to(dtype=torch.int8).cuda()
x_scales = torch.randn(65536, 1).to(dtype=torch.float32).cuda()
# transposed view: shape (144, 432), non-contiguous (column-major strides)
w_vals_int8 = torch.randn(432, 144).to(dtype=torch.int8).cuda().t()

qcm = torch.compile(quantized_matmul, mode='max-autotune-no-cudagraphs')
qcm(x_vals_int8, x_scales, w_vals_int8)
```
produces:

```
python: /root/.triton/llvm/llvm-86b69c31-almalinux-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::FloatAttr, From = mlir::Attribute]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Aborted (core dumped)
```
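For reference, `quantized_matmul` multiplies an int8 `(65536, 144)` matrix by the transposed int8 weight (a `(144, 432)` view of a `(432, 144)` tensor), produces an int32 result via `out_dtype`, and then scales each row by a float32 value. The plain-Python sketch below spells out those semantics on tiny inputs, assuming `out_dtype(mm, int32, ...)` means "int8 products accumulated in int32". It can serve as a cross-check for compiled results on small shapes. The helper name `quantized_matmul_ref` is hypothetical, and Python floats are float64, so the sketch only approximates float32 rounding.

```python
from typing import List

INT8_MIN, INT8_MAX = -128, 127

def quantized_matmul_ref(
    x_int8: List[List[int]],       # m x k, int8 values
    x_scales: List[List[float]],   # m x 1, per-row scales
    w_nk: List[List[int]],         # n x k, int8 values (layout before .t())
) -> List[List[float]]:
    """Hypothetical reference: (x @ w_nk.t()) accumulated as ints, scaled per row."""
    m, k, n = len(x_int8), len(x_int8[0]), len(w_nk)
    for row in x_int8 + w_nk:
        assert len(row) == k and all(INT8_MIN <= v <= INT8_MAX for v in row)
    out = []
    for i in range(m):
        out_row = []
        for j in range(n):
            # w_nk[j][p] is element (p, j) of w_nk.t(), so this computes x @ w.t()
            acc = sum(x_int8[i][p] * w_nk[j][p] for p in range(k))
            # With k = 144 as in the repro, |acc| <= 144 * 128 * 128 = 2,359,296,
            # well inside int32, so int32 accumulation cannot overflow there.
            out_row.append(acc * x_scales[i][0])  # broadcast the (m, 1) scale
        out.append(out_row)
    return out

# Tiny example: m = 2, k = 2, n = 3
x = [[1, 2], [3, -4]]
scales = [[0.5], [2.0]]
w = [[1, 0], [0, 1], [1, 1]]
print(quantized_matmul_ref(x, scales, w))
```

On these inputs the sketch prints `[[0.5, 1.0, 1.5], [6.0, -8.0, -2.0]]`.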
This works on nightly20241126py312 with pytorch-triton 3.1.0+cf34004b8a. I can do a more fine-grained bisection if needed.
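The window between the last known-good nightly (20241126) and the failing one listed under Versions (20250113) can be narrowed by binary search over nightly dates. A minimal sketch, assuming a hypothetical `is_good(date)` callback that installs that day's nightly, runs the repro, and reports whether it passed; it also assumes the regression is monotonic and does not handle missing nightlies.

```python
from datetime import date, timedelta
from typing import Callable

def bisect_nightlies(good: date, bad: date, is_good: Callable[[date], bool]) -> date:
    """Return the first date in (good, bad] where is_good is False.

    Invariant: `lo` is known good and `hi` is known bad.
    """
    lo, hi = good, bad
    while (hi - lo).days > 1:
        mid = lo + timedelta(days=(hi - lo).days // 2)
        if is_good(mid):
            lo = mid
        else:
            hi = mid
    return hi

# Example with a fake predicate that pretends 2024-12-20 is the first bad nightly
first_bad = bisect_nightlies(date(2024, 11, 26), date(2025, 1, 13),
                             lambda d: d < date(2024, 12, 20))
print(first_bad)
```

Each probe halves the remaining window, so the 48-day range needs about six runs of the repro.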
Versions
Versions of relevant libraries:
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.2.0+git0d4682f0
[pip3] torch==2.7.0.dev20250113+cu124
[pip3] torchaudio==2.6.0.dev20250113+cu124
[pip3] torchvision==0.22.0.dev20250113+cu124
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] pytorch-triton 3.2.0+git0d4682f0 pypi_0 pypi
[conda] torch 2.7.0.dev20250113+cu124 pypi_0 pypi
[conda] torchaudio 2.6.0.dev20250113+cu124 pypi_0 pypi
[conda] torchvision 0.22.0.dev20250113+cu124 pypi_0 pypi