Commit
Support cache chain
Signed-off-by: SimFG <bang.fu@zilliz.com>
SimFG committed Mar 29, 2023
1 parent 6554e39 commit c5f6ff9
Showing 5 changed files with 36 additions and 7 deletions.
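In short, this commit lets one Cache point at another through a new next_cache parameter: when a lookup misses in the primary cache, the OpenAI adapter re-issues the request against the next cache in the chain before falling back to the real API. A condensed usage sketch, assembled from the example/map/map_manager.py change below (file paths and the cached question are illustrative):

from gpt_cache.cache.factory import get_data_manager
from gpt_cache.core import cache, Cache
from gpt_cache.view import openai

# Backup cache, then the primary cache chained onto it.
bak_cache = Cache()
bak_cache.init(data_manager=get_data_manager("map", data_path="data_map_bak.txt", max_size=10))
cache.init(data_manager=get_data_manager("map", data_path="data_map.txt", max_size=10),
           next_cache=bak_cache)

# A question saved only in the backup cache is still served from the chain
# instead of triggering a real OpenAI call.
bak_cache.data_manager.save("foo15", "receiver the foo 15", bak_cache.embedding_func("foo15"))
answer = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "system", "content": "You are a helpful assistant."},
              {"role": "user", "content": "foo15"}],
)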
1 change: 1 addition & 0 deletions .gitignore
@@ -130,6 +130,7 @@ dmypy.json
 
 .idea
 **/data_map.txt
+**/data_map**.txt
 **/faiss.index
 **/sqlite.db
 **/example.py
15 changes: 12 additions & 3 deletions example/map/map_manager.py
@@ -2,24 +2,33 @@
 
 from gpt_cache.cache.factory import get_data_manager
 from gpt_cache.view import openai
-from gpt_cache.core import cache
+from gpt_cache.core import cache, Cache
 
 
 def run():
     dirname, _ = os.path.split(os.path.abspath(__file__))
+    bak_cache = Cache()
+    bak_cache.init(data_manager=get_data_manager("map",
+                                                 data_path=dirname + "/data_map_bak.txt",
+                                                 max_size=10))
     cache.init(data_manager=get_data_manager("map",
                                              data_path=dirname + "/data_map.txt",
-                                             max_size=10))
+                                             max_size=10),
+               next_cache=bak_cache)
     mock_messages = [
         {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "foo5"}
+        {"role": "user", "content": "foo15"}
     ]
 
     # you should CLOSE it if you SECONDLY run it
     for i in range(10):
         question = f"foo{i}"
         answer = f"receiver the foo {i}"
         cache.data_manager.save(question, answer, cache.embedding_func(question))
+    for i in range(10, 20):
+        question = f"foo{i}"
+        answer = f"receiver the foo {i}"
+        bak_cache.data_manager.save(question, answer, bak_cache.embedding_func(question))
 
     answer = openai.ChatCompletion.create(
         model="gpt-3.5-turbo",
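The updated example seeds foo0–foo9 into the primary cache and foo10–foo19 only into the backup, then asks for "foo15". Roughly what should happen on that request, inferred from the rest of the diff rather than from verified output:

# "foo15" was saved only through bak_cache, so the primary map lookup misses.
emb = cache.embedding_func("foo15")
assert cache.data_manager.search(emb) == []     # miss: empty list instead of KeyError
assert bak_cache.data_manager.search(emb)       # hit in the backup cache
# The adapter then re-enters ChatCompletion.create with cache_obj=bak_cache and
# serves the pre-saved "receiver the foo 15" instead of calling OpenAI.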
5 changes: 4 additions & 1 deletion gpt_cache/cache/data_manager.py
@@ -52,7 +52,10 @@ def get_scalar_data(self, vector_data, **kwargs):
         return vector_data
 
     def search(self, embedding_data, **kwargs):
-        return [self.data[embedding_data]]
+        try:
+            return [self.data[embedding_data]]
+        except KeyError:
+            return []
 
     def close(self):
         try:
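Swallowing the KeyError and returning an empty list is what makes the fall-through decision possible: the caller can treat an empty result as an ordinary miss instead of handling an exception. A minimal sketch of the calling pattern this enables (names as in the example above):

results = cache.data_manager.search(cache.embedding_func("never-cached"))
if not results:
    # Previously this lookup raised KeyError; now an empty list signals a miss,
    # so the adapter can retry the request against cache.next_cache.
    pass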
10 changes: 9 additions & 1 deletion gpt_cache/core.py
@@ -77,6 +77,7 @@ def __init__(self):
         self.similarity_positive = True
         self.config = None
         self.report = Report()
+        self.next_cache = None
 
     def init(self,
              cache_enable_func=cache_all,
@@ -88,6 +89,7 @@ def init(self,
              similarity_threshold=0.5,
              similarity_positive=True,
              config=Config(),
+             next_cache=None,
              **kwargs
              ):
         self.cache_enable_func = cache_enable_func
@@ -100,6 +102,12 @@
         self.similarity_positive = similarity_positive
         self.data_manager.init(**kwargs)
         self.config = config
+        self.next_cache = next_cache
 
+    def close(self):
+        self.data_manager.close()
+        if self.next_cache:
+            self.next_cache.close()
+
     @staticmethod
     def set_openai_key():
@@ -112,6 +120,6 @@ def set_openai_key():
 @atexit.register
 def cache_close():
     try:
-        cache.data_manager.close()
+        cache.close()
     except Exception as e:
         print(e)
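Cache.close() now shuts down the instance's own data manager and then recurses into next_cache, and the atexit hook calls cache.close() so the whole chain is flushed on exit. Closing a chain by hand looks the same; a sketch mirroring the atexit handler:

try:
    cache.close()           # closes cache.data_manager, then bak_cache via next_cache
except Exception as e:      # same defensive catch as the atexit cache_close() hook
    print(e)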
12 changes: 10 additions & 2 deletions gpt_cache/view/openai.py
@@ -22,6 +22,8 @@ def create(cls, *args, **kwargs):
                 func_name="search",
                 report_func=chat_cache.report.search,
             )(embedding_data, extra_param=context.get('search', None))
+            if cache_data_list is None:
+                cache_data_list = []
             cache_answers = []
             for cache_data in cache_data_list:
                 cache_question, cache_answer = chat_cache.data_manager.get_scalar_data(
@@ -43,9 +45,15 @@
                 chat_cache.report.hint_cache()
                 return construct_resp_from_cache(return_message)
 
-        # TODO cache poster -> can chain
+        # TODO support stream data
-        openai_data = openai.ChatCompletion.create(*args, **kwargs)
+        next_cache = chat_cache.next_cache
+        if next_cache:
+            print("next_cache")
+            kwargs["cache_obj"] = next_cache
+            openai_data = ChatCompletion.create(*args, **kwargs)
+        else:
+            openai_data = openai.ChatCompletion.create(*args, **kwargs)
 
         if cache_enable:
             chat_cache.data_manager.save(pre_embedding_data,
                                          get_message_from_openai_answer(openai_data),
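On a miss, the adapter no longer goes straight to OpenAI: if chat_cache.next_cache is set, it re-invokes ChatCompletion.create with kwargs["cache_obj"] pointed at the next cache, so the chain is walked until a hit or until the last cache finally calls the real API (the print("next_cache") reads like a leftover debug trace). The same cache_obj hook can also be used directly to target a specific cache; a sketch with an illustrative message:

resp = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "foo15"}],
    cache_obj=bak_cache,    # look up in (and save to) the backup cache directly
)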
