Skip to content

Commit

Permalink
Merge pull request TransformerOptimus#1061 from TransformerOptimus/we…
Browse files Browse the repository at this point in the history
…aviate-backend-main

Weaviate Backend
  • Loading branch information
Tarraann authored Aug 16, 2023
2 parents fa17e09 + 5c980c9 commit 6dbd7da
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 156 deletions.
11 changes: 9 additions & 2 deletions superagi/controllers/vector_db_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(ge
knowledge_with_config = KnowledgeConfigs.fetch_knowledge_config_details_marketplace(knowledge['id'])
pinecone = []
qdrant = []
weaviate = []
for vector_db in vector_dbs:
indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id)
for index in indices:
Expand All @@ -26,13 +27,17 @@ def get_marketplace_valid_indices(knowledge_name: str, organisation = Depends(ge
pinecone.append(data)
if vector_db.db_type == "Qdrant":
qdrant.append(data)
return {"pinecone": pinecone, "qdrant": qdrant}
if vector_db.db_type == "Weaviate":
data["is_valid_dimension"] = True
weaviate.append(data)
return {"pinecone": pinecone, "qdrant": qdrant, "weaviate": weaviate}

@router.get("/user/valid_indices")
def get_user_valid_indices(organisation = Depends(get_user_organisation)):
vector_dbs = Vectordbs.get_vector_db_from_organisation(db.session, organisation)
pinecone = []
qdrant = []
weaviate = []
for vector_db in vector_dbs:
indices = VectordbIndices.get_vector_indices_from_vectordb(db.session, vector_db.id)
for index in indices:
Expand All @@ -42,4 +47,6 @@ def get_user_valid_indices(organisation = Depends(get_user_organisation)):
pinecone.append(data)
if vector_db.db_type == "Qdrant":
qdrant.append(data)
return {"pinecone": pinecone, "qdrant": qdrant}
if vector_db.db_type == "Weaviate":
weaviate.append(data)
return {"pinecone": pinecone, "qdrant": qdrant, "weaviate": weaviate}
27 changes: 24 additions & 3 deletions superagi/controllers/vector_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def connect_pinecone_vector_db(data: dict, organisation = Depends(get_user_organ
pinecone_db = Vectordbs.add_vector_db(db.session, data["name"], "Pinecone", organisation)
VectordbConfigs.add_vector_db_config(db.session, pinecone_db.id, db_creds)
for collection in data["collections"]:
VectordbIndices.add_vector_index(db.session, collection, pinecone_db.id, db_connect_for_index["dimensions"], index_state)
VectordbIndices.add_vector_index(db.session, collection, pinecone_db.id, index_state, db_connect_for_index["dimensions"])
return {"id": pinecone_db.id, "name": pinecone_db.name}

@router.post("/connect/qdrant")
Expand All @@ -97,10 +97,30 @@ def connect_qdrant_vector_db(data: dict, organisation = Depends(get_user_organis
qdrant_db = Vectordbs.add_vector_db(db.session, data["name"], "Qdrant", organisation)
VectordbConfigs.add_vector_db_config(db.session, qdrant_db.id, db_creds)
for collection in data["collections"]:
VectordbIndices.add_vector_index(db.session, collection, qdrant_db.id, db_connect_for_index["dimensions"], index_state)
VectordbIndices.add_vector_index(db.session, collection, qdrant_db.id, index_state, db_connect_for_index["dimensions"])

return {"id": qdrant_db.id, "name": qdrant_db.name}

@router.post("/connect/weaviate")
def connect_weaviate_vector_db(data: dict, organisation = Depends(get_user_organisation)):
    """Register a Weaviate vector database and its collections.

    Every collection listed in ``data["collections"]`` is probed first; the
    database, its config and its indices are only persisted once all probes
    have succeeded.

    Args:
        data: Request payload. The keys ``"api_key"``, ``"url"``, ``"name"``
            and ``"collections"`` are read from it.
        organisation: Value supplied by the ``get_user_organisation`` dependency.

    Returns:
        A dict with the ``id`` and ``name`` of the stored vector db record.

    Raises:
        HTTPException: Status 400 when a collection's storage cannot be built
            or its index stats cannot be read.
    """
    db_creds = {
        "api_key": data["api_key"],
        "url": data["url"]
    }

    # Remember the state of each collection separately; a single shared
    # variable would label every index with the state of the last one probed.
    index_states = {}
    for collection in data["collections"]:
        try:
            vector_db_storage = VectorFactory.build_vector_storage("weaviate", collection, **db_creds)
            db_connect_for_index = vector_db_storage.get_index_stats()
            index_states[collection] = "Custom" if db_connect_for_index["vector_count"] > 0 else "None"
        except Exception as err:
            # Keep the original failure attached as the cause for debugging.
            raise HTTPException(status_code=400, detail="Unable to connect Weaviate") from err

    weaviate_db = Vectordbs.add_vector_db(db.session, data["name"], "Weaviate", organisation)
    VectordbConfigs.add_vector_db_config(db.session, weaviate_db.id, db_creds)
    for collection in data["collections"]:
        # No dimensions argument is passed for Weaviate indices.
        VectordbIndices.add_vector_index(db.session, collection, weaviate_db.id, index_states[collection])

    return {"id": weaviate_db.id, "name": weaviate_db.name}

@router.put("/update/vector_db/{vector_db_id}")
def update_vector_db(new_indices: list, vector_db_id: int):
vector_db = Vectordbs.get_vector_db_from_id(db.session, vector_db_id)
Expand All @@ -119,9 +139,10 @@ def update_vector_db(new_indices: list, vector_db_id: int):
vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, index, **db_creds)
vector_db_index_stats = vector_db_storage.get_index_stats()
index_state = "Custom" if vector_db_index_stats["vector_count"] > 0 else "None"
dimensions = vector_db_index_stats["dimensions"] if 'dimensions' in vector_db_index_stats else None
except:
raise HTTPException(status_code=400, detail="Unable to update vector db")
VectordbIndices.add_vector_index(db.session, index, vector_db_id, vector_db_index_stats["dimensions"], index_state)
VectordbIndices.add_vector_index(db.session, index, vector_db_id, index_state, dimensions)



Expand Down
2 changes: 1 addition & 1 deletion superagi/models/vector_db_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def delete_vector_db_index(cls, session, vector_index_id):
session.commit()

@classmethod
def add_vector_index(cls, session, index_name, vector_db_id, dimensions, state):
def add_vector_index(cls, session, index_name, vector_db_id, state, dimensions = None): #will be none only in the case of weaviate
vector_index = VectordbIndices(name=index_name, vector_db_id=vector_db_id, dimensions=dimensions, state=state)
session.add(vector_index)
session.commit()
Expand Down
2 changes: 1 addition & 1 deletion superagi/tools/knowledge_search/knowledge_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _execute(self, query: str):
embedding_model = AgentExecutor.get_embedding(model_source, model_api_key)
try:
if vector_db_index.state == "Custom":
filters = {}
filters = None
if vector_db_index.state == "Marketplace":
filters = {"knowledge_name": knowledge.name}
vector_db_storage = VectorFactory.build_vector_storage(vector_db.db_type, vector_db_index.name, embedding_model, **db_creds)
Expand Down
6 changes: 5 additions & 1 deletion superagi/vector_embeddings/vector_embedding_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pinecone import UnauthorizedException
from superagi.vector_embeddings.pinecone import Pinecone
from superagi.vector_embeddings.qdrant import Qdrant
from superagi.vector_embeddings.weaviate import Weaviate
from superagi.types.vector_store_types import VectorStoreType

class VectorEmbeddingFactory:
Expand Down Expand Up @@ -40,4 +41,7 @@ def build_vector_storage(cls, vector_store: VectorStoreType, chunk_json: Optiona
return Pinecone(uuid, embeds, metadata)

if vector_store == VectorStoreType.QDRANT:
return Qdrant(uuid, embeds, metadata)
return Qdrant(uuid, embeds, metadata)

if vector_store == VectorStoreType.WEAVIATE:
return Weaviate(uuid, embeds, metadata)
14 changes: 14 additions & 0 deletions superagi/vector_embeddings/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any
from superagi.vector_embeddings.base import VectorEmbeddings

class Weaviate(VectorEmbeddings):
    """Holds ids, embedding vectors and metadata and exposes them as one payload."""

    def __init__(self, uuid, embeds, metadata):
        """Keep the three inputs on the instance unchanged.

        Args:
            uuid: Stored as ``self.uuid``; returned under the ``'ids'`` key.
            embeds: Stored as ``self.embeds``; returned under the ``'vectors'`` key.
            metadata: Stored as ``self.metadata``; returned under the ``'data_object'`` key.
        """
        self.uuid, self.embeds, self.metadata = uuid, embeds, metadata

    def get_vector_embeddings_from_chunks(self):
        """Return the stored values as a dict keyed ``ids``, ``data_object``, ``vectors``."""
        payload = dict(ids=self.uuid, data_object=self.metadata, vectors=self.embeds)
        return payload
10 changes: 9 additions & 1 deletion superagi/vector_store/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,12 @@ def build_vector_storage(cls, vector_store: VectorStoreType, index_name, embeddi
client = qdrant.create_qdrant_client(creds["api_key"], creds["url"], creds["port"])
return qdrant.Qdrant(client, embedding_model, index_name)
except:
raise ValueError("Qdrant API key not found")
raise ValueError("Qdrant API key not found")

if vector_store == VectorStoreType.WEAVIATE:
try:
client = weaviate.create_weaviate_client(creds["url"], creds["api_key"])
return weaviate.Weaviate(client, embedding_model, index_name)
except:
raise ValueError("Weaviate API key not found")

126 changes: 79 additions & 47 deletions superagi/vector_store/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

import weaviate

from uuid import uuid4
from superagi.vector_store.base import VectorStore
from superagi.vector_store.document import Document


def create_weaviate_client(
use_embedded: bool = True,
url: Optional[str] = None,
api_key: Optional[str] = None,
) -> weaviate.Client:
Expand All @@ -28,9 +27,7 @@ def create_weaviate_client(
Raises:
ValueError: If invalid argument combination are passed.
"""
if use_embedded:
client = weaviate.Client(embedded_options=weaviate.embedded.EmbeddedOptions())
elif url:
if url:
if api_key:
auth_config = weaviate.AuthApiKey(api_key=api_key)
else:
Expand All @@ -45,9 +42,9 @@ def create_weaviate_client(

class Weaviate(VectorStore):
def __init__(
self, client: weaviate.Client, embedding_model: Any, index: str, text_field: str
self, client: weaviate.Client, embedding_model: Any, class_name: str, text_field: str = "text"
):
self.index = index
self.class_name = class_name
self.embedding_model = embedding_model
self.text_field = text_field

Expand All @@ -56,48 +53,47 @@ def __init__(
def add_texts(
self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any
) -> List[str]:
result = []
with self.client.batch as batch:
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
data_object = metadata.copy()
data_object[self.text_field] = text
vector = self.embedding_model.get_embedding(text)

batch.add_data_object(data_object, class_name=self.index, vector=vector)

object = batch.create_objects()[0]
result.append(object["id"])
return result
result = {}
collected_ids = []
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
data_object = metadata.copy()
data_object[self.text_field] = text
vector = self.embedding_model.get_embedding(text)
id = str(uuid4())
result = {"ids": id, "data_object": data_object, "vectors": vector}
collected_ids.append(id)
self.add_embeddings_to_vector_db(result)
return collected_ids

def get_matching_text(
self, query: str, top_k: int = 5, **kwargs: Any
self, query: str, top_k: int = 5, metadata: dict = None, **kwargs: Any
) -> List[Document]:
alpha = kwargs.get("alpha", 0.5)
metadata_fields = self._get_metadata_fields()
query_vector = self.embedding_model.get_embedding(query)

results = (
self.client.query.get(self.index, metadata_fields + [self.text_field])
.with_hybrid(query, vector=query_vector, alpha=alpha)
.with_limit(top_k)
.do()
)

results_data = results["data"]["Get"][self.index]
documents = []
for result in results_data:
text_content = result[self.text_field]
metadata = {}
for field in metadata_fields:
metadata[field] = result[field]
document = Document(text_content=text_content, metadata=metadata)
documents.append(document)

return documents

if metadata is not None:
for key, value in metadata.items():
filters = {
"path": [key],
"operator": "Equal",
"valueString": value
}

results = self.client.query.get(
self.class_name,
metadata_fields + [self.text_field],
).with_near_vector(
{"vector": query_vector, "certainty": 0.7}
).with_where(filters).with_limit(top_k).do()

results_data = results["data"]["Get"][self.class_name]
search_res = self._get_search_res(results_data, query)
documents = self._build_documents(results_data, metadata_fields)

return {"search_res": search_res, "documents": documents}

def _get_metadata_fields(self) -> List[str]:
schema = self.client.schema.get(self.index)
schema = self.client.schema.get(self.class_name)
property_names = []
for property_schema in schema["properties"]:
property_names.append(property_schema["name"])
Expand All @@ -106,10 +102,46 @@ def _get_metadata_fields(self) -> List[str]:
return property_names

def get_index_stats(self) -> dict:
    """Return the object count reported for this store's class.

    Runs an aggregate query with a meta count against ``self.class_name`` and
    reads the count from the first aggregate entry of the response.

    Returns:
        A dict with the single key ``'vector_count'``.
    """
    aggregate_query = self.client.query.aggregate(self.class_name)
    response = aggregate_query.with_meta_count().do()
    class_stats = response['data']['Aggregate'][self.class_name]
    return {'vector_count': class_stats[0]['meta']['count']}

def add_embeddings_to_vector_db(self, embeddings: dict) -> None:
    """Add embedding records to ``self.class_name`` through the client's batch.

    Args:
        embeddings: Mapping with the keys ``'ids'``, ``'data_object'`` and
            ``'vectors'``. Normally each value is a sequence with one entry
            per record. A single record is also accepted: ``'ids'`` given as
            one string, ``'data_object'`` as one dict and ``'vectors'`` as one
            vector.

    Raises:
        Any exception raised by the client propagates unchanged.
    """
    ids = embeddings['ids']
    data_objects = embeddings['data_object']
    vectors = embeddings['vectors']

    if isinstance(ids, str):
        # A bare id string would otherwise be iterated character by character
        # and the data-object dict indexed by integer position.
        ids, data_objects, vectors = [ids], [data_objects], [vectors]

    with self.client.batch as batch:
        for position, object_id in enumerate(ids):
            # Copy the properties so the caller's dict is not handed to the client.
            properties = dict(data_objects[position].items())
            batch.add_data_object(
                properties,
                class_name=self.class_name,
                uuid=object_id,
                vector=vectors[position],
            )

def delete_embeddings_from_vector_db(self, ids: List[str]) -> None:
    """Delete the objects with the given ids from ``self.class_name``.

    Issues one delete call per id, in order. Any exception raised by the
    client propagates to the caller.

    Args:
        ids: Ids of the objects to delete.
    """
    for object_id in ids:
        self.client.data_object.delete(uuid=object_id, class_name=self.class_name)

def _build_documents(self, results_data, metadata_fields) -> List[Document]:
    """Turn result rows into ``Document`` objects.

    Args:
        results_data: Iterable of result rows; each row is indexed by
            ``self.text_field`` and by every name in ``metadata_fields``.
        metadata_fields: Names of the fields copied into each document's metadata.

    Returns:
        One ``Document`` per row, in the same order as ``results_data``.
    """
    return [
        Document(
            text_content=row[self.text_field],
            metadata={field: row[field] for field in metadata_fields},
        )
        for row in results_data
    ]

def _get_search_res(self, results, query):
text = [item['text'] for item in results]
search_res = f"Query: {query}\n"
i = 0
for context in text:
search_res += f"Chunk{i}: \n{context}\n"
i += 1
return search_res
Empty file.
25 changes: 25 additions & 0 deletions tests/integration_tests/vector_embeddings/test_weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import unittest
from superagi.vector_embeddings.base import VectorEmbeddings
from superagi.vector_embeddings.weaviate import Weaviate

class TestWeaviate(unittest.TestCase):
    """Checks that the Weaviate embeddings wrapper stores and returns its inputs."""

    def setUp(self):
        # A fresh wrapper is built before every test.
        self.weaviate = Weaviate(
            uuid="1234",
            embeds=[0.1, 0.2, 0.3, 0.4],
            metadata={"info": "sample data"},
        )

    def test_init(self):
        """The constructor keeps each argument on the matching attribute."""
        wrapper = self.weaviate
        self.assertEqual(wrapper.uuid, "1234")
        self.assertEqual(wrapper.embeds, [0.1, 0.2, 0.3, 0.4])
        self.assertEqual(wrapper.metadata, {"info": "sample data"})

    def test_get_vector_embeddings_from_chunks(self):
        """The payload maps ids, data_object and vectors to the stored values."""
        actual = self.weaviate.get_vector_embeddings_from_chunks()
        expected_result = dict(
            ids="1234",
            data_object={"info": "sample data"},
            vectors=[0.1, 0.2, 0.3, 0.4],
        )
        self.assertEqual(actual, expected_result)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 6dbd7da

Please sign in to comment.