Fix flaky test_custom_4d_attention_mask (#35606)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
ydshieh and ydshieh authored Jan 10, 2025
1 parent f63829c commit bbc0004
Showing 2 changed files with 12 additions and 2 deletions.
src/transformers/testing_utils.py (8 additions, 2 deletions)

@@ -1431,14 +1431,20 @@ def set_model_tester_for_less_flaky_test(test_case):
         and target_num_hidden_layers is not None
     ):
         test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
-        test_case.model_tester.vision_config["num_hidden_layers"] = target_num_hidden_layers
+        if isinstance(test_case.model_tester.vision_config, dict):
+            test_case.model_tester.vision_config["num_hidden_layers"] = 1
+        else:
+            test_case.model_tester.vision_config.num_hidden_layers = 1
     if (
         hasattr(test_case.model_tester, "text_config")
         and "num_hidden_layers" in test_case.model_tester.text_config
         and target_num_hidden_layers is not None
     ):
         test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
-        test_case.model_tester.text_config["num_hidden_layers"] = target_num_hidden_layers
+        if isinstance(test_case.model_tester.text_config, dict):
+            test_case.model_tester.text_config["num_hidden_layers"] = 1
+        else:
+            test_case.model_tester.text_config.num_hidden_layers = 1
 
     # A few model class specific handling
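The testing_utils.py hunk above stops assuming that vision_config / text_config is a plain dict: some model testers hold these sub-configs as config-like objects, and presumably the old dict-style write breaks on those. A minimal sketch of the same branching, using a hypothetical object-style stand-in (not a real transformers class):

```python
# Hypothetical stand-in for an object-style sub-config; real model testers may hold
# either a plain dict or a config object in vision_config / text_config.
class ObjectStyleSubConfig:
    def __init__(self):
        self.num_hidden_layers = 4

    def __contains__(self, key):
        # lets `"num_hidden_layers" in sub_config` work, mirroring the helper's membership check
        return hasattr(self, key)


def shrink_num_hidden_layers(sub_config):
    # Same branching as the patched helper: dict-style vs attribute-style write.
    if isinstance(sub_config, dict):
        sub_config["num_hidden_layers"] = 1
    else:
        sub_config.num_hidden_layers = 1


shrink_num_hidden_layers({"num_hidden_layers": 4})  # dict-style sub-config
shrink_num_hidden_layers(ObjectStyleSubConfig())    # object-style sub-config
```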
tests/test_modeling_common.py (4 additions, 0 deletions)

@@ -4707,13 +4707,17 @@ def test_custom_4d_attention_mask(self):
                 reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
             )
 
+        set_model_tester_for_less_flaky_test(self)
+
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_static_cache:
                 self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            set_config_for_less_flaky_test(config)
             if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
                 self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
             model = model_class(config).to(device=torch_device, dtype=torch.float32)
+            set_model_for_less_flaky_test(model)
 
             (
                 input_ids,
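Taken together, the test now applies the flakiness helpers at three levels: the model tester (before configs are built), each freshly built config, and each instantiated model. A condensed sketch of that ordering; helper names are taken from the diff, and the import location is assumed rather than shown in this commit:

```python
# Condensed sketch of the ordering used by the patched test; assumes the helpers and
# torch_device live in transformers.testing_utils.
import torch

from transformers.testing_utils import (
    set_config_for_less_flaky_test,
    set_model_for_less_flaky_test,
    set_model_tester_for_less_flaky_test,
    torch_device,
)


def test_custom_4d_attention_mask(self):
    set_model_tester_for_less_flaky_test(self)  # 1) shrink the tester's configs before anything is built

    for model_class in self.all_generative_model_classes:
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        set_config_for_less_flaky_test(config)  # 2) adjust the freshly built config

        model = model_class(config).to(device=torch_device, dtype=torch.float32)
        set_model_for_less_flaky_test(model)  # 3) adjust the instantiated model

        ...  # build the custom 4D attention mask and compare outputs, as in the full test
```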
