diff --git a/deeplake/constants.py b/deeplake/constants.py index cd5922fd9d..8d2fc900b4 100644 --- a/deeplake/constants.py +++ b/deeplake/constants.py @@ -335,3 +335,10 @@ "REMOVE": 2, "UPDATE": 3, } + + +DEFAULT_RATE_LIMITER_KEY_TO_VALUE = { + "enabled": False, + "bytes_per_minute": MAX_BYTES_PER_MINUTE, + "batch_byte_size": TARGET_BYTE_SIZE, +} diff --git a/deeplake/core/vectorstore/deep_memory.py b/deeplake/core/vectorstore/deep_memory.py index 012186939c..a56b666c4f 100644 --- a/deeplake/core/vectorstore/deep_memory.py +++ b/deeplake/core/vectorstore/deep_memory.py @@ -1,3 +1,4 @@ +import logging import uuid from collections import defaultdict from typing import Any, Dict, Optional, List, Union, Callable, Tuple @@ -32,6 +33,7 @@ def __init__( self, dataset: Dataset, client: DeepMemoryBackendClient, + logger: logging.Logger, embedding_function: Optional[Any] = None, token: Optional[str] = None, creds: Optional[Dict[str, Any]] = None, @@ -41,6 +43,7 @@ def __init__( Args: dataset (Dataset): deeplake dataset object. client (DeepMemoryBackendClient): Client to interact with the DeepMemory managed service. Defaults to None. + logger (logging.Logger): Logger object. embedding_function (Optional[Any], optional): Embedding funtion class used to convert queries/documents to embeddings. Defaults to None. token (Optional[str], optional): API token for the DeepMemory managed service. Defaults to None. creds (Optional[Dict[str, Any]], optional): Credentials to access the dataset. Defaults to None. @@ -63,6 +66,7 @@ def __init__( self.embedding_function = embedding_function self.client = client self.creds = creds or {} + self.logger = logger def train( self, @@ -94,6 +98,7 @@ def train( Raises: ValueError: if embedding_function is not specified either during initialization or during training. 
""" + self.logger.info("Starting DeepMemory training job") feature_report_path( path=self.dataset.path, feature_name="dm.train", @@ -127,8 +132,10 @@ def train( runtime=runtime, token=token or self.token, creds=self.creds, + verbose=False, ) + self.logger.info("Preparing training data for deepmemory:") queries_vs.add( text=[query for query in queries], metadata=[ @@ -144,7 +151,9 @@ def train( queries_path=queries_path, ) - print(f"DeepMemory training job started. Job ID: {response['job_id']}") + self.logger.info( + f"DeepMemory training job started. Job ID: {response['job_id']}" + ) return response["job_id"] def cancel(self, job_id: str): @@ -305,75 +314,67 @@ def evaluate( top_k: List[int] = [1, 3, 5, 10, 50, 100], qvs_params: Optional[Dict[str, Any]] = None, ) -> Dict[str, Dict[str, float]]: - """Evaluate a model on DeepMemory managed service. + """ + Evaluate a model using the DeepMemory managed service. Examples: - >>> #1. Evaluate a model with embedding function - >>> relevance: List[List[Tuple[str, int]]] = [[("doc_id_1", 1), ("doc_id_2", 1)], [("doc_id_3", 1)]] - >>> # doc_id_1, doc_id_2, doc_id_3 are the ids of the documents in the corpus dataset that is relevant to the queries. It is stored in the `id` tensor of the corpus dataset. - >>> queries: List[str] = ["What is the capital of India?", "What is the capital of France?"] - >>> embedding_function: Callable[..., List[np.ndarray] = openai_embedding.embed_documents - >>> vectorstore.deep_memory.evaluate( - ... relevance=relevance, - ... queries=queries, - ... embedding_function=embedding_function, - ... ) - >>> #2. Evaluate a model with precomputed embeddings - >>> relevance: List[List[Tuple[str, int]]] = [[("doc_id_1", 1), ("doc_id_2", 1)], [("doc_id_3", 1)]] - >>> # doc_id_1, doc_id_2, doc_id_3 are the ids of the documents in the corpus dataset that is relevant to the queries. It is stored in the `id` tensor of the corpus dataset. 
- >>> queries: List[str] = ["What is the capital of India?", "What is the capital of France?"] - >>> embedding: Union[List[np.ndarray[Any, Any]], List[List[float]] = [[-1.2, 12, ...], ...] - >>> vectorstore.deep_memory.evaluate( - ... relevance=relevance, - ... queries=queries, - ... embedding=embedding, - ... ) - >>> #3. Evaluate a model with precomputed embeddings and log queries - >>> relevance: List[List[Tuple[str, int]]] = [[("doc_id_1", 1), ("doc_id_2", 1)], [("doc_id_3", 1)]] - >>> # doc_id_1, doc_id_2, doc_id_3 are the ids of the documents in the corpus dataset that is relevant to the queries. It is stored in the `id` tensor of the corpus dataset. - >>> queries: List[str] = ["What is the capital of India?", "What is the capital of France?"] - >>> embedding: Union[List[np.ndarray[Any, Any]], List[List[float]] = [[-1.2, 12, ...], ...] - >>> vectorstore.deep_memory.evaluate( - ... relevance=relevance, - ... queries=queries, - ... embedding=embedding, - ... qvs_params={ - ... "log_queries": True, - ... } - ... ) - >>> #4. Evaluate a model with precomputed embeddings and log queries, and custom branch - >>> relevance: List[List[Tuple[str, int]]] = [[("doc_id_1", 1), ("doc_id_2", 1)], [("doc_id_3", 1)]] - >>> # doc_id_1, doc_id_2, doc_id_3 are the ids of the documents in the corpus dataset that is relevant to the queries. It is stored in the `id` tensor of the corpus dataset. - >>> queries: List[str] = ["What is the capital of India?", "What is the capital of France?"] - >>> embedding: Union[List[np.ndarray[Any, Any]], List[List[float]] = [[-1.2, 12, ...], ...] - >>> vectorstore.deep_memory.evaluate( - ... relevance=relevance, - ... queries=queries, - ... embedding=embedding, - ... qvs_params={ - ... "log_queries": True, - ... "branch": "queries", - ... } - ... ) + # 1. 
Evaluate a model using an embedding function: + relevance = [[("doc_id_1", 1), ("doc_id_2", 1)], [("doc_id_3", 1)]] + queries = ["What is the capital of India?", "What is the capital of France?"] + embedding_function = openai_embedding.embed_documents + vectorstore.deep_memory.evaluate( + relevance=relevance, + queries=queries, + embedding_function=embedding_function, + ) + + # 2. Evaluate a model with precomputed embeddings: + embeddings = [[-1.2, 12, ...], ...] + vectorstore.deep_memory.evaluate( + relevance=relevance, + queries=queries, + embedding=embeddings, + ) + + # 3. Evaluate a model with precomputed embeddings and log queries: + vectorstore.deep_memory.evaluate( + relevance=relevance, + queries=queries, + embedding=embeddings, + qvs_params={"log_queries": True}, + ) + + # 4. Evaluate with precomputed embeddings, log queries, and a custom branch: + vectorstore.deep_memory.evaluate( + relevance=relevance, + queries=queries, + embedding=embeddings, + qvs_params={ + "log_queries": True, + "branch": "queries", + } + ) Args: - queries (List[str]): List of queries to evaluate the model on. - relevance (List[List[Tuple[str, int]]]): List of relevant documents for each query with their respective relevance score. - The outer list corresponds to the queries and the inner list corresponds to the doc_id, relevence_score pair for each query. - doc_id is the document id in the corpus dataset. It is stored in the `id` tensor of the corpus dataset. - relevence_score is the relevance score of the document for the query. The range is between 0 and 1, where 0 stands for not relevant and 1 stands for relevant. - embedding (Optional[np.ndarray], optional): Embedding of the queries. Defaults to None. - embedding_function (Optional[Callable[..., List[np.ndarray]]], optional): Embedding funtion used to convert queries to embeddings. Defaults to None. - top_k (List[int], optional): List of top_k values to evaluate the model on. Defaults to [1, 3, 5, 10, 50, 100]. 
- qvs_params (Optional[Dict], optional): Parameters to initialize the queries vectorstore. Defaults to None. + queries (List[str]): Queries for model evaluation. + relevance (List[List[Tuple[str, int]]]): Relevant documents and scores for each query. + - Outer list: matches the queries. + - Inner list: pairs of doc_id and relevance score. + - doc_id: Document ID from the corpus dataset, found in the `id` tensor. + - relevance_score: Between 0 (not relevant) and 1 (relevant). + embedding (Optional[np.ndarray], optional): Query embeddings. Defaults to None. + embedding_function (Optional[Callable[..., List[np.ndarray]]], optional): Function to convert queries into embeddings. Defaults to None. + top_k (List[int], optional): Ranks for model evaluation. Defaults to [1, 3, 5, 10, 50, 100]. + qvs_params (Optional[Dict], optional): Parameters to initialize the queries vectorstore. When specified, creates a new vectorstore to track evaluation queries, the Deep Memory response, and the naive vector search results. Defaults to None. Returns: - Dict[str, Dict[str, float]]: Dictionary of recalls for each top_k value. + Dict[str, Dict[str, float]]: Recalls for each rank. Raises: - ImportError: if indra is not installed - ValueError: if embedding_function is not specified either during initialization or during evaluation. + ImportError: If `indra` is not installed. + ValueError: If no embedding_function is provided either during initialization or evaluation. 
""" + feature_report_path( path=self.dataset.path, feature_name="dm.evaluate", @@ -440,7 +441,7 @@ def evaluate( (True, "deepmemory_distance"), ]: eval_type = "with" if use_model else "without" - print(f"---- Evaluating {eval_type} model ---- ") + print(f"---- Evaluating {eval_type} Deep Memory ---- ") avg_recalls, queries_dict = recall_at_k( indra_dataset, relevance, diff --git a/deeplake/core/vectorstore/deeplake_vectorstore.py b/deeplake/core/vectorstore/deeplake_vectorstore.py index 16775855b3..2a591194f7 100644 --- a/deeplake/core/vectorstore/deeplake_vectorstore.py +++ b/deeplake/core/vectorstore/deeplake_vectorstore.py @@ -8,7 +8,6 @@ import deeplake from deeplake.core import index_maintenance from deeplake.core.distance_type import DistanceType -from deeplake.util.dataset import try_flushing from deeplake.util.exceptions import DeepMemoryWaitingListError from deeplake.util.path import convert_pathlib_to_string_if_needed @@ -340,10 +339,9 @@ def add( embedding_data=embedding_data, embedding_tensor=embedding_tensor, rate_limiter=rate_limiter, + logger=self.logger, ) - try_flushing(self.dataset) - if self.verbose: self.dataset.summary() @@ -448,8 +446,6 @@ def search( username=self.username, ) - try_flushing(self.dataset) - if exec_option is None and self.exec_option != "python" and callable(filter): self.logger.warning( 'Switching exec_option to "python" (runs on client) because filter is specified as a function. 
' @@ -603,8 +599,6 @@ def delete( self.dataset.pop_multiple(row_ids) - try_flushing(self.dataset) - return True def update_embedding( @@ -677,8 +671,6 @@ def update_embedding( username=self.username, ) - try_flushing(self.dataset) - ( embedding_function, embedding_source_tensor, @@ -711,8 +703,6 @@ def update_embedding( self.dataset[row_ids].update(embedding_tensor_data) - try_flushing(self.dataset) - @staticmethod def delete_by_path( path: Union[str, pathlib.Path], diff --git a/deeplake/core/vectorstore/deepmemory_vectorstore.py b/deeplake/core/vectorstore/deepmemory_vectorstore.py index adcc303710..fd5351e50e 100644 --- a/deeplake/core/vectorstore/deepmemory_vectorstore.py +++ b/deeplake/core/vectorstore/deepmemory_vectorstore.py @@ -17,6 +17,7 @@ def __init__(self, client, *arg, **kwargs): embedding_function=self.embedding_function, client=client, creds=self.creds, + logger=self.logger, ) def search( diff --git a/deeplake/core/vectorstore/vector_search/dataset/__init__.py b/deeplake/core/vectorstore/vector_search/dataset/__init__.py index 73cbf86705..8efeae2625 100644 --- a/deeplake/core/vectorstore/vector_search/dataset/__init__.py +++ b/deeplake/core/vectorstore/vector_search/dataset/__init__.py @@ -11,4 +11,5 @@ convert_id_to_row_id, search_row_ids, extend, + populate_rate_limiter, ) diff --git a/deeplake/core/vectorstore/vector_search/dataset/dataset.py b/deeplake/core/vectorstore/vector_search/dataset/dataset.py index c7a67b3388..15ef5c9986 100644 --- a/deeplake/core/vectorstore/vector_search/dataset/dataset.py +++ b/deeplake/core/vectorstore/vector_search/dataset/dataset.py @@ -18,6 +18,7 @@ MAX_BYTES_PER_MINUTE, TARGET_BYTE_SIZE, VECTORSTORE_EXTEND_BATCH_SIZE, + DEFAULT_RATE_LIMITER_KEY_TO_VALUE, ) from deeplake.util.exceptions import IncorrectEmbeddingShapeError @@ -460,6 +461,7 @@ def extend( dataset: deeplake.core.dataset.Dataset, rate_limiter: Dict, _extend_batch_size: int = VECTORSTORE_EXTEND_BATCH_SIZE, + logger=None, ): """ Function to extend the 
dataset with new data. @@ -468,8 +470,15 @@ def extend( embedding_data = [embedding_data] if embedding_function: + number_of_batches = ceil(len(embedding_data[0]) / _extend_batch_size) + progressbar_str = ( + f"Creating {len(embedding_data[0])} embeddings in " + f"{number_of_batches} batches of size {min(_extend_batch_size, len(embedding_data[0]))}:" + ) + for idx in tqdm( - range(0, len(embedding_data[0]), _extend_batch_size), "creating embeddings" + range(0, len(embedding_data[0]), _extend_batch_size), + progressbar_str, ): batch_start, batch_end = idx, idx + _extend_batch_size @@ -488,9 +497,32 @@ def extend( batched_processed_tensors = {**batched_embeddings, **batched_tensors} - dataset.extend(batched_processed_tensors) + dataset.extend(batched_processed_tensors, progressbar=False) + else: + logger.info("Uploading data to deeplake dataset.") + dataset.extend(processed_tensors, progressbar=True) + + +def populate_rate_limiter(rate_limiter): + if rate_limiter is None or rate_limiter == {}: + return { + "enabled": False, + "bytes_per_minute": MAX_BYTES_PER_MINUTE, + "batch_byte_size": TARGET_BYTE_SIZE, + } else: - dataset.extend(processed_tensors) + rate_limiter_keys = ["enabled", "bytes_per_minute", "batch_byte_size"] + + for key in rate_limiter_keys: + if key not in rate_limiter: + rate_limiter[key] = DEFAULT_RATE_LIMITER_KEY_TO_VALUE[key] + + for item in rate_limiter: + if item not in rate_limiter_keys: + raise ValueError( + f"Invalid rate_limiter key: {item}. Valid keys are: 'enabled', 'bytes_per_minute', 'batch_byte_size'." 
+ ) + return rate_limiter def extend_or_ingest_dataset( @@ -500,7 +532,9 @@ def extend_or_ingest_dataset( embedding_tensor, embedding_data, rate_limiter, + logger, ): + rate_limiter = populate_rate_limiter(rate_limiter) # TODO: Add back the old logic with checkpointing after indexing is fixed extend( embedding_function, @@ -509,6 +543,7 @@ def extend_or_ingest_dataset( processed_tensors, dataset, rate_limiter, + logger=logger, ) diff --git a/deeplake/core/vectorstore/vector_search/dataset/test_dataset.py b/deeplake/core/vectorstore/vector_search/dataset/test_dataset.py index da6d753c83..d123fe1073 100644 --- a/deeplake/core/vectorstore/vector_search/dataset/test_dataset.py +++ b/deeplake/core/vectorstore/vector_search/dataset/test_dataset.py @@ -428,3 +428,23 @@ def mock_embedding_function(text): assert ( abs(elapsed_minutes - expected_time) <= tolerance ), "Rate limiting did not work as expected!" + + +def test_populate_rate_limiter(): + rate_limiter = { + "enabled": True, + } + + rate_limiter_parsed = dataset_utils.populate_rate_limiter(rate_limiter) + assert rate_limiter_parsed == { + "enabled": True, + "bytes_per_minute": MAX_BYTES_PER_MINUTE, + "batch_byte_size": TARGET_BYTE_SIZE, + } + + rate_limiter = { + "enabled": True, + "bytes_per_second": 1000, + } + with pytest.raises(ValueError): + rate_limiter_parsed = dataset_utils.populate_rate_limiter(rate_limiter) diff --git a/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py b/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py index 3e64b88c57..e8835830ce 100644 --- a/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py +++ b/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py @@ -68,15 +68,7 @@ def search( return_tensors, ) - if runtime: - view, data = deeplake_dataset.query( - tql_query, runtime=runtime, return_data=True - ) - if return_view: - return view - - return_data = data - elif deep_memory: + if deep_memory: if not INDRA_INSTALLED: raise 
raise_indra_installation_error(indra_import_error=None) @@ -103,6 +95,15 @@ def search( for tensor in view.tensors: return_data[tensor] = utils.parse_tensor_return(view[tensor]) + elif runtime: + view, data = deeplake_dataset.query( + tql_query, runtime=runtime, return_data=True + ) + if return_view: + return view + + return_data = data + else: if not INDRA_INSTALLED: raise raise_indra_installation_error( diff --git a/deeplake/requirements/plugins.txt b/deeplake/requirements/plugins.txt index 4d1a43d8c5..4ba1ebe9b3 100644 --- a/deeplake/requirements/plugins.txt +++ b/deeplake/requirements/plugins.txt @@ -4,7 +4,7 @@ torchvision tensorflow tensorflow_datasets pickle5>=0.0.11; python_version < "3.8" and python_version >= "3.6" -ray==2.3.0 +ray==2.7.1 datasets~=1.17 mmcv-full==1.7.1; platform_system == "Linux" and python_version >= "3.7" mmdet==2.28.1; platform_system == "Linux" and python_version >= "3.7" diff --git a/docs/source/Deep-Memory.rst b/docs/source/Deep-Memory.rst index 56784d242b..a434927e05 100644 --- a/docs/source/Deep-Memory.rst +++ b/docs/source/Deep-Memory.rst @@ -1,8 +1,17 @@ .. _deep_memory: -Deep Memory -===================== +Deep Memory API +=============== +.. currentmodule:: deeplake.core.vectorstore.deep_memory +.. autoclass:: DeepMemory() + :members: + + .. automethod:: __init__ + + +Syntax +~~~~~~ .. role:: sql(code) :language: sql @@ -10,11 +19,8 @@ This page describes :meth:`ds.query `. De to improve the search results, by aligning queries with the corpus dataset. It gives up to +22% of recall improvement on an eval dataset. To use deep_memory, please subscribe to our waitlist. -Syntax -~~~~~~~~ - Training ------- +-------- To start training you should first create a vectostore object, and then preprocess the data and use deep memory with it: @@ -36,6 +42,8 @@ and the relevance score (range is 0-1, where 0 represents unrelated document and ... 
embedding_function = embedding_function, # function that takes converts texts into embeddings, it is optional and can be skipped if provided during initialization ... ) +Tracking the training progress +------------------------------ ``job_id`` is string, which can be used to track the training progress. You can use ``db.deep_memory.status(job_id)`` to get the status of the job. when the model is still in pending state (not started yet) you will see the following: @@ -74,7 +82,8 @@ ID STATUS RESULTS PROGRESS 651a4d41d05a21a5a6a15f67 completed recall@10: 0.62% (+0.62%) eta: 2.5 seconds recall@10: 0.62% (+0.62%) - +Deep Memory Evaluation +---------------------- Once the training is completed, you can use ``db.deep_memory.evaluate`` to evaluate the model performance on the custom dataset. Once again you would need to preprocess the dataset so that, ``corpus``, will become a list of list of tuples, where outer list corresponds to the query and inner list to the relevant documents. Each tuple should contain the document id (``id`` tensor from the corpus dataset) @@ -84,12 +93,17 @@ and the relevance score (range is 0-1, where 0 represents unrelated document and ... corpus: List[List[Tuple[str, float]]] = corpus, ... queries: List[str] = queries, ... embedding_function = embedding_function, # function that takes converts texts into embeddings, it is optional and can be skipped if provided during initialization +... qvs_params = {"log_queries": True} ... ) ``recalls`` is a dictionary with the following keys: ``with_model`` contains a dictionary with recall metrics for the naive vector search on the custom dataset for different k values ``without_model`` contains a dictionary with recall metrics for the naive vector search on the custom dataset for different k values +``qvs_params`` when specified creates a separate vectorstore that tracks all evaluation queries and documents, so that you can use it to compare the performance of +deep_memory to naive vector search. 
By default, it is turned off. If enabled, the dataset will be created at ``hub://{$ORG_ID}/{$DATASET_ID}_eval_queries`` +Deep Memory Search +------------------ After the model is trained you also can search using it: >>> results = db.search( diff --git a/docs/source/Vector-Store.rst b/docs/source/Vector-Store.rst index 969460350a..fc839efabf 100644 --- a/docs/source/Vector-Store.rst +++ b/docs/source/Vector-Store.rst @@ -34,3 +34,40 @@ Vector Store Properties VectorStore.summary VectorStore.tensors VectorStore.__len__ + +VectorStore.DeepMemory +====================== + +Creating a Deep Memory +~~~~~~~~~~~~~~~~~~~~~~ + +If Deep Memory is available on your plan, it will be automatically initialized when you create a Vector Store. + +.. currentmodule:: deeplake.core.vectorstore.deep_memory +.. autosummary:: + :toctree: + :nosignatures: + + DeepMemory.__init__ + +Deep Memory Operations +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: + :nosignatures: + + DeepMemory.train + DeepMemory.cancel + DeepMemory.delete + +Deep Memory Properties +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: + :nosignatures: + + DeepMemory.status + DeepMemory.list_jobs + DeepMemory.__len__ \ No newline at end of file diff --git a/docs/source/deeplake.VectorStore.rst b/docs/source/deeplake.VectorStore.rst index e12aa2fb65..9eb9953fb7 100644 --- a/docs/source/deeplake.VectorStore.rst +++ b/docs/source/deeplake.VectorStore.rst @@ -2,5 +2,6 @@ deeplake.VectorStore -------------------- .. autoclass:: deeplake.core.vectorstore.deeplake_vectorstore.VectorStore - :members: - :show-inheritance: \ No newline at end of file + :members: + :show-inheritance: + :special-members: __init__ \ No newline at end of file