feat: Add support for tokenizer (baidubce#21)
* support tokenizer

* update function name

* add tokenizer readme

* add notes for local simulation
ZingLix authored Oct 27, 2023
1 parent 16e7691 commit 39169fe
Showing 4 changed files with 156 additions and 0 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -371,6 +371,19 @@ import qianfan
chat_comp = qianfan.ChatCompletion(query_per_second=0.5)
```

### Tokenizer

Large language models generally impose a limit on token length, so we provide the `Tokenizer` utility class, which makes it easy to estimate the number of tokens in a piece of text.

Usage:

```python
text = "这是待计算 token 数量的一段文本"
count = qianfan.Tokenizer.count_tokens(text)
```

Estimation is currently only supported for the `ERNIE-Bot` family of models, and is approximately `(Chinese character count) + (word count * 1.3)`.
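
For example, counting a mixed Chinese/English string (a sketch with an illustrative input; the result follows the truncating `int()` used in the implementation, and `mode`/`model` are the parameters defined below):

```python
import qianfan

text = "你好hello world 123"  # 2 Chinese characters + 3 English words
count = qianfan.Tokenizer.count_tokens(text, mode="local", model="ERNIE-Bot")
print(count)  # 2 + int(3 * 1.3) = 5
```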

## License

Apache-2.0
2 changes: 2 additions & 0 deletions src/qianfan/__init__.py
@@ -26,6 +26,7 @@
from qianfan.resources.llm.embedding import Embedding
from qianfan.resources.llm.plugin import Plugin
from qianfan.resources.tools.prompt import Prompt
from qianfan.resources.tools.tokenizer import Tokenizer
from qianfan.resources.typing import QfMessages, QfResponse, QfRole
from qianfan.utils import disable_log, enable_log
from qianfan.version import VERSION
@@ -42,6 +43,7 @@
"FineTune",
"Model",
"Service",
"Tokenizer",
"Prompt",
"AK",
"SK",
111 changes: 111 additions & 0 deletions src/qianfan/resources/tools/tokenizer.py
@@ -0,0 +1,111 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tokenizer
"""

import unicodedata
from typing import Any

from qianfan.errors import InternalError, InvalidArgumentError


class Tokenizer(object):
"""
Class for Tokenizer API
"""

@classmethod
def count_tokens(
cls, text: str, mode: str = "local", model: str = "ERNIE-Bot", **kwargs: Any
) -> int:
"""
Count the number of tokens in a given text.
Parameters:
text (str):
The input text for which tokens need to be counted.
mode (str, optional):
`local` (default):
local **SIMULATION** (Chinese characters count + English word count * 1.3)
model (str, optional):
The name of the model to be used for token counting, which
may influence the counting strategy. Default is 'ERNIE-Bot'.
kwargs (Any):
Additional keyword arguments that can be passed to customize the request.
"""
if mode not in ["local"]:
raise InvalidArgumentError(
f"Mode `{mode}` is not supported for count token, supported mode:"
" `local`"
)
if mode == "local":
return cls._local_count_tokens(text)

# unreachable
raise InternalError

    @classmethod
    def _local_count_tokens(cls, text: str, model: str = "ERNIE-Bot") -> int:
        """
        Calculate the token count for a given text using a local simulation.
        ** THIS IS CALCULATED BY LOCAL SIMULATION, NOT THE REAL TOKEN COUNT **
        The token count is computed as follows:
        (Chinese character count) + (English word count * 1.3)
        """
        han_count = 0
        text_only_word = ""
        for ch in text:
            if cls._is_cjk_character(ch):
                han_count += 1
                text_only_word += " "
            elif cls._is_punctuation(ch) or cls._is_space(ch):
                text_only_word += " "
            else:
                text_only_word += ch
        word_count = len(list(filter(lambda x: x != "", text_only_word.split(" "))))
        return han_count + int(word_count * 1.3)

    @staticmethod
    def _is_cjk_character(ch: str) -> bool:
        """
        Check if the character is a CJK character.
        """
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF

    @staticmethod
    def _is_space(ch: str) -> bool:
        """
        Check if the character is a space.
        """
        return ch in {" ", "\n", "\r", "\t"} or unicodedata.category(ch) == "Zs"

    @staticmethod
    def _is_punctuation(ch: str) -> bool:
        """
        Check if the character is punctuation.
        """
        code = ord(ch)
        return (
            33 <= code <= 47
            or 58 <= code <= 64
            or 91 <= code <= 96
            or 123 <= code <= 126
            or unicodedata.category(ch).startswith("P")
        )
30 changes: 30 additions & 0 deletions src/qianfan/tests/tokenizer_test.py
@@ -0,0 +1,30 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit test for Tokenizer
"""

from qianfan import Tokenizer


def test_count_token_local():
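    # each expected count follows the local estimate:
    # Chinese character count + int(English word count * 1.3)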
    cases = [
        ("hello world 1 2 3", 6),
        ("你好呀", 3),
        ("你好hello哈哈world 1 2 3", 10),
        (" 你好 123 哈哈1! \t 123 Hello world", 10),
    ]
    for text, count in cases:
        assert Tokenizer.count_tokens(text, mode="local") == count
