Signed-off-by: vax521 <13263397018@163.com>
Signed-off-by: feimeng <13263397018@163.com>
Showing 6 changed files with 153 additions and 2 deletions.
@@ -0,0 +1,32 @@
from gptcache.adapter import openai
from gptcache import cache
from gptcache.manager import get_data_manager, CacheBase, VectorBase
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.embedding import PaddleNLP


def run():
    # PaddleNLP embedding model; defaults to ernie-3.0-medium-zh
    paddlenlp = PaddleNLP()

    # sqlite stores the cached question/answer pairs,
    # faiss indexes the question embeddings for similarity search
    cache_base = CacheBase('sqlite')
    vector_base = VectorBase('faiss', dimension=paddlenlp.dimension)
    data_manager = get_data_manager(cache_base, vector_base)

    cache.init(embedding_func=paddlenlp.to_embeddings,
               data_manager=data_manager,
               similarity_evaluation=SearchDistanceEvaluation(),
               )
    cache.set_openai_key()

    answer = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=[
            {'role': 'user', 'content': 'what is chatgpt'}
        ],
    )
    print(answer)


if __name__ == '__main__':
    run()
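To see the cache pay off, one could append a paraphrased question at the end of run() and time it — a minimal sketch, not part of this diff; whether it actually hits the cache depends on the embedding distance and the SearchDistanceEvaluation threshold, and the paraphrase used here is purely illustrative:

    import time

    start = time.time()
    # a paraphrase of the cached question; served from the cache if the
    # similarity evaluation accepts it, otherwise forwarded to OpenAI
    answer = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=[
            {'role': 'user', 'content': 'can you explain what chatgpt is'}
        ],
    )
    print(answer)
    print(f'latency: {time.time() - start:.2f}s')  # near-zero on a cache hit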
@@ -0,0 +1,76 @@
import numpy as np

from gptcache.utils import import_paddlenlp, import_paddle
from gptcache.embedding.base import BaseEmbedding

import_paddle()
import_paddlenlp()


import paddle  # pylint: disable=C0413
from paddlenlp.transformers import AutoModel, AutoTokenizer  # pylint: disable=C0413


class PaddleNLP(BaseEmbedding):
    """Generate sentence embedding for given text using pretrained models from PaddleNLP transformers.

    :param model: model name, defaults to 'ernie-3.0-medium-zh'.
    :type model: str

    Example:
        .. code-block:: python

            from gptcache.embedding import PaddleNLP

            test_sentence = 'Hello, world.'
            encoder = PaddleNLP(model='ernie-3.0-medium-zh')
            embed = encoder.to_embeddings(test_sentence)
    """

    def __init__(self, model: str = "ernie-3.0-medium-zh"):
        self.model = AutoModel.from_pretrained(model)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = "<pad>"
        self.__dimension = None

    def to_embeddings(self, data, **_):
        """Generate embedding given text input.

        :param data: text in string.
        :type data: str

        :return: a text embedding in shape of (dim,).
        """
        if not isinstance(data, list):
            data = [data]
        inputs = self.tokenizer(
            data, padding=True, truncation=True, return_tensors="pd"
        )
        outs = self.model(**inputs)[0]
        # mean-pool the token embeddings into a single sentence vector
        emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy()
        return np.array(emb).astype("float32")

    def post_proc(self, token_embeddings, inputs):
        # an all-ones mask is used here, so every token (padding included)
        # contributes equally to the mean-pooled sentence embedding
        attention_mask = paddle.ones(inputs["token_type_ids"].shape)
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.shape).astype("float32")
        )
        sentence_embs = paddle.sum(
            token_embeddings * input_mask_expanded, 1
        ) / paddle.clip(input_mask_expanded.sum(1), min=1e-9)
        return sentence_embs

    @property
    def dimension(self):
        """Embedding dimension.

        :return: embedding dimension
        """
        if not self.__dimension:
            # probe once with a dummy input, then cache the result
            self.__dimension = len(self.to_embeddings("foo"))
        return self.__dimension
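Since post_proc builds an all-ones mask, the pooling reduces to a plain mean over the token axis. A small NumPy sketch (shapes are illustrative) shows the equivalence:

    import numpy as np

    # encoder output: (batch, seq_len, hidden)
    token_embeddings = np.random.rand(1, 5, 8).astype("float32")
    mask = np.ones((1, 5, 1), dtype="float32")  # all-ones mask, as in post_proc
    pooled = (token_embeddings * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)
    assert np.allclose(pooled, token_embeddings.mean(axis=1))  # plain token mean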
@@ -0,0 +1,14 @@
from gptcache.embedding import PaddleNLP
from gptcache.adapter.api import _get_model


def test_paddlenlp():
    # direct construction
    t = PaddleNLP("ernie-3.0-nano-zh")
    dimension = t.dimension
    data = t.to_embeddings("中国")  # "China"
    assert len(data) == dimension, f"{len(data)}, {t.dimension}"

    # construction through the model registry
    t = _get_model(model_src="paddlenlp", model_config={"model": "ernie-3.0-nano-zh"})
    dimension = t.dimension
    data = t.to_embeddings("中国")
    assert len(data) == dimension, f"{len(data)}, {t.dimension}"
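Beyond the dimension check, the property the cache relies on is that paraphrases embed close together. A hedged sketch — the sentences are illustrative and the exact score depends on the model:

    import numpy as np
    from gptcache.embedding import PaddleNLP

    encoder = PaddleNLP("ernie-3.0-nano-zh")
    a = encoder.to_embeddings("什么是机器学习")      # "what is machine learning"
    b = encoder.to_embeddings("请解释一下机器学习")  # "please explain machine learning"
    score = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    print(score)  # paraphrases should land near 1.0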