Support caching of async completion and cache completion (#513)
* Use the old version for the chromadb (#492)

Signed-off-by: SimFG <bang.fu@zilliz.com>
Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* added support for weaviate vector database (#493)

* added support for weaviate vector database

Signed-off-by: pranaychandekar <pranayc6@gmail.com>

* added support for local db for weaviate vector store

Signed-off-by: pranaychandekar <pranayc6@gmail.com>

* added unit test case for weaviate vector store

Signed-off-by: pranaychandekar <pranayc6@gmail.com>

* resolved unit test case error for weaviate vector store

Signed-off-by: pranaychandekar <pranayc6@gmail.com>

* increased code coverage
resolved pylint issues

pylint: disabled C0413

Signed-off-by: pranaychandekar <pranayc6@gmail.com>

---------

Signed-off-by: pranaychandekar <pranayc6@gmail.com>
Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* Update the version to `0.1.37` (#494)

Signed-off-by: SimFG <bang.fu@zilliz.com>
Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* ✨ support caching of async completion and cache completion

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* ✨ add streaming support for ChatCompletion

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* ✅ improve test coverage and formatting

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* ✨ support caching of async completion and cache completion

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* ✨ add streaming support for ChatCompletion

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* ✅ improve test coverage and formatting

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* correct merge duplication

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* correct update cache callback

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* add additional tests for improved coverage

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

* remove redundant param in docstring

Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>

---------

Signed-off-by: SimFG <bang.fu@zilliz.com>
Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>
Signed-off-by: pranaychandekar <pranayc6@gmail.com>
Co-authored-by: SimFG <bang.fu@zilliz.com>
Co-authored-by: Pranay Chandekar <pranayc6@gmail.com>
3 people authored Sep 5, 2023
1 parent bca8de9 commit d19bf74
Showing 3 changed files with 403 additions and 34 deletions.
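Before the file-by-file diff, a minimal sketch of what the headline feature enables: cached OpenAI chat completions under asyncio via the new `acreate` entry point. The bare `cache.init()` / `set_openai_key()` setup and the sample question are illustrative only; a real deployment would configure an embedding function and storage.

```python
import asyncio

from gptcache import cache
from gptcache.adapter.openai import ChatCompletion

cache.init()            # simplest exact-match cache; real setups add embeddings/storage
cache.set_openai_key()  # reads OPENAI_API_KEY from the environment

async def main():
    # The first call goes to OpenAI; repeating the same question is served
    # from the cache without a network round trip.
    resp = await ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "what is github?"}],
    )
    print(resp["choices"][0]["message"]["content"])

asyncio.run(main())
```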
1 change: 0 additions & 1 deletion gptcache/adapter/adapter.py
@@ -513,7 +513,6 @@ def update_cache_func(handled_llm_data, question=None):
             == 0
         ):
             chat_cache.flush()
-
     llm_data = update_cache_callback(
         llm_data, update_cache_func, *args, **kwargs
    )
109 changes: 97 additions & 12 deletions gptcache/adapter/openai.py
@@ -3,21 +3,21 @@
 import os
 import time
 from io import BytesIO
-from typing import Iterator, Any, List
+from typing import Any, AsyncGenerator, Iterator, List
 
 from gptcache import cache
-from gptcache.adapter.adapter import adapt
+from gptcache.adapter.adapter import aadapt, adapt
 from gptcache.adapter.base import BaseCacheLLM
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.utils import import_openai, import_pillow
 from gptcache.utils.error import wrap_error
 from gptcache.utils.response import (
-    get_stream_message_from_openai_answer,
-    get_message_from_openai_answer,
-    get_text_from_openai_answer,
-    get_audio_text_from_openai_answer,
     get_image_from_openai_b64,
     get_image_from_openai_url,
+    get_audio_text_from_openai_answer,
+    get_message_from_openai_answer,
+    get_stream_message_from_openai_answer,
+    get_text_from_openai_answer,
 )
 from gptcache.utils.token import token_counter
 
@@ -56,15 +56,40 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if cls.llm is None else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if cls.llm is None
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e
 
+    @classmethod
+    async def _allm_handler(cls, *llm_args, **llm_kwargs):
+        try:
+            return (
+                (await super().acreate(*llm_args, **llm_kwargs))
+                if cls.llm is None
+                else await cls.llm(*llm_args, **llm_kwargs)
+            )
+        except openai.OpenAIError as e:
+            raise wrap_error(e) from e
+
     @staticmethod
     def _update_cache_callback(
         llm_data, update_cache_func, *args, **kwargs
     ):  # pylint: disable=unused-argument
-        if not isinstance(llm_data, Iterator):
+        if isinstance(llm_data, AsyncGenerator):
+
+            async def hook_openai_data(it):
+                total_answer = ""
+                async for item in it:
+                    total_answer += get_stream_message_from_openai_answer(item)
+                    yield item
+                update_cache_func(Answer(total_answer, DataType.STR))
+
+            return hook_openai_data(llm_data)
+        elif not isinstance(llm_data, Iterator):
             update_cache_func(
                 Answer(get_message_from_openai_answer(llm_data), DataType.STR)
             )
@@ -92,8 +117,6 @@ def cache_data_convert(cache_data):
                 saved_token = [input_token, output_token]
             else:
                 saved_token = [0, 0]
-            if kwargs.get("stream", False):
-                return _construct_stream_resp_from_cache(cache_data, saved_token)
             return _construct_resp_from_cache(cache_data, saved_token)
 
         kwargs = cls.fill_base_args(**kwargs)
@@ -105,6 +128,38 @@ def cache_data_convert(cache_data):
             **kwargs,
         )
 
+    @classmethod
+    async def acreate(cls, *args, **kwargs):
+        chat_cache = kwargs.get("cache_obj", cache)
+        enable_token_counter = chat_cache.config.enable_token_counter
+
+        def cache_data_convert(cache_data):
+            if enable_token_counter:
+                input_token = _num_tokens_from_messages(kwargs.get("messages"))
+                output_token = token_counter(cache_data)
+                saved_token = [input_token, output_token]
+            else:
+                saved_token = [0, 0]
+            if kwargs.get("stream", False):
+                return async_iter(
+                    _construct_stream_resp_from_cache(cache_data, saved_token)
+                )
+            return _construct_resp_from_cache(cache_data, saved_token)
+
+        kwargs = cls.fill_base_args(**kwargs)
+        return await aadapt(
+            cls._allm_handler,
+            cache_data_convert,
+            cls._update_cache_callback,
+            *args,
+            **kwargs,
+        )
+
+
+async def async_iter(input_list):
+    for item in input_list:
+        yield item
+
 
 class Completion(openai.Completion, BaseCacheLLM):
     """Openai Completion Wrapper
@@ -128,7 +183,22 @@ class Completion(openai.Completion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if not cls.llm
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
+        except openai.OpenAIError as e:
+            raise wrap_error(e) from e
+
+    @classmethod
+    async def _allm_handler(cls, *llm_args, **llm_kwargs):
+        try:
+            return (
+                (await super().acreate(*llm_args, **llm_kwargs))
+                if cls.llm is None
+                else await cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e

@@ -154,6 +224,17 @@ def create(cls, *args, **kwargs):
             **kwargs,
         )
 
+    @classmethod
+    async def acreate(cls, *args, **kwargs):
+        kwargs = cls.fill_base_args(**kwargs)
+        return await aadapt(
+            cls._allm_handler,
+            cls._cache_data_convert,
+            cls._update_cache_callback,
+            *args,
+            **kwargs,
+        )
+
 
 class Audio(openai.Audio):
     """Openai Audio Wrapper
@@ -319,7 +400,11 @@ class Moderation(openai.Moderation, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if not cls.llm
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e
 
