Skip to content

Commit

Permalink
Fix seamless TTS generate (#34968)
Browse files Browse the repository at this point in the history
* fix seamless tts generate

* apply same fix for v2

* [run-slow] seamless_m4t, seamless_m4t_v2

* remove TODO

* [run-slow] seamless_m4t, seamless_m4t_v2

* [run-slow] seamless_m4t, seamless_m4t_v2

* ignore failing test on multigpus

* [run-slow] seamless_m4t, seamless_m4t_v2

* [run-slow] seamless_m4t, seamless_m4t_v2
  • Loading branch information
ylacombe authored Dec 11, 2024
1 parent 33c12e4 commit 6181c6b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def format_speech_generation_kwargs(kwargs):
elif key.startswith("speech_"):
key = key[len("speech_") :]
kwargs_speech[key] = value
elif key == "generation_config":
kwargs_text[key] = value
else:
# If the key is already in a specific config, then it's been set with a
# submodules specific value and we don't override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def format_speech_generation_kwargs(kwargs):
elif key.startswith("speech_"):
key = key[len("speech_") :]
kwargs_speech[key] = value
elif key == "generation_config":
kwargs_text[key] = value
else:
# If the key is already in a specific config, then it's been set with a
# submodules specific value and we don't override
Expand Down
5 changes: 5 additions & 0 deletions tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,11 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)

# TODO: @ydshieh: refer to #34968
@unittest.skip(reason="Failing on multi-gpu runner")
def test_retain_grad_hidden_states_attentions(self):
pass


@require_torch
class SeamlessM4Tv2ModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
Expand Down
3 changes: 0 additions & 3 deletions tests/pipelines/test_pipelines_text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
require_torch,
require_torch_accelerator,
require_torch_or_tf,
run_test_using_subprocess,
slow,
torch_device,
)
Expand Down Expand Up @@ -67,10 +66,8 @@ def test_small_musicgen_pt(self):
audio = [output["audio"] for output in outputs]
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

# TODO: @ylacombe: `SeamlessM4TForTextToSpeech.generate` has issue with `generation_config`. See issue #34811
@slow
@require_torch
@run_test_using_subprocess
def test_medium_seamless_m4t_pt(self):
speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")

Expand Down

0 comments on commit 6181c6b

Please sign in to comment.