[CUDA] Fix MultiHeadAttention thread safe and bias support #21498

Merged
merged 18 commits into main from tlwu/fix_dmmha_input_check
Jul 31, 2024

Conversation

tianleiwu
Contributor

@tianleiwu tianleiwu commented Jul 25, 2024

### Description

#### Issues Fixed

(1) **TRT cross attention not thread safe**. Core changes like [this](6fd7aba) are used to make it thread-safe:

* Add a `once_flag` to `CumulatedSequenceLengthCache` to make sure it is initialized only once, and make the cache read-only after initialization. Previously, the content was not read-only, so it could be changed by another thread and potentially cause a buffer overrun (a sketch of the pattern follows below).
* The kernel initialization is not guarded (although the kernel-loading factory has a static mutex to guard multiple threads), so the mutable variables could be set by two different threads at the same time. Add a `once_flag` to avoid that.

This also requires some workspace computation changes, so I did not create a separate pull request.
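A minimal sketch of the `once_flag` pattern described above, assuming hypothetical member names (the real `CumulatedSequenceLengthCache` and kernel-loading code differ in detail):

```cpp
#include <mutex>
#include <utility>

// Illustrative stand-in for a cache that must be filled exactly once and then
// treated as read-only; the field names here are assumptions, not the ORT code.
struct OnceInitializedCacheSketch {
  std::once_flag initialized;        // guards one-time initialization
  const void* device_buffer = nullptr;

  template <typename InitFn>
  void Initialize(InitFn&& init) {
    // std::call_once runs init exactly once even if many threads arrive
    // concurrently; other threads block until initialization completes,
    // after which the cache contents are only read, never written.
    std::call_once(initialized, std::forward<InitFn>(init));
  }
};
```

The same pattern applies to the unguarded kernel initialization: wrapping the mutable setup in a `call_once` prevents two threads from writing that state at the same time.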

(2) **Bias for cross attention**

This scenario assumes that only the query has a bias, not the key and value. However, the assumption was not verified at runtime, there was no comment documenting it, and there was no test case, so support for the scenario was disabled by mistake. The scenario is actually used in the Whisper model. (TODO: we shall add Whisper tests to the CI pipeline, and also update the fusion script to verify such assumptions if needed.)

The CUDA/CPU kernels support bias for cross attention as long as the bias is zero for key and value. I updated the check to support this scenario and added comments wherever the assumption is made.

(3) **Fallback support**

Previously, the unfused kernel did not support the packed QKV and packed KV formats, so some cases could fail because there was no fallback. Example error messages: `packed QKV format is not implemented for current GPU. Please disable it in fusion options.` or `packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.`

I added new AddBiasTranspose CUDA kernels for these formats to support the fallback, so that no supported case fails.
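As a rough illustration of what such a fallback pre-processing kernel does (a simplified sketch with assumed shapes and names, not the actual AddBiasTranspose implementation), packed QKV of shape BxSxNx3xH can be split into separate Q/K/V buffers while adding bias so the unfused path can consume them:

```cuda
// Simplified sketch (assumed layouts): packed_qkv is B x S x N x 3 x H,
// bias is 3 x N x H, and the outputs q/k/v are each B x S x N x H.
__global__ void UnpackQkvAddBiasSketch(const float* packed_qkv, const float* bias,
                                       float* q, float* k, float* v,
                                       int sequence_length, int num_heads, int head_size) {
  int h = threadIdx.x;   // position within a head (head_size <= blockDim.x assumed)
  int n = blockIdx.x;    // head index
  int s = blockIdx.y;    // sequence index
  int b = blockIdx.z;    // batch index
  if (h >= head_size) return;

  size_t row = ((size_t)b * sequence_length + s) * num_heads + n;
  size_t in_base = row * 3 * head_size + h;   // start of the (q, k, v) triplet for this row
  size_t out_index = row * head_size + h;
  int bias_offset = n * head_size + h;
  int nh = num_heads * head_size;

  q[out_index] = packed_qkv[in_base] + bias[bias_offset];
  k[out_index] = packed_qkv[in_base + head_size] + bias[nh + bias_offset];
  v[out_index] = packed_qkv[in_base + 2 * head_size] + bias[2 * nh + bias_offset];
}
```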

#### Improvements

(4) **QKV workspace size**

The logic for no_qkv_workspace could easily get out of sync since the related code was scattered across different source files. I refactored the code to move all related logic into one file (attention_prepare_qkv.cu) and added asserts, so that the logic stays in sync.

(5) **Remove the confusing concept of pass past in kv**

parameters.pass_past_in_kv is confusing since the K/V in cross attention are not past state. Remove it and use `parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH` instead.

The new code does not use past_key/past_value for cross attention, so the logic is clearer.
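A hedged sketch of the resulting condition (the enum value and field name follow this description; the surrounding struct and helper are illustrative only, not the actual ORT definitions):

```cpp
// Illustrative subset of the attention QKV formats and parameters.
enum AttentionQkvFormat {
  Q_K_V_BSNH,             // unpacked Q/K/V, all in BSNH
  Q_K_V_BSNH_BNSH_BNSH,   // cross attention: Q in BSNH, K/V in BNSH
};

struct AttentionParametersSketch {
  AttentionQkvFormat qkv_format;
  // pass_past_in_kv removed: the K/V of cross attention are real inputs, not past state.
};

bool IsCrossAttentionWithBnshKv(const AttentionParametersSketch& parameters) {
  return parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH;
}
```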

(6) **More coverage, less workspace, and fewer transposes for flash and efficient attention**

Previously, one condition prevented flash or efficient attention from running:

```
 bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```

After this change, flash and efficient attention can be used for this case, and less workspace is needed.

For example, for cross attention with bias, the original code uses two additional workspaces:

```
  transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
  add bias:  query => q,   temp_k_workspace => k,   temp_v_workspace => v
```

The new logic is:

```
   if (has bias)
      Add bias to query, key, value, and store in q, k, v workspace
   else
      Use query, key and value directly as q, k and v in kernel
```

We no longer need to allocate temp_k_workspace and temp_v_workspace, so less memory is used, and the new code also saves two transposes in this case.

Flash and efficient attention support both BSNH and BNSH formats for K and V. In the old code, K/V were always converted to BSNH, which is sometimes unnecessary. I changed the code to convert K/V to BSNH or BNSH case by case, so that more cases can be covered by flash or efficient attention to improve performance.

(7) **Debugging support**

Previously, there was little debug info. This change adds a flag for debug info to AttentionData, so that debug info can be output during processing.

It also adds functions to consolidate the dumping of inputs, QKV processing, and outputs, and an environment variable `ORT_ENABLE_GPU_DUMP` that allows disabling dumping from the CUDA kernels.
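For example, a dump guard could read the variable like the sketch below (the variable name comes from this PR; the helper itself and the default-on behavior are assumptions, not the actual ORT implementation):

```cpp
#include <cstdlib>
#include <cstring>

// Returns false when ORT_ENABLE_GPU_DUMP is explicitly set to "0", so that
// dumping from the CUDA kernels can be turned off without rebuilding.
bool IsGpuDumpEnabledSketch() {
  const char* value = std::getenv("ORT_ENABLE_GPU_DUMP");
  if (value == nullptr) {
    return true;  // assumption: dumping stays on by default when the debug flag is enabled
  }
  return std::strcmp(value, "0") != 0;
}
```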

#### Summary of changes

(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed QKV or packed KV inputs.
(3) Change a few cases in PrepareQKV so that more cases are covered by flash and efficient attention.
(4) Use `parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH` to replace `parameters.pass_past_in_kv`.
(5) Allow a bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments for the assumption that key/value have no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.

Currently supported scenarios for MultiHeadAttention on CUDA/CPU:

| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ----- | ----- | -------- | -------- | ---- | ------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |

### Motivation and Context

#18854

@tianleiwu tianleiwu marked this pull request as draft July 25, 2024 07:22
@tianleiwu tianleiwu marked this pull request as ready for review July 25, 2024 16:51
@tianleiwu tianleiwu marked this pull request as draft July 26, 2024 00:43
@tianleiwu tianleiwu force-pushed the tlwu/fix_dmmha_input_check branch from 98dea5e to 9d34245 July 26, 2024 06:38
@tianleiwu tianleiwu force-pushed the tlwu/fix_dmmha_input_check branch from d726987 to f4b85fe July 26, 2024 07:38
@tianleiwu tianleiwu marked this pull request as ready for review July 26, 2024 19:40
@tianleiwu tianleiwu marked this pull request as draft July 26, 2024 22:10
@tianleiwu tianleiwu force-pushed the tlwu/fix_dmmha_input_check branch from 322f6ec to bcb25bf July 29, 2024 21:15
@tianleiwu tianleiwu changed the title [CUDA] Fix DecoderMaskedMultiHeadAttention bias input check [CUDA] Fix MultiHeadAttention thread safe and bias support Jul 29, 2024
@tianleiwu tianleiwu force-pushed the tlwu/fix_dmmha_input_check branch from bcb25bf to cc97579 July 30, 2024 00:02
@tianleiwu tianleiwu marked this pull request as ready for review July 30, 2024 06:47
@tianleiwu tianleiwu requested review from cloudhan and wangyems July 30, 2024 06:47
wangyems
wangyems previously approved these changes Jul 30, 2024
@tianleiwu tianleiwu merged commit c5f8389 into main Jul 31, 2024
92 of 95 checks passed
@tianleiwu tianleiwu deleted the tlwu/fix_dmmha_input_check branch July 31, 2024 16:01
@tianleiwu tianleiwu added the release:1.19.0 Cherry pick to ORT 1.19 label Aug 1, 2024
prathikr pushed a commit that referenced this pull request Aug 3, 2024
prathikr pushed a commit that referenced this pull request Aug 5, 2024
@prathikr prathikr added the cherry-picked Cherry-picked for a cherrypicks branch label Aug 6, 2024