is_pretokenized seems to work incorrectly #6046

Closed

Zhylkaaa opened this issue Jul 26, 2020 · 5 comments

Zhylkaaa (Contributor) commented Jul 26, 2020

🐛 Bug

Information

Model I am using (Bert, XLNet ...): roberta

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

I use RobertaTokenizerFast on pretokenized text, but the problem also arises when I switch to the slow version.

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

I am trying to implement a sliding window for RoBERTa

To reproduce

I use the tokenizer.tokenize(text) method to tokenize the whole text (1-3 sentences), then divide the tokens into chunks and call the __call__ method (I also tried encode) with the is_pretokenized=True argument, but this creates additional tokens (about 3 times more than there should be). I worked around this with a tokenize -> convert_tokens_to_ids -> prepare_for_model -> pad pipeline, but I believe the batch methods should be faster and more memory efficient.
Steps to reproduce the behavior:

  1. tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)
  2. ex_text = 'long text'
  3. tokens = tokenizer.tokenize(ex_text)
  4. examples = [tokens[i:i+126] for i in range(0, len(tokens), 100)]
  5. print(len(tokenizer(examples, is_pretokenized=True)['input_ids'][0])) # this prints more than 128

Expected behavior

I would expect to get result similar to result I get when I use

tokens = tokenizer.tokenize(ex_text)
inputs = tokenizer.convert_tokens_to_ids(tokens)
inputs = [inputs[i:i+126] for i in range(0, len(tokens), 100)]
inputs = [tokenizer.prepare_for_model(example) for example in inputs] 
inputs = tokenizer.pad(inputs, padding='longest')

Am I doing something wrong, or is this unexpected behaviour?

Environment info

  • transformers version: 3.0.2
  • Platform: MacOs
  • Python version: 3.8.3
  • PyTorch version (GPU?): 1.5.1 (no GPU)
  • Tensorflow version (GPU?): NO
  • Using GPU in script?: NO
  • Using distributed or parallel set-up in script?: NO

EDIT:
I see that when I use __call__ it actually treats Ġ as 2 tokens:
tokenizer(tokenizer.tokenize('How'), is_pretokenized=True)['input_ids']
out: [0, 4236, 21402, 6179, 2], where 4236 and 21402 are the Ġ
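
A quick way to see which token strings those extra ids correspond to (a minimal sketch, assuming the same roberta-base setup as in the steps above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)

# Tokenize a single word, then feed the resulting tokens back in as "pretokenized" input
tokens = tokenizer.tokenize('How')
ids = tokenizer(tokens, is_pretokenized=True)['input_ids']

# Map the ids back to token strings to see where the extra pieces come from
print(tokenizer.convert_ids_to_tokens(ids))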

Zhylkaaa changed the title from "Is_pretokenized seems to not work" to "is_pretokenized seems to work incorrectly" on Jul 27, 2020
tholor (Contributor) commented Aug 6, 2020

We face a similar issue with the distilbert tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-german-cased")
tokens = ['1980', 'kam', 'der', 'Crow', '##n', 'von', 'Toy', '##ota']
result = tokenizer.encode_plus(text=tokens,
                               text_pair=None,
                               add_special_tokens=True,
                               truncation=False,
                               return_special_tokens_mask=True,
                               return_token_type_ids=True,
                               is_pretokenized=True
                               )
result["input_ids"]
# returns:
[102, 3827, 1396, 125, 28177, 1634, 1634, 151, 195, 25840, 1634, 1634, 23957, 30887, 103]

tokenizer.decode(result["input_ids"])
# returns:
'[CLS] 1980 kam der Crow # # n von Toy # # ota [SEP]'

It seems that subword tokens (here ##n and ##ota) get split into further tokens even though we set is_pretokenized=True. This seems unexpected to me but maybe I am missing something?
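
If the goal is to feed wordpieces that the tokenizer already produced, the convert_tokens_to_ids / prepare_for_model route mentioned above avoids the re-tokenization (a minimal sketch, not verified against this exact model):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-german-cased")
tokens = ['1980', 'kam', 'der', 'Crow', '##n', 'von', 'Toy', '##ota']

# Map the existing wordpieces straight to ids instead of re-tokenizing them,
# and let prepare_for_model add the special tokens
ids = tokenizer.convert_tokens_to_ids(tokens)
result = tokenizer.prepare_for_model(ids, add_special_tokens=True)

print(tokenizer.decode(result["input_ids"]))
# expected (if the wordpieces are in the vocab): '[CLS] 1980 kam der Crown von Toyota [SEP]'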

Zhylkaaa (Contributor, Author) commented Aug 7, 2020

As I mentioned before, we used is_pretokenized to create a sliding window, but we recently discovered that this can be achieved using:

stride = max_seq_length - 2 - int(max_seq_length*stride)
tokenized_examples = tokenizer(examples, return_overflowing_tokens=True, 
                               max_length=max_seq_length, stride=stride, truncation=True)

This returns a dict with input_ids, attention_mask and overflow_to_sample_mapping (the latter helps to map windows back to their source examples, but you should check for its presence: if you pass one short example it might not be there).
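
For example (a minimal sketch with made-up inputs, assuming roberta-base and a fast tokenizer; the stride fraction is arbitrary):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base', use_fast=True)

examples = ['first long text ...', 'second long text ...']  # made-up inputs
max_seq_length = 128
stride_fraction = 0.5  # arbitrary; plugged into the same formula as above
stride = max_seq_length - 2 - int(max_seq_length * stride_fraction)

tokenized_examples = tokenizer(examples, return_overflowing_tokens=True,
                               max_length=max_seq_length, stride=stride, truncation=True)

# Map each produced window back to its source example; the key can be missing
# when nothing overflowed, so fall back to an identity mapping.
n_windows = len(tokenized_examples['input_ids'])
mapping = tokenized_examples.get('overflow_to_sample_mapping', list(range(n_windows)))
for window_idx, example_idx in enumerate(mapping):
    print(window_idx, '->', examples[example_idx])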

Hope this will help someone 🤗

PhilipMay (Contributor):
I have the same issue as @tholor - there seem to be some nasty differences between slow and fast tokenizer implementations.
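
A quick side-by-side check of the two implementations (a minimal sketch reusing the tokens from the example above; outputs not verified here):

from transformers import AutoTokenizer

tokens = ['1980', 'kam', 'der', 'Crow', '##n', 'von', 'Toy', '##ota']

slow_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-german-cased", use_fast=False)
fast_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-german-cased", use_fast=True)

# Encode the same "pretokenized" input with both implementations and compare
for name, tok in [("slow", slow_tokenizer), ("fast", fast_tokenizer)]:
    ids = tok.encode_plus(tokens, is_pretokenized=True, add_special_tokens=True)["input_ids"]
    print(name, tok.convert_ids_to_tokens(ids))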

chrk623 commented Aug 18, 2020

Just got the same issue with bert-base-uncased. However, when is_pretokenized=False it seems to be OK. Is this expected behaviour?

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text  = "huggingface transformers"
tok = tokenizer.tokenize(text)
print(tok)
# ['hugging', '##face', 'transformers']

output = tokenizer.encode_plus(tok, is_pretokenized=True)
tokenizer.convert_ids_to_tokens(output["input_ids"])
# ['[CLS]', 'hugging', '#', '#', 'face', 'transformers', '[SEP]']

When is_pretokenized=False:

output2 = tokenizer.encode_plus(tok, is_pretokenized=False)
tokenizer.convert_ids_to_tokens(output2["input_ids"])
# ['[CLS]', 'hugging', '##face', 'transformers', '[SEP]']

Zhylkaaa (Contributor, Author):
I believe this issue can be closed because of the explanation in #6575, which states that is_pretokenized expects a list of words split on whitespace, not actual tokens. So this is "kind of expected" behaviour :)
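
In other words, something like this is the intended usage (a minimal sketch, assuming roberta-base; the sentence is made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)

# Intended input for is_pretokenized: words split on whitespace, not subword tokens
words = 'How are you doing'.split()
print(tokenizer(words, is_pretokenized=True)['input_ids'])

# Not intended: feeding tokenizer.tokenize() output back in, since every subword
# string then gets tokenized again (which is what produced the extra ids above)
subwords = tokenizer.tokenize('How are you doing')
print(tokenizer(subwords, is_pretokenized=True)['input_ids'])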
