is_pretokenized seems to work incorrectly #6046
Comments
We face a similar issue with the DistilBERT tokenizer. It seems that subword tokens (here ##n and ##ota) get split into further tokens even though we set `is_pretokenized=True`.
As I mentioned before, we used ... and this returns ... Hope this will help someone 🤗
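For anyone who wants to reproduce this, here is a minimal sketch of the behaviour described above (assuming `distilbert-base-uncased` and a transformers 3.x release where the flag is still called `is_pretokenized`; the example word is mine, not the one from our pipeline):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)

# First pass: produce subword tokens (pieces prefixed with '##')
tokens = tokenizer.tokenize("huggingface transformers")
print(tokens)  # e.g. ['hugging', '##face', 'transformers']

# Second pass: feed the subword tokens back in with is_pretokenized=True.
# Each piece is re-tokenized as if it were a plain word, so '##'-prefixed
# pieces get split into further tokens instead of being kept intact.
encoded = tokenizer.encode_plus(tokens, is_pretokenized=True)
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
```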
I have the same issue as @tholor - there seem to be some nasty differences between slow and fast tokenizer implementations.
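A quick way to check whether the two implementations actually diverge (a sketch, assuming `bert-base-uncased` and the 3.x `is_pretokenized` flag):

```python
from transformers import BertTokenizer, BertTokenizerFast

slow = BertTokenizer.from_pretrained("bert-base-uncased")
fast = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Subword tokens produced by a first tokenization pass
tokens = slow.tokenize("huggingface transformers")  # ['hugging', '##face', 'transformers']

# Encode the same pre-tokenized input with both implementations and compare
slow_out = slow.encode_plus(tokens, is_pretokenized=True)
fast_out = fast.encode_plus(tokens, is_pretokenized=True)

print(slow.convert_ids_to_tokens(slow_out["input_ids"]))
print(fast.convert_ids_to_tokens(fast_out["input_ids"]))
```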
Just got the same issue with:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text = "huggingface transformers"
tok = tokenizer.tokenize(text)
print(tok)
# ['hugging', '##face', 'transformers']

output = tokenizer.encode_plus(tok, is_pretokenized=True)
tokenizer.convert_ids_to_tokens(output["input_ids"])
# ['[CLS]', 'hugging', '#', '#', 'face', 'transformers', '[SEP]']
```

whereas with `is_pretokenized=False`:

```python
output2 = tokenizer.encode_plus(tok, is_pretokenized=False)
tokenizer.convert_ids_to_tokens(output2["input_ids"])
# ['[CLS]', 'hugging', '##face', 'transformers', '[SEP]']
```
I believe this issue can be closed, given the explanation in #6575.
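For context, my understanding is that `is_pretokenized` expects input that is pre-split into words (e.g. on whitespace), not into the subword tokens coming out of `tokenize()`; each word is still run through the subword tokenizer internally. A minimal sketch of that usage, assuming `bert-base-uncased`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Pre-split into *words*, not into subword tokens
words = "huggingface transformers".split()

output = tokenizer(words, is_pretokenized=True)
print(tokenizer.convert_ids_to_tokens(output["input_ids"]))
# expected: ['[CLS]', 'hugging', '##face', 'transformers', '[SEP]']
```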
🐛 Bug
Information
Model I am using (Bert, XLNet ...): roberta
Language I am using the model on (English, Chinese ...): English
The problem arises when using:
I use `RobertaTokenizerFast` on pretokenized text, but the problem also arises when I switch to the slow version.

The tasks I am working on is:
I am trying to implement a sliding window for RoBERTa.
To reproduce
I use the `tokenizer.tokenize(text)` method to tokenize the whole text (1-3 sentences), then I divide the tokens into chunks and use the `__call__` method (I also tried `encode`) with the `is_pretokenized=True` argument, but this creates additional tokens (roughly 3 times more than there should be). I worked around this by using the `tokenize` -> `convert_tokens_to_ids` -> `prepare_for_model` -> `pad` pipeline (sketched after the snippet below), but I believe the batch methods should be faster and more memory efficient.

Steps to reproduce the behavior:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)

ex_text = 'long text'  # any text long enough to yield several chunks
tokens = tokenizer.tokenize(ex_text)

# sliding window: chunks of 126 tokens with a stride of 100
examples = [tokens[i:i+126] for i in range(0, len(tokens), 100)]

print(len(tokenizer(examples, is_pretokenized=True)['input_ids'][0]))  # this prints more than 128
```
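For comparison, a rough sketch of the workaround pipeline mentioned above (details untested, 3.x API assumed):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)

ex_text = 'long text'  # any text long enough to yield several chunks
tokens = tokenizer.tokenize(ex_text)
chunks = [tokens[i:i+126] for i in range(0, len(tokens), 100)]

# Convert each chunk to ids and add the special tokens per example
encoded = [
    tokenizer.prepare_for_model(tokenizer.convert_tokens_to_ids(chunk))
    for chunk in chunks
]

# Pad the whole batch to the same length
batch = tokenizer.pad(encoded, padding=True)
print([len(ids) for ids in batch['input_ids']])  # padded to the longest chunk (126 tokens + 2 special tokens = 128)
```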
Expected behavior
I would expect to get a result similar to the one I get with the workaround pipeline described above.
Am I doing something wrong, or is this unexpected behaviour?
Environment info
transformers version: 3.0.2

EDIT:
I see that when I use `__call__` it actually treats `Ġ` as 2 tokens:

```python
tokenizer(tokenizer.tokenize('How'), is_pretokenized=True)['input_ids']
# out: [0, 4236, 21402, 6179, 2]
```

where 4236 and 21402 correspond to `Ġ`.
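To see what those extra ids actually are, one can map them back to tokens (a quick diagnostic sketch, same roberta-base setup as above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)

ids = tokenizer(tokenizer.tokenize('How'), is_pretokenized=True)['input_ids']
# Mapping the ids back shows how the leading 'Ġ' of the byte-level BPE token was handled
print(tokenizer.convert_ids_to_tokens(ids))
```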