Unable to export GLM models to ONNX #35021

Open
@xenova

Description

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-6.5.0-1025-azure-x86_64-with-glibc2.31
  • Python version: 3.12.1
  • Huggingface_hub version: 0.26.1
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0+cu124 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the following script:

import os
import torch
from transformers import (
    AutoProcessor,
    GlmForCausalLM,
    DynamicCache,
)

class PatchedGlmForCausalLM(GlmForCausalLM):
    def forward(self, *args):
        input_ids, attention_mask, position_ids, *past_key_values_args = args

        # Convert past_key_values list to DynamicCache
        if len(past_key_values_args) == 0:
            past_key_values = None
        else:
            past_key_values = DynamicCache(self.config.num_hidden_layers)
            for i in range(self.config.num_hidden_layers):
                key = past_key_values_args.pop(0)
                value = past_key_values_args.pop(0)
                past_key_values.update(key_states=key, value_states=value, layer_idx=i)

        o = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value

        return flattened_past_key_values_outputs


# Constants
OUTPUT_FOLDER = "output"
TEXT_MODEL_NAME = "model.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")


# Load model and processor
model_id = "hf-internal-testing/tiny-random-GlmForCausalLM"
model = PatchedGlmForCausalLM.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)


# Save model configs and processor
model.config.save_pretrained(OUTPUT_FOLDER)
model.generation_config.save_pretrained(OUTPUT_FOLDER)
processor.save_pretrained(OUTPUT_FOLDER)
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)


# Configuration values
## Text model
text_config = model.config
num_heads = text_config.num_attention_heads
num_key_value_heads = text_config.num_key_value_heads
head_dim = text_config.head_dim
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size


# Dummy input sizes
batch_size = 2
sequence_length = 16
past_sequence_length = 0

## Text inputs
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size,
        num_key_value_heads,
        past_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for key in ["key", "value"]
}
input_ids = torch.randint(
    0, text_config.vocab_size,
    (batch_size, sequence_length),
)
attention_mask = torch.ones(batch_size, sequence_length + past_sequence_length, dtype=torch.int64)
position_ids = torch.ones(batch_size, sequence_length, dtype=torch.int64)

text_inputs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **dummy_past_key_values_kwargs,
)
text_inputs_positional = tuple(text_inputs.values())
text_outputs = model.forward(*text_inputs_positional)  # Test forward pass


# ONNX Exports
## Text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(text_inputs.keys()),
    output_names=["logits"]
    + [f"present.{i}.{key}" for i in range(num_layers) for key in ["key", "value"]],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
    },
)

It produces this error:

Traceback (most recent call last):
  File "/workspaces/glm.py", line 110, in <module>
    torch.onnx.export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 639, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 369, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset11.py", line 519, in cat
    return opset9.cat(g, tensor_list, dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 281, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py", line 575, in cat
    assert all(a)
AssertionError

Expected behavior

The model should export correctly. This may in fact be a bug on the ONNX export side rather than in transformers, but I'm not 100% sure. Similar models such as Gemma export correctly, so the failure appears to be GLM-specific.

Activity

github-actions commented on Dec 30, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker (Collaborator) commented on Feb 11, 2025

xenova (Contributor, Author) commented on Feb 11, 2025

I did a long deep-dive on this recently, and it appears to have been a bug in PyTorch! pytorch/pytorch#145100 👀

I was able to apply a fix in Optimum by overriding the repeat_interleave op (huggingface/optimum#2162), so it may not necessarily be something we can (or should) fix in transformers.
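
For illustration only (this is not the actual patch from huggingface/optimum#2162, just a sketch of the kind of rewrite that avoids the failing symbolic): a repeat_interleave along the last dimension can be expressed with unsqueeze/expand/flatten, which the TorchScript exporter should be able to trace without dispatching to the repeat_interleave symbolic.

import torch

def repeat_interleave_last_dim(x: torch.Tensor, repeats: int) -> torch.Tensor:
    # Equivalent to torch.repeat_interleave(x, repeats, dim=-1), written with
    # unsqueeze/expand/flatten so the exporter never goes through the
    # repeat_interleave symbolic function.
    return x.unsqueeze(-1).expand(*x.shape, repeats).flatten(-2)

# Quick equivalence check
x = torch.randn(2, 4, 8)
assert torch.equal(repeat_interleave_last_dim(x, 2), torch.repeat_interleave(x, 2, dim=-1))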

Cyrilvallez (Member) commented on Feb 17, 2025

Hey @xenova! Thanks for looking deep into this. To be fair, repeat_interleave should be used directly in the RotaryEmbedding class rather than in the apply_rotary_pos_emb function. I initially implemented it the way it currently is because I wanted to use modular, and this was a very simple and efficient workaround. I did not know at the time that we already had an existing model that uses interleave instead of cat in the RotaryEmbedding class, and manual interleaving in rotate_half - but it turns out Cohere does!
I have been thinking about simplifying the implementation by inheriting from Cohere instead. Would that solve the issue if the repeat_interleave is located directly in the RotaryEmbedding class?

xenova (Contributor, Author) commented on Feb 17, 2025

Possibly... 👀 but to be sure, if you could post code snippets, I can test them out with the torch ONNX exporter.

Cyrilvallez (Member) commented on Feb 17, 2025

Sure, here are the class and functions with the modification:

class GlmRotaryEmbedding(nn.Module):
    def __init__(self, config: GlmConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Keep half or full tensor for later concatenation
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    # Apply rotary embeddings on the first half or full tensor
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

    # Concatenate back to full shape
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed
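
A minimal sketch of one way to try this against the reproduction script without rebuilding transformers: monkey-patch the GLM modeling module before loading the model. This is hypothetical and assumes the snippet above is defined in scope with the usual modeling-file imports (torch, torch.nn, ROPE_INIT_FUNCTIONS, GlmConfig), and that the model is built from transformers.models.glm.modeling_glm.

import transformers.models.glm.modeling_glm as modeling_glm

# Swap in the proposed implementations before instantiating the model so that
# GlmModel / GlmAttention pick up the patched rotary embedding and helpers.
modeling_glm.GlmRotaryEmbedding = GlmRotaryEmbedding
modeling_glm.rotate_half = rotate_half
modeling_glm.apply_rotary_pos_emb = apply_rotary_pos_emb

# ...then re-run the torch.onnx.export call from the reproduction script above.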