torch.cond + torch.nonzero does not work with torch.export.export #144691

Open
@xadupre

Description

🐛 Describe the bug

I can't export the following model after rewriting it with torch.cond. I tried several configurations, all listed below, and none of them worked.

import torch


class Model(torch.nn.Module):
    def forward(
        self,
        input_ids,
        image_features,
        vocab_size,
    ):
        if image_features.numel():
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])

            # positions for image tokens
            condition = (input_ids < 0) & (input_ids > -int(1e9))
            positions = torch.where(condition)
            # has_image = len(positions[0].tolist()) > 0
            input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)

            return (input_ids, *positions)

        return (input_ids, *torch.where(torch.zeros((1, 1), dtype=torch.bool)))


inputs = [
    (
        (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
        torch.arange(32).reshape((2, -1)).to(torch.float32),
        1025,
    ),
    (
        (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
        torch.tensor([[], []], dtype=torch.float32),
        1025,
    ),
]
model = Model()
expected = [model(*inp) for inp in inputs]
assert len(expected) == 2
assert len(expected[0]) == len(expected[1]) == 3


# Rewriting with torch.cond.

class Model2(torch.nn.Module):
    def forward(self, input_ids, image_features, vocab_size):
        def then_branch(input_ids, image_features, vocab_size):
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])

            condition = (input_ids < 0) & (input_ids > -int(1e9))
            positions = torch.nonzero(condition, as_tuple=True)
            input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)
            return (input_ids, positions[0], positions[1])

        def else_branch(input_ids, image_features, vocab_size):
            r = torch.where(torch.zeros((1, 1), dtype=torch.bool))
            return (input_ids, r[0], r[1])

        a, b, c = torch.cond(
            image_features.numel() > 0,
            then_branch,
            else_branch,
            [input_ids, image_features, vocab_size],
        )
        return a, b, c

# Check that it is equivalent.
model2 = Model2()
new_out = [model2(*inp) for inp in inputs]
for i in range(2):
    for j in range(3):
        torch.testing.assert_close(expected[i][j], new_out[i][j])

batch = torch.export.Dim("batch")
seq_length = torch.export.Dim("seq_length")
dynamic_shapes = ({0: batch}, {0: batch, 1: seq_length}, None)

# We try to export with (tensor, tensor, int)
# ep = torch.export.export(model2, inputs[0], dynamic_shapes=dynamic_shapes, strict=False)
# fails with Expect operands to be a tuple of possibly nested dict/list/tuple that only consists of tensor leaves, but got [FakeTensor(..., size=(s1, 12), dtype=torch.int64), FakeTensor(..., size=(s2, s3)), 1025].
# print(ep)


# We try to export with (tensor, tensor, tensor)
new_inputs = (*inputs[0][:2], torch.tensor([1025], dtype=torch.int64))
# ep = torch.export.export(model2, new_inputs, dynamic_shapes=dynamic_shapes, strict=False)
# torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default; to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True
# torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
# print(ep)

# With the flag suggested by the previous error enabled, the export still fails (error below).
torch._dynamo.config.capture_dynamic_output_shape_ops = True
ep = torch.export.export(model2, new_inputs, dynamic_shapes=dynamic_shapes, strict=False)
# torch._dynamo.exc.UncapturedHigherOrderOpError: Expected true_fn_output and false_fn_output to have same metadata but found:
# pair[1] differ in 'shape: torch.Size([u0]) vs torch.Size([u1])', where lhs is FakeTensor(..., size=(u0,), dtype=torch.int64) and rhs is FakeTensor(..., size=(u1,), dtype=torch.int64)
# pair[2] differ in 'shape: torch.Size([u0]) vs torch.Size([u1])', where lhs is FakeTensor(..., size=(u0,), dtype=torch.int64) and rhs is FakeTensor(..., size=(u1,), dtype=torch.int64)
print(ep)
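
The last error reports that the two branches return index tensors of sizes u0 and u1. As a minimal eager-mode sketch (no export involved, reusing only the first input from the script above), the following shows that the index tensors produced by the two branch bodies have lengths that depend on the data. The names then_positions, else_positions, then_shapes and else_shapes are introduced here for illustration only. This just illustrates the shapes behind the error message; it is not a fix.

```python
import torch

# Same first input as in the report: values -8..15, shaped (2, 12).
input_ids = (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64)

# Same computation as then_branch: indices of the negative entries.
condition = (input_ids < 0) & (input_ids > -int(1e9))
then_positions = torch.nonzero(condition, as_tuple=True)

# Same computation as else_branch: indices of an all-False mask.
else_positions = torch.where(torch.zeros((1, 1), dtype=torch.bool))

# The number of indices is only known once the data is known.
then_shapes = [tuple(p.shape) for p in then_positions]
else_shapes = [tuple(p.shape) for p in else_positions]
print(then_shapes, else_shapes)
```

In eager mode each branch is free to return a different length. During export each nonzero/where call gets its own unbacked size symbol, which appears to be what the "shape: torch.Size([u0]) vs torch.Size([u1])" comparison is rejecting.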

Versions

Collecting environment information...
PyTorch version: 2.7.0.dev20250113+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.4
Libc version: glibc-2.35

Python version: 3.12.8 (main, Dec 4 2024, 08:54:12) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
Nvidia driver version: 538.92
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.3.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13800H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 5836.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.2.1
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.18.0
[pip3] onnx-extended==0.3.0
[pip3] onnxruntime_extensions==0.13.0
[pip3] onnxruntime-training==1.21.0+cu126
[pip3] pytorch-triton==3.2.0+git0d4682f0
[pip3] torch==2.7.0.dev20250113+cu126
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.6.0.dev20250113+cu126
[pip3] torchvision==0.22.0.dev20250113+cu126
[conda] Could not collect

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
