Add the BaseCacheLLM abstract class to wrap the llm (#394)
Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG authored May 29, 2023
1 parent 0c67584 commit 873fca7
Showing 19 changed files with 400 additions and 241 deletions.
2 changes: 1 addition & 1 deletion gptcache/adapter/api.py
@@ -148,7 +148,7 @@ def init_similar_cache(
:param post_func: post-processing of the cached result list, the most similar result is taken by default
:type post_func: Callable[[List[Any]], Any]
:param config: cache configuration, the core is similar threshold
-:type config: gptcache.Config
+:type config: Config
:return: None
Example:
70 changes: 70 additions & 0 deletions gptcache/adapter/base.py
@@ -0,0 +1,70 @@
from abc import ABCMeta
from typing import Any, Dict, Callable, Optional


class BaseCacheLLM(metaclass=ABCMeta):
    """Base LLM wrapper. When you have an enhanced LLM that does not call the original LLM API directly,
    you can use this class as a proxy and still get the cache's capabilities.

    NOTE: Make sure the custom LLM returns the same value as the original one.

    For example, suppose you wrap the OpenAI client to record latency before sending each request,
    and that wrapper lives in a separate package, different from openai itself.
    As long as the wrapped API takes the same request parameters and returns the same results as the original,
    you can use this class to obtain the cache-related capabilities.

    Example:
        .. code-block:: python

            import time
            import openai

            from gptcache import Cache
            from gptcache.adapter import openai as cache_openai


            def proxy_openai_chat_complete(*args, **kwargs):
                start_time = time.time()
                res = openai.ChatCompletion.create(*args, **kwargs)
                print("Time consumed:", round(time.time() - start_time, 2))
                return res


            llm_cache = Cache()

            cache_openai.ChatCompletion.llm = proxy_openai_chat_complete
            cache_openai.ChatCompletion.cache_args = {"cache_obj": llm_cache}

            cache_openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "user",
                        "content": "What's GitHub?",
                    }
                ],
            )
    """

    llm: Optional[Callable] = None
    """
    On a cache miss, if this attribute is set, it is called instead of the original llm;
    otherwise the original llm is called.
    """

    cache_args: Dict[str, Any] = {}
    """
    Common cache-related parameters, such as cache_obj.
    Set it once if you don't want to pass the same parameters on every request.
    """

    @classmethod
    def fill_base_args(cls, **kwargs):
        """Fill the common cache args into the request kwargs without overriding values that are already set."""
        for key, value in cls.cache_args.items():
            if key not in kwargs:
                kwargs[key] = value

        return kwargs
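
For context, the three pieces above (the llm override, the shared cache_args, and fill_base_args) combine as in the minimal sketch below. It is not part of this commit: MyCompletion, my_original_llm, and the direct dispatch are hypothetical, and the real adapters hand the merged kwargs to gptcache.adapter.adapter.adapt, which performs the actual cache lookup.

from gptcache.adapter.base import BaseCacheLLM


def my_original_llm(**kwargs):
    # Stand-in for the real model call that runs on a cache miss.
    return {"answer": "echo: " + str(kwargs.get("prompt", ""))}


class MyCompletion(BaseCacheLLM):
    @classmethod
    def create(cls, **kwargs):
        # Merge the shared defaults (e.g. cache_obj) into this request
        # without overriding anything the caller passed explicitly.
        kwargs = cls.fill_base_args(**kwargs)
        # Cache-only parameters would normally be consumed by adapt();
        # drop them here because this sketch calls the llm directly.
        kwargs.pop("cache_obj", None)
        # Prefer the user-registered proxy llm, otherwise fall back to the original call.
        handler = cls.llm if cls.llm is not None else my_original_llm
        return handler(**kwargs)


MyCompletion.cache_args = {"cache_obj": "my_llm_cache"}  # set once, reused on every request
print(MyCompletion.create(prompt="What's GitHub?"))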
12 changes: 6 additions & 6 deletions gptcache/adapter/diffusers.py
@@ -3,10 +3,10 @@

from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import Answer, DataType
-from gptcache.utils.error import CacheError
from gptcache.utils import (
import_pillow, import_diffusers, import_huggingface
-)
+)
+from gptcache.utils.error import CacheError

import_pillow()
import_huggingface()
@@ -41,15 +41,15 @@ class StableDiffusionPipeline(diffusers.StableDiffusionPipeline):
image = pipe(prompt=prompt).images[0]
"""

-def llm_handler(self, *llm_args, **llm_kwargs):
+def _llm_handler(self, *llm_args, **llm_kwargs):
try:
return super().__call__(*llm_args, **llm_kwargs)
except Exception as e:
raise CacheError("diffuser error") from e

def __call__(self, *args, **kwargs):
def cache_data_convert(cache_data):
-return construct_resp_from_cache(cache_data)
+return _construct_resp_from_cache(cache_data)

def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
img = llm_data["images"][0]
@@ -60,11 +60,11 @@ def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pyli
return llm_data

return adapt(
-self.llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
+self._llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)


-def construct_resp_from_cache(img_64):
+def _construct_resp_from_cache(img_64):
im_bytes = base64.b64decode(img_64) # im_bytes is a binary image
im_file = BytesIO(im_bytes) # convert image to file-like object
img = Image.open(im_file)
10 changes: 5 additions & 5 deletions gptcache/adapter/dolly.py
@@ -1,8 +1,8 @@
from typing import Any

from gptcache.adapter.adapter import adapt
-from gptcache.utils import import_huggingface, import_torch
from gptcache.manager.scalar_data.base import Answer, DataType
+from gptcache.utils import import_huggingface, import_torch

import_torch()
import_huggingface()
@@ -52,17 +52,17 @@ def from_model(cls, model: str, **kwargs):
def __call__(self, prompt: str, **kwargs):
return adapt(
self._dolly_pipeline,
-cache_data_convert,
-update_cache_callback,
+_cache_data_convert,
+_update_cache_callback,
inputs=prompt,
**kwargs
)


-def cache_data_convert(cache_data):
+def _cache_data_convert(cache_data):
return [{"generated_text": cache_data, "gptcache": True}]


-def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
+def _update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
update_cache_func(Answer(llm_data[0]["generated_text"], DataType.STR))
return llm_data
20 changes: 10 additions & 10 deletions gptcache/adapter/langchain_models.py
@@ -48,8 +48,8 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
session = self.session if "session" not in kwargs else kwargs.pop("session")
return adapt(
self.llm,
-cache_data_convert,
-update_cache_callback,
+_cache_data_convert,
+_update_cache_callback,
prompt=prompt,
stop=stop,
session=session,
@@ -93,8 +93,8 @@ def _generate(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
session = self.session if "session" not in kwargs else kwargs.pop("session")
return adapt(
self.chat._generate,
-cache_msg_data_convert,
-update_cache_msg_callback,
+_cache_msg_data_convert,
+_update_cache_msg_callback,
messages=messages,
stop=stop,
session=session,
@@ -105,8 +105,8 @@ async def _agenerate(self, messages: List[List[BaseMessage]], stop: Optional[Lis
session = self.session if "session" not in kwargs else kwargs.pop("session")
return adapt(
self.chat._agenerate,
-cache_msg_data_convert,
-update_cache_msg_callback,
+_cache_msg_data_convert,
+_update_cache_msg_callback,
messages=messages,
stop=stop,
session=session,
@@ -118,23 +118,23 @@ def __call__(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
return res.generations[0].message


-def cache_data_convert(cache_data):
+def _cache_data_convert(cache_data):
return cache_data


-def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
+def _update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
update_cache_func(Answer(llm_data, DataType.STR))
return llm_data


-def cache_msg_data_convert(cache_data):
+def _cache_msg_data_convert(cache_data):
llm_res = ChatResult(generations=[ChatGeneration(text="",
generation_info=None,
message=AIMessage(content=cache_data, additional_kwargs={}))],
llm_output=None)
return llm_res


-def update_cache_msg_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
+def _update_cache_msg_callback(llm_data, update_cache_func, *args, **kwargs): # pylint: disable=unused-argument
update_cache_func(llm_data.generations[0].text)
return llm_data
32 changes: 15 additions & 17 deletions gptcache/adapter/llama_cpp.py
@@ -1,12 +1,10 @@
-from typing import Iterator
import time
+from typing import Iterator

from gptcache.adapter.adapter import adapt
from gptcache.manager.scalar_data.base import DataType, Answer

from gptcache.utils import import_llama_cpp_python


import_llama_cpp_python()

import llama_cpp # pylint: disable=wrong-import-position
@@ -21,16 +19,16 @@ class Llama(llama_cpp.Llama):
Example:
.. code-block:: python
onnx = Onnx()
m = manager_factory('sqlite,faiss,local', data_dir=root, vector_params={"dimension": onnx.dimension})
llm_cache = Cache()
llm_cache.init(
pre_embedding_func=get_prompt,
data_manager=m,
embedding_func=onnx.to_embeddings
)
llm = Llama('./models/7B/ggml-model.bin')
answer = llm(prompt=question, cache_obj=llm_cache)
"""
def __call__(
self,
@@ -54,8 +52,8 @@ def stream_answer(it):

def cache_data_convert(cache_data):
if kwargs.get("stream", False):
-return construct_stream_resp_from_cache(cache_data)
-return construct_resp_from_cache(cache_data)
+return _construct_stream_resp_from_cache(cache_data)
+return _construct_resp_from_cache(cache_data)

return adapt(
self.create_completion,
@@ -66,7 +64,7 @@ def cache_data_convert(cache_data):
)


-def construct_resp_from_cache(return_message):
+def _construct_resp_from_cache(return_message):
return {
"gptcache": True,
"choices": [
@@ -82,7 +80,7 @@ def construct_resp_from_cache(return_message):
}


-def construct_stream_resp_from_cache(return_message):
+def _construct_stream_resp_from_cache(return_message):
return [
{
"gptcache": True,
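
Both cache-constructed responses above carry a "gptcache": True field, so a caller can tell a cache hit from a fresh llama.cpp completion. A hypothetical check, assuming llm and llm_cache are set up as in the Llama docstring above and streaming is not enabled:

answer = llm(prompt=question, cache_obj=llm_cache)
if answer.get("gptcache"):
    print("Served from cache:", answer)
else:
    print("Fresh completion:", answer)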
14 changes: 7 additions & 7 deletions gptcache/adapter/minigpt4.py
@@ -1,13 +1,8 @@
-from gptcache.adapter.adapter import adapt
-from gptcache.utils.error import CacheError
-from gptcache.manager.scalar_data.base import DataType, Question, Answer

from argparse import Namespace

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# pylint: disable=wildcard-import
# imports modules for registration
from minigpt4.datasets.builders import *
@@ -16,12 +11,17 @@
from minigpt4.runners import *
from minigpt4.tasks import *

+from gptcache.adapter.adapter import adapt
+from gptcache.manager.scalar_data.base import DataType, Question, Answer
+from gptcache.utils.error import CacheError


class MiniGPT4: # pragma: no cover
"""MiniGPT4 Wrapper
Example:
.. code-block:: python
from gptcache import cache
from gptcache.processor.pre import get_image_question
from gptcache.adapter.minigpt4 import MiniGPT4
@@ -53,7 +53,7 @@ def from_pretrained(cls, cfg_path, gpu_id=0, options=None, return_hit=False):
chat = Chat(model, vis_processor, device="cuda:{}".format(args.gpu_id))
return cls(chat, return_hit)

-def llm_handler(self, image, question):
+def _llm_handler(self, image, question):
chat_state = CONV_VISION.copy()
img_list = []
try:
@@ -86,5 +86,5 @@ def update_cache_callback(llm_data, update_cache_func, *args, **kwargs): # pyli
return llm_data

return adapt(
-self.llm_handler, cache_data_convert, update_cache_callback, image=image, question=question, cache_context=cache_context, *args, **kwargs
+self._llm_handler, cache_data_convert, update_cache_callback, image=image, question=question, cache_context=cache_context, *args, **kwargs
)