From 7cea166f00b1a028908d854a2425b692ba54d384 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 31 Jul 2024 09:01:05 -0700 Subject: [PATCH] [CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed
not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed
not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention
bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past
only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present
(not share buffer) ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/18854 --- .../contrib_ops/cpu/bert/attention_base.cc | 1 - .../contrib_ops/cpu/bert/attention_common.h | 13 +- .../cpu/bert/multihead_attention.cc | 15 +- .../cpu/bert/multihead_attention_helper.h | 574 +++++++----- .../cuda/bert/add_bias_transpose.cu | 114 +++ .../cuda/bert/add_bias_transpose.h | 54 +- .../contrib_ops/cuda/bert/attention.cc | 4 +- .../contrib_ops/cuda/bert/attention_impl.cu | 202 ++-- .../contrib_ops/cuda/bert/attention_impl.h | 55 +- .../cuda/bert/attention_kernel_options.h | 1 + .../cuda/bert/attention_kv_cache.cu | 73 +- .../cuda/bert/attention_prepare_qkv.cu | 864 ++++++++++++------ .../cuda/bert/attention_transpose.cu | 6 + .../decoder_masked_multihead_attention.cc | 13 +- ...decoder_masked_multihead_attention_impl.cu | 3 + .../cuda/bert/multihead_attention.cc | 221 ++--- .../cuda/bert/multihead_attention.h | 6 + .../cuda/bert/packed_attention_impl.cu | 22 +- .../bert/packed_multihead_attention_impl.cu | 58 +- .../quantization/attention_quantization.cc | 4 +- .../cuda/utils/dump_cuda_tensor.cc | 9 + .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 2 +- .../contrib_ops/rocm/bert/attention_impl.cu | 10 +- .../contrib_ops/rocm/bert/attention_impl.h | 6 - .../rocm/bert/multihead_attention.cu | 8 +- .../tools/transformers/fusion_attention.py | 3 + .../test/python/transformers/benchmark_mha.py | 99 +- .../test/python/transformers/test_mha.py | 346 ++++--- 28 files changed, 1729 insertions(+), 1057 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 515a967aa2386..f7d8fedc734e4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -258,7 +258,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->scale = scale_; output_parameters->mask_type = mask_type; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = false; output_parameters->qkv_format = Q_K_V_BNSH; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 55292b35e1e38..88127387d08ea 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -6,6 +6,12 @@ namespace onnxruntime { namespace contrib { +enum AttentionType { + kAttention, + kMultiHeadAttention, + kDecoderMaskedMultiHeadAttention, +}; + enum AttentionMaskType { MASK_NONE, // No mask MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length @@ -24,10 +30,12 @@ enum AttentionQkvFormat { UNKNOWN, // enum value not set, or depends on qkv projection implementation details Q_K_V_BNSH, // for non-packed qkv, permuted Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention - QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BSNH_BNSH_BNSH, // for cross attention, k and v are permuted Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) - Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + QKV_BSN3H, // for TRT fused attention, qkv are packed + QKV_BS3NH, // for DecoderMaskedMultiHeadAttention, qkv are packed QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed }; @@ -61,7 +69,6 @@ struct AttentionParameters { bool past_present_share_buffer; bool do_rotary; bool broadcast_res_pos_bias; - bool pass_past_in_kv; float mask_filter_value; float scale; bool use_tf32; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 9677c30f22d8a..0d77376779230 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -85,7 +85,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { scale_, is_unidirectional_, past_present_share_buffer, - false)); + kMultiHeadAttention)); const int batch_size = parameters.batch_size; const int q_sequence_length = parameters.sequence_length; @@ -121,20 +121,13 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - // For each of Q/K/V, there are multiple scenarios: - // 1) Combined QKV bias is null - // a) Q/K/V is (B, S, D) - // b) Q/K/V is (B, S, N, H) - // 2) No packed QKV in Q - // a) Q/K/V has seq_len = 1 - // b) Q/K/V has seq_len > 1 - OrtValue Q; ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q)); - if (parameters.pass_past_in_kv) { // key and value in BNSH format - assert(bias == nullptr); + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { + // For cross attention with k and v in BNSH format, we assume that bias for key and value are zeros. + // So we don't need to add bias for key and value here. assert(past_key == nullptr); assert(past_value == nullptr); return ApplyAttention(Q.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index bd7ab09659170..cfb8d36843777 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -11,6 +11,232 @@ namespace onnxruntime { namespace contrib { namespace multihead_attention_helper { +template +Status Check_QKV(const T* packed_qkv, AttentionQkvFormat& qkv_format) { + const auto& query_dims = packed_qkv->Shape().GetDims(); + if (query_dims.size() == 3) { + // Packed qkv used by DecoderMaskedMultiHeadAttention. Query shape is (B, S, 3D), no key and value. + qkv_format = AttentionQkvFormat::QKV_BS3NH; + } else { + assert(query_dims.size() == 5); + if (static_cast(query_dims[3]) != 3) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'query' shape (batch_size, sequence_length, num_heads, 3, head_size) for packed qkv"); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + + return Status::OK(); +} + +template +Status Check_Q_KV(const T* query, const T* packed_kv, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = packed_kv->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query be 3 for packed kv"); + } + + if (key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key be 5 for packed kv"); + } + + if (key_dims[0] != query_dims[0] || + static_cast(key_dims[2]) != num_heads || + static_cast(key_dims[3]) != 2 || + static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + kv_sequence_length = static_cast(key_dims[1]); + return Status::OK(); +} + +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, int num_heads, int head_size, + AttentionQkvFormat& qkv_format, int& kv_sequence_length, int& v_hidden_size) { + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = key->Shape().GetDims(); + const auto& value_dims = value->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of query be 3 for packed kv"); + } + + if (key_dims.size() != value_dims.size() || (key_dims.size() != 3 && value_dims.size() != 4)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Expect rank of key and value be same, and either 3 or 4"); + } + + if (key_dims[0] != query_dims[0] || value_dims[0] != query_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query', 'key' and 'value' shall have same dim 0 (batch_size)"); + } + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + + if (key_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same dim 1 (kv_sequence_length)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + kv_sequence_length = static_cast(key_dims[1]); + v_hidden_size = static_cast(value_dims[2]); + } else { // key_dims.size() == 4 + if (value->Shape() != key->Shape() || + static_cast(key_dims[1]) != num_heads || + static_cast(key_dims[3]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall have same shape (batch_size, num_heads, kv_sequence_length, head_size)"); + } + + qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + kv_sequence_length = static_cast(key_dims[2]); + v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); + } + + return Status::OK(); +} + +template +Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len, + int batch_size, int num_heads, int head_size, bool past_present_share_buffer, + int& past_sequence_length, int& max_sequence_length) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 1 should be same as number of heads, got ", + past_key_dims[1]); + } + if (past_value_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 1 should be same as number of heads, got ", + past_value_dims[1]); + } + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); + } + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + past_sequence_length = static_cast(past_key_dims[2]); + if (past_present_share_buffer) { + max_sequence_length = static_cast(past_key_dims[2]); + if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); + } + past_sequence_length = *((*past_seq_len).template Data()); + } + return Status::OK(); +} + +template +Status CheckRelativePositionBias( + const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length, + bool& broadcast_res_pos_bias) { + const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); + + if (relative_position_bias_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' is expected to have 4 dimensions, got ", + relative_position_bias_dims.size()); + } + if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", + relative_position_bias_dims[0]); + } + if (relative_position_bias_dims[0] == 1) { + broadcast_res_pos_bias = true; + } + if (relative_position_bias_dims[1] != num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", + relative_position_bias_dims[1]); + } + if (relative_position_bias_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", + relative_position_bias_dims[2]); + } + if (relative_position_bias_dims[3] != total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", + relative_position_bias_dims[3]); + } + return Status::OK(); +} + +template +AttentionMaskType GetMaskType(const T* key_padding_mask, int batch_size, int sequence_length, int total_sequence_length) { + AttentionMaskType mask_type = AttentionMaskType::MASK_UNKNOWN; + const auto& mask_dims = key_padding_mask->Shape().GetDims(); + if (mask_dims.size() == 1) { + if (mask_dims[0] == static_cast(batch_size)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; + } + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; + } + return mask_type; +} + template Status CheckInputs(const T* query, const T* key, @@ -27,176 +253,128 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing) { - // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None - // relative_position_bias : (B, 1, S, L) - // past_key : (B, N, S*, H) - // past_value : (B, N, S*, H) - // When no packing for q/k/v: + AttentionType operator_type) { + // --------------------------------------------------------------- + // Notations: + // B: batch_size + // N: num_heads + // H: head_size (V might have different head size than Q and K) + // D: hidden_size = N * H + // S: q_sequence_length + // P: past_sequence_length + // L: kv_sequence_length + // T: total_sequence_length = P + L + // M: max_sequence_length + // --------------------------------------------------------------- + // MultiHeadAttention inputs: + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: // query (Q) : (B, S, D) - // key (K) : (B, L, D) or (B, N, S*, H) - // value (V) : (B, L, D_v) or (B, N, S*, H) - // bias (Q/K/V) : (D + D + D_v) - // When packed kv is used: + // key (K) : (B, L, D) + // value (V) : (B, L, D_v) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v): // query (Q) : (B, S, D) - // key (K) : (B, L, N, 2, H) - // value (V) : None - // bias (Q/K/V) : None - // When packed qkv is used: - // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv): + // query (Q) : (B, S, D) + // key (K/V) : (B, L, N, 2, H) + // value : None + // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v): + // query (Q/K/V) : (B, S, N, 3, H) + // key : None + // value : None + // + // Other inputs: + // bias (Q/K/V) : None or (D + D + D_v) + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T) + // relative_position_bias : (B, N, S, T) or (1, N, S, T) + // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH. + // --------------------------------------------------------------- + // DecoderMaskedMultiHeadAttention inputs (S == 1, D == D_v): + // --------------------------------------------------------------- + // Q_K_V_BSNH - no packing: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) + // value (V) : (B, L, D) + // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T): + // query (Q) : (B, S, D) + // key (K) : (B, N, L, H) + // value (V) : (B, N, L, H) + // QKV_BS3NH - packed qkv (S == L): + // query (Q) : (B, S, 3 * D) // key (K) : None // value (V) : None - // bias (Q/K/V) : None or (D + D + D_v) - - AttentionQkvFormat qkv_format; + // + // Other inputs: + // bias (Q/K/V) : None or (3 * D) + // key_padding_mask (K/V) : None or (B, T) + // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA. + // + // The following inputs are not used in cross attention (so they are None for cross attention): + // past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_value : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True. + // For CUDA, past_present_share_buffer is always True. ROCm supports both. + // past_sequence_length : scalar (1) when past_present_share_buffer is True. + // CUDA version has extra inputs (beam_width, cache_indirection) that are not checked in the class. + // For ROCm, see contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh for more details. + // --------------------------------------------------------------- + AttentionQkvFormat qkv_format = UNKNOWN; const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3 && query_dims.size() != 5) { + + int query_rank = static_cast(query_dims.size()); + if (query_rank != 3 && query_rank != 5) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ", - query_dims.size()); + query_rank); } int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); - int hidden_size = (query_dims.size() == 3) + bool dmmha_packing = operator_type == kDecoderMaskedMultiHeadAttention && key == nullptr && value == nullptr; + int hidden_size = (query_rank == 3) ? (dmmha_packing ? (static_cast(query_dims[2]) / 3) : static_cast(query_dims[2])) : (num_heads * static_cast(query_dims[4])); int head_size = static_cast(hidden_size) / num_heads; int kv_sequence_length = sequence_length; + int v_hidden_size = hidden_size; + if (key != nullptr) { + if (value == nullptr) { + ORT_RETURN_IF_ERROR(Check_Q_KV(query, key, num_heads, head_size, qkv_format, kv_sequence_length)); + } else { + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, head_size, + qkv_format, kv_sequence_length, v_hidden_size)); + } + } else if (value == nullptr) { // no key and value + ORT_RETURN_IF_ERROR(Check_QKV(query, qkv_format)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value' shall absent when 'key' is absent"); + } + int past_sequence_length = 0; int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - if (past_key_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 1 should be same as number of heads, got ", - past_key_dims[1]); - } - if (past_value_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 1 should be same as number of heads, got ", - past_value_dims[1]); - } - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", - past_key_dims[2], " vs ", past_value_dims[2]); - } - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - past_sequence_length = static_cast(past_key_dims[2]); - max_sequence_length = static_cast(past_key_dims[2]); - if (past_present_share_buffer) { - if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); - } - past_sequence_length = *((*past_seq_len).template Data()); - } + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, past_seq_len, + batch_size, num_heads, head_size, past_present_share_buffer, + past_sequence_length, max_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ", - query_dims.size()); - } - - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3 && key_dims.size() != 4 && key_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3, 4, or 5 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { + if (operator_type == kMultiHeadAttention) { + if (qkv_format == AttentionQkvFormat::QKV_BS3NH) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); + "Packed qkv of 3D BS3NH format is not support by MultiHeadAttention"); } - if (key_dims.size() == 3) { - if (key_dims[2] != query_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else if (key_dims.size() == 5) { - if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); - } - - qkv_format = Q_KV_BSNH_BSN2H; - kv_sequence_length = static_cast(key_dims[1]); - } else { // key_dims.size() == 4 (cross-attention with past_key) - if (static_cast(key_dims[1]) != num_heads || static_cast(key_dims[3]) != head_size) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, num_heads, kv_sequence_length, head_size)"); - } - - if (value == nullptr || value->Shape().GetDims().size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' shall be 4D when 'key' is 4D"); - } - - if (bias != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when 'key' is 4D"); - } - - qkv_format = UNKNOWN; - kv_sequence_length = static_cast(key_dims[2]); - } - } else { // packed QKV - if (query_dims.size() != 3 && query_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions when key is empty, got ", - query_dims.size()); - } - if (query_dims.size() == 5 && (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3)) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv"); + if (qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H && bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' shall be empty when packed kv is used"); } - - qkv_format = QKV_BSN3H; } if (bias != nullptr) { @@ -206,116 +384,31 @@ Status CheckInputs(const T* query, bias_dims.size()); } - if (value == nullptr) { - // Currently, bias is not allowed for packed KV. This constraint can be removed later. - // Here we assume that fusion tool will not include bias for packed KV. - if (query_dims.size() == 5 && query_dims[3] == 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. "); - } + int expected_bias_length = 2 * hidden_size + v_hidden_size; + if (bias_dims[0] != static_cast(expected_bias_length)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' length is expected to be 2 * hidden_size + hidden_size_v, got ", + bias_dims.size()); } } int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { - mask_type = AttentionMaskType::MASK_UNKNOWN; - const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1) { - if (mask_dims[0] == static_cast(batch_size)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(kv_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; - } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && - mask_dims[1] == static_cast(sequence_length) && - mask_dims[2] == static_cast(total_sequence_length)) { - mask_type = AttentionMaskType::MASK_3D_ATTENTION; - } - + mask_type = GetMaskType(key_padding_mask, batch_size, sequence_length, total_sequence_length); if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); - } - } - - // NOTE: In Cross-Attention, we pass the past key and value to 'key' and 'value' instead of 'past_key' and 'past_value'. - bool pass_past_in_kv = false; - int v_hidden_size = hidden_size; - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3 && value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 or 4 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (value_dims.size() == 3) { - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } - v_hidden_size = static_cast(value_dims[2]); - } else { // value_dims.size() == 4 - if (static_cast(kv_sequence_length) != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 2 (kv_sequence_length)"); - } - - if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be empty when 'value' is 4D"); - } - - v_hidden_size = static_cast(value_dims[1]) * static_cast(value_dims[3]); - pass_past_in_kv = true; + "Input 'key_padding_mask' shape is not expected."); } } bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { - const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); - - if (relative_position_bias_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' is expected to have 4 dimensions, got ", - relative_position_bias_dims.size()); - } - if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ", - relative_position_bias_dims[0]); - } - if (relative_position_bias_dims[0] == 1) { - broadcast_res_pos_bias = true; - } - if (relative_position_bias_dims[1] != num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ", - relative_position_bias_dims[1]); - } - if (relative_position_bias_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ", - relative_position_bias_dims[2]); - } - if (relative_position_bias_dims[3] != total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ", - relative_position_bias_dims[3]); - } + ORT_RETURN_IF_ERROR(CheckRelativePositionBias( + relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias)); } - // TODO: ORT_RETURN_IF(qkv_format == UNKNOWN, "Unrecognized QKV format"); + assert(qkv_format != UNKNOWN); + if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; @@ -323,7 +416,7 @@ Status CheckInputs(const T* query, output_parameters->past_sequence_length = past_sequence_length; output_parameters->kv_sequence_length = kv_sequence_length; output_parameters->total_sequence_length = total_sequence_length; - output_parameters->max_sequence_length = max_sequence_length; + output_parameters->max_sequence_length = past_present_share_buffer ? max_sequence_length : total_sequence_length; output_parameters->input_hidden_size = 0; output_parameters->hidden_size = hidden_size; output_parameters->v_hidden_size = v_hidden_size; @@ -336,7 +429,6 @@ Status CheckInputs(const T* query, output_parameters->mask_type = mask_type; output_parameters->scale = scale; output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias; - output_parameters->pass_past_in_kv = pass_past_in_kv; output_parameters->qkv_format = qkv_format; } @@ -359,7 +451,7 @@ Status CheckInputs(const T* query, float scale, bool is_unidirectional, bool past_present_share_buffer, - bool dmmha_packing, + AttentionType operator_type, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); @@ -367,7 +459,7 @@ Status CheckInputs(const T* query, return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, - past_present_share_buffer, dmmha_packing); + past_present_share_buffer, operator_type); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 9e6752b451868..62d6a723bf32c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -520,6 +520,39 @@ __global__ void AddBiasUnpack(int M, const T* input, const T* biases, T* output) } } +template +__global__ void AddBiasTransposeUnpack(int M, const T* input, const T* biases, T* output) { + // Format 5 to unpack TRT packed input format to BNSH for unfused attention. + // Input: BxSxNxMxH + // Output: MxBxNxSxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = m * head_size + n * M * H + (s * NH + b * NHS) * M; + const int out_offset = (s + n * sequence_length) * head_size + (b + m * batch_size) * NHS; + + const int h = threadIdx.x; + if (h < head_size) { + if (biases != nullptr) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } else { + output[out_offset + h] = input[in_offset + h]; + } + } +} + template __global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { // Format 3 for cutlass memory efficient attention @@ -692,6 +725,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 AddBiasUnpack<<>>(total_matrix_count, input, biases, output); + } else if (format == 5) { // format == 5 + AddBiasTransposeUnpack<<>>(total_matrix_count, input, biases, output); } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } @@ -716,6 +751,8 @@ void InvokeAddBiasTranspose( } } else if (format == 4) { // format == 4 ORT_THROW("AddBiasTranspose (format 4) not implemented for hidden_size > max_threads_per_block"); + } else if (format == 5) { // format == 5 + ORT_THROW("AddBiasTranspose (format 5) not implemented for hidden_size > max_threads_per_block"); } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } @@ -904,6 +941,7 @@ void InvokeAddBias( AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); } } + // K { const dim3 grid(kv_sequence_length, batch_size, num_matrices); @@ -1011,6 +1049,82 @@ void LaunchAddBias( } } +template +void InvokeAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q) { + assert(num_heads <= max_threads_per_block); + constexpr int num_matrices = 1; + const dim3 grid(sequence_length, batch_size, num_matrices); + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + AddBiasTransposeTrt<<>>(query, biases, q); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + AddBiasTransposeTrtLarge<<>>(head_size, query, biases, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const float* biases, const float* query, float* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const float4* query2 = reinterpret_cast(query); + const float4* biases2 = reinterpret_cast(biases); + float4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const float2* query2 = reinterpret_cast(query); + const float2* biases2 = reinterpret_cast(biases); + float2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const half* biases, const half* query, half* q) { + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const Half4* query2 = reinterpret_cast(query); + const Half4* biases2 = reinterpret_cast(biases); + Half4* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const half2* query2 = reinterpret_cast(query); + const half2* biases2 = reinterpret_cast(biases); + half2* q2 = reinterpret_cast(q); + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + } else { + InvokeAddBias(stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index efc31db43bcdb..bd4e123a272bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -3,14 +3,15 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { namespace cuda { -// Fused kernel of Add (bias) and Transpose. +// Fused kernel of Add bias (optional, can be None) and Transpose. // Shape of inputs and outputs: -// biases: (num_matrices, num_heads * head_size) +// biases: (num_matrices, num_heads * head_size) or None // format 0: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (num_matrices, batch_size, sequence_length, num_heads, head_size) // output: (num_matrices, batch_size, num_heads, sequence_length, head_size) @@ -24,9 +25,12 @@ namespace cuda { // format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) // input: (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) -// format 4: (requires qk_head_size = v_head_size) +// format 4: (requires qk_head_size == v_head_size) // input: (batch_size, sequence_length, num_heads, num_matrices, head_size) // output: (num_matrices, batch_size, sequence_length, num_heads, head_size) +// format 5: (requires qk_head_size == v_head_size) +// input: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// output: (num_matrices, batch_size, num_heads, sequence_length, head_size) template void LaunchAddBiasTranspose( @@ -35,7 +39,7 @@ void LaunchAddBiasTranspose( const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr, int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0); -// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format. +// Add bias (optional, can be None) and Transpose for separated inputs of Q, K and V, and output Trt format. // For self attention: // output: (batch_size, sequence_length, num_heads, 3, head_size) // It assumes sequence_length == kv_sequence_length and head_size == v_head_size. @@ -50,7 +54,7 @@ void LaunchAddBiasTransposeTrt( const T* biases, const T* query, const T* key, const T* value, T* output, bool is_cross_attention, int kv_sequence_length = -1); -// Add (bias) for separated inputs of Q, K and V. +// Add bias (required) for separated inputs of Q, K and V. // Q: (batch_size, sequence_length, num_heads, head_size) // K: (batch_size, kv_sequence_length, num_heads, head_size) // V: (batch_size, kv_sequence_length, num_heads, v_head_size) @@ -61,6 +65,46 @@ void LaunchAddBias( const int num_heads, const int head_size, const int v_head_size, const T* biases, const T* query, const T* key, const T* value, T* q, T* k, T* v); +// Add bias (required) for Q: (batch_size, sequence_length, num_heads, head_size) +template +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const T* biases, const T* query, T* q); + +// Add bias (optional, can be None) transpose kernel defined in packed_multihead_attention_impl.cu. +// Support the following format transforms (for float and half only). +// source_format => target_format: +// Q_K_V_TNH => Q_K_V_BNSH (requires token_offset) +// Q_K_V_TNH => Q_K_V_TNH +// Q_K_V_TNH => QKV_TN3H +// QKV_TN3H => Q_K_V_BNSH (requires token_offset) +// QKV_TN3H => Q_K_V_TNH +// QKV_TN3H => QKV_TN3H +template +void AddBiasTransposePacked( + const T* query, const T* key, const T* value, const T* bias, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +// Add bias (required) transpose kernel defined in packed_attention_impl.cu. +// Support the following format transforms (for float and half only): +// format transform +// Q_K_V_BNSH: Tx3xNxH => 3xBxNxSxH (requires token_offset) +// Q_K_V_BSNH: Tx3xNxH => 3xTxNxH +// QKV_BSN3H: Tx3xNxH => TxNx3xH +template +void AddBiasTransposePacked( + const T* input, const T* biases, T* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3b7f980ba1881..5c0989bced70c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -260,7 +260,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + false); IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; @@ -281,6 +282,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 997493acd9cb7..f9eabe27d97e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,31 +58,25 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { - if (this->sequence_length != seq_length) { - ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); - LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, seq_length, stream); - this->sequence_length = seq_length; +const int32_t* CumulatedSequenceLengthCache::TryGet(int batch_size, int32_t seq_len, cudaStream_t stream) { + if (this->sequence_length == 0 && seq_len > 0) { + // Initialize only once with sequence length in the first request. + std::call_once(init_once_flag_, [&]() { + ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, + this->max_batch_size, seq_len, stream); + // Syncronize to ensure thread-safe since other thread will not wait for the above kernel finish. + // Otherwise, the data might be consumed by other threads before it is ready and causes data race issue. + cudaStreamSynchronize(stream); + this->sequence_length = seq_len; + }); } -} -int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache, - const int* mask_index, - int batch_size, - int sequence_length, - cudaStream_t stream, - void* scratch_buffer) { - if (mask_index == nullptr && cache != nullptr) { - if (batch_size <= cache->max_batch_size) { - cache->Initialize(sequence_length, stream); - return reinterpret_cast(cache->buffer.get()); - } + if (this->sequence_length == seq_len && batch_size <= this->max_batch_size) { + return reinterpret_cast(buffer.get()); } - int* sequence_offset = reinterpret_cast(scratch_buffer); - LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream); - return sequence_offset; + return nullptr; } size_t GetAttentionScratchSize( @@ -114,10 +108,12 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention) { + bool use_memory_efficient_attention, + bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. - const size_t qkv_bytes = element_size * batch_size * num_heads * - ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_size = element_size * batch_size * num_heads * + ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); + const size_t qkv_bytes = no_qkv_workspace ? 0 : qkv_size; #if USE_FLASH_ATTENTION if (use_flash_attention) { @@ -162,39 +158,44 @@ Status FusedTrtCrossAttention( // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. assert(data.mask_index == nullptr); - + assert(data.scratch != nullptr); + assert(data.q != nullptr); + assert(data.k != nullptr); + +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + 2 * GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, - sequence_length, stream, - data.scratch); + int32_t* q_sequence_offset = const_cast(data.cumulated_sequence_length_q_cache); + if (q_sequence_offset == nullptr) { + q_sequence_offset = reinterpret_cast(data.scratch); + LaunchTrtSequenceOffset(q_sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_INIT(); DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, parameters.kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + int32_t* kv_sequence_offset = const_cast(data.cumulated_sequence_length_kv_cache); + if (kv_sequence_offset == nullptr) { + int* scratch = reinterpret_cast(data.scratch) + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = reinterpret_cast(scratch); + LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, parameters.kv_sequence_length, stream); + } + CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); - // When there is no bias, we can directly use q and packed kv from inputs. - void const* query = data.q; - void const* packed_kv = data.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV + data.q, // Q + data.k, // packed KV q_sequence_offset, // cumulated sequence length of Q kv_sequence_offset, // cumulated sequence length of KV data.output, // output @@ -206,8 +207,6 @@ Status FusedTrtCrossAttention( parameters.kv_sequence_length, // sequence length of KV stream); - DUMP_TENSOR("trt cross output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -225,24 +224,33 @@ Status FusedTrtSelfAttention( cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { + assert(data.scratch != nullptr); +#ifndef NDEBUG + char* scratch_end = reinterpret_cast(data.scratch) + GetSequenceOffsetSize(parameters.batch_size, false); + char* buffer_end = reinterpret_cast(data.workspace) + data.workspace_bytes; + assert(scratch_end <= buffer_end); +#endif + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const bool causal = parameters.is_unidirectional; - int* sequence_offset = reinterpret_cast(data.scratch); - - DUMP_TENSOR_INIT(); + const int32_t* sequence_offset = data.cumulated_sequence_length_q_cache; if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + LaunchTrtSequenceOffset2d(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); + if (sequence_offset == nullptr) { + LaunchTrtSequenceOffset(reinterpret_cast(data.scratch), data.mask_index, batch_size, sequence_length, stream); + sequence_offset = reinterpret_cast(data.scratch); + } } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length); @@ -252,22 +260,12 @@ Status FusedTrtSelfAttention( if (!causal) { assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. - void const* packed_qkv = data.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + fused_fp16_runner->Run(b, s, data.q, sequence_offset, data.output, stream); } else { assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, - batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); } @@ -289,38 +287,19 @@ Status FlashAttention( contrib::AttentionParameters& parameters, AttentionData& data, float scale) { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); assert(nullptr == data.mask_index); assert(nullptr == data.relative_position_bias); assert(parameters.head_size == parameters.v_head_size); - void* query = reinterpret_cast(data.q); - void* key = reinterpret_cast(data.k); - void* value = reinterpret_cast(data.v); - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); - - bool is_bf16 = false; + constexpr bool is_bf16 = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - true)); - - DUMP_TENSOR("flash attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); return Status::OK(); } @@ -351,25 +330,8 @@ Status EfficientAttention( float scale) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = data.q; - const void* key = data.k; - const void* value = data.v; - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, - parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, - parameters.batch_size, parameters.total_sequence_length, - parameters.num_heads, parameters.v_head_size); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -394,21 +356,19 @@ Status EfficientAttention( ? nullptr : const_cast(reinterpret_cast( data.mask_index + 2 * parameters.batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; + p.query = data.q; + p.key = data.k; + p.value = data.v; p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; - p.is_kv_bsnh = true; + p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; p.stream = stream; p.has_custom_right_padding = false; run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, - parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); return Status::OK(); } @@ -449,10 +409,6 @@ Status UnfusedAttention( cublasSetStream(cublas, stream); - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); - const int present_sequence_length = parameters.past_present_share_buffer ? parameters.max_sequence_length : total_sequence_length; @@ -467,8 +423,7 @@ Status UnfusedAttention( &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop, parameters.use_tf32)); - DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_INIT(); DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); constexpr size_t element_size = sizeof(T); @@ -523,7 +478,6 @@ Status UnfusedAttention( // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, device_prop.maxThreadsPerBlock, false, temp_output, data.output); - DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } @@ -554,7 +508,7 @@ Status QkvToContext( if (!parameters.past_present_share_buffer) { ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, parameters.pass_past_in_kv, + sequence_length, total_sequence_length, stream, max_threads_per_block, data)); } else { // past_present_share_buffer diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 56836bdda197c..fad353dcfeb07 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include "core/framework/allocator.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -15,13 +17,18 @@ namespace cuda { constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128; +// A cache for cumulated sequence length. It will be initialized in the first request, then become read-only after that. struct CumulatedSequenceLengthCache { onnxruntime::IAllocatorUniquePtr buffer; int32_t max_batch_size; int32_t sequence_length; - CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {} - void Initialize(int32_t sequence_length, cudaStream_t stream); + CumulatedSequenceLengthCache() : max_batch_size(kCumulatedSequenceLengthCacheMaxBatchSize), sequence_length(0) {} + + const int32_t* TryGet(int batch_size, int32_t sequence_length, cudaStream_t stream); + + // Use this flag to guard the initializaton only once in multi-threading. + mutable std::once_flag init_once_flag_; }; size_t @@ -46,7 +53,8 @@ size_t GetAttentionWorkspaceSize( void* fused_runner, bool use_flash_attention, bool use_fused_cross_attention, - bool use_memory_efficient_attention); + bool use_memory_efficient_attention, + bool no_qkv_workspace); template struct AttentionData { @@ -65,8 +73,6 @@ struct AttentionData { bool has_qkv_workspace = false; T* workspace = nullptr; - T* temp_k_workspace = nullptr; - T* temp_v_workspace = nullptr; T* output = nullptr; T* present = nullptr; @@ -79,22 +85,50 @@ struct AttentionData { bool use_flash_attention = false; bool use_memory_efficient_attention = false; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; + const int32_t* cumulated_sequence_length_q_cache = nullptr; + const int32_t* cumulated_sequence_length_kv_cache = nullptr; // Intermediate data T* q = nullptr; T* k = nullptr; T* v = nullptr; T* scratch = nullptr; - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + AttentionQkvFormat qkv_format = AttentionQkvFormat::UNKNOWN; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; + + // For Debugging + size_t workspace_bytes = 0; + bool allow_debug_info = false; + + bool IsUnfused() const { + return !use_flash_attention && !use_memory_efficient_attention && + (fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr); + } + + void PrintDebugInfo() const { + std::cout << "flash=" << use_flash_attention + << ", efficient=" << use_memory_efficient_attention + << ", fused_runner=" << (fused_runner != nullptr) + << ", fused_cross=" << (fused_cross_attention_kernel != nullptr) + << ", bias=" << (bias != nullptr) + << ", attn_bias=" << (relative_position_bias != nullptr) + << ", mask_dims=" << mask_index_dims.size() + << ", has_qkv_workspace=" << has_qkv_workspace + << ", workspace=" << workspace_bytes + << ", past=" << (past != nullptr ? 1 : (past_key != nullptr ? 2 : 0)) + << ", present=" << (present != nullptr ? 1 : (present_key != nullptr ? 2 : 0)) + << std::endl; + } }; +// Return true if it does not need qkv workspace, false otherwise. +template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, @@ -129,6 +163,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, int total_matrix_count = -1); +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block); + Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const half* input, half* output, cudaStream_t stream, const int max_threads_per_block); @@ -158,7 +195,7 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h index bd7df5f490c76..aba1e01bfd91b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h @@ -50,6 +50,7 @@ class AttentionKernelOptions { bool use_unfused_{true}; bool use_trt_flash_attention_{true}; + bool use_trt_cross_attention_{true}; // Causal attention is disabled by default in #14732. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 89be0f1115f41..9f0f49348c225 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -249,16 +249,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, - cudaStream_t stream, - int max_threads_per_block, + int sequence_length, int total_sequence_length, + cudaStream_t stream, int max_threads_per_block, AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. - if (nullptr != data.present) { + if (nullptr != data.present) { // Attention op assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); @@ -270,58 +269,52 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int // Update pointers to present_k and present_v. data.k = data.present; data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } else if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + } else { // MultiHeadAttention op + if (nullptr != data.present_key) { + ORT_ENFORCE(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + if (nullptr != data.past_key) { + assert(data.past_key != data.k); + assert(data.past_value != data.v); + + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + // Update pointers to present_k and present_v. data.k = data.present_key; data.v = data.present_value; - } else { - assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - data.k = data.temp_k_workspace; - data.v = data.temp_v_workspace; + } else { // nullptr == data.past_key && nullptr != data.present_key + if (data.k != data.present_key) { + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present_key, data.k, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } + + if (data.v != data.present_value) { + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_value, data.v, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } } - } else if (pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - data.k = const_cast(data.past_key); - data.v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, data.k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, data.v, data.present_value)); - // Update pointers to present_k and present_v. - data.k = data.present_key; - data.v = data.present_value; } } + return CUDA_CALL(cudaGetLastError()); } // Template Instantiation template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, - int sequence_length, int total_sequence_length, bool pass_past_in_kv, + int sequence_length, int total_sequence_length, cudaStream_t stream, int max_threads_per_block, AttentionData& data); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 040d6124e7456..05c592ec61059 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -12,12 +12,101 @@ namespace onnxruntime { namespace contrib { namespace cuda { +#if DEBUG_TENSOR_LEVEL > 1 +// Dump the workspace for Q, K, V after processing QKV data. +template +void DumpQkv(AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("q(BNSH)", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", data.k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", data.v, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("q(BSN3H)", data.q, batch_size, sequence_length, num_heads * 3, qk_head_size); + } +} + +// Dump the inputs before processing QKV data. +template +void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) { + DUMP_TENSOR_D("Query(BSNH)", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Key(BSNH)", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("Value(BSNH)", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Key(BNSH)", data.key, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BNSH)", data.value, batch_size, num_heads, kv_sequence_length, v_head_size); + } else if (data.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + DUMP_TENSOR_D("Query(BSN3H)", data.query, batch_size, sequence_length, num_heads * 3, qk_head_size); + } else if (data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + DUMP_TENSOR_D("Query(BNSH)", data.query, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("Value(BSN2H)", data.value, batch_size, sequence_length, num_heads * 2, qk_head_size); + } + + if (data.bias != nullptr) { + DUMP_TENSOR_D("Q_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("K_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } + + if (data.relative_position_bias != nullptr) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + parameters.broadcast_res_pos_bias ? 1 : batch_size, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr) { + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length); + } + if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1); + } + } +} + +// Dump the kernel outputs +template +void DumpOutputs(AttentionData& data) { + DUMP_TENSOR_INIT(); + DUMP_TENSOR("output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); +} +#endif + template Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -40,7 +129,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, int matrix_to_trans = (past_present_share_buffer ? 1 : 3); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } else { // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) @@ -48,13 +137,13 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, // For fused causal kernel, use format 1 since we need have K and V to update present state, // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); + data.qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); // For fused causal, we will update gemm_buffer with bias directly. T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; @@ -71,367 +160,526 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, return Status::OK(); } -// For MultiHeadAttention with past state +// Return true if the workspace is not needed for Q, K, V inputs, false otherwise. +// This shall be in sync with the following function PrepareQkv_MHA_Cross. template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { +bool NoQkvWorkspace_MHA_Cross(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr); +} + +// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format) +template +Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH); + // past_key or past_value is not supported for cross attention + // present_key and present_value can be supported in theory, although we do not allow the senario for now. + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_Cross(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - DUMP_TENSOR_INIT(); - if (data.bias == nullptr) { - // Below logic does not support fused attention with past without bias - // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. - - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Add bias for Q + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, + data.bias, data.query, data.q); + } else { + data.q = const_cast(data.query); } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance + // Here we have assumption that there is no bias for key and value when they are in BNSH format. + data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + if (data.bias == nullptr) { + // Transpose query from BSNH to BNSH ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); + max_threads_per_block, false, data.query, data.q)); + } else { + // Add bias to query, and transpose it: Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + + // Here we have assumption that there is no bias for key and value when they are in BNSH format. + // So we do not need to add bias for key and value. Just use the key and value directly. + data.k = const_cast(data.key); + data.v = const_cast(data.value); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_NoPast(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr; +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + assert(data.mask_index == nullptr); + assert(parameters.hidden_size == parameters.v_hidden_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length); + data.v = nullptr; + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; } #if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + if (data.bias != nullptr) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, data.q, data.k, data.v); + } else { + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = const_cast(data.value); + } + + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. get initial kv to use as past_kv in the next iteration) - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.present_key != nullptr && - data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); +#endif + else if (data.fused_runner != nullptr) { + assert(qk_head_size == v_head_size); + assert(data.relative_position_bias == nullptr); + + // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length); + data.k = nullptr; + data.v = nullptr; + + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k, + true, -1); - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + + return Status::OK(); +} + +template +bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData& data) { + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Q, K and V redirects to query, present_k and present_v, so we do not need extra workspace for QKV. + return data.past_key == nullptr && data.present_key != nullptr; + } + return false; +} + +// For MultiHeadAttention with kv cache (past or present), but no bias +template +Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + assert(data.bias == nullptr); + assert(data.fused_runner == nullptr); + assert(data.fused_cross_attention_kernel == nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.past_key == nullptr && data.past_value == nullptr || + data.past_key != nullptr && data.past_value != nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // When there is no past state and there is present state, we output K and V directly to present state. + if (data.past_key == nullptr && data.present_key != nullptr) { + data.k = data.present_key; + data.v = data.present_value; + } + +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Use oiginal Query (BSNH) since there is no bias. + data.q = const_cast(data.query); + + // Key (BxLxNxH) => K (BxNxLxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + // Value (BxLxNxH) => V (BxNxLxH) ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else +#endif + { // unfused kernel + assert(data.IsUnfused()); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, data.q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.v)); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + + return Status::OK(); +} - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +template +constexpr bool NoQkvWorkspace_MHA_WithPast_Bias(AttentionData& /*data*/) { + return false; +} + +// For MultiHeadAttention with both kv cache (past or present) and bias +template +Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(data.bias != nullptr); + assert(!(data.past_key != nullptr && data.present_key == nullptr)); + assert(data.fused_runner == nullptr); + assert(data.fused_cross_attention_kernel == nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.past_key == nullptr && data.past_value == nullptr || + data.past_key != nullptr && data.past_value != nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_Bias(data)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // When there is no past state and there is present state, we output K and V directly to present state. + if (data.past_key == nullptr && data.present_key != nullptr) { + data.k = data.present_key; + data.v = data.present_value; } + +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // Query(BxSxNxH) + Bias_Q => Q (BxSxNxH) + LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, + data.bias, data.query, data.q); + + // Key (BxLxNxH) + Bias_K => K (BxNxLxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, data.k, true, -1); + + // Key (BxLxNxH) + Bias_K => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH; + } else #endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed + { // unfused kernel + assert(data.IsUnfused()); + constexpr int format = 0; // Query (BxSxNxH) => Q (BxNxSxH) LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, + data.query, data.bias, data.q, true, -1); - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, data.k, + true, -1); - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, data.v, + true, -1); - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +template +bool NoQkvWorkspace_MHA_PackedQKV(AttentionData& data) { + // query, key and value are passed as Q, K and V for the following conditions. + return nullptr != data.fused_runner && data.bias == nullptr; +} + // For MultiHeadAttention without past state, with packed QKV inputs template Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(!parameters.is_unidirectional); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedQKV(data)); + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + // unpack qkv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, + data.query, data.bias, data.q, true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (nullptr != data.fused_runner) { + assert(nullptr == data.relative_position_bias); + if (data.bias == nullptr) { + // When there is no bias, we can directly use the original packed QKV input. + // Need revisit this when we add support for causal. + data.q = const_cast(data.query); + data.k = nullptr; + data.v = nullptr; + } else { // data.bias != nullptr + AddBiasTransposePacked( + data.query, data.key, data.value, data.bias, data.q, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + AttentionQkvFormat::QKV_TN3H, AttentionQkvFormat::QKV_TN3H, + nullptr, batch_size * sequence_length, + stream); } - qkv_format = AttentionQkvFormat::QKV_BSN3H; + data.qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + assert(data.IsUnfused()); + // unpack qkv to BNSH + constexpr int format = 5; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, data.q, + true, v_head_size, qkv_add_bias, 3); + + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } +// This shall be in sync with the following function PrepareQkv_MHA_PackedQKV. +template +bool NoQkvWorkspace_MHA_PackedKV(AttentionData& data) { + return data.fused_cross_attention_kernel != nullptr; +} + // For MultiHeadAttention without past state, with packed KV inputs template Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { + int max_threads_per_block) { + assert(parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); + assert(data.bias == nullptr); + assert(data.past_key == nullptr); + assert(data.past_value == nullptr); + assert(data.present_key == nullptr); + assert(data.present_value == nullptr); + assert(parameters.head_size == parameters.v_head_size); + assert(data.fused_runner == nullptr); + assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_PackedKV(data)); + const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + // Note that there is no bias so we need not output query to q. + data.q = const_cast(data.query); + // Unpack kv to BSNH. constexpr int format = 4; T* qkv_add_bias = nullptr; const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, batch_size, kv_sequence_length, num_heads, qk_head_size, data.key, kv_bias, data.k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } + true, v_head_size, qkv_add_bias); + data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else if (data.fused_cross_attention_kernel != nullptr) { + data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + data.q = const_cast(data.query); + data.k = const_cast(data.key); + data.v = nullptr; + } else { // unfused kernel + assert(data.IsUnfused()); + // Transpose q from BSNH to BNSH. Note that there is no bias. + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(batch_size, parameters.sequence_length, num_heads, qk_head_size, + data.query, data.q, stream, max_threads_per_block)); - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + // Unpack kv to BNSH. + constexpr int format = 5; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, data.k, + true, v_head_size, qkv_add_bias, 2); + data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH; } + return Status::OK(); } -// For MultiHeadAttention without past state, with Q, K and V inputs +// Prepare Q, K and V for MultiHeadAttention operator. template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif - - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; +Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_Cross(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::QKV_BSN3H: + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block)); + break; + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block)); + } + } else { // no past state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block)); + } + break; + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } -#endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose( - stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + return Status::OK(); +} - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; +// Check whether there is no needed to have workspace for Q, K and V for MultiHeadAttention operator. +// Please make it in sync with PrepareQkv_MultiHeadAttention. +template +bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data) { + switch (parameters.qkv_format) { + case AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH: + return NoQkvWorkspace_MHA_Cross(data); + case AttentionQkvFormat::Q_KV_BSNH_BSN2H: + return NoQkvWorkspace_MHA_PackedKV(data); + case AttentionQkvFormat::QKV_BSN3H: + return NoQkvWorkspace_MHA_PackedQKV(data); + case AttentionQkvFormat::Q_K_V_BSNH: + if (data.past_key != nullptr || data.present_key != nullptr) { + if (data.bias == nullptr) { + return NoQkvWorkspace_MHA_WithPast_NoBias(data); + } else { + return NoQkvWorkspace_MHA_WithPast_Bias(data); + } + } else { // no past state + return NoQkvWorkspace_MHA_NoPast(data); + } + default: + ORT_THROW("Unsupported QKV format: ", parameters.qkv_format); } - return Status::OK(); } template @@ -439,7 +687,6 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block) { - data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -452,28 +699,37 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.k = data.workspace + elements_q; data.v = data.k + elements_k; data.scratch = data.v + elements_v; + } else { + data.q = nullptr; + data.k = nullptr; + data.v = nullptr; + data.scratch = data.workspace; } +#if DEBUG_TENSOR_LEVEL > 1 + DumpInputs(parameters, data); +#endif + if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - data.qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, data.qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - data.q, data.k, data.v, data.qkv_format)); + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block)); + } else { // MultiHeadAttention operator + ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention(parameters, data, stream, max_threads_per_block)); } + assert(data.qkv_format != AttentionQkvFormat::UNKNOWN); + +#if DEBUG_TENSOR_LEVEL > 1 + DumpQkv(data); +#endif + CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } // Template Instantiation +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); +template bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData& data); + template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index bd38a21aadfcb..9f3e396b7f949 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -304,6 +304,12 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c max_threads_per_block, false, input, output); } +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 66c0aceaed1e7..037a4fdf3d9a0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -75,7 +75,6 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); bool is_unidirectional = false; - bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, @@ -91,7 +90,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* scale_, is_unidirectional, past_present_share_buffer_, - is_dmmha_packing, // dmmha_packing + kDecoderMaskedMultiHeadAttention, device_prop.maxThreadsPerBlock)); if (bias) { @@ -157,7 +156,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.is_cross_attention = true; parameters.total_sequence_length = parameters.kv_sequence_length; parameters.max_sequence_length = parameters.kv_sequence_length; - // parameters.k and paraneters.v are nullptr + // parameters.k and parameters.v are nullptr parameters.k_cache = const_cast(key->Data()); parameters.v_cache = const_cast(value->Data()); parameters.k_bias = nullptr; @@ -188,12 +187,14 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* } parameters.is_cross_attention = false; - parameters.is_packed_qkv = is_dmmha_packing; - parameters.k = is_dmmha_packing + bool is_packed_qkv = (key == nullptr && value == nullptr); + parameters.is_packed_qkv = is_packed_qkv; + + parameters.k = is_packed_qkv ? const_cast(query->Data() + parameters.hidden_size) : const_cast(key->Data()); - parameters.v = is_dmmha_packing + parameters.v = is_packed_qkv ? const_cast(query->Data() + 2 * static_cast(parameters.hidden_size)) : const_cast(value->Data()); parameters.k_cache = present_key_data; diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 9efb6f08e8e99..2f8d277cb7342 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -183,6 +183,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; } + // This has assumption that key and value does not have bias for cross attention when they are in BNSH format. if (!params.is_cross_attention) { Qk_vec_k k; @@ -580,6 +581,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // One group of threads computes the product(s) for the current timestep. V_vec_k v_bias; + + // This has assumption that key and value does not have bias for cross attention when they are in BNSH format. if (params.v_bias && !params.is_cross_attention) { zero(v_bias); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 663bd020ddac7..c36abc8e1d624 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -7,6 +7,7 @@ #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -44,7 +45,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); + ORT_ENFORCE(!is_unidirectional_, + "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead."); kernel_options_ = this->GetAttentionKernelOptions(); @@ -95,7 +97,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { scale_, is_unidirectional_, false, // past_present_share_buffer - false, // dmmha_packing + kMultiHeadAttention, device_prop.maxThreadsPerBlock)); int sequence_length = parameters.sequence_length; @@ -111,25 +113,43 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_key = context->Output(1, present_shape); Tensor* present_value = context->Output(2, present_shape); - MHARunner* fused_runner = nullptr; + int num_past = static_cast(past_key != nullptr) + static_cast(past_value != nullptr); + int num_present = static_cast(present_key != nullptr) + static_cast(present_value != nullptr); + if (num_past == 0 && num_present == 0) { + // It is valid case without past state. + } else if ((num_past == 2 && num_present == 2) || (num_past == 0 && num_present == 2)) { + if (parameters.qkv_format == AttentionQkvFormat::QKV_BSN3H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed QKV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for packed KV format"); + } + + if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be empty for cross attention"); + } + } else { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past_key', 'past_value', 'present_key' and 'present_value' shall be all provided, " + "or all empty, or only present_key and present_value are provided"); + } + MHARunner* fused_runner = nullptr; const FusedMultiHeadCrossAttentionKernel* fused_cross_attention_kernel = nullptr; // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; - bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - - const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); - -#if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION - // Exclude this case since PrepareQkv will convert the format to BNSH. - bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; -#endif - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && - !past_no_bias && nullptr == relative_position_bias && nullptr == key_padding_mask && parameters.head_size == parameters.v_head_size && @@ -138,7 +158,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.num_heads, parameters.num_heads); // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. - if (use_flash_attention && key == nullptr && value == nullptr && + if (use_flash_attention && parameters.qkv_format == AttentionQkvFormat::QKV_BS3NH && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) { use_flash_attention = false; } @@ -162,19 +182,21 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - bool use_fused_cross_attention = !use_flash_attention && - !disable_fused_cross_attention_ && - nullptr == key_padding_mask && - nullptr == relative_position_bias && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - key != nullptr && - (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV - parameters.hidden_size == parameters.v_hidden_size && - has_fused_cross_attention_kernel(sm, parameters.head_size, - parameters.kv_sequence_length); + bool use_fused_cross_attention = + !use_flash_attention && + !disable_fused_cross_attention_ && + nullptr == key_padding_mask && + nullptr == relative_position_bias && + nullptr == past_key && nullptr == present_key && + (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && + parameters.hidden_size == parameters.v_hidden_size && + has_fused_cross_attention_kernel(sm, parameters.head_size, + parameters.kv_sequence_length); if (use_fused_cross_attention) { if (fused_fp16_cross_attention_kernel_ == nullptr) { - fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + std::call_once(fused_cross_init_once_flag_, [&]() { + fused_fp16_cross_attention_kernel_ = get_fused_cross_attention_kernels(sm); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -184,17 +206,18 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !use_flash_attention && - !disable_fused_self_attention_ && - fused_cross_attention_kernel == nullptr && - nullptr == relative_position_bias && - (value != nullptr || key == nullptr) && - (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && - (nullptr == key_padding_mask || is_mask_1d_seq_len) && - parameters.hidden_size == parameters.v_hidden_size && - parameters.sequence_length == parameters.kv_sequence_length && - FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + bool use_fused_runner = + !use_flash_attention && + !disable_fused_self_attention_ && + fused_cross_attention_kernel == nullptr && + nullptr == relative_position_bias && + (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && + nullptr == past_key && nullptr == present_key && + (nullptr == key_padding_mask || AttentionMaskType::MASK_1D_KEY_SEQ_LEN) && + parameters.hidden_size == parameters.v_hidden_size && + parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner + FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { @@ -214,10 +237,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { #if USE_MEMORY_EFFICIENT_ATTENTION int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32(); - bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 + bool is_long_sequence = std::is_same::value || // sequence length threshold is 0 for FP16 parameters.sequence_length >= length_threshold || parameters.kv_sequence_length >= length_threshold; + // Check whether the relative position bias alignment is good for memory efficient attention. bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; bool use_memory_efficient_attention = @@ -226,82 +250,25 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - !past_no_bias && (relative_position_bias == nullptr || is_good_for_rpb) && (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && - has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size); + has_memory_efficient_attention(sm, std::is_same::value, + parameters.head_size, parameters.v_head_size); #else constexpr bool use_memory_efficient_attention = false; #endif - if (kernel_options_->AllowDebugInfo()) { - AttentionKernelDebugInfo debug_info; - debug_info.use_flash_attention = use_flash_attention; - debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; - debug_info.use_efficient_attention = use_memory_efficient_attention; - if (fused_fp16_runner_ != nullptr) { - debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); - } - - debug_info.Print("MultiHeadAttention", - this->Node().Name(), - std::is_same::value, - std::is_same::value); - } - - // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. - // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. - bool no_qkv_workspace = nullptr == value && - (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && - nullptr == key_padding_mask && - nullptr == bias; - - size_t workspace_bytes; - constexpr size_t element_size = sizeof(T); - if (no_qkv_workspace) { - workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0; - } else { - workspace_bytes = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.v_head_size, - parameters.sequence_length, - parameters.kv_sequence_length, - parameters.total_sequence_length, - fused_runner, - use_flash_attention, - use_fused_cross_attention, - use_memory_efficient_attention); - } - - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); - - const size_t past_k_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.head_size; - const size_t past_v_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.v_head_size; - const bool use_temp_k_v_workspace = parameters.pass_past_in_kv || use_memory_efficient_attention || use_flash_attention; - auto temp_k_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; - auto temp_v_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; - typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); - data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); - data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); + data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); + data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) - : (nullptr == past_key) ? nullptr - : reinterpret_cast(past_key->Data()); - data.past_value = pass_key_value_as_past ? reinterpret_cast(value->Data()) - : (nullptr == past_value) ? nullptr - : reinterpret_cast(past_value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; - data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); @@ -309,8 +276,41 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); - data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + + // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). + // The cache will be initialized only once, and become readonly after that. + if ((data.fused_cross_attention_kernel != nullptr || data.fused_runner != nullptr) && data.mask_index == nullptr) { + cudaStream_t stream = Stream(context); + data.cumulated_sequence_length_q_cache = this->cumulated_sequence_length_q_cache_.TryGet( + parameters.batch_size, parameters.sequence_length, stream); + + if (data.fused_cross_attention_kernel != nullptr) { + data.cumulated_sequence_length_kv_cache = this->cumulated_sequence_length_kv_cache_.TryGet( + parameters.batch_size, parameters.kv_sequence_length, stream); + } + } + + const bool no_qkv_workspace = NoQkvWorkspace(parameters, data); + size_t workspace_bytes = GetAttentionWorkspaceSize(sizeof(T), + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.v_head_size, + parameters.sequence_length, + parameters.kv_sequence_length, + parameters.total_sequence_length, + fused_runner, + use_flash_attention, + use_fused_cross_attention, + use_memory_efficient_attention, + no_qkv_workspace); + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + + data.has_qkv_workspace = !no_qkv_workspace; + data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workspace_bytes; + + data.allow_debug_info = kernel_options_->AllowDebugInfo(); if (softmax_lse_accum_buffer != nullptr) { data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); } @@ -318,8 +318,23 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } - cublasHandle_t cublas = GetCublasHandle(context); + if (data.allow_debug_info) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr; + debug_info.use_efficient_attention = use_memory_efficient_attention; + if (fused_fp16_runner_ != nullptr) { + debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length); + } + debug_info.Print("MultiHeadAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + + data.PrintDebugInfo(); + } + cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 26e38dbad9fd7..68fd0c9943fca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" @@ -32,11 +33,16 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_cross_attention_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + + // These mutable members are readonly after they are initialized so that they can be shared among multiple threads. + // Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource. mutable std::unique_ptr fused_fp16_runner_; mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; + mutable std::once_flag fused_cross_init_once_flag_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; + const AttentionKernelOptions* kernel_options_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ac2cb5165a94c..2521cd49b5482 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -297,7 +297,7 @@ struct T2 { }; template -void LaunchAddBiasTranspose( +void AddBiasTransposePacked( const T* input, const T* biases, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -452,7 +452,7 @@ Status FusedScaledDotProductAttention( void* fused_runner = data.fused_runner; ORT_RETURN_IF_NOT(nullptr != fused_runner, "fused_runner cannot be NULL"); - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::QKV_BSN3H, data.token_offset, @@ -477,7 +477,7 @@ Status FusedScaledDotProductAttentionCutlass( const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BSNH, data.token_offset, @@ -564,7 +564,7 @@ Status UnfusedScaledDotProductAttention( T* k = q + elements_q; T* v = k + elements_k; - LaunchAddBiasTranspose(data.gemm_buffer, data.bias, data.workspace, + AddBiasTransposePacked(data.gemm_buffer, data.bias, data.workspace, batch_size, sequence_length, num_heads, qk_head_size, v_head_size, AttentionQkvFormat::Q_K_V_BNSH, data.token_offset, @@ -657,6 +657,20 @@ Status QkvToContext( return UnfusedScaledDotProductAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const float* input, const float* biases, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const half* input, const half* biases, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index b4ca0194b08bc..e5a4c54f48903 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -502,7 +502,7 @@ struct T2 { }; template -void LaunchTranspose( +void AddBiasTransposePacked( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, @@ -566,11 +566,11 @@ Status FusedAttentionTrt( // When packed QKV is used, we can directly pass it to fused runner. Otherwise, we need transpose to BSN3H format. const T* qkv = data.query; if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::QKV_TN3H, + data.token_offset, parameters.token_count, stream); qkv = data.workspace; } @@ -601,11 +601,11 @@ Status FlashAttention( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) @@ -675,11 +675,11 @@ Status FusedAttentionCutlass( // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); } MemoryEfficientAttentionParams p; @@ -746,11 +746,11 @@ Status UnfusedAttention( const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); // Q, K and V pointers when fused attention is not used - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, - data.token_offset, parameters.token_count, stream); + AddBiasTransposePacked(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_BNSH, + data.token_offset, parameters.token_count, stream); T* qkv = data.workspace; T* q = qkv; @@ -848,6 +848,22 @@ Status QkvToContext( return UnfusedAttention(device_prop, cublas, stream, parameters, data); } +template void AddBiasTransposePacked( + const half* query, const half* key, const half* value, const half* bias, half* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + +template void AddBiasTransposePacked( + const float* query, const float* key, const float* value, const float* bias, float* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 168c69c69f003..b62e566d43f89 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -190,7 +190,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { fused_runner, use_flash_attention, use_fused_cross_attention, - use_memory_efficient_attention); + use_memory_efficient_attention, + true); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -208,6 +209,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); + data.workspace_bytes = workSpaceSize; data.output = reinterpret_cast(output->MutableData()); if (nullptr != present) { data.present = reinterpret_cast(present->MutableData()); diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index e10c2ec63fd51..6d52ff7282799 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -13,6 +13,9 @@ namespace cuda { #if DUMP_TENSOR_LEVEL > 0 +// Environment variable to enable/disable GPU Tensor dumping +constexpr const char* kEnableGpuTensorDumper = "ORT_ENABLE_GPU_DUMP"; + // Total number of elements which trigger snippet rather than full dump (default 200). Value 0 disables snippet. constexpr const char* kTensorSnippetThreshold = "ORT_TENSOR_SNIPPET_THRESHOLD"; @@ -202,6 +205,10 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { + is_enabled_ = ParseEnvironmentVariableWithDefault(kEnableGpuTensorDumper, 1) != 0; +} + void CudaTensorConsoleDumper::Print(const std::string& value) const { std::cout << value << std::endl; } @@ -329,6 +336,8 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else +CudaTensorConsoleDumper::CudaTensorConsoleDumper() { +} void CudaTensorConsoleDumper::Print(const std::string&) const { } diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 6ad0ad9a67b75..4f41161cd4a31 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -13,7 +13,7 @@ namespace cuda { class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { public: - CudaTensorConsoleDumper() = default; + CudaTensorConsoleDumper(); virtual ~CudaTensorConsoleDumper() {} void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index b0ed3ff82226a..b94971ffd44d5 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -119,7 +119,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; return Status::OK(); } @@ -128,7 +128,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; return Status::OK(); } @@ -136,7 +136,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; return Status::OK(); } @@ -146,7 +146,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; return Status::OK(); } @@ -154,7 +154,7 @@ Status ClassifyAttentionMode( if (attn->qkv_format == Q_K_V_BSNH) { attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); - } else if (attn->pass_past_in_kv) { + } else if (attn->qkv_format == Q_K_V_BSNH_BNSH_BNSH) { attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 349df045becf2..d593bc0012826 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -132,12 +132,6 @@ class CompatRocblasMathModeSetter { } }; -enum AttentionType { - kAttention, - kMultiHeadAttention, - kDecoderMaskedMultiHeadAttention, -}; - enum AttentionMode { // Q,K,V,PastK,PastV,PresentK,PresentV QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 09e7d61b71db9..5997daaca6e8a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -122,9 +122,11 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, num_heads_, - mask_filter_value_, scale_, false, /*is_unidirectional_*/ - past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ + past_present_share_buffer_, + attn_type_, + device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index dc2b38f3928ac..a9ff623fb6967 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -691,6 +691,9 @@ def create_multihead_attention_node( return None # Add bias to inputs for MHA + # Bias for cross attention is not fully supported in DMMHA and cpu MHA kernels since they assume + # bias has been added to key and value when they are in BNSH format, so only bias for query is used. + # Need add checks if we found such assumption is not true. if not self.disable_multi_head_attention_bias: bias_name = self.create_combined_qkv_bias(q_add, k_add, v_add, mha_node_name) mha_inputs.append(bias_name) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 715a92431e6bf..ec350874af32c 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -88,9 +88,11 @@ def __init__( enable_cuda_graph: bool = False, dtype=torch.float, use_kv_cache: bool = False, + has_past_input: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, verbose: bool = False, + has_bias: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -103,15 +105,25 @@ def __init__( self.causal = causal self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + # Support the case that there is no past but need present output (for prompt case). + self.has_past_input = has_past_input + if has_past_input: + assert use_kv_cache + else: # no past input + assert past_sequence_length == 0 + + self.has_present_output = use_kv_cache + self.use_kv_cache = use_kv_cache if not use_kv_cache: assert past_sequence_length == 0 else: assert self.kv_sequence_length == self.sequence_length - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - assert not use_kv_cache + # Only BSNH input format supports past state. + if input_format != InputFormats.Q_K_V_BSNH_BSNH_BSNH: + assert not self.has_past_input + assert not self.has_present_output # Derived values self.total_sequence_length = self.kv_sequence_length + past_sequence_length @@ -130,6 +142,7 @@ def __init__( self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H self.verbose = verbose + self.has_bias = has_bias def __repr__(self): return ( @@ -140,7 +153,8 @@ def __repr__(self): f"causal={self.causal}), softmax_scale={self.softmax_scale}, use_kv_cache={self.use_kv_cache}, " f"share_past_present_buffer={self.share_past_present_buffer}, " f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, " - f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}" + f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, " + f"has_bias={self.has_bias}" ) def shape_dict(self, input_format=None): @@ -176,16 +190,23 @@ def shape_dict(self, input_format=None): "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), } - if self.use_kv_cache: - assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" + if self.has_past_input: shapes = { **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), + } + + if self.has_present_output: + shapes = { + **shapes, "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes def symbolic_shape_dict(self, input_format=None): @@ -221,19 +242,26 @@ def symbolic_shape_dict(self, input_format=None): "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), } - if self.use_kv_cache: - assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" + if self.has_past_input: shapes = { **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), + } + + if self.has_present_output: + shapes = { + **shapes, "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } + if self.has_bias: + shapes["bias"] = (3 * self.num_heads * self.head_size,) + return shapes - def random_inputs(self, seed: int = 123): + def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False): device = self.device dtype = self.dtype @@ -246,6 +274,14 @@ def random_inputs(self, seed: int = 123): q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1) + + bias_q = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_k = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + bias_v = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1) + if no_bias_k_v: + bias_k = torch.zeros_like(bias_k) + bias_v = torch.zeros_like(bias_v) + k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) @@ -277,7 +313,7 @@ def random_inputs(self, seed: int = 123): "value": v_bnsh.contiguous(), } - if self.use_kv_cache: + if self.has_past_input: feeds = { **feeds, "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), @@ -286,6 +322,9 @@ def random_inputs(self, seed: int = 123): ), } + if self.has_bias: + feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous() + return feeds def get_input_output_names(self): @@ -299,15 +338,29 @@ def get_input_output_names(self): else: inputs, outputs = ["query", "key", "value"], ["output"] - if self.use_kv_cache: - return [*inputs, "past_key", "past_value"], [*outputs, "present_key", "present_value"] - else: - return inputs, outputs + if self.has_bias: + inputs = [*inputs, "bias"] + + if self.has_past_input: + inputs = [*inputs, "past_key", "past_value"] + + if self.has_present_output: + outputs = [*outputs, "present_key", "present_value"] + + return inputs, outputs def fill_optional_mha_inputs(input_names): inputs = ["query", "key", "value", "bias", "key_padding_mask", "relative_position_bias", "past_key", "past_value"] - return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:] + + # Remove optional inputs that are not in input_names with empty string + inputs_with_optional = [input if input in input_names else "" for input in inputs] + + # Remove empty string at the end of the list. + while inputs_with_optional[-1] == "": + inputs_with_optional.pop(-1) + + return inputs_with_optional def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False): @@ -317,7 +370,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use nodes = [ helper.make_node( "MultiHeadAttention", - fill_optional_mha_inputs(input_names) if config.use_kv_cache else input_names, + fill_optional_mha_inputs(input_names), output_names, "MultiHeadAttention_0", num_heads=config.num_heads, @@ -331,11 +384,13 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use inputs = [ helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name])) for input_name in input_names + if input_name ] outputs = [ helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name])) for output_name in output_names + if output_name ] graph = helper.make_graph( @@ -355,6 +410,7 @@ def create_ort_session( session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_symbolic_shape: bool = True, + use_tf32: bool = True, ) -> CudaSession: if config.verbose: print(f"create session for {vars(config)}") @@ -364,6 +420,7 @@ def create_ort_session( device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) provider_options["sdpa_kernel"] = int(attention_kernel) + provider_options["use_tf32"] = int(use_tf32) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] @@ -373,9 +430,11 @@ def create_ort_session( def create_session( - config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_tf32: bool = True ) -> CudaSession: - ort_session = create_ort_session(config, session_options, attention_kernel, use_symbolic_shape=False) + ort_session = create_ort_session( + config, session_options, attention_kernel, use_symbolic_shape=False, use_tf32=use_tf32 + ) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -385,8 +444,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__(self, config: MultiHeadAttentionConfig, session_options=None): - self.ort_session = create_session(config, session_options) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_tf32: bool = True): + self.ort_session = create_session(config, session_options, use_tf32=use_tf32) self.feed_dict = config.random_inputs() def infer(self): diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 0fcbd889847e9..a35d02b0b9d52 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -21,6 +21,47 @@ import onnxruntime +def get_provider_support_info(provider: str, use_kv_cache: bool): + if provider == "CUDAExecutionProvider": + if not use_kv_cache: + formats = [ + InputFormats.Q_K_V_BSNH_BSNH_BSNH, + InputFormats.Q_KV_BSNH_BSN2H, + InputFormats.QKV_BSN3H, + InputFormats.Q_K_V_BSNH_BNSH_BNSH, + ] + else: + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + else: + assert provider == "CPUExecutionProvider" + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + if not use_kv_cache: + formats.append(InputFormats.Q_K_V_BSNH_BNSH_BNSH) + device = torch.device("cpu") + dtype = torch.float + return device, dtype, formats + + +def get_bias_support(format: InputFormats): + if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + return [True, False] + + if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return [True, False] + + if format == InputFormats.Q_KV_BSNH_BSN2H: + return [False] + + if format == InputFormats.QKV_BSN3H: + return [True, False] + + raise RuntimeError(f"Unknown format: {format}") + + def attention_reference( head_size: int, query: torch.Tensor, @@ -84,8 +125,8 @@ def attention_reference( def mha_with_past_reference( config: MultiHeadAttentionConfig, - past_k: torch.Tensor, - past_v: torch.Tensor, + past_k: Optional[torch.Tensor], + past_v: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -94,41 +135,23 @@ def mha_with_past_reference( ): assert config.kv_sequence_length == config.sequence_length assert config.use_kv_cache - assert past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) # both BNSH format - assert past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) # both BNSH format - - present_k = torch.cat((past_k, k), dim=2) - present_v = torch.cat((past_v, v), dim=2) + if past_k is not None: + assert ( + past_k.dim() == 4 and k.dim() == 4 and past_k.size(1) == k.size(1) + ), f"expect BNSH format: {past_k.shape=} {k.shape=}" + + if past_v is not None: + assert ( + past_v.dim() == 4 and v.dim() == 4 and past_v.size(1) == v.size(1) + ), f"expect BNSH format: {past_v.shape=} {v.shape=}" + + present_k = torch.cat((past_k, k), dim=2) if past_k is not None else k + present_v = torch.cat((past_v, v), dim=2) if past_v is not None else v out = attention_reference(config.head_size, q, present_k, present_v, scale=scale, mask=mask) return out, present_k, present_v -def get_provider_support_info(provider: str, use_kv_cache: bool): - if provider == "CUDAExecutionProvider": - if not use_kv_cache: - formats = [ - InputFormats.Q_K_V_BSNH_BSNH_BSNH, - InputFormats.Q_KV_BSNH_BSN2H, - InputFormats.QKV_BSN3H, - InputFormats.Q_K_V_BSNH_BNSH_BNSH, - ] - else: - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - dtype = torch.float16 - else: - assert provider == "CPUExecutionProvider" - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - if not use_kv_cache: - formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH) - device = torch.device("cpu") - dtype = torch.float - return device, dtype, formats - - def get_compute_capability(): if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): major, minor = torch.cuda.get_device_capability() @@ -143,35 +166,38 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 2048] + sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] device, dtype, formats = get_provider_support_info(provider, False) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -179,25 +205,27 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=0, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=False, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for format in formats: + for has_bias in get_bias_support(format): + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=0, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=False, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def kv_cache_test_cases(provider: str, comprehensive: bool): @@ -206,37 +234,42 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): yield batch_sizes = [1, 2, 3] - sequence_lengths = [1, 15, 16, 255, 256, 2048] + sequence_lengths = [1, 15, 16, 255, 256, 512] heads = [1, 3, 4, 16] head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - - sequence_length = 1 device, dtype, formats = get_provider_support_info(provider, True) if comprehensive: + sequence_lengths = [*sequence_lengths, 2048] # Large sequence length is slow and need a lot of memory for batch_size in batch_sizes: for past_sequence_length in sequence_lengths: for num_heads in heads: for head_size in head_sizes: for format in formats: for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config else: test_cases = max(len(batch_sizes), len(sequence_lengths), len(heads), len(head_sizes)) for i in range(test_cases): @@ -244,31 +277,31 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): past_sequence_length = sequence_lengths[i % len(sequence_lengths)] num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] - format = formats[i % len(formats)] for causal in [True, False]: - config = MultiHeadAttentionConfig( - batch_size=batch_size, - sequence_length=sequence_length, - num_heads=num_heads, - head_size=head_size, - causal=causal, - past_sequence_length=past_sequence_length, - kv_sequence_length=sequence_length, - max_cache_sequence_length=None, - provider=provider, - device=device, - dtype=dtype, - use_kv_cache=True, - share_past_present_buffer=False, - input_format=format, - ) - yield config - - -def mha_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_test_cases(provider, comprehensive), kv_cache_test_cases(provider, comprehensive) - ) + for format in formats: + for has_past_input in [True, False]: + for has_bias in get_bias_support(format): + sequence_length = 1 if has_past_input else past_sequence_length + past_seq_len = past_sequence_length if has_past_input else 0 + config = MultiHeadAttentionConfig( + batch_size=batch_size, + sequence_length=sequence_length, + num_heads=num_heads, + head_size=head_size, + causal=causal, + past_sequence_length=past_seq_len, + kv_sequence_length=sequence_length, + max_cache_sequence_length=None, + provider=provider, + device=device, + dtype=dtype, + use_kv_cache=True, + has_past_input=has_past_input, + share_past_present_buffer=False, + input_format=format, + has_bias=has_bias, + ) + yield config def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): @@ -343,6 +376,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device=device, dtype=dtype, use_kv_cache=True, + has_past_input=True, share_past_present_buffer=False, input_format=format, ) @@ -350,13 +384,6 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): yield configs -def multi_thread_test_cases(provider: str, comprehensive: bool): - return itertools.chain( - no_kv_cache_multi_thread_test_cases(provider, comprehensive), - kv_cache_multi_thread_test_cases(provider, comprehensive), - ) - - def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) @@ -374,28 +401,31 @@ def parity_check_mha( if config.causal and config.provider == "CUDAExecutionProvider": return - ort_mha = OrtMultiHeadAttention(config) + ort_mha = OrtMultiHeadAttention(config, use_tf32=False) ort_outputs = ort_mha.infer() out = ort_outputs["output"] out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + no_bias_k_v = config.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH - ref_inputs = config.random_inputs() - q = ( - ref_inputs["query"] - .reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - k = ( - ref_inputs["key"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) - v = ( - ref_inputs["value"] - .reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) - .transpose(1, 2) - ) + ref_inputs = config.random_inputs(no_bias_k_v=no_bias_k_v) + q = ref_inputs["query"].reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + k = ref_inputs["key"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + v = ref_inputs["value"].reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size)) + + if "bias" in ref_inputs: + bias = ref_inputs["bias"] + bias = bias.reshape((3, config.num_heads, config.head_size)) + bias_q = bias[0, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_k = bias[1, :, :].reshape(1, 1, config.num_heads, config.head_size) + bias_v = bias[2, :, :].reshape(1, 1, config.num_heads, config.head_size) + q = q + bias_q + k = k + bias_k + v = v + bias_v + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) mask = None if config.causal: @@ -404,8 +434,8 @@ def parity_check_mha( k_cache = None v_cache = None if config.use_kv_cache: - past_k = ref_inputs["past_key"] - past_v = ref_inputs["past_value"] + past_k = ref_inputs.get("past_key", None) + past_v = ref_inputs.get("past_value", None) out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask) else: out_ref = attention_reference(config.head_size, q, k, v, mask=mask) @@ -445,7 +475,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - attention_kernel: int = SdpaKernel.DEFAULT, + attention_kernel=SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -454,6 +484,7 @@ def parity_check_mha_multi_threading( # For now, MHA CUDA kernel does not support causal so skip such test cases. if config.causal and config.provider == "CUDAExecutionProvider": return None + # Some kernel does not support certain input format. if attention_kernel not in [ SdpaKernel.DEFAULT, @@ -462,7 +493,7 @@ def parity_check_mha_multi_threading( ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True) + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True, use_tf32=False) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -572,18 +603,32 @@ def check_parity_with_config(i: int): return None -# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine. +def mha_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_test_cases(provider, comprehensive), + kv_cache_test_cases(provider, comprehensive), + ) + + +def multi_thread_test_cases(provider: str, comprehensive: bool): + return itertools.chain( + no_kv_cache_multi_thread_test_cases(provider, comprehensive), + kv_cache_multi_thread_test_cases(provider, comprehensive), + ) + + +# Off by default so that we do not run too many tests in CI pipeline. comprehensive_mode = False class TestMultiHeadAttention(unittest.TestCase): @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cuda(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) @parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True) def test_mha_cpu(self, config): - parity_check_mha(config) + parity_check_mha(config, rtol=5e-3, atol=5e-3) def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): @@ -604,19 +649,24 @@ def run_mha_cuda_multi_threading(self, attention_kernel): assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): - self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) + if get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) def test_mha_cuda_multi_threading_efficient(self): - self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) + + def test_mha_cuda_multi_threading_math(self): + if comprehensive_mode and get_compute_capability() >= 60: + self.run_mha_cuda_multi_threading(SdpaKernel.MATH) def test_mha_cuda_multi_threading_trt(self): - sm = get_compute_capability() - if sm in [75, 80, 86, 89]: + if get_compute_capability() in [75, 80, 86, 89]: self.run_mha_cuda_multi_threading( SdpaKernel.TRT_FUSED_ATTENTION | SdpaKernel.TRT_FLASH_ATTENTION - | SdpaKernel.TRT_CROSS_ATTENTION | SdpaKernel.TRT_CAUSAL_ATTENTION + | SdpaKernel.TRT_CROSS_ATTENTION )