Description
System Info
- transformers version: 4.47.0.dev0
- Platform: Linux-6.5.0-1025-azure-x86_64-with-glibc2.31
- Python version: 3.12.1
- Huggingface_hub version: 0.26.1
- Safetensors version: 0.4.5
- Accelerate version: 1.0.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.5.0+cu124 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Run the following script:
import os

import torch
from transformers import (
    AutoProcessor,
    GlmForCausalLM,
    DynamicCache,
)


class PatchedGlmForCausalLM(GlmForCausalLM):
    def forward(self, *args):
        input_ids, attention_mask, position_ids, *past_key_values_args = args

        # Convert the flat list of past_key_values tensors to a DynamicCache
        if len(past_key_values_args) == 0:
            past_key_values = None
        else:
            past_key_values = DynamicCache(self.config.num_hidden_layers)
            for i in range(self.config.num_hidden_layers):
                key = past_key_values_args.pop(0)
                value = past_key_values_args.pop(0)
                past_key_values.update(key_states=key, value_states=value, layer_idx=i)

        o = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        # Flatten the output cache back into individually named tensors for ONNX
        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value
        return flattened_past_key_values_outputs
# Constants
OUTPUT_FOLDER = "output"
TEXT_MODEL_NAME = "model.onnx"
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")
# Load model and processor
model_id = "hf-internal-testing/tiny-random-GlmForCausalLM"
model = PatchedGlmForCausalLM.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Save model configs and processor
model.config.save_pretrained(OUTPUT_FOLDER)
model.generation_config.save_pretrained(OUTPUT_FOLDER)
processor.save_pretrained(OUTPUT_FOLDER)
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)
# Configuration values
## Text model
text_config = model.config
num_heads = text_config.num_attention_heads
num_key_value_heads = text_config.num_key_value_heads
head_dim = text_config.head_dim
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size
# Dummy input sizes
batch_size = 2
sequence_length = 16
past_sequence_length = 0
## Text inputs
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size,
        num_key_value_heads,
        past_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for key in ["key", "value"]
}
input_ids = torch.randint(
    0, text_config.vocab_size,
    (batch_size, sequence_length),
)
attention_mask = torch.ones(batch_size, sequence_length + past_sequence_length, dtype=torch.int64)
position_ids = torch.ones(batch_size, sequence_length, dtype=torch.int64)
text_inputs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    **dummy_past_key_values_kwargs,
)
text_inputs_positional = tuple(text_inputs.values())
text_outputs = model.forward(*text_inputs_positional) # Test forward pass
# ONNX Exports
## Text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(text_inputs.keys()),
    output_names=["logits"]
    + [f"present.{i}.{key}" for i in range(num_layers) for key in ["key", "value"]],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{key}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{key}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(num_layers)
            for key in ["key", "value"]
        },
    },
)
It produces this error:
Traceback (most recent call last):
  File "/workspaces/glm.py", line 110, in <module>
    torch.onnx.export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
    graph = _optimize_graph(
            ^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 639, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/utils.py", line 1836, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 369, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset11.py", line 519, in cat
    return opset9.cat(g, tensor_list, dim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_helper.py", line 281, in wrapper
    return fn(g, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/torch/onnx/symbolic_opset9.py", line 575, in cat
    assert all(a)
AssertionError
Expected behavior
The model should export correctly. This may in fact be a bug in the ONNX exporter rather than in transformers, but I'm not 100% sure. Models like Gemma export correctly with the same script, so the failure seems to be GLM-specific.
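As a point of comparison, the Gemma claim can be checked by swapping the architecture while keeping the rest of the script identical. Below is a sketch of a generic version of the patch above that makes such A/B tests easy; make_patched is a hypothetical helper, and the tiny-random Gemma checkpoint id is an assumption, not part of the original report:

import torch
from transformers import DynamicCache, GemmaForCausalLM

def make_patched(base_cls):
    # Generic form of PatchedGlmForCausalLM above: wraps any causal-LM class
    # so the KV cache is passed in and returned as flat, individually named tensors.
    class Patched(base_cls):
        def forward(self, *args):
            input_ids, attention_mask, position_ids, *pkv = args
            past_key_values = None
            if pkv:
                past_key_values = DynamicCache(self.config.num_hidden_layers)
                for i in range(self.config.num_hidden_layers):
                    past_key_values.update(
                        key_states=pkv[2 * i], value_states=pkv[2 * i + 1], layer_idx=i
                    )
            o = base_cls.forward(
                self,
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
            )
            out = {"logits": o.logits}
            for i, (k, v) in enumerate(
                zip(o.past_key_values.key_cache, o.past_key_values.value_cache)
            ):
                out[f"present.{i}.key"] = k
                out[f"present.{i}.value"] = v
            return out

    return Patched

# Rebuilding the dummy inputs from this model's config and calling
# torch.onnx.export with the same arguments reportedly succeeds for Gemma.
gemma_model = make_patched(GemmaForCausalLM).from_pretrained(
    "hf-internal-testing/tiny-random-GemmaForCausalLM"  # assumed checkpoint id
).eval()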
Activity
github-actions commented on Dec 30, 2024
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
ArthurZucker commented on Feb 11, 2025
cc @Cyrilvallez
xenova commented on Feb 11, 2025
I did a long deep-dive on this recently, and in fact, it appears to have been a bug with PyTorch! pytorch/pytorch#145100 👀
I was able to apply a fix in Optimum by overriding the repeat_interleave op (huggingface/optimum#2162), so it may not necessarily be something we can (or should) fix in transformers.
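The fix in that Optimum PR works at the operator-export level; the same idea can be applied at the model level by expressing the op with shape primitives the exporter already handles. A sketch (the helper name is hypothetical; the actual patch is in huggingface/optimum#2162):

import torch

def repeat_interleave_x2_last_dim(x: torch.Tensor) -> torch.Tensor:
    # Equivalent to x.repeat_interleave(2, dim=-1), but built from
    # unsqueeze/expand/reshape, avoiding the cat() path that asserts
    # in the traceback above.
    return (
        x.unsqueeze(-1)
        .expand(*x.shape, 2)
        .reshape(*x.shape[:-1], 2 * x.shape[-1])
    )

# Quick equivalence check against the eager op:
t = torch.randn(2, 4, 8)
torch.testing.assert_close(
    repeat_interleave_x2_last_dim(t), t.repeat_interleave(2, dim=-1)
)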
Cyrilvallez commented on Feb 17, 2025
Hey @xenova! Thanks for looking deep into this. To be fair, repeat_interleave should be used directly in the RotaryEmbedding class, instead of the apply_rotary_pos_emb function. I initially implemented it the way it currently is because I wanted to use modular, and this was a very simple and efficient workaround. I did not know at the time that we already had an existing model using interleave instead of cat in the RotaryEmbedding class, and manual interleave in rotate_half - but it turns out Cohere does!
I have been thinking about simplifying the implementation using inheritance from Cohere instead. Would that solve the issue if the repeat_interleave is located directly in the RotaryEmbedding class?
xenova commented on Feb 17, 2025
Possibly... 👀 but to be sure, if you could post code snippets, I can test them out with the torch ONNX exporter.
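For tests like that, a minimal probe that exports only the suspect op can be handy; a sketch (the probe module is hypothetical, and whether it trips the same assertion depends on the PyTorch version, per pytorch/pytorch#145100):

import torch

class RepeatInterleaveProbe(torch.nn.Module):
    # Tiny module exercising only the op under suspicion, so candidate
    # rewrites can be run through the TorchScript ONNX exporter without
    # exporting a full model.
    def forward(self, x):
        return x.repeat_interleave(2, dim=-1)

torch.onnx.export(
    RepeatInterleaveProbe(),
    (torch.randn(2, 16, 8),),
    "probe.onnx",
    opset_version=14,
    input_names=["x"],
    # Dynamic axes included to mirror the original repro's export settings.
    dynamic_axes={"x": {0: "batch_size", 1: "sequence_length"}},
)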
Cyrilvallez commented on Feb 17, 2025
Sure, here are the class and functions with the modification:
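(The snippet itself is not preserved in this excerpt. As a rough illustration of the idea under discussion, building the cos/sin tables already interleaved inside the rotary embedding, Cohere-style, so that apply_rotary_pos_emb no longer needs repeat_interleave, here is a hypothetical sketch. It is not the code actually posted, and it omits GLM's partial-rotary handling:)

import torch
from torch import nn

class InterleavedRotaryEmbeddingSketch(nn.Module):
    # Hypothetical sketch: frequencies are interleaved (f0, f0, f1, f1, ...)
    # with stack+flatten at table-construction time, so no repeat_interleave
    # is needed later in apply_rotary_pos_emb.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        # freqs: (batch, seq_len, dim // 2)
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        # Same result as freqs.repeat_interleave(2, dim=-1), export-friendly.
        emb = torch.stack((freqs, freqs), dim=-1).flatten(-2)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

def rotate_half_interleaved(x: torch.Tensor) -> torch.Tensor:
    # Cohere-style rotation for the interleaved layout:
    # pairs (x_even, x_odd) -> (-x_odd, x_even).
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

With this layout, apply_rotary_pos_emb reduces to q * cos + rotate_half_interleaved(q) * sin, and the exporter never sees repeat_interleave.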