Skip to content

Commit

Permalink
Review the paddlenlp code
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG committed May 22, 2023
1 parent dff1c77 commit 1825e90
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 18 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/Nightly_CI_main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,6 @@ jobs:
shell: bash
working-directory: tests
run: |
IS_CI=true python3 -m pytest ./ --tags L2
export IS_CI=true
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
python3 -m pytest ./ --tags L2
4 changes: 3 additions & 1 deletion .github/workflows/unit_test_main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ jobs:
shell: bash
working-directory: tests
run: |
IS_CI=true python3 -m pytest ./
export IS_CI=true
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
python3 -m pytest ./
- name: Generate coverage report
run: |
Expand Down
31 changes: 25 additions & 6 deletions gptcache/adapter/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,30 @@
import gptcache.processor.pre
from gptcache import Cache, cache, Config
from gptcache.adapter.adapter import adapt
from gptcache.embedding import Onnx, Huggingface, SBERT, FastText, Data2VecAudio, Timm, ViT, OpenAI, Cohere, Rwkv
from gptcache.embedding import (
Onnx,
Huggingface,
SBERT,
FastText,
Data2VecAudio,
Timm,
ViT,
OpenAI,
Cohere,
Rwkv,
PaddleNLP,
)
from gptcache.embedding.base import BaseEmbedding
from gptcache.manager import manager_factory
from gptcache.manager.data_manager import DataManager
from gptcache.processor.post import first
from gptcache.processor.pre import get_prompt
from gptcache.similarity_evaluation import (
SearchDistanceEvaluation, NumpyNormEvaluation, OnnxModelEvaluation,
ExactMatchEvaluation, KReciprocalEvaluation
SearchDistanceEvaluation,
NumpyNormEvaluation,
OnnxModelEvaluation,
ExactMatchEvaluation,
KReciprocalEvaluation,
)
from gptcache.utils import import_ruamel

Expand Down Expand Up @@ -145,7 +160,9 @@ def init_similar_cache(
embedding = Onnx()
if not data_manager:
data_manager = manager_factory(
"sqlite,faiss", data_dir=data_dir, vector_params={"dimension": embedding.dimension}
"sqlite,faiss",
data_dir=data_dir,
vector_params={"dimension": embedding.dimension},
)
evaluation = SearchDistanceEvaluation()
cache_obj = cache_obj if cache_obj else cache
Expand Down Expand Up @@ -207,7 +224,7 @@ def init_similar_cache_from_config(config_dir: str, cache_obj: Optional[Cache] =
)


def _get_model(model_src, model_config = None):
def _get_model(model_src, model_config=None):
model_src = model_src.lower()
model_config = model_config or {}

Expand All @@ -231,9 +248,11 @@ def _get_model(model_src, model_config = None):
return Cohere(**model_config)
if model_src == "rwkv":
return Rwkv(**model_config)
if model_src == "paddlenlp":
return PaddleNLP(**model_config)


def _get_eval(strategy, kws = None):
def _get_eval(strategy, kws=None):
strategy = strategy.lower()
kws = kws or {}

Expand Down
7 changes: 3 additions & 4 deletions gptcache/adapter/stability_sdk.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from io import BytesIO
import base64
import warnings
from dataclasses import dataclass
from io import BytesIO
from typing import List

from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
from gptcache.utils.error import CacheError
from gptcache.utils import (
import_stability, import_pillow
)
)
from gptcache.utils.error import CacheError

import_pillow()
import_stability()
Expand All @@ -19,7 +19,6 @@
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation # pylint: disable=C0413



class StabilityInference(client.StabilityInference):
"""client.StabilityInference Wrapper
Expand Down
8 changes: 3 additions & 5 deletions gptcache/embedding/paddlenlp.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np

from gptcache.utils import import_paddlenlp,import_paddle
from gptcache.embedding.base import BaseEmbedding
from gptcache.utils import import_paddlenlp, import_paddle

import_paddle()
import_paddlenlp()


import paddle # pylint: disable=C0413
from paddlenlp.transformers import AutoModel,AutoTokenizer # pylint: disable=C0413
import paddle # pylint: disable=C0413
from paddlenlp.transformers import AutoModel, AutoTokenizer # pylint: disable=C0413

class PaddleNLP(BaseEmbedding):
"""Generate sentence embedding for given text using pretrained models from PaddleNLP transformers.
Expand All @@ -35,7 +35,6 @@ def __init__(self, model: str = "ernie-3.0-medium-zh"):
self.tokenizer.pad_token = "<pad>"
self.__dimension = None


def to_embeddings(self, data, **_):
"""Generate embedding given text input
Expand Down Expand Up @@ -63,7 +62,6 @@ def post_proc(self, token_embeddings, inputs):
) / paddle.clip(input_mask_expanded.sum(1), min=1e-9)
return sentence_embs


@property
def dimension(self):
"""Embedding dimension.
Expand Down
3 changes: 2 additions & 1 deletion gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def import_docarray():


def import_paddle():
_check_library("paddlepaddle", package="paddlepaddle==2.4.0")
prompt_install("protobuf==3.20.0")
_check_library("paddlepaddle")


def import_paddlenlp():
Expand Down
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ mock
pexpect
spacy
safetensors
protobuf==3.20.0

0 comments on commit 1825e90

Please sign in to comment.