Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for weaviate vector store #474

Closed
Prev Previous commit
Next Next commit
added support for weaviate vector-store
Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>
  • Loading branch information
sunilkumardash9 committed Jul 7, 2023
commit 90d49916a2e778c713828bb717443597f4160ae8
26 changes: 26 additions & 0 deletions gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,32 @@ def get(name, **kwargs):
flush_interval_sec=flush_interval_sec,
index_params=index_params,
)
elif name == "weaviate":
from .. vector_data.weaviate import Weaviate
url = kwargs.get("url", None)
auth_client_secret = kwargs.get('auth_client_secret', None),
timeout_config = kwargs.get("timeout_config", (10, 60))
proxies = kwargs.get("proxies", None)
trust_env = kwargs.get("trust_env", False)
additional_headers = kwargs.get("additional_headers", None)
startup_period = kwargs.get("startup_period", 5)
embedded_options = kwargs.get("embedded_options", None)
additional_config = kwargs.get("additional_config", None)
class_name = kwargs.get("class_name", "Gptcache")
top_k = kwargs.get("top_k", 1)
vector_base = Weaviate(
url= url,
auth_client_secret = auth_client_secret,
timeout_config = timeout_config,
proxies = proxies,
trust_env = trust_env,
additional_headers = additional_headers,
startup_period = startup_period,
embedded_options = embedded_options,
additional_config = additional_config,
class_name = class_name,
top_k = top_k,
)
else:
raise NotFoundError("vector store", name)
return vector_base
126 changes: 126 additions & 0 deletions gptcache/manager/vector_data/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import List

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from ... utils import import_weaviate
from gptcache.utils.log import gptcache_log

from weaviate import Client, EmbeddedOptions, Config

import_weaviate()

class Weaviate(VectorBase):
"""Weaviate Vector store"""
def __init__(self,
url: str | None = None,
auth_client_secret: None = None,
timeout_config = (10, 60),
proxies: dict | str | None = None,
trust_env: bool = False,
additional_headers: dict | None = None,
startup_period: int | None = 5,
embedded_options: None = None,
additional_config: None = None,
top_k: int = 1,
distance: str = "cosine",
collection_name: str = "Gptcache",
):
self.class_name = collection_name
self.top_k = top_k
self.distance = distance
if embedded_options:
self.client = Client(embedded_options = EmbeddedOptions(),
startup_period = startup_period,
timeout_config = timeout_config,
additional_config=additional_config)
else:
self.client = Client(url,
auth_client_secret,
timeout_config,
proxies,
trust_env,
additional_headers,
startup_period,
embedded_options,
additional_config,
)

def _create_collection(self, class_name: str):
if not class_name:
class_name = self.class_name
if self.client.schema.exists(class_name):
gptcache_log.info(
"The %s already exists, and it will be used directly", class_name
)
else:
gptcache_class_schema = {
"class": class_name,
"description": "caching LLM responses",
"properties": [
{
"name": "id_",
"dataType": ["int"],
}
],
'vectorIndexConfig':
{
"distance": self.distance
}
}
self.client.schema.create_class(gptcache_class_schema)

def mul_add(self, datas: List[VectorData]):
with self.client.batch(
batch_size=len(datas)
) as batch:
# Batch import
for data in datas:
properties = {
"id_": data.id,
}
self.client.batch.add_data_object(
properties,
self.class_name,
vector = data.data.tolist()
)

def search(self, data: np.ndarray, top_k: int = -1):
if not self.client.schema.exists(self.class_name):
self._create_collection(self.class_name)
if top_k==-1:
top_k = self.top_k
result = self.client.query.get(class_name = self.class_name, properties = ['id_']).\
with_near_vector(content={"vector": data.tolist()}).\
with_additional(['distance']).\
with_limit(top_k).do()
return list(map(lambda x: (x['_additional']['distance'], x['id_']), result['data']['Get'][self.class_name]))

def get_uuids(self, ids: List[str]):
uuid_list = []
for id_ in ids:
res = self.client.query.get(class_name=self.class_name, properties=['id_']).\
with_where({"path": ["id_"], "operator":"Equal", "valueNumber":id_}).\
with_additional(["id"]).do()
uuid_list.append(res['data']['Get'][self.class_name][0]['_additional']['id'])
return uuid_list

def delete(self, ids: List[str]):
uuids = self.get_uuids(ids)
for uuid_ in uuids:
self.client.data_object.delete(class_name='example', uuid=uuid_)

def rebuild(self, ids=None) :
return

def flush(self):
return True

def close(self):
pass






5 changes: 5 additions & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"import_fastapi",
"import_redis",
"import_qdrant",
"import_weaviate"
]

import importlib.util
Expand Down Expand Up @@ -257,3 +258,7 @@ def import_fastapi():

def import_redis():
_check_library("redis")


def import_weaviate():
_check_library("weaviate-client")
30 changes: 30 additions & 0 deletions tests/unit_tests/manager/test_weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import unittest

import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestUSearchDB(unittest.TestCase):
def test_normal(self):
size = 1000
dim = 512
top_k = 10
weaviate = VectorBase(
"weaviate",
top_k = top_k
)
data = np.random.randn(size, dim).astype(np.float32)
weaviate.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
search_result = weaviate.search(data[0], top_k)
self.assertEqual(len(search_result), top_k)
weaviate.mul_add([VectorData(id=size, data=data[0])])
ret = weaviate.search(data[0])
self.assertIn(ret[0][1], [0, size])
self.assertIn(ret[1][1], [0, size])
weaviate.delete([0, 1, 2, 3, 4, 5, size])
ret = weaviate.search(data[0])
self.assertNotIn(ret[0][1], [0, size])
weaviate.rebuild()
weaviate.close()