Support mixed-language batches in WhisperGenerationMixin
#29688
Conversation
This looks great to me! Thanks for the great work here!
I've left a few comments here and there. Could you also make sure that all slow tests still pass?
Also requesting a review from @sanchit-gandhi as the changes are pretty heavy!
)
else:
    language = [language] * batch_size
This is a nit, but I wonder if you could keep a list of length 1 here and repeat init_tokens batch_size times at the end of the method instead, WDYT?
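For concreteness, a minimal sketch of that idea in isolation (token ids and variable names are made up for the example, not the actual code in _retrieve_init_tokens):

import copy

batch_size = 4
language = "en"                      # a single language string passed by the user
language = [language]                # keep a length-1 list instead of repeating it batch_size times

# build the prompt once for that single language (ids are illustrative)
decoder_start_token_id, lang_id, task_id = 50258, 50259, 50359
init_tokens = [[decoder_start_token_id, lang_id, task_id] for _ in language]

# only broadcast to the full batch at the very end of the method
init_tokens = [copy.copy(init_tokens[0]) for _ in range(batch_size)]
print(init_tokens)                   # four independent copies of the same prompt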
Sounds like a good idea to me.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@ylacombe I ran the slow tests in
 if task in TASK_IDS:
-    init_tokens.append(generation_config.task_to_id[generation_config.task])
+    task_id = generation_config.task_to_id[generation_config.task]
+    for i in range(len(init_tokens)):
init_tokens is only of length 1 in the case of language=None at this point (because the deepcopy on init_tokens assumes that a list of languages has been passed in). We also need to set language for a later check in the function, since the dim of language is used in a check.
To fix this, we can do:
id_to_lang = {v: k for k, v in generation_config.lang_to_id.items()}
language = [id_to_lang[lang_id.item()] for lang_id in lang_ids]
init_tokens = [copy.deepcopy(init_tokens[0]) for _ in lang_ids]
(The init_tokens line is a bit janky; it would be cleaner not to do the deepcopy above, since then it is a list of lists.)
You can reproduce the current failure with the multiple-languages test, but passing None as the language.
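For example, something along these lines should hit that code path (the checkpoint and dummy audio are stand-ins, not the actual slow test):

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# two dummy waveforms standing in for audio clips in different languages
audio = [np.zeros(16_000, dtype=np.float32), np.zeros(16_000, dtype=np.float32)]
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# language=None forces per-item language detection, which is where the failure shows up
generated = model.generate(inputs.input_features, language=None, task="transcribe")
print(processor.batch_decode(generated, skip_special_tokens=True))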
Not sure I follow the issue here, could you expand on this?
@ylacombe 0b424f7 was based on your suggestion, but it turns out it broke language detection. So maybe we can revert that unless you think it would introduce significant overhead.
Another option would be to do something like this:
if isinstance(language, (list, tuple)):
    ...
elif language is None:
    # Language will be detected for each item in batch
    language = [None] * batch_size
else:
    language = [language]  # Use a length-1 list now, broadcast later

# Separate init_tokens for each language
init_tokens = [copy.deepcopy(init_tokens) for _ in language]
Actually @naveen-corpusant, I don't think you're right on your second point. That later check is not checking the detected language; it's there to check whether language was passed to generate() and, if so, set the task to transcription. I'll introduce a variable languages in order to keep the original language parameter intact.
I went with my suggestion above, so language will be a length-1 list if a single language was passed; otherwise it will be a length-batch_size list of either languages or Nones (if the languages need to be detected).
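Spelled out, the resulting shapes look like this (values illustrative):

batch_size = 3
single_language = ["fr"]                 # one language passed -> length-1 list, broadcast later
per_item        = ["en", "de", "fr"]     # a list passed -> one entry per batch item
to_detect       = [None] * batch_size    # nothing passed -> detect the language per item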
As long as this PR doesn't introduce new failing tests, this should be okay!
Thank you for your contribution @cifkao! The PR looks nearly ready to go, with just some small refactor suggestions below. Since this PR touches some core generation code in Whisper, could you please run the slow tests for the model and confirm that they pass before merging? No worries if you don't have the compute to do this, we can also run it on a machine and copy back the results (although it'll certainly be faster for you to iterate if you have a machine you can use!). Also cc @kamilakesbi for visibility
@@ -1067,7 +1068,10 @@ def _set_language_and_task(language, task, is_multilingual, generation_config):
             "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
             "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
         )
-        language = language.lower()
+        if isinstance(language, str):
I would be in favour of moving this normalisation of the language argument to the _retrieve_init_tokens method. It can be executed just before we map the language string argument to the language id token:
if language is not None:
@sanchit-gandhi But this already is the _retrieve_init_tokens function. Still, I'll move it here to normalize the list of languages once it's already been ensured that it's a list.
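A small sketch of what that normalisation could look like once language is guaranteed to be a list (values are illustrative; the exact placement is what is being discussed):

language = ["English", "DE", None]   # None marks items whose language still needs detecting
languages = [lang.lower() if isinstance(lang, str) else lang for lang in language]
print(languages)                     # ['english', 'de', None]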
# Separate init_tokens for each language
init_tokens = [copy.deepcopy(init_tokens) for _ in language]

# from v4.39 the forced decoder ids are always None in favour of decoder input ids
Ah ok it's added back here. Could we move it back up so it sits next to the forced decoder id logic?
language = [language]  # Use a length-1 list now, broadcast later

# Separate init_tokens for each language
init_tokens = [copy.deepcopy(init_tokens) for _ in language]
Is there a reason we need a deepcopy here?
Not really, will change it to copy
then.
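A quick illustration of why a shallow copy is enough for a flat list of token ids (ids made up):

import copy

init_tokens = [50258, 50259, 50359]
per_item = [copy.copy(init_tokens) for _ in range(3)]

per_item[0][1] = 50261               # changing one item's language token...
print(per_item[1][1])                # ...leaves the other copies untouched: 50259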
 if task in TASK_IDS:
-    init_tokens.append(generation_config.task_to_id[generation_config.task])
+    task_id = generation_config.task_to_id[generation_config.task]
+    for i in range(len(init_tokens)):
The rest of the code assumes that we're using the same init tokens across the entire batch, but in the proposed changes we're looping over each batch item and computing the init tokens for each. This is redundant, since the init tokens will be the same for each element of the batch. In practice, we should only need to compute them once!
Instead of doing this looping, can we keep the existing code and compute the init tokens just once? We can then copy them for each element in the batch, in a similar way to how you did previously:
# Separate init_tokens for each language
init_tokens = [copy.deepcopy(init_tokens) for _ in language]
Actually, init_tokens will either be length 1 (if a single language was passed), or have possibly different values for each element in the batch (if a batch of languages was passed or the languages were detected). So there is no redundancy.
One could maybe handle the task tokens first and only then expand init_tokens
to the size of the batch (if needed), but with the way it's written, this is not easily possible (because the language token comes before the task token).
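To make the shape concrete: once each batch item may carry a different language token, the task token has to be appended per item, roughly like this (ids made up):

start_token, task_id = 50258, 50359
lang_ids = [50259, 50261, 50262]                 # a different language token per batch item

init_tokens = [[start_token, lang_id] for lang_id in lang_ids]
for i in range(len(init_tokens)):                # append the (shared) task token to every item
    init_tokens[i].append(task_id)

print(init_tokens)
# [[50258, 50259, 50359], [50258, 50261, 50359], [50258, 50262, 50359]]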
Sure - but the rest of the init tokens (start token id, task and timestamps) are fixed across the batch? So we only need to do a batched computation for the language token id (multiple ids), and for the rest we only need to do it once for a single item, and then copy it for the rest of the batch elements?
The current version of the code makes it clear why we need to do this - happy with how it is! #29688 (comment)
# Both languages in the same batch
Nice test! Would you mind defining a new test function for this test and moving the new checks in here (e.g. test_large_generation_batched_languages)? Keeping the tests small and modular helps identify which parts of the code are potentially broken with new PRs.
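Something roughly shaped like this, inside the existing integration test class (the method name follows the suggestion; the checkpoint, data-loading helper, and assertion are placeholders rather than the actual test):

@slow
def test_large_generation_batched_languages(self):
    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")

    audio = self._load_datasamples(2)            # placeholder: one clip per language would be needed
    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

    # both languages in the same batch
    generated = model.generate(inputs.input_features, language=["en", "de"], task="transcribe")
    transcripts = processor.batch_decode(generated, skip_special_tokens=True)
    self.assertEqual(len(transcripts), 2)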
Done!
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@ylacombe @sanchit-gandhi The code was quite hard to follow, so I refactored it a bit.
I re-ran the slow tests and got the same output – a few failing tests, but last time I checked, they were failing on main as well:
========================================================================================================================== short test summary info ===========================================================================================================================
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch - assert [' While Porashaggy sits there, a cooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wi...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_prev_cond - assert [" While poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I w...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_prompt_ids - AssertionError: assert 'quilter' in 'because you were sleeping instead of conquering the lovely rose princess has become a fiddle without a bow while poor shaggy sits there a cooing dove'
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, accooing dove. He has gone and gone for good. Antered polychrome who had managed to squeeze into the room beside the ...
FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_single_batch_prev_cond - assert [" Because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, all poor ashaggy sits there, accooing dove. He has gone and gone for good. Antered polychrome who had managed to squeeze into the room beside the ...
==================================================================================================== 5 failed, 293 passed, 106 skipped, 121 warnings in 259.68s (0:04:19) ====================================================================================================
# if language is defined it'll overwrite language ids that might have already been defined via the generation_config
replace_or_add(init_tokens, lang_id, generation_config.lang_to_id.values())
I'm not sure why this function was used here. The language token is always going to be at init_tokens[1], so it's doing the same thing as this bit in the elif branch (only slower):
transformers/src/transformers/models/whisper/generation_whisper.py
Lines 1177 to 1181 in 56b64bf
# append or replace lang_id to init_tokens
if len(init_tokens) > 1:
    init_tokens[1] = lang_id
else:
    init_tokens.append(lang_id)
So I unified the two branches.
# Both languages in the same batch
Done!
Very clean PR - thanks for iterating @cifkao!
del languages

# Update init_tokens with task
for i in range(len(init_tokens)):
This looks much cleaner and I understand the motivation for looping over the init_tokens!
Failing slow tests will be addressed by #30152
Hi @cifkao,
Thanks for working on this!
LGTM
Reviewing in a bit to finish it! Would be nice to rebase as it's been a while! 🤗
Looks good to me!
This is a nice feature, thanks for adding support for batches!
Can you merge with main to make sure all the tests pass? 🤗
@ArthurZucker Done!
Congrats on this PR! 🤗
* Add support for mixing languages in a single batch
* Update docstring
* Enable different detected languages in batch
* Do not require input_features
* Test list of languages
* Fix comment
* Make init_tokens length-1 if possible, broadcast at the end
* Test for ValueError with language list of incorrect length
* Slow test for batched multilingual transcription
* fixup
* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Address review, refactor
* Second attempt to move this line where it was originally
* Split test, fix a bug

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Hi @ArthurZucker @sanchit-gandhi @cifkao,

model.generate(
    batch.input_features,
    language=batch.language,
)

We observe the following warning:

Can it be ignored? Since we are in generate mode, the pad token does not affect the process.
What does this PR do?
Fixes #29685
Before submitting
Who can review?
@ylacombe @sanchit-gandhi