
Commit

Disable the FA backend for SDPA on AMD GPUs (huggingface#30850)
* disable fa

* disable fa

* update warning

* update warning
mht-sharma authored May 16, 2024
1 parent 9d889f8 commit 0753134
Showing 1 changed file with 10 additions and 0 deletions.
src/transformers/modeling_utils.py (10 additions, 0 deletions)
@@ -1479,6 +1479,16 @@ def _autoset_attn_implementation(
                 config,
                 hard_check_only=False if requested_attn_implementation is None else True,
             )
+
+            if (
+                torch.version.hip is not None
+                and config._attn_implementation == "sdpa"
+                and torch.cuda.device_count() > 1
+            ):
+                logger.warning_once(
+                    "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+                )
+                torch.backends.cuda.enable_flash_sdp(False)
         else:
             config._attn_implementation = "eager"
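
For reference, a minimal standalone sketch (not part of this commit) that reproduces the same guard using only public PyTorch APIs and then reports which SDPA backends remain enabled; the helper name maybe_disable_flash_sdp is hypothetical:

import torch

# Sketch only: mirror the guard added by this commit.
# torch.version.hip is None on CUDA/CPU builds and a version string on ROCm.
def maybe_disable_flash_sdp() -> None:
    on_rocm = torch.version.hip is not None
    multi_gpu = torch.cuda.device_count() > 1
    if on_rocm and multi_gpu:
        # Turn off the Flash-Attention backend; scaled_dot_product_attention
        # then dispatches to the memory-efficient or math backends instead.
        torch.backends.cuda.enable_flash_sdp(False)

maybe_disable_flash_sdp()
print("flash sdp enabled:", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient sdp enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math sdp enabled:", torch.backends.cuda.math_sdp_enabled())

Note that enable_flash_sdp toggles a process-wide flag rather than a per-call setting, which is presumably why the commit pairs it with warning_once so the message is emitted only once.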

