
Support mixed-language batches in WhisperGenerationMixin #29688

Merged: 15 commits merged into huggingface:main on May 15, 2024

Conversation

@cifkao (Contributor) commented Mar 16, 2024

What does this PR do?

Fixes #29685

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ylacombe @sanchit-gandhi

@ylacombe (Contributor) left a comment


This looks great to me! Thanks for the great work here!

I've left a few comments here and there. Could you also make sure that all slow tests still pass?

Also requesting a review from @sanchit-gandhi as the changes are pretty heavy!

tests/models/whisper/test_modeling_whisper.py (resolved review thread)
)
else:
language = [language] * batch_size
Contributor


This is a nit, but I wonder if you could keep a list of length 1 here and repeat init_tokens batch_size times at the end of the method instead, WDYT ?

Contributor Author


Sounds like a good idea to me.

@ylacombe ylacombe requested a review from sanchit-gandhi March 18, 2024 11:43
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cifkao (Contributor Author) commented Mar 18, 2024

@ylacombe I ran the slow tests in test_modeling_whisper. I get some failing tests, but the same ones fail on main as well, with the same output:

========================================================================================================================== short test summary info ===========================================================================================================================
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch - assert [' While Porashaggy sits there, a cooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wi...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_prev_cond - assert [" While poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I w...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_prompt_ids - AssertionError: assert 'quilter' in 'because you were sleeping instead of conquering the lovely rose princess has become a fiddle without a bow while poor shaggy sits there a cooing dove'
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, accooing dove. He has gone and gone for good. Antered polychrome who had managed to squeeze into the room beside the ...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch_prev_cond - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, accooing dove. He has gone and gone for good. Antered polychrome who had managed to squeeze into the room beside the ...
==================================================================================================== 5 failed, 292 passed, 106 skipped, 120 warnings in 249.92s (0:04:09) ====================================================================================================

if task in TASK_IDS:
init_tokens.append(generation_config.task_to_id[generation_config.task])
task_id = generation_config.task_to_id[generation_config.task]
for i in range(len(init_tokens)):

@naveen-corpusant

init_tokens is only 1 element long at this point in the case of language=None (because the deepcopy on init_tokens assumes that a list of languages has been passed in). We also need to set language for a later check in the function, since the length of language is used in that check.

To fix this, we can do:

id_to_lang = {v: k for k, v in generation_config.lang_to_id.items()}
language = [id_to_lang[lang_id.item()] for lang_id in lang_ids]
init_tokens = [copy.deepcopy(init_tokens[0]) for _ in lang_ids]

(the init_tokens handling above is a bit janky; it would be cleaner if we didn't do the deepcopy above, since then it is a list of lists).
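
For concreteness, a toy illustration of that inversion (the ids below are made up; only the pattern matters):

    lang_to_id = {"<|en|>": 50259, "<|fr|>": 50265}     # stand-in for generation_config.lang_to_id
    id_to_lang = {v: k for k, v in lang_to_id.items()}  # invert once
    lang_ids = [50265, 50259]                            # one detected language id per batch item
    language = [id_to_lang[lang_id] for lang_id in lang_ids]
    assert language == ["<|fr|>", "<|en|>"]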


You can validate the current failure with the test for multiple languages, but passing None as the language.

Contributor Author


True! But it's becoming a bit convoluted then. It's probably cleaner to revert 0b424f7 so that init_tokens and language are always of length batch_size. @ylacombe?

Contributor


Not sure I follow the issue here, could you expand on this?

@cifkao (Contributor Author) commented Apr 1, 2024

@ylacombe 0b424f7 was based on your suggestion, but it turns out it broke language detection. So maybe we can revert that unless you think it would introduce significant overhead.

Another option would be to do something like this:

        if isinstance(language, (list, tuple)):
            ...
        elif language is None:
            # Language will be detected for each item in batch
            language = [None] * batch_size
        else:
            language = [language]  # Use a length-1 list now, broadcast later

        # Separate init_tokens for each language
        init_tokens = [copy.deepcopy(init_tokens) for _ in language]
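
The "broadcast later" step at the end of the method could then look roughly like this (a sketch with a hypothetical helper name):

    def broadcast_init_tokens(init_tokens, batch_size):
        # If a single language was passed, init_tokens is still a length-1 list here;
        # repeat the one prompt so there is an entry per batch item.
        if len(init_tokens) == 1 and batch_size > 1:
            return [list(init_tokens[0]) for _ in range(batch_size)]
        return init_tokens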

Contributor Author


Actually @naveen-corpusant, I don't think you're right on your second point. That later check is not checking the detected language, it's there to check if language was passed to generate() and if so, set the task to transcription. I'll introduce a variable languages in order to keep the original language parameter intact.

Contributor Author


I went with my suggestion above, so language will be a length-1 list if a single language was passed, otherwise it will be a length-batch_size list of either languages or Nones (if the languages need to be detected).
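
For illustration, the possible shapes of language after this normalization (example values only):

    language = ["fr"]          # a single language was passed -> broadcast at the end
    language = ["en", "fr"]    # one language per batch item was passed
    language = [None, None]    # nothing was passed -> detect per batch item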

@ylacombe (Contributor) commented Apr 1, 2024

@ylacombe I ran the slow tests in test_modeling_whisper. I get some failing tests, but the same ones fail on main as well, with the same output:

As long as this PR doesn't introduce new failing tests, this should be okay!

@sanchit-gandhi (Contributor) left a comment

Thank you for your contribution @cifkao! The PR looks nearly ready to go, with just some small refactor suggestions below. Since this PR touches some core generation code in Whisper, could you please run the slow tests for the model and confirm that they pass before merging? No worries if you don't have the compute to do this, we can also run it on a machine and copy back the results (although it'll certainly be faster for you to iterate if you have a machine you can use!). Also cc @kamilakesbi for visibility
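
For reference, a sketch of one way to kick those off from Python (the usual invocation is simply RUN_SLOW=yes python -m pytest tests/models/whisper/test_modeling_whisper.py from the shell):

    import os

    import pytest

    os.environ["RUN_SLOW"] = "yes"  # transformers skips slow tests unless this env var is set
    pytest.main(["tests/models/whisper/test_modeling_whisper.py"])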

@@ -1067,7 +1068,10 @@ def _set_language_and_task(language, task, is_multilingual, generation_config):
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
language = language.lower()
if isinstance(language, str):
Contributor


I would be in favour of moving this normalisation of the language argument to the _retrieve_init_tokens method. It can be executed just before we map the language string argument to the language id token.

Contributor Author


@sanchit-gandhi This already is the _retrieve_init_tokens function. But yes, I'll move it here so that the list of languages is normalized once it has been ensured to be a list.
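
Roughly along these lines (a sketch, not the exact final code):

    def normalize_languages(languages):
        # Lower-case every language that was explicitly passed; leave None entries
        # (languages that still have to be detected) untouched.
        return [lang.lower() if isinstance(lang, str) else lang for lang in languages]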

src/transformers/models/whisper/generation_whisper.py (outdated review threads, resolved)
# Separate init_tokens for each language
init_tokens = [copy.deepcopy(init_tokens) for _ in language]

# from v4.39 the forced decoder ids are always None in favour of decoder input ids
Contributor


Ah, OK, it's added back here. Could we move it back up so it sits next to the forced decoder id logic?

language = [language] # Use a length-1 list now, broadcast later

# Separate init_tokens for each language
init_tokens = [copy.deepcopy(init_tokens) for _ in language]
Contributor


Is there a reason we need a deepcopy here?

Contributor Author


Not really, will change it to copy then.

if task in TASK_IDS:
init_tokens.append(generation_config.task_to_id[generation_config.task])
task_id = generation_config.task_to_id[generation_config.task]
for i in range(len(init_tokens)):
Contributor


The rest of the code assumes that we're using the same init tokens across the entire batch, but in the proposed changes we're looping over each batch item and computing the init tokens for each one. This is redundant, since the init tokens will be the same for each element of the batch. In practice, we should only need to compute them once!

Instead of doing this looping, can we keep the existing code and compute the init tokens just once? We can then copy them for each element in the batch, in a similar way to how you did previously:

        # Separate init_tokens for each language
        init_tokens = [copy.deepcopy(init_tokens) for _ in language]

@cifkao (Contributor Author) commented Apr 5, 2024

Actually, init_tokens will either be length 1 (if a single language was passed), or have possibly different values for each element in the batch (if a batch of languages was passed or the languages were detected). So there is no redundancy.

One could maybe handle the task tokens first and only then expand init_tokens to the size of the batch (if needed), but with the way it's written, this is not easily possible (because the language token comes before the task token).
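
To illustrate the two cases (the token ids below are just examples):

    # Single language passed: still one shared prompt at this point, broadcast later.
    init_tokens = [[50258, 50259]]                    # e.g. <|startoftranscript|>, <|en|>
    # Per-item languages passed or detected: one prompt per batch item, language ids differ.
    init_tokens = [[50258, 50259], [50258, 50265]]    # e.g. an English and a French item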

Contributor


Sure - but the rest of the init tokens (start token id, task and timestamps) are fixed across the batch? So we only need to do a batched computation for the language token id (multiple ids), and for the rest we only need to do it once for a single item and then copy it for the rest of the batch elements?

Contributor


The current version of the code makes it clear why we need to do this - happy with how it is! #29688 (comment)


# Both languages in the same batch
Contributor


Nice test! Would you mind defining a new test function for this test and moving the new checks into it? (e.g. test_large_generation_batched_languages) Keeping the tests small and modular helps identify which parts of the code are potentially broken by new PRs.
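
For illustration, a rough sketch of what such a dedicated test could look like (the tiny checkpoint, dummy audio, and assertion are placeholders rather than the actual test):

    import numpy as np

    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    def test_large_generation_batched_languages():
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        # Two dummy one-second clips standing in for an English and a French sample.
        audio = [np.zeros(16_000, dtype=np.float32), np.zeros(16_000, dtype=np.float32)]
        inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
        generated_ids = model.generate(inputs.input_features, language=["en", "fr"], task="transcribe")
        transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        assert len(transcripts) == 2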

Contributor Author


Done!

@cifkao (Contributor Author) left a comment

@ylacombe @sanchit-gandhi The code was quite hard to follow, so I refactored it a bit.

I re-ran the slow tests and got the same output – a few failing tests, but last time I checked, they were failing on main as well:

========================================================================================================================== short test summary info ===========================================================================================================================
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch - assert [' While Porashaggy sits there, a cooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wi...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_prev_cond - assert [" While poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I w...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_prompt_ids - AssertionError: assert 'quilter' in 'because you were sleeping instead of conquering the lovely rose princess has become a fiddle without a bow while poor shaggy sits there a cooing dove'
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, accooing dove. He has gone and gone for good. Antered polychrome who had managed to squeeze into the room beside the ...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch_prev_cond - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, accooing dove. He has gone and gone for good. Antered polychrome who had managed to squeeze into the room beside the ...
==================================================================================================== 5 failed, 293 passed, 106 skipped, 121 warnings in 259.68s (0:04:19) ====================================================================================================


# if language is defined it'll overwrite language ids that might have already been defined via the generation_config
replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values())
@cifkao (Contributor Author) commented Apr 5, 2024

I'm not sure why this function was used here. The language token is always going to be at init_tokens[1], so it's doing the same thing as this bit in the elif branch (only slower):

# append or replace lang_id to init_tokens
if len(init_tokens) > 1:
init_tokens[1] = lang_id
else:
init_tokens.append(lang_id)

So I unified the two branches.
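
For illustration, the unified logic applied per batch item has roughly this shape (a sketch, not the exact code in the PR):

    def set_language_tokens(init_tokens, lang_ids):
        # For every batch item the language token sits at index 1, right after
        # <|startoftranscript|>: overwrite it if present, otherwise append it.
        for tokens, lang_id in zip(init_tokens, lang_ids):
            if len(tokens) > 1:
                tokens[1] = lang_id
            else:
                tokens.append(lang_id)
        return init_tokens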


# Both languages in the same batch
Contributor Author


Done!

@sanchit-gandhi (Contributor) left a comment

Very clean PR - thanks for iterating @cifkao!

del languages

# Update init_tokens with task
for i in range(len(init_tokens)):
Contributor


This looks much cleaner and I understand the motivation for looping over the init_tokens!

@sanchit-gandhi (Contributor)

Failing slow tests will be addressed by #30152

@kamilakesbi (Contributor) left a comment

Hi @cifkao,

Thanks for working on this!
LGTM

@ArthurZucker (Collaborator)

Reviewing in a bit to finish it! Would be nice to rebase as it's been a while! 🤗

@ArthurZucker (Collaborator) left a comment

Looks good to me!
This is a nice feature, thanks for adding support for batches!

@ArthurZucker (Collaborator)

Can you merge with main to make sure all the tests pass? 🤗

@cifkao (Contributor Author) commented May 13, 2024

@ArthurZucker Done!

@ArthurZucker ArthurZucker merged commit be3aa43 into huggingface:main May 15, 2024
21 checks passed
@ArthurZucker (Collaborator)

Congrats for this PR! 🤗

itazap pushed a commit that referenced this pull request May 24, 2024
* Add support for mixing languages in a single batch

* Update docstring

* Enable different detected languages in batch

* Do not require input_features

* Test list of languages

* Fix comment

* Make init_tokens length-1 if possible, broadcast at the end

* Test for ValueError with language list of incorrect length

* Slow test for batched multilingual transcription

* fixup

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Address review, refactor

* Second attempt to move this line where it was originally

* Split test, fix a bug

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@AvivSham (Contributor)

Hi @ArthurZucker @sanchit-gandhi @cifkao,
Thank you for this PR, much needed.
When calling generate:

model.generate(
                batch.input_features,
                language=batch.language,
            )

We observe the following warning:

You have passed language=['en', 'fr'], but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of language=['en', 'fr'].
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.

Can it be ignored? Since we are in generate mode, the pad token does not affect the process.
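
A sketch of one way to avoid the first warning (assuming the legacy forced_decoder_ids in the generation config really are redundant here, as the warning itself suggests):

    # Clear the legacy attribute so only the `language` argument controls the prompt tokens.
    model.generation_config.forced_decoder_ids = None
    generated_ids = model.generate(batch.input_features, language=batch.language)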
