Update replacing MultiHeadAttention with GroupQueryAttention #19882

Merged
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -110,6 +110,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
  parameters.do_rotary = do_rotary_;
  parameters.rotary_interleaved = rotary_interleaved_;

  if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1");
  }

  TensorShapeVector output_shape(3);
  output_shape[0] = static_cast<int64_t>(parameters.batch_size);
  output_shape[1] = static_cast<int64_t>(sequence_length);
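For context, a minimal sketch of how a GroupQueryAttention node with rotary embedding enabled is expected to be wired; with the new check above, leaving the cos_cache/sin_cache inputs empty while do_rotary = 1 now fails with INVALID_ARGUMENT instead of running with missing caches. All tensor names and head counts below are illustrative, not taken from this PR.

import onnx

# Hypothetical GroupQueryAttention node with rotary embedding applied inside the op.
# The last two inputs must be non-empty once do_rotary=1 (enforced by the check above).
gqa_node = onnx.helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value",        # Q/K/V projections (or packed QKV via query)
        "past_key", "past_value",       # KV cache
        "seqlens_k", "total_seq_len",   # int32 lengths derived from the attention mask
        "cos_cache", "sin_cache",       # required when do_rotary=1
    ],
    outputs=["attn_output", "present_key", "present_value"],
    name="GroupQueryAttention_0",
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=32,
    do_rotary=1,
    rotary_interleaved=0,
)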
154 changes: 143 additions & 11 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -1273,7 +1273,7 @@ def find_past_seq_len_usage(subg: GraphProto):


def replace_mha_with_gqa(
    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
):
    # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
    #
@@ -1339,31 +1339,163 @@ def replace_mha_with_gqa(
    )

    # Replace MultiHeadAttention with GroupQueryAttention
    #
    # When replacing, fuse the following subgraph:
    #
    #               root_input
    #              /     |     \
    #         MatMul   MatMul   MatMul
    #            |       |        |
    #           Add     Add      Add      (optional Adds)
    #            |       |        |
    #         RotEmb   RotEmb     |
    #             \      |       /
    #           MultiHeadAttention
    #
    # to this new subgraph:
    #
    #               root_input
    #                   |
    #              PackedMatMul (if possible)
    #                   |
    #               PackedAdd (if possible)
    #                   |
    #          GroupQueryAttention
    #
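    # If the three MatMuls do not share the same root_input, or the optional bias
    # Adds are not uniformly present or absent, the packing step is skipped and the
    # existing MatMul/Add outputs are wired into GroupQueryAttention directly
    # (see the else branch below).
    #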

    mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
    for node in mha_nodes:
        num_heads_mha = 0
    for idx, node in enumerate(mha_nodes):
        # Detect Q path to MHA
        q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
        q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])

        q_rotary, q_add, q_matmul = None, None, None
        if q_path_1 is not None:
            q_rotary, q_add, q_matmul = q_path_1
        elif q_path_2 is not None:
            q_rotary, q_matmul = q_path_2

        # Detect K path to MHA
        k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
        k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])

        k_rotary, k_add, k_matmul = None, None, None
        if k_path_1 is not None:
            k_rotary, k_add, k_matmul = k_path_1
        elif k_path_2 is not None:
            k_rotary, k_matmul = k_path_2

        # Detect V path to MHA
        v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
        v_path_2 = model.match_parent_path(node, ["MatMul"], [2])

        v_add, v_matmul = None, None
        if v_path_1 is not None:
            v_add, v_matmul = v_path_1
        elif v_path_2 is not None:
            v_matmul = v_path_2[0]

        # Get `interleaved` attribute from RotaryEmbedding
        interleaved = 0
        if q_rotary is not None and k_rotary is not None:
            for att in q_rotary.attribute:
                if att.name == "interleaved":
                    interleaved = att.i

        # Get `num_heads` attribute from MHA
        num_heads = 0
        for att in node.attribute:
            if att.name == "num_heads":
                num_heads_mha = att.i
                num_heads = att.i

        # Check if root_input to Q/K/V paths is the same
        root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]

        # Check if Q/K/V paths all have bias or all don't have bias
        all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
        all_paths_have_no_bias = q_add is None and k_add is None and v_add is None

        # Make PackedMatMul node if possible
        q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
        if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
            qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
            kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
            vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))

            dim = qw.shape[-1]
            qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
            qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
            model.add_initializer(qkv_weight)

            packed_matmul_node = onnx.helper.make_node(
                "MatMul",
                inputs=[q_matmul.input[0], qkv_weight.name],
                outputs=[f"{qkv_weight.name}_output"],
                name=model.create_node_name("MatMul"),
            )
            model.model.graph.node.extend([packed_matmul_node])
            model.model.graph.node.remove(q_matmul)
            model.model.graph.node.remove(k_matmul)
            model.model.graph.node.remove(v_matmul)
            q_input_to_attention = packed_matmul_node.output[0]

            # Make PackedAdd node if possible
            if all_paths_have_bias:
                qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
                kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
                vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))

                dim = qb.shape[-1]
                qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
                qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
                model.add_initializer(qkv_bias)
                packed_add_node = onnx.helper.make_node(
                    "Add",
                    inputs=[packed_matmul_node.output[0], qkv_bias.name],
                    outputs=[f"{qkv_bias.name}_output"],
                )
                model.model.graph.node.extend([packed_add_node])
                model.model.graph.node.remove(q_add)
                model.model.graph.node.remove(k_add)
                model.model.graph.node.remove(v_add)
                q_input_to_attention = packed_add_node.output[0]

        else:
            q_input_to_attention = q_matmul.output[0]
            k_input_to_attention = k_matmul.output[0]
            v_input_to_attention = v_matmul.output[0]

        # Make GQA node
        gqa_node = onnx.helper.make_node(
            "GroupQueryAttention",
            inputs=[
                node.input[0],  # query
                node.input[1],  # key
                node.input[2],  # value
                q_input_to_attention,  # query
                k_input_to_attention,  # key
                v_input_to_attention,  # value
                node.input[6],  # past_key
                node.input[7],  # past_value
                "seqlens_k",  # seqlens_k (for attention_mask)
                "total_seq_len",  # total_seq_len (for attention_mask)
                seqlen_k_cast_node.output[0],  # seqlens_k (for attention mask)
                total_seqlen_cast_node.output[0],  # total_seq_len (for attention mask)
                q_rotary.input[2] if q_rotary is not None else "",  # cos_cache (for rotary embeddings)
                q_rotary.input[3] if q_rotary is not None else "",  # sin_cache (for rotary embeddings)
            ],
            outputs=node.output,
            name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
            domain="com.microsoft",
            num_heads=num_heads_mha // world_size,
            kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
            num_heads=num_heads // world_size,
            kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
            local_window_size=window_size,
            do_rotary=int(q_rotary is not None and k_rotary is not None),
            rotary_interleaved=interleaved,
        )
        model.model.graph.node.remove(node)
        model.model.graph.node.extend([gqa_node])

        if q_rotary is not None:
            model.model.graph.node.remove(q_rotary)
        if k_rotary is not None:
            model.model.graph.node.remove(k_rotary)

    return model

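The weight and bias packing performed above can be illustrated in isolation. The following standalone numpy sketch (shapes and the toy input are made up, not part of this PR) shows that stacking the Q/K/V weights along axis 1 and reshaping to (dim, 3 * dim) lets a single MatMul plus a single Add reproduce the three separate projections:

import numpy as np

dim = 4  # hidden size, illustrative only
rng = np.random.default_rng(0)
qw, kw, vw = (rng.standard_normal((dim, dim)).astype(np.float32) for _ in range(3))
qb, kb, vb = (rng.standard_normal(dim).astype(np.float32) for _ in range(3))

# Same packing as replace_mha_with_gqa: (dim, 3, dim) -> (dim, 3 * dim) for weights,
# (3, dim) -> (3 * dim,) for biases.
qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)

x = rng.standard_normal((2, dim)).astype(np.float32)  # toy batch of hidden states
packed = x @ qkv_weight + qkv_bias                     # one MatMul + one Add

# Splitting the packed result recovers the individual Q/K/V projections.
q_out, k_out, v_out = np.split(packed, 3, axis=-1)
assert np.allclose(q_out, x @ qw + qb)
assert np.allclose(k_out, x @ kw + kb)
assert np.allclose(v_out, x @ vw + vb)

This mirrors the initializer layout that the PackedMatMul/PackedAdd path feeds into GroupQueryAttention as a packed QKV input.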

Changes to an additional file (file path not captured in this view):
@@ -222,7 +222,8 @@ def get_msft_sample_inputs(
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
    num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
    num_heads = config.num_key_value_heads // world_size
    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    past_kv = [
        (
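The head_size change above matters for checkpoints whose config sets head_dim explicitly rather than implying it from hidden_size and num_attention_heads. A small sketch with mock config objects (stand-ins for transformers.AutoConfig, not real model values) shows both branches:

from types import SimpleNamespace

def resolve_head_size(config):
    # Mirrors the updated logic in get_past_kv_inputs above.
    return config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

cfg_derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
cfg_explicit = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=256)

print(resolve_head_size(cfg_derived))   # 128: hidden_size // num_attention_heads
print(resolve_head_size(cfg_explicit))  # 256: taken directly from config.head_dim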
@@ -286,7 +287,14 @@ def add_io_bindings(
):
    io_binding = model.io_binding()

    model_inputs = set(map(lambda i: i.name, model.get_inputs()))
    for k, v in ort_inputs.items():
        # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
        # but `position_ids` is used as a PyTorch model input
        if k not in model_inputs:
            continue

        # Bind OrtValue inputs to device
        if use_gqa and ("cache" in k or "past_key_values" in k):
            if k not in kv_cache_ortvalues:
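The model_inputs filter above keeps I/O binding working when an ONNX input such as position_ids has been folded into the GQA + RotaryEmbedding fusion but is still produced on the PyTorch side. Below is a hypothetical standalone sketch of the same pattern with a plain InferenceSession; "model.onnx" and the input dict are placeholders, not taken from this PR.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
model_inputs = {i.name for i in sess.get_inputs()}

ort_inputs = {
    "input_ids": np.zeros((1, 8), dtype=np.int64),
    "position_ids": np.arange(8, dtype=np.int64).reshape(1, 8),  # may have been fused away
}

# Feed only the tensors the exported graph actually declares, silently skipping
# PyTorch-side extras such as a fused-away position_ids.
feed = {k: v for k, v in ort_inputs.items() if k in model_inputs}
outputs = sess.run(None, feed)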