Skip to content

Commit

Permalink
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
Browse files Browse the repository at this point in the history
### Description

#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](6fd7aba)
are used to make it thread-safe:
* Add an once_flag to CumulatedSequenceLengthCache to make sure it is
only initialized once; and change the cache to be read only after
initialization. Previously, the content is not read-only so it might be
changed by other thread and potentially cause buffer overrun.
* The kernel initialization is not guarded (Although the factory of
kernel loading has static mutex to guard multiple threading), so the
mutable variable might be set by two different threads at the same time.
Add an once_flag to avoid that.

This requires need some workspace computation change as well. So I did
not create a separated pull request.

(2) **Bias for cross attention**

That scenario has assumption that only query has bias, but not for key
and value. However, such assumption is not verified in runtime and there
was no comment of assumption, and there was no test case so the support
of scenario was disabled by mistake. Actually, the scenario is used in
whisper model (TODO: we shall add tests for whisper to CI pipeline, and
also update fusion script to verify such assumptions if needed.)

CUDA/CPU kernels supports bias for cross attention as long as bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever there is such assumption.

(3) **Fallback support**

Previously, unfused kernel did not support packed qkv and packed kv
formats. That means some case might fail since there is no fallback. I
added new AddBiasTranpose cuda kernels for them to support fallback, so
that all supported cases will not fail.

#### Improvements

(4) **QKV workspace size**.

The logic for no_qkv_workspace could be easily out of sync since related
code are scattered in different source files. I refactor the code to
move all related code to one file (attention_prepare_qkv.cu) and add
asserts, so that the logic can be in sync.

(5) **Remove confusing concept of pass past in kv**

parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.

New code does not use past_key/past_value for cross attention, so the
logic is more clear.

(6) **More coverage and less workspace and less transpose of flash and
efficient attention**
Previously, there is one condition does not run flash or efficient
attention:
```
 bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for the
case, and also less workspace.

For example, cross attention with bias, the original code uses two
additional workspaces:
```
  transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
  add bias: query => q,   temp_k_workspace => k,   temp_v_workspace => v
```

New logic is like
```
   if (has bias)
      Add bias to query, key, value, and store in q, k, v workspace
   else
      Use query, key and value directly as q, k and v in kernel
```

We can see that, we do not need allocate temp_k_workspace and
temp_v_workspace so use less memory. New code saved two transposes in
this case.

Flash and efficient attention supports BSNH or BNSH formats for k and v.
In old code, k/v are also converted to BSNH format. Some is not
necessary. I do some change to convert k/v to BSNH or BNSH case by case.
So that there are more cases can be covered by flash or efficient
attention to improve performance.

(6) **Debugging support**
Previously, there is less debug info. In this change, I add a flag for
debug info in the AttentionData. So that we can output debug info during
the processing.

Also add functions to consolidate the dumping of inputs, QKV processing
and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disable dumping from cuda kernel.

#### Summary of changes
(1) Refactoring the CheckInputs, and pass in operator type.
(2) Refactoring the PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few case of PrepareQKV to allow more case covered by flash
and efficient attention.
(4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of
assumption that key/value has no bias in this case.
(6) Fix thread-safe issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.

Current support scenarios for MultiHeadAttention for CUDA/CPU:

| Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- |
-----|---------
| BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed
| BLN3H| - | - | - | - | - | - | QKV | qkv packed <br> not support in
CPU
| BSNH | BLN2H| - | - | - | - | - | --- | kv packed <br> not support in
CPU
| BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention <br> bias for
Q only
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only
present
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and
present <br> (not share buffer)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
#18854
  • Loading branch information
tianleiwu authored and prathikr committed Aug 5, 2024
1 parent 188e453 commit 7cea166
Show file tree
Hide file tree
Showing 28 changed files with 1,729 additions and 1,057 deletions.
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
output_parameters->pass_past_in_kv = false;
output_parameters->qkv_format = Q_K_V_BNSH;
}

Expand Down
13 changes: 10 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
namespace onnxruntime {
namespace contrib {

enum AttentionType {
kAttention,
kMultiHeadAttention,
kDecoderMaskedMultiHeadAttention,
};

enum AttentionMaskType {
MASK_NONE, // No mask
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
Expand All @@ -24,10 +30,12 @@ enum AttentionQkvFormat {
UNKNOWN, // enum value not set, or depends on qkv projection implementation details
Q_K_V_BNSH, // for non-packed qkv, permuted
Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
QKV_BSN3H, // for TRT fused attention, qkv are packed
Q_K_V_BSNH_BNSH_BNSH, // for cross attention, k and v are permuted
Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
QKV_BSN3H, // for TRT fused attention, qkv are packed
QKV_BS3NH, // for DecoderMaskedMultiHeadAttention, qkv are packed
QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed
};

Expand Down Expand Up @@ -61,7 +69,6 @@ struct AttentionParameters {
bool past_present_share_buffer;
bool do_rotary;
bool broadcast_res_pos_bias;
bool pass_past_in_kv;
float mask_filter_value;
float scale;
bool use_tf32;
Expand Down
15 changes: 4 additions & 11 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
scale_,
is_unidirectional_,
past_present_share_buffer,
false));
kMultiHeadAttention));

const int batch_size = parameters.batch_size;
const int q_sequence_length = parameters.sequence_length;
Expand Down Expand Up @@ -121,20 +121,13 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

// For each of Q/K/V, there are multiple scenarios:
// 1) Combined QKV bias is null
// a) Q/K/V is (B, S, D)
// b) Q/K/V is (B, S, N, H)
// 2) No packed QKV in Q
// a) Q/K/V has seq_len = 1
// b) Q/K/V has seq_len > 1

OrtValue Q;
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
context, allocator, batch_size, num_heads_, q_sequence_length, qk_head_size, query, bias, q_bias_offset, Q));

if (parameters.pass_past_in_kv) { // key and value in BNSH format
assert(bias == nullptr);
if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) {
// For cross attention with k and v in BNSH format, we assume that bias for key and value are zeros.
// So we don't need to add bias for key and value here.
assert(past_key == nullptr);
assert(past_value == nullptr);
return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
Expand Down
Loading

0 comments on commit 7cea166

Please sign in to comment.