Commit a2b7466

Support the openai moderation adapter (#376)

Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG authored May 22, 2023
1 parent d95be09 commit a2b7466
Showing 9 changed files with 304 additions and 42 deletions.
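In short, this commit adds a Moderation wrapper to the openai adapter and a get_openai_moderation_input pre-processor so that moderation responses can be served from the cache. The sketch below is adapted from the docstring added in this commit; the cache setup is illustrative, and the final print assumes the standard moderation response shape.

from gptcache.adapter import openai
from gptcache.adapter.api import init_similar_cache
from gptcache.processor.pre import get_openai_moderation_input

# Use the new pre-processor so the moderation "input" field becomes the cache key.
init_similar_cache(pre_func=get_openai_moderation_input)

# The first call goes to the OpenAI Moderation endpoint; sufficiently similar
# inputs can later be answered from the cache.
res = openai.Moderation.create(input="I want to kill them.")
print(res["results"][0]["flagged"])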
7 changes: 3 additions & 4 deletions examples/processor/temperature_example.py
@@ -1,12 +1,11 @@
import os
import time

from gptcache import cache, Config
from gptcache.manager import manager_factory
from gptcache import cache
from gptcache.adapter import openai
from gptcache.embedding import Onnx
from gptcache.manager import manager_factory
from gptcache.processor.post import temperature_softmax
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.adapter import openai

cache.set_openai_key()

36 changes: 27 additions & 9 deletions gptcache/adapter/adapter.py
@@ -1,7 +1,7 @@
import logging
from gptcache import cache
from gptcache.processor.post import temperature_softmax
from gptcache.utils.error import NotInitError
from gptcache.utils.log import gptcache_log
from gptcache.utils.time import time_cal


@@ -23,8 +23,14 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
if 0 < temperature < 2:
cache_skip_options = [True, False]
prob_cache_skip = [0, 1]
cache_skip = kwargs.pop("cache_skip", temperature_softmax(
messages=cache_skip_options, scores = prob_cache_skip, temperature=temperature))
cache_skip = kwargs.pop(
"cache_skip",
temperature_softmax(
messages=cache_skip_options,
scores=prob_cache_skip,
temperature=temperature,
),
)
elif temperature >= 2:
cache_skip = kwargs.pop("cache_skip", True)
else: # temperature <= 0
@@ -56,7 +62,9 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
)(
embedding_data,
extra_param=context.get("search_func", None),
top_k=kwargs.pop("top_k", 5) if (user_temperature and not user_top_k) else kwargs.pop("top_k", -1),
top_k=kwargs.pop("top_k", 5)
if (user_temperature and not user_top_k)
else kwargs.pop("top_k", -1),
)
if cache_data_list is None:
cache_data_list = []
@@ -83,7 +91,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
if "deps" in context and hasattr(ret.question, "deps"):
eval_query_data = {
"question": context["deps"][0]["data"],
"embedding": None
"embedding": None,
}
eval_cache_data = {
"question": ret.question.deps[0].data,
@@ -108,6 +116,12 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
eval_cache_data,
extra_param=context.get("evaluation_func", None),
)
gptcache_log.debug(
"similarity: [user question] %s, [cache question] %s, [value] %f",
pre_store_data,
ret.question,
rank,
)
if rank_threshold <= rank:
cache_answers.append((rank, ret.answers[0].answer, cache_data))
chat_cache.data_manager.hit_cache_callback(cache_data)
@@ -118,15 +132,17 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
return_message = chat_cache.post_process_messages_func(
messages=[t[1] for t in cache_answers],
scores=[t[0] for t in cache_answers],
temperature=temperature
temperature=temperature,
)
else:
return_message = chat_cache.post_process_messages_func(
[t[1] for t in cache_answers]
)
chat_cache.report.hint_cache()
if session:
chat_cache.data_manager.add_session(answers_dict[return_message], session.name, pre_embedding_data)
chat_cache.data_manager.add_session(
answers_dict[return_message], session.name, pre_embedding_data
)
return cache_data_convert(return_message)

next_cache = chat_cache.next_cache
@@ -157,7 +173,9 @@ def update_cache_func(handled_llm_data, question=None):
session=session,
)

llm_data = update_cache_callback(llm_data, update_cache_func, *args, **kwargs)
llm_data = update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
)
except Exception as e: # pylint: disable=W0703
logging.warning("failed to save the data to cache, error: %s", e)
gptcache_log.warning("failed to save the data to cache, error: %s", e)
return llm_data
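The cache_skip logic reformatted above works as follows: for 0 < temperature < 2, temperature_softmax samples from [True, False] with scores [0, 1], so a higher temperature makes skipping the cache more likely; at temperature >= 2 the cache is always skipped, and the else branch (collapsed above) handles temperature <= 0. An illustrative, standalone re-implementation of that sampling step (not the library code):

import math
import random

def sample_with_temperature(messages, scores, temperature):
    # Softmax the scores scaled by temperature, then sample one message.
    weights = [math.exp(s / temperature) for s in scores]
    total = sum(weights)
    return random.choices(messages, weights=[w / total for w in weights], k=1)[0]

# Mirrors the call in adapter.py: options [True, False] with scores [0, 1].
for t in (0.1, 1.0, 1.9):
    skip = sample_with_temperature([True, False], [0, 1], temperature=t)
    print(f"temperature={t}: cache_skip={skip}")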
144 changes: 120 additions & 24 deletions gptcache/adapter/openai.py
@@ -1,8 +1,9 @@
import base64
import json
import os
import time
from io import BytesIO
from typing import Iterator, Any
from typing import Iterator, Any, List

from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
@@ -55,9 +56,13 @@ def llm_handler(cls, *llm_args, **llm_kwargs):
raise CacheError("openai error") from e

@staticmethod
def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
def update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
): # pylint: disable=unused-argument
if not isinstance(llm_data, Iterator):
update_cache_func(Answer(get_message_from_openai_answer(llm_data), DataType.STR))
update_cache_func(
Answer(get_message_from_openai_answer(llm_data), DataType.STR)
)
return llm_data
else:

@@ -82,7 +87,7 @@ def cache_data_convert(cache_data):
cache_data_convert,
cls.update_cache_callback,
*args,
**kwargs
**kwargs,
)


@@ -114,7 +119,9 @@ def cache_data_convert(cache_data):
return construct_text_from_cache(cache_data)

@staticmethod
def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
def update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
): # pylint: disable=unused-argument
update_cache_func(Answer(get_text_from_openai_answer(llm_data), DataType.STR))
return llm_data

@@ -125,7 +132,7 @@ def create(cls, *args, **kwargs):
cls.cache_data_convert,
cls.update_cache_callback,
*args,
**kwargs
**kwargs,
)


@@ -150,6 +157,7 @@ class Audio(openai.Audio):
audio_file= open("/path/to/audio.mp3", "rb")
transcript = openai.Audio.translate("whisper-1", audio_file)
"""

@classmethod
def transcribe(cls, model: str, file: Any, *args, **kwargs):
def llm_handler(*llm_args, **llm_kwargs):
@@ -161,12 +169,22 @@ def llm_handler(*llm_args, **llm_kwargs):
def cache_data_convert(cache_data):
return construct_audio_text_from_cache(cache_data)

def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
update_cache_func(Answer(get_audio_text_from_openai_answer(llm_data), DataType.STR))
def update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
): # pylint: disable=unused-argument
update_cache_func(
Answer(get_audio_text_from_openai_answer(llm_data), DataType.STR)
)
return llm_data

return adapt(
llm_handler, cache_data_convert, update_cache_callback, model=model, file=file, *args, **kwargs
llm_handler,
cache_data_convert,
update_cache_callback,
model=model,
file=file,
*args,
**kwargs,
)

@classmethod
@@ -180,12 +198,22 @@ def llm_handler(*llm_args, **llm_kwargs):
def cache_data_convert(cache_data):
return construct_audio_text_from_cache(cache_data)

def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
update_cache_func(Answer(get_audio_text_from_openai_answer(llm_data), DataType.STR))
def update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
): # pylint: disable=unused-argument
update_cache_func(
Answer(get_audio_text_from_openai_answer(llm_data), DataType.STR)
)
return llm_data

return adapt(
llm_handler, cache_data_convert, update_cache_callback, model=model, file=file, *args, **kwargs
llm_handler,
cache_data_convert,
update_cache_callback,
model=model,
file=file,
*args,
**kwargs,
)


@@ -224,25 +252,93 @@ def llm_handler(*llm_args, **llm_kwargs):

def cache_data_convert(cache_data):
return construct_image_create_resp_from_cache(
image_data=cache_data,
response_format=response_format,
size=size
)
image_data=cache_data, response_format=response_format, size=size
)

def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
def update_cache_callback(
llm_data, update_cache_func, *args, **kwargs
): # pylint: disable=unused-argument
if response_format == "b64_json":
img_b64 = get_image_from_openai_b64(llm_data)
if isinstance(img_b64, str):
img_b64 = img_b64.encode("ascii")
update_cache_func(Answer(img_b64, DataType.IMAGE_BASE64))
elif response_format == "url":
update_cache_func(Answer(get_image_from_openai_url(llm_data), DataType.IMAGE_URL))
update_cache_func(
Answer(get_image_from_openai_url(llm_data), DataType.IMAGE_URL)
)
return llm_data

return adapt(
llm_handler, cache_data_convert, update_cache_callback, response_format=response_format, size=size, *args, **kwargs
llm_handler,
cache_data_convert,
update_cache_callback,
response_format=response_format,
size=size,
*args,
**kwargs,
)


class Moderation(openai.Moderation):
"""Openai Moderation Wrapper
Example:
.. code-block:: python
from gptcache.adapter import openai
from gptcache.adapter.api import init_similar_cache
from gptcache.processor.pre import get_openai_moderation_input
init_similar_cache(pre_func=get_openai_moderation_input)
openai.Moderation.create(
input="I want to kill them.",
)
"""

@classmethod
def llm_handler(cls, *llm_args, **llm_kwargs):
try:
return super().create(*llm_args, **llm_kwargs)
except openai.error.OpenAIError as e:
raise CacheError("openai error") from e

@classmethod
def cache_data_convert(cls, cache_data):
return json.loads(cache_data)

@classmethod
def update_cache_callback(
cls, llm_data, update_cache_func, *args, **kwargs
): # pylint: disable=unused-argument
update_cache_func(Answer(json.dumps(llm_data, indent=4), DataType.STR))
return llm_data

@classmethod
def create(cls, *args, **kwargs):
res = adapt(
cls.llm_handler,
cls.cache_data_convert,
cls.update_cache_callback,
*args,
**kwargs,
)

input_request_param = kwargs.get("input")
expect_res_len = 1
if isinstance(input_request_param, List):
expect_res_len = len(input_request_param)
if len(res.get("results")) != expect_res_len:
kwargs["cache_skip"] = True
res = adapt(
cls.llm_handler,
cls.cache_data_convert,
cls.update_cache_callback,
*args,
**kwargs,
)
return res


def construct_resp_from_cache(return_message):
return {
@@ -329,15 +425,15 @@ def construct_image_create_resp_from_cache(image_data, response_format, size):
elif response_format == "b64_json":
image_data = base64.b64encode(buffered.getvalue()).decode("ascii")
else:
raise AttributeError(f"Invalid response_format: {response_format} is not one of ['url', 'b64_json']")
raise AttributeError(
f"Invalid response_format: {response_format} is not one of ['url', 'b64_json']"
)

return {
"gptcache": True,
"created": int(time.time()),
"data": [
{response_format: image_data}
]
}
"data": [{response_format: image_data}],
}


def construct_audio_text_from_cache(return_text):
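One detail worth noting in the new Moderation.create: after the cached call, it checks that the response contains one entry in "results" per input item (one when "input" is a single string, len(input) when it is a list); if the counts do not match, the request is retried with cache_skip=True to force a fresh API call. A hypothetical wrapper illustrating the same fallback pattern:

from typing import Any, Callable, Dict, List, Union

def create_with_consistency_check(
    create_fn: Callable[..., Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
    # Retry with cache_skip=True if the (possibly cached) result does not
    # cover every input item.
    res = create_fn(**kwargs)
    input_param: Union[str, List[str], None] = kwargs.get("input")
    expected = len(input_param) if isinstance(input_param, list) else 1
    if len(res.get("results", [])) != expected:
        kwargs["cache_skip"] = True  # bypass the cache and call the API directly
        res = create_fn(**kwargs)
    return res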
4 changes: 2 additions & 2 deletions gptcache/config.py
@@ -8,8 +8,8 @@ class Config:
:param log_time_func: optional, customized log time function
:type log_time_func: Optional[Callable[[str, float], None]]
:param similarity_threshold: a threshold ranged from 0 to 1 to filter search results with similarity score higher than the threshold.
When it is 0, there is no hits. When it is 1, all search results will be returned as hits.
:param similarity_threshold: a threshold ranged from 0 to 1 to filter search results with similarity score higher \
than the threshold. When it is 0, there is no hits. When it is 1, all search results will be returned as hits.
:type similarity_threshold: float
:param prompts: optional, if the request content will remove the prompt string when the request contains the prompt list
:type prompts: Optional[List[str]]
2 changes: 1 addition & 1 deletion gptcache/embedding/huggingface.py
@@ -27,7 +27,7 @@ class Huggingface(BaseEmbedding):
test_sentence = '什么是Github'
huggingface = Huggingface(model='uer/albert-base-chinese-cluecorpussmall')
embed = encoder.to_embeddings(test_sentence)
embed = huggingface.to_embeddings(test_sentence)
"""

def __init__(self, model: str = "distilbert-base-uncased"):
18 changes: 18 additions & 0 deletions gptcache/processor/pre.py
@@ -224,3 +224,21 @@ def get_inputs(data: Dict[str, Any], **_: Dict[str, Any]):
# "hello"
"""
return data.get("inputs")


def get_openai_moderation_input(data: Dict[str, Any], **_: Dict[str, Any]) -> str:
"""get the input param of the openai moderation request params
:param data: the user openai moderation request data
:type data: Dict[str, Any]
Example:
.. code-block:: python
from gptcache.processor.pre import get_openai_moderation_input
content = get_openai_moderation_input({"input": ["hello", "world"]})
# "['hello', 'world']"
"""

return str(data.get("input"))
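The new pre-processor simply stringifies the request's "input" field, so both a single string and a list of strings yield a deterministic text key for embedding and similarity search. A quick illustration, assuming the behavior shown in the diff:

from gptcache.processor.pre import get_openai_moderation_input

# A plain string is returned as-is; a list becomes its repr.
print(get_openai_moderation_input({"input": "I want to kill them."}))
# I want to kill them.
print(get_openai_moderation_input({"input": ["hello", "world"]}))
# ['hello', 'world']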