Generate: Export TF generate with a TF tokenizer #22310

Merged · 2 commits · Mar 22, 2023
84 changes: 31 additions & 53 deletions src/transformers/generation/tf_utils.py
@@ -1725,14 +1725,13 @@ def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
 
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
-            maximum_iterations = max_length - cur_len
-            generated, _, cur_len, _ = tf.while_loop(
-                greedy_search_cond_fn,
-                greedy_search_body_fn,
-                (generated, finished_sequences, cur_len, model_kwargs),
-                maximum_iterations=maximum_iterations,
-            )
+        maximum_iterations = max_length - cur_len
+        generated, _, cur_len, _ = tf.while_loop(
+            greedy_search_cond_fn,
+            greedy_search_body_fn,
+            (generated, finished_sequences, cur_len, model_kwargs),
+            maximum_iterations=maximum_iterations,
+        )
 
         # 6. prepare outputs
         if not use_xla:
@@ -2016,14 +2015,13 @@ def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
 
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
-            maximum_iterations = max_length - cur_len
-            generated, _, cur_len, _ = tf.while_loop(
-                sample_cond_fn,
-                sample_body_fn,
-                (generated, finished_sequences, cur_len, model_kwargs),
-                maximum_iterations=maximum_iterations,
-            )
+        maximum_iterations = max_length - cur_len
+        generated, _, cur_len, _ = tf.while_loop(
+            sample_cond_fn,
+            sample_body_fn,
+            (generated, finished_sequences, cur_len, model_kwargs),
+            maximum_iterations=maximum_iterations,
+        )
 
         # 6. prepare outputs
         if not use_xla:
@@ -2565,7 +2563,8 @@ def beam_search_body_fn(
 
         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
        # NOT yield EOS token though)
-        if beam_search_cond_fn(
+        maximum_iterations = max_length - cur_len
+        (
             cur_len,
             running_sequences,
             running_scores,
@@ -2574,9 +2573,10 @@
             scores,
             beam_indices,
             is_sent_finished,
-            model_kwargs,
-        ):
-            maximum_iterations = max_length - cur_len
+            _,
+        ) = tf.while_loop(
+            beam_search_cond_fn,
+            beam_search_body_fn,
             (
                 cur_len,
                 running_sequences,
@@ -2586,23 +2586,10 @@
                 scores,
                 beam_indices,
                 is_sent_finished,
-                _,
-            ) = tf.while_loop(
-                beam_search_cond_fn,
-                beam_search_body_fn,
-                (
-                    cur_len,
-                    running_sequences,
-                    running_scores,
-                    running_beam_indices,
-                    sequences,
-                    scores,
-                    beam_indices,
-                    is_sent_finished,
-                    model_kwargs,
-                ),
-                maximum_iterations=maximum_iterations,
-            )
+                model_kwargs,
+            ),
+            maximum_iterations=maximum_iterations,
+        )
 
         # 6. prepare outputs
         # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
@@ -3019,22 +3006,13 @@ def contrastive_search_body_fn(
 
         # 2-to-n generation steps can then be run in autoregressive fashion
        # only in case 1st generation step does NOT yield EOS token though
-        if contrastive_search_cond_fn(
-            generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
-        ):
-            maximum_iterations = max_length - cur_len
-            (
-                generated,
-                _,
-                cur_len,
-                _,
-                _,
-            ) = tf.while_loop(
-                contrastive_search_cond_fn,
-                contrastive_search_body_fn,
-                (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
-                maximum_iterations=maximum_iterations,
-            )
+        maximum_iterations = max_length - cur_len
+        generated, _, cur_len, _, _ = tf.while_loop(
+            contrastive_search_cond_fn,
+            contrastive_search_body_fn,
+            (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
+            maximum_iterations=maximum_iterations,
+        )
 
         # 6. prepare outputs
         if not use_xla:
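Editor's note on the pattern above: all four decoding strategies (greedy, sample, beam search, contrastive search) drop the same eager Python `if cond_fn(...)` guard around `tf.while_loop`. The guard is redundant, since `tf.while_loop` already runs its body zero times when the condition is false on entry, and a data-dependent Python `if` is what gets in the way when this code has to be traced into an exported `tf.function` together with a TF-native tokenizer. A minimal sketch of the zero-iteration behavior (the function and values below are illustrative, not from the PR):

import tensorflow as tf

@tf.function
def generate_steps(cur_len, max_length):
    # Loop vars loosely mirror the (generated, ..., cur_len, ...) tuple in tf_utils.py.
    cond = lambda steps, cur_len: cur_len < max_length
    body = lambda steps, cur_len: (steps + 1, cur_len + 1)
    # No eager `if cond(...)` guard is needed: when cur_len == max_length the
    # condition is false on entry, the body runs zero times, and the inputs
    # are returned unchanged -- exactly the case the removed `if` handled.
    return tf.while_loop(cond, body, (tf.constant(0), cur_len), maximum_iterations=max_length - cur_len)

print(generate_steps(tf.constant(3), tf.constant(5)))  # runs 2 iterations
print(generate_steps(tf.constant(5), tf.constant(5)))  # runs 0 iterations, no error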
39 changes: 37 additions & 2 deletions tests/generation/test_tf_utils.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import tempfile
 import unittest
 
 import numpy as np
 from huggingface_hub import hf_hub_download
 
-from transformers import is_tf_available
-from transformers.testing_utils import require_tf, slow
+from transformers import is_tensorflow_text_available, is_tf_available
+from transformers.testing_utils import require_tensorflow_text, require_tf, slow
 
 from ..test_modeling_tf_common import floats_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -40,6 +42,9 @@
     tf_top_k_top_p_filtering,
 )
 
+if is_tensorflow_text_available():
+    import tensorflow_text as text
+
 
 @require_tf
 class UtilsFunctionsTest(unittest.TestCase):
@@ -239,6 +244,36 @@ def serving(self, input_ids, attention_mask):
         tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
         tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
 
+    @slow
+    @require_tensorflow_text
+    def test_generate_tf_function_export_with_tf_tokenizer(self):
+        # TF-only test: tf.saved_model export
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # file needed to load the TF tokenizer
+            hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
+
+            class CompleteSentenceTransformer(tf.keras.layers.Layer):
+                def __init__(self):
+                    super().__init__()
+                    self.tokenizer = text.SentencepieceTokenizer(
+                        model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
+                    )
+                    self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+                def call(self, inputs, *args, **kwargs):
+                    tokens = self.tokenizer.tokenize(inputs)
+                    input_ids, attention_mask = text.pad_model_inputs(
+                        tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
+                    )
+                    outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
+                    return self.tokenizer.detokenize(outputs)
+
+            complete_model = CompleteSentenceTransformer()
+            inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
+            outputs = complete_model(inputs)
+            keras_model = tf.keras.Model(inputs, outputs)
+            keras_model.save(tmp_dir)
+
     def test_eos_token_id_int_and_list_top_k_top_sampling(self):
         # Has PT equivalent: this test relies on random sampling
         generation_kwargs = {
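For context, a hypothetical usage sketch of the artifact this test produces (not part of the test): the saved model bundles SentencePiece tokenization, `generate()`, and detokenization, so the export can be served directly on raw strings. The path and prompt below are placeholders:

import tensorflow as tf

# Assumes the test above ran keras_model.save(exported_dir).
exported_dir = "path/to/exported_model"  # placeholder path

reloaded = tf.saved_model.load(exported_dir)

# Shape (batch, 1) of strings, matching the tf.keras.layers.Input in the test.
sentences = tf.constant([["A hypothetical prompt for the tiny T5 model."]])
outputs = reloaded(sentences)  # generated text as a string tensor
print(outputs)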