clean up tokenization - fix python 2 tests
thomwolf committed Feb 18, 2019
1 parent d44db11 commit b450a7f
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions pytorch_pretrained_bert/tokenization_gpt2.py
@@ -20,14 +20,19 @@
import logging
import os
import regex as re
import sys
from io import open
-from functools import lru_cache

-from tqdm import tqdm
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python 2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache(func):
+        def func_wrapper(*inputs, **args):
+            return func(*inputs, **args)
+        return func_wrapper

from .file_utils import cached_path
from .tokenization import BasicTokenizer
logger = logging.getLogger(__name__)

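As an aside, here is a minimal standalone sketch of that Python 2 shim in isolation (the square function is a hypothetical example, not part of the file). The wrapper simply forwards its arguments without caching anything; note the shim only supports the bare @lru_cache form, not the parenthesized @lru_cache() call that functools accepts.

def lru_cache(func):
    def func_wrapper(*inputs, **args):
        return func(*inputs, **args)
    return func_wrapper

@lru_cache
def square(x):
    return x * x

print(square(6))  # prints 36; recomputed on every call, nothing is cached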
@@ -125,7 +130,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
        return tokenizer

-    def __init__(self, vocab_file, merges_file, errors='replace'):
+    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
@@ -188,6 +194,12 @@ def encode(self, text):
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        if len(bpe_tokens) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                "sequence length for this OpenAI GPT-2 model ({} > {}). Running this "
+                "sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
+            )
        return bpe_tokens

    def decode(self, tokens):
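Together, the second and third hunks thread an optional max_len through the tokenizer: __init__ stores it (defaulting to an effectively unlimited int(1e12)), and encode() raises once the BPE token count exceeds it. A usage sketch, assuming local vocab.json and merges.txt files (placeholder paths; from_pretrained normally resolves the real ones):

from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

# "vocab.json" / "merges.txt" are placeholder paths for this sketch
tokenizer = GPT2Tokenizer("vocab.json", "merges.txt", max_len=512)

ids = tokenizer.encode("Hello world")   # short input: returns a list of token ids
try:
    tokenizer.encode("word " * 10000)   # expands to far more than 512 BPE tokens
except ValueError as err:
    print(err)  # fails fast here instead of as an indexing error inside the model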
