Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Second Attempt - Add concurrent insertion of vector rows in the Cassandra Vector Store #7017

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions docs/extras/ecosystem/integrations/cassandra.mdx
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Cassandra

>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
>[Apache Cassandra®](https://cassandra.apache.org/) is a free and open-source, distributed, wide-column
> store, NoSQL database management system designed to handle large amounts of data across many commodity servers,
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> providing high availability with no single point of failure. Cassandra offers support for clusters spanning
> multiple datacenters, with asynchronous masterless replication allowing low latency operations for all clients.
> Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication
> Cassandra was designed to implement a combination of _Amazon's Dynamo_ distributed storage and replication
> techniques combined with _Google's Bigtable_ data and storage engine model.

## Installation and Setup
Expand All @@ -16,6 +16,16 @@ pip install cassio



## Vector Store

See a [usage example](/docs/modules/data_connection/vectorstores/integrations/cassandra.html).

```python
from langchain.vectorstores import Cassandra
```



## Memory

See a [usage example](/docs/modules/memory/integrations/cassandra_chat_message_history.html).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.5\""
"!pip install \"cassio>=0.0.7\""
]
},
{
Expand All @@ -44,14 +44,16 @@
"import os\n",
"import getpass\n",
"\n",
"database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
"database_mode = (input('\\n(C)assandra or (A)stra DB? ')).upper()\n",
"\n",
"keyspace_name = input('\\nKeyspace name? ')\n",
"\n",
"if database_mode == 'A':\n",
" ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
" #\n",
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
" ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')\n",
"elif database_mode == 'C':\n",
" CASSANDRA_CONTACT_POINTS = input('Contact points? (comma-separated, empty for localhost) ').strip()"
]
},
{
Expand All @@ -72,8 +74,15 @@
"from cassandra.cluster import Cluster\n",
"from cassandra.auth import PlainTextAuthProvider\n",
"\n",
"if database_mode == 'L':\n",
" cluster = Cluster()\n",
"if database_mode == 'C':\n",
" if CASSANDRA_CONTACT_POINTS:\n",
" cluster = Cluster([\n",
" cp.strip()\n",
" for cp in CASSANDRA_CONTACT_POINTS.split(',')\n",
" if cp.strip()\n",
" ])\n",
" else:\n",
" cluster = Cluster()\n",
" session = cluster.connect()\n",
"elif database_mode == 'A':\n",
" ASTRA_DB_CLIENT_ID = \"token\"\n",
Expand Down Expand Up @@ -261,7 +270,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.6"
}
},
"nbformat": 4,
Expand Down
77 changes: 44 additions & 33 deletions langchain/vectorstores/cassandra.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Wrapper around Cassandra vector-store capabilities, based on cassIO."""
from __future__ import annotations

import hashlib
import typing
import uuid
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar

import numpy as np
Expand All @@ -17,14 +17,6 @@

CVST = TypeVar("CVST", bound="Cassandra")

# a positive number of seconds to expire entries, or None for no expiration.
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None


def _hash(_input: str) -> str:
"""Use a deterministic hashing approach."""
return hashlib.md5(_input.encode()).hexdigest()


class Cassandra(VectorStore):
"""Wrapper around Cassandra embeddings platform.
Expand All @@ -46,7 +38,7 @@ class Cassandra(VectorStore):

_embedding_dimension: int | None

def _getEmbeddingDimension(self) -> int:
def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
self._embedding_dimension = len(
self.embedding.embed_query("This is a sample sentence.")
Expand All @@ -59,7 +51,7 @@ def __init__(
session: Session,
keyspace: str,
table_name: str,
ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS,
ttl_seconds: Optional[int] = None,
) -> None:
try:
from cassio.vector import VectorTable
Expand All @@ -81,8 +73,8 @@ def __init__(
session=session,
keyspace=keyspace,
table=table_name,
embedding_dimension=self._getEmbeddingDimension(),
auto_id=False, # the `add_texts` contract admits user-provided ids
embedding_dimension=self._get_embedding_dimension(),
primary_key_type="TEXT",
)

def delete_collection(self) -> None:
Expand All @@ -99,11 +91,27 @@ def clear(self) -> None:
def delete_by_document_id(self, document_id: str) -> None:
return self.table.delete(document_id)

def delete(self, ids: List[str]) -> Optional[bool]:
"""Delete by vector ID.

Args:
ids: List of ids to delete.

Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
for document_id in ids:
self.delete_by_document_id(document_id)
return True

def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Expand All @@ -112,33 +120,39 @@ def add_texts(
texts (Iterable[str]): Texts to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs.
batch_size (int): Number of concurrent requests to send to the server.
ttl_seconds (Optional[int], optional): Optional time-to-live
for the added texts.

Returns:
List[str]: List of IDs of the added texts.
"""
_texts = list(texts) # lest it be a generator or something
if ids is None:
# unless otherwise specified, we have deterministic IDs:
# re-inserting an existing document will not create a duplicate.
# (and effectively update the metadata)
ids = [_hash(text) for text in _texts]
ids = [uuid.uuid4().hex for _ in _texts]
if metadatas is None:
metadatas = [{} for _ in _texts]
#
ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds)
ttl_seconds = ttl_seconds or self.ttl_seconds
#
embedding_vectors = self.embedding.embed_documents(_texts)
for text, embedding_vector, text_id, metadata in zip(
_texts, embedding_vectors, ids, metadatas
):
self.table.put(
document=text,
embedding_vector=embedding_vector,
document_id=text_id,
metadata=metadata,
ttl_seconds=ttl_seconds,
)
#
for i in range(0, len(_texts), batch_size):
batch_texts = _texts[i : i + batch_size]
batch_embedding_vectors = embedding_vectors[i : i + batch_size]
batch_ids = ids[i : i + batch_size]
batch_metadatas = metadatas[i : i + batch_size]

futures = [
self.table.put_async(
text, embedding_vector, text_id, metadata, ttl_seconds
)
for text, embedding_vector, text_id, metadata in zip(
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
)
]
for future in futures:
future.result()
return ids

# id-returning search facilities
Expand Down Expand Up @@ -181,7 +195,6 @@ def similarity_search_with_score_id(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
Expand Down Expand Up @@ -219,12 +232,10 @@ def similarity_search(
k: int = 4,
**kwargs: Any,
) -> List[Document]:
#
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k,
**kwargs,
)

def similarity_search_by_vector(
Expand All @@ -245,7 +256,6 @@ def similarity_search_with_score(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
Expand All @@ -266,7 +276,6 @@ def _similarity_search_with_relevance_scores(
return self.similarity_search_with_score(
query,
k,
**kwargs,
)

def max_marginal_relevance_search_by_vector(
Expand Down Expand Up @@ -352,6 +361,7 @@ def from_texts(
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
batch_size: int = 16,
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
Expand All @@ -378,6 +388,7 @@ def from_documents(
cls: Type[CVST],
documents: List[Document],
embedding: Embeddings,
batch_size: int = 16,
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
Expand Down
Loading