
Support for weaviate vector store #474

Closed
Merge branch 'dev' into support-for-Weaviate
sunilkumardash9 authored Jul 18, 2023
commit 70d0fd068fcd904f35ad7cb56f5c026151829640
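The change set below merges `dev` into the Weaviate branch, so it mixes the Weaviate groundwork (the new `import_weaviate` helper) with unrelated fixes such as the Redis cache store and the Mongo report table. As orientation, here is a minimal sketch of how a Weaviate-backed vector store would plug into GPTCache's factory API; the `"weaviate"` store name and its parameters are assumptions based on how the other stores in this diff are configured, not something this commit shows.

```python
# Hedged sketch: wiring a Weaviate vector store through manager_factory.
# The "weaviate" name and the vector_params key are assumed by analogy with
# the faiss/qdrant stores; the Weaviate adapter itself is not in this diff.
from gptcache.manager import manager_factory

data_manager = manager_factory(
    "sqlite,weaviate",                 # scalar store, vector store (assumed name)
    data_dir="./workspace",
    vector_params={"dimension": 128},  # assumed parameter, mirroring the faiss example
)
```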
10 changes: 10 additions & 0 deletions docs/release_note.md
@@ -5,6 +5,16 @@ To read the following content, you need to understand the basic use of GPTCache,
- [Readme doc](https://github.com/zilliztech/GPTCache)
- [Usage doc](https://github.com/zilliztech/GPTCache/blob/main/docs/usage.md)

## v0.1.36 (2023.7.14)

1. Fix the connection error with the remote Redis cache store
2. Add the OpenAI proxy for the chat completion API

## v0.1.35 (2023.7.7)

1. Support Redis as the cache store, usage example: [redis+onnx](https://github.com/zilliztech/GPTCache/blob/main/tests/integration_tests/test_redis_onnx.py) (a configuration sketch follows this file's diff)
2. Add a report table for easy analysis of cache data

## v0.1.34 (2023.6.30)

1. Add support for Qdrant Vector Store
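The v0.1.35 entry above introduces Redis as a cache store. A minimal configuration sketch, assuming the `redis_host`/`redis_port` keys added to `gptcache/manager/scalar_data/manager.py` later in this diff are passed through `scalar_params`:

```python
# Hedged sketch: Redis as the scalar (cache) store plus a faiss vector store.
# The scalar_params keys mirror the kwargs popped in scalar_data/manager.py;
# treat the exact combination as illustrative rather than a verified recipe.
from gptcache.manager import manager_factory

data_manager = manager_factory(
    "redis,faiss",
    data_dir="./workspace",
    scalar_params={"redis_host": "localhost", "redis_port": 6379},
    vector_params={"dimension": 128},
)
```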
2 changes: 1 addition & 1 deletion gptcache/__init__.py
@@ -1,5 +1,5 @@
"""gptcache version"""
__version__ = "0.1.34"
__version__ = "0.1.36"

from gptcache.config import Config
from gptcache.core import Cache
9 changes: 7 additions & 2 deletions gptcache/adapter/adapter.py
@@ -19,6 +19,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
:param kwargs: llm kwargs
:return: llm result
"""
start_time = time.time()
search_only_flag = kwargs.pop("search_only", False)
user_temperature = "temperature" in kwargs
user_top_k = "top_k" in kwargs
@@ -170,7 +171,9 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
rank,
)
if rank_threshold <= rank:
cache_answers.append((float(rank), cache_data.answers[0].answer, search_data))
cache_answers.append(
(float(rank), cache_data.answers[0].answer, search_data, cache_data)
)
chat_cache.data_manager.hit_cache_callback(search_data)
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
answers_dict = dict((d[1], d) for d in cache_answers)
@@ -424,7 +427,9 @@ async def aadapt(
rank,
)
if rank_threshold <= rank:
cache_answers.append((float(rank), cache_data.answers[0].answer, search_data))
cache_answers.append(
(float(rank), cache_data.answers[0].answer, search_data, cache_data)
)
chat_cache.data_manager.hit_cache_callback(search_data)
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
answers_dict = dict((d[1], d) for d in cache_answers)
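Both the synchronous and asynchronous paths now append a four-element tuple, adding the full `cache_data` object so later steps can report which cached question was hit. The selection logic around it, unchanged in this hunk, sorts candidates by rank and deduplicates by answer text; a small self-contained illustration with placeholder values:

```python
# Illustration of the (rank, answer, search_data, cache_data) candidates shown
# above, using placeholder strings instead of real cache objects.
cache_answers = [
    (0.92, "Paris", "search-a", "cache-a"),
    (0.88, "Paris", "search-b", "cache-b"),   # duplicate answer, lower rank
    (0.75, "Lyon", "search-c", "cache-c"),
]
# Sort candidates by rank, highest first.
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
# Build a dict keyed on the answer text; duplicate answers collapse to a single
# entry (the last tuple written for that key in the sorted order).
answers_dict = dict((d[1], d) for d in cache_answers)
print(answers_dict["Paris"])   # one tuple per distinct answer remains
```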
9 changes: 9 additions & 0 deletions gptcache/manager/scalar_data/manager.py
@@ -93,6 +93,15 @@ def get(name, **kwargs):
username=kwargs.get("username"),
password=kwargs.get("password")
)
elif name == "redis":
from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage

return RedisCacheStorage(
host=kwargs.pop("redis_host", "localhost"),
port=kwargs.pop("redis_port", 6379),
global_key_prefix=kwargs.pop("global_key_prefix", TABLE_NAME),
**kwargs
)
else:
raise NotFoundError("cache store", name)
return cache_base
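The new `redis` branch follows the same factory pattern as the other stores. A usage sketch built directly from the keyword names popped above (values are placeholders; extra kwargs fall through to `RedisCacheStorage`):

```python
# Sketch of selecting the Redis cache store by name. The keyword names are the
# ones popped in get() above; the global_key_prefix value is a placeholder for
# the TABLE_NAME default used in the diff.
from gptcache.manager import CacheBase

cache_store = CacheBase(
    "redis",
    redis_host="localhost",
    redis_port=6379,
    global_key_prefix="gptcache",
)
```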
175 changes: 117 additions & 58 deletions gptcache/manager/scalar_data/mongo.py
@@ -3,27 +3,29 @@

import numpy as np

from gptcache.manager.scalar_data.base import CacheStorage, CacheData, Question, QuestionDep
from gptcache.manager.scalar_data.base import (
CacheStorage,
CacheData,
Question,
QuestionDep,
)
from gptcache.utils import import_mongodb

import_mongodb()

from mongoengine import Document # pylint: disable=wrong-import-position
from mongoengine import fields # pylint: disable=wrong-import-position
import mongoengine as me # pylint: disable=wrong-import-position
# pylint: disable=C0413
from mongoengine import Document
from mongoengine import fields
import mongoengine as me


def get_models():
class Questions(Document):
"""
questions collection
"""
meta = {
"collection": "questions",
"indexes": [
"deleted"
]
}

meta = {"collection": "questions", "indexes": ["deleted"]}
_id = fields.SequenceField()
question = fields.StringField()
create_on = fields.DateTimeField(default=datetime.now())
@@ -39,13 +41,9 @@ class Answers(Document):
"""
answer collection
"""

_id = fields.SequenceField()
meta = {
"collection": "answers",
"indexes": [
"question_id"
]
}
meta = {"collection": "answers", "indexes": ["question_id"]}
answer = fields.StringField()
answer_type = fields.IntField()
question_id = fields.IntField()
@@ -58,13 +56,8 @@ class Sessions(Document):
"""
session collection
"""
meta = {
"collection": "sessions",
"indexes": [
"question_id"
]

}
meta = {"collection": "sessions", "indexes": ["question_id"]}
_id = fields.SequenceField()
session_id = fields.StringField()
session_question = fields.StringField()
@@ -78,12 +71,8 @@ class QuestionDeps(Document):
"""
Question Dep collection
"""
meta = {
"collection": "question_deps",
"indexes": [
"question_id"
]
}

meta = {"collection": "question_deps", "indexes": ["question_id"]}
_id = fields.SequenceField()
question_id = fields.IntField()
dep_name = fields.StringField()
@@ -94,7 +83,30 @@ class QuestionDeps(Document):
def oid(self):
return self._id

return Questions, Answers, QuestionDeps, Sessions
class Report(Document):
"""
Report
"""

meta = {
"collection": "report",
"indexes": ["cache_question_id", "similarity", "cache_delta_time"],
}
_id = fields.SequenceField()
user_question = fields.StringField()
cache_question_id = fields.IntField()
cache_question = fields.StringField()
cache_answer = fields.StringField()
similarity = fields.FloatField()
cache_delta_time = fields.FloatField()
cache_time = fields.DateTimeField(default=datetime.now())
extra = fields.StringField()

@property
def oid(self):
return self._id

return Questions, Answers, QuestionDeps, Sessions, Report


class MongoStorage(CacheStorage):
Expand All @@ -108,38 +120,66 @@ class MongoStorage(CacheStorage):

:param host: mongodb host, default value 'localhost'
:type host: str

:param port: mongodb port, default value 27017
:type port: int

:param dbname: database name, default value 'gptcache'
:type dbname: str

:param username: username for authentication, default value None
:type username: str

:param password: password for authentication, default value None
:type password: str

Example:
.. code-block:: python

from gptcache.manager import CacheBase, manager_factory

cache_store = CacheBase('mongo',
mongo_host="localhost",
mongo_port=27017,
dbname="gptcache",
username=None,
password=None,
)
# or
data_manager = manager_factory("mongo,faiss", data_dir="./workspace",
scalar_params={
"mongo_host": "localhost",
"mongo_port": 27017,
"dbname"="gptcache",
"username"="",
"password"="",
},
vector_params={"dimension": 128},
)
"""

def __init__(
self,
host: str = "localhost",
port: int = 27017,
dbname: str = "gptcache",
username: str = None,
password: str = None,
**kwargs):
self.con = me.connect(host=host,
port=port,
db=dbname,
username=username,
password=password,
**kwargs)
self._ques, self._answer, self._ques_dep, self._session = get_models()
self,
host: str = "localhost",
port: int = 27017,
dbname: str = "gptcache",
username: str = None,
password: str = None,
**kwargs
):
self.con = me.connect(
host=host,
port=port,
db=dbname,
username=username,
password=password,
**kwargs
)
(
self._ques,
self._answer,
self._ques_dep,
self._session,
self._report,
) = get_models()

def create(self):
pass
@@ -151,7 +191,7 @@ def _insert(self, data: CacheData):
else data.question.content,
embedding_data=data.embedding_data.tobytes()
if data.embedding_data is not None
else None
else None,
)
ques_data.save()
if isinstance(data.question, Question) and data.question.deps is not None:
@@ -209,8 +249,7 @@ def get_data_by_id(self, key) -> Optional[CacheData]:

res_ans = [(item.answer, item.answer_type) for item in answers]
res_deps = [
QuestionDep(item.dep_name, item.dep_data, item.dep_type)
for item in deps
QuestionDep(item.dep_name, item.dep_data, item.dep_type) for item in deps
]
return CacheData(
question=qs.question if not deps else Question(qs.question, res_deps),
@@ -243,17 +282,18 @@ def count(self, state: int = 0, is_all: bool = False):
return self._ques.objects(deleted=state).count()

def add_session(self, question_id, session_id, session_question):
self._session(question_id=question_id,
session_id=session_id,
session_question=session_question
).save()
self._session(
question_id=question_id,
session_id=session_id,
session_question=session_question,
).save()

def list_sessions(self, session_id=None, key=None):
query = {}
if session_id:
query["session_id"] = session_id
if key:
query["_id"] = key
query["question_id"] = key

return self._session.objects(__raw__=query)

@@ -263,6 +303,25 @@ def delete_session(self, keys):
def count_answers(self):
return self._answer.objects.count()

def report_cache(
self,
user_question,
cache_question,
cache_question_id,
cache_answer,
similarity_value,
cache_delta_time,
):
report_data = self._report(
user_question=user_question,
cache_question=cache_question,
cache_question_id=cache_question_id,
cache_answer=cache_answer,
similarity=similarity_value,
cache_delta_time=cache_delta_time,
)
report_data.save()

def close(self):
me.disconnect()
self.con.close()
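The new `Report` document and `report_cache` method back the report table mentioned in the v0.1.35 release note. A hedged sketch of exercising it directly; the values are placeholders, and in normal operation the adapter is expected to record these fields on a cache hit rather than user code:

```python
# Illustrative call to the new report_cache method with placeholder values.
from gptcache.manager.scalar_data.mongo import MongoStorage

storage = MongoStorage(host="localhost", port=27017, dbname="gptcache")
storage.report_cache(
    user_question="what is gptcache?",
    cache_question="What is GPTCache?",
    cache_question_id=1,
    cache_answer="A semantic cache for LLM responses.",
    similarity_value=0.95,
    cache_delta_time=0.12,   # assumed to be the time saved, in seconds
)
storage.close()
```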
8 changes: 7 additions & 1 deletion gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -258,7 +258,13 @@ def import_fastapi():

def import_redis():
_check_library("redis")
_check_library("redis_om")


def import_weaviate():
_check_library("weaviate-client")
_check_library("weaviate-client")


def import_starlette():
_check_library("starlette")

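`import_weaviate` follows the same lazy-import guard used elsewhere in the package: check the optional dependency before importing it, as `mongo.py` does with `import_mongodb()` above. A sketch of how a (hypothetical) Weaviate store module would use it:

```python
# Hypothetical module skeleton showing the dependency-guard pattern.
from gptcache.utils import import_weaviate

import_weaviate()  # make sure the weaviate-client package is available

import weaviate  # pylint: disable=wrong-import-position
```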
5 changes: 4 additions & 1 deletion tests/unit_tests/manager/test_mongo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import time

import numpy as np
from mongoengine import connect, disconnect

from gptcache.manager.scalar_data.base import CacheData, Question
from gptcache.manager.scalar_data.mongo import MongoStorage
from gptcache.utils import import_mongodb

import_mongodb()
from mongoengine import connect, disconnect


def test_mongo():