From 0753134f4d2c723aa4460f0ed14f668c94d21050 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Thu, 16 May 2024 17:01:14 +0530
Subject: [PATCH] Disable the FA backend for SDPA on AMD GPUs (#30850)

* disable fa

* disable fa

* update warning

* update warning
---
 src/transformers/modeling_utils.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 599147e6ccb219..106f79ae8e3b58 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1479,6 +1479,16 @@ def _autoset_attn_implementation(
                 config,
                 hard_check_only=False if requested_attn_implementation is None else True,
             )
+
+            if (
+                torch.version.hip is not None
+                and config._attn_implementation == "sdpa"
+                and torch.cuda.device_count() > 1
+            ):
+                logger.warning_once(
+                    "Using the `SDPA` attention implementation on a multi-GPU setup with ROCm may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+                )
+                torch.backends.cuda.enable_flash_sdp(False)
         else:
             config._attn_implementation = "eager"
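
For readers who want to reproduce this behavior outside of transformers, the guard above reduces to the minimal standalone sketch below. It uses only the public PyTorch APIs that appear in the diff (torch.version.hip, torch.cuda.device_count, torch.backends.cuda.enable_flash_sdp); the maybe_disable_flash_sdp helper name and the demo tensor shapes are illustrative assumptions, not part of this patch.

    import torch
    import torch.nn.functional as F


    def maybe_disable_flash_sdp() -> None:
        """Standalone sketch of the guard this patch adds (names are illustrative)."""
        # torch.version.hip is a version string on ROCm builds of PyTorch and
        # None on CUDA or CPU-only builds, so this branch only fires on AMD GPUs.
        if torch.version.hip is not None and torch.cuda.device_count() > 1:
            # Disable only the flash-attention backend of PyTorch's SDPA; the
            # memory-efficient and math backends stay enabled, and SDPA will
            # dispatch to one of them instead.
            torch.backends.cuda.enable_flash_sdp(False)


    if __name__ == "__main__":
        maybe_disable_flash_sdp()
        print("flash sdp enabled:        ", torch.backends.cuda.flash_sdp_enabled())
        print("mem-efficient sdp enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
        print("math sdp enabled:         ", torch.backends.cuda.math_sdp_enabled())

        if torch.cuda.is_available():
            # SDPA transparently picks among the backends that remain enabled.
            q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
            out = F.scaled_dot_product_attention(q, k, v)
            print("output shape:", tuple(out.shape))

Because enable_flash_sdp(False) toggles process-global backend state, the patch only needs to call it once at model-load time; every later scaled_dot_product_attention call then falls back to the memory-efficient or math backend.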