
Commit

WIP
adolkhan committed May 12, 2023
1 parent 5524f7f commit 6b35c09
Showing 6 changed files with 163 additions and 49 deletions.
1 change: 1 addition & 0 deletions deeplake/constants.py
@@ -191,3 +191,4 @@
TRANSFORM_CHUNK_CACHE_SIZE = 64 * MB

DEFAULT_DEEPLAKE_PATH = "./deeplake_vector_store"
MAX_RETRY_ATTEMPTS = 5
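
MAX_RETRY_ATTEMPTS caps how many times the new ingestion pipeline (see data_ingestion.py below) re-runs a failed transform. A minimal sketch of how such a cap is consumed, assuming only the constant from this commit; the helper name and the failing step are illustrative:

```python
from deeplake.constants import MAX_RETRY_ATTEMPTS  # added in this commit

def retry_until_cap(step, attempt: int = 0):
    # Illustrative helper: keep retrying `step` until it succeeds or the cap is exceeded.
    while True:
        try:
            return step()
        except Exception:
            attempt += 1
            if attempt > MAX_RETRY_ATTEMPTS:
                raise
```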
6 changes: 4 additions & 2 deletions deeplake/core/vectorstore/deeplake_vectorstore.py
@@ -4,7 +4,7 @@
from deeplake.core.vectorstore.vector_search import filter as filter_utils
from deeplake.constants import DEFAULT_DEEPLAKE_PATH
from deeplake.core.vectorstore.vector_search import vector_search
from deeplake.core.vectorstore.vector_search.ingestion import data_ingestion
from deeplake.core.vectorstore.vector_search.ingestion import ingest_data

try:
from indra import api
@@ -67,6 +67,7 @@ def add(
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
embeddings: Optional[np.ndarray] = None,
total_samples_processed: Optional[Any] = None,
) -> List[str]:
"""Adding elements to deeplake vector store
@@ -79,12 +80,13 @@
ids (List[str]): List of document IDs
"""
elements = dataset_utils.create_elements(ids, texts, metadatas, embeddings)
data_ingestion.run_data_ingestion(
ingest_data.run_data_ingestion(
elements=elements,
dataset=self.dataset,
embedding_function=self.embedding_function,
ingestion_batch_size=self.ingestion_batch_size,
num_workers=self.num_workers,
total_samples_processed=total_samples_processed,
)
self.dataset.commit(allow_empty=True)
if self.verbose:
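A hedged usage sketch of the resume flow the new total_samples_processed argument enables: when ingestion fails repeatedly, the error raised in data_ingestion.py (below) reports the value to pass back into add so that already-committed batches are skipped. Here deeplake_vector_store is assumed to be an existing DeepLakeVectorStore instance and all inputs are placeholders:

```python
import numpy as np

# Placeholder inputs; only the `total_samples_processed` keyword is new in this commit.
texts = ["a", "b", "c", "d"]
metadatas = [{"source": "placeholder"}] * len(texts)
ids = ["1", "2", "3", "4"]
embeddings = np.zeros((len(texts), 1536), dtype=np.float32)

try:
    deeplake_vector_store.add(
        texts=texts, metadatas=metadatas, ids=ids, embeddings=embeddings
    )
except Exception:
    # Resume from the last saved checkpoint; 2 stands in for the value reported
    # in the "Maximum retry attempts exceeded" error message.
    deeplake_vector_store.add(
        texts=texts,
        metadatas=metadatas,
        ids=ids,
        embeddings=embeddings,
        total_samples_processed=2,
    )
```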
10 changes: 7 additions & 3 deletions deeplake/core/vectorstore/vector_search/dataset/dataset.py
@@ -33,7 +33,7 @@ def create_or_load_dataset(
if dataset_exists(dataset_path, token, creds, **kwargs):
return load_dataset(dataset_path, token, creds, logger, read_only, **kwargs)

return create_dataset(dataset_path, token, **kwargs)
return create_dataset(dataset_path, token, exec_option, **kwargs)


def dataset_exists(dataset_path, token, creds, **kwargs):
@@ -60,8 +60,12 @@ def load_dataset(dataset_path, token, creds, logger, read_only, **kwargs):
return dataset


def create_dataset(dataset_path, token, **kwargs):
dataset = deeplake.empty(dataset_path, token=token, **kwargs)
def create_dataset(dataset_path, token, exec_option, **kwargs):
runtime = None
if exec_option == "db_engine":
runtime = {"db_engine": True}

dataset = deeplake.empty(dataset_path, token=token, runtime=runtime, **kwargs)

with dataset:
dataset.create_tensor(
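A small usage sketch of the new create_dataset code path, assuming the intended option name is db_engine (corrected from the typo above) and that the function is imported directly from this module; path and token are placeholders:

```python
from deeplake.core.vectorstore.vector_search.dataset.dataset import create_dataset  # assumed import path

# When exec_option selects the managed engine, runtime={"db_engine": True} is
# forwarded to deeplake.empty; otherwise runtime stays None.
ds = create_dataset(
    dataset_path="hub://my_org/my_vector_store",  # placeholder path
    token="my_token",                             # placeholder token
    exec_option="db_engine",
)
```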
138 changes: 103 additions & 35 deletions deeplake/core/vectorstore/vector_search/ingestion/data_ingestion.py
@@ -6,44 +6,112 @@
from deeplake.core.dataset import Dataset as DeepLakeDataset
from deeplake.core.vectorstore.vector_search import utils
from deeplake.util.exceptions import TransformError
from deeplake.constants import MAX_RETRY_ATTEMPTS


def run_data_ingestion(
elements: List[Dict[str, Any]],
dataset: DeepLakeDataset,
embedding_function: Callable,
ingestion_batch_size: int,
num_workers: int,
):
"""Running data ingestion into deeplake dataset.
Args:
elements (List[Dict[str, Any]]): List of dictionaries. Each dictionary contains mapping of
names of 4 tensors (i.e. "embedding", "metadata", "ids", "text") to their corresponding values.
dataset (DeepLakeDataset): deeplake dataset object.
embedding_function (Callable): function used to convert query into an embedding.
ingestion_batch_size (int): The batch size to use during ingestion.
num_workers (int): The number of workers to use for ingesting data in parallel.
"""
batch_size = min(ingestion_batch_size, len(elements))
if batch_size == 0:
raise ValueError("batch_size must be a positive number greater than zero.")

batched = [
elements[i : i + batch_size] for i in range(0, len(elements), batch_size)
]

num_workers = min(num_workers, len(batched) // max(num_workers, 1))
checkpoint_interval = int(
(0.1 * len(batched) // max(num_workers, 1)) * max(num_workers, 1)
)

ingest(embedding_function=embedding_function).eval(
batched,
dataset,
num_workers=num_workers,
checkpoint_interval=checkpoint_interval,
)


class DataIngestion:
def __init__(
self,
elements,
dataset,
embedding_function: Callable,
ingestion_batch_size: int,
num_workers: int,
retry_attempt: int,
total_samples_processed=None,
):
self.elements = elements
self.dataset = dataset
self.embedding_function = embedding_function
self.ingestion_batch_size = ingestion_batch_size
self.num_workers = num_workers
self.retry_attempt = retry_attempt
self.total_samples_processed = total_samples_processed or 0  # default to 0 so retries can accumulate the count

def collect_batched_data(self):
batch_size = min(self.ingestion_batch_size, len(self.elements))
if batch_size == 0:
raise ValueError("batch_size must be a positive number greater than zero.")

if self.total_samples_processed:
if self.total_samples_processed * batch_size >= len(self.elements):
return []

elements = self.elements[self.total_samples_processed * batch_size :]

batched = [
elements[i : i + batch_size] for i in range(0, len(elements), batch_size)
]
return batched

def get_num_workers(self, batched):
return min(self.num_workers, len(batched) // max(self.num_workers, 1))

def get_checkpoint_interval(self, batched):
checkpoint_interval = max(
int(
(0.1 * len(batched) // max(self.num_workers, 1))
* max(self.num_workers, 1),
),
self.num_workers,
1,
)
return checkpoint_interval

def run(self):
batched_data = self.collect_batched_data()
num_workers = self.get_num_workers(batched_data)
checkpoint_interval = self.get_checkpoint_interval(batched_data)

self._ingest(
batched=batched_data,
num_workers=num_workers,
checkpoint_interval=checkpoint_interval,
)

def _ingest(
self,
batched,
num_workers,
checkpoint_interval,
):
try:
ingest(embedding_function=self.embedding_function).eval(
batched,
self.dataset,
num_workers=num_workers,
checkpoint_interval=checkpoint_interval,
)
except Exception as e:
self.retry_attempt += 1
if self.retry_attempt > MAX_RETRY_ATTEMPTS:
raise Exception(
f"""Maximum retry attempts exceeded. You can resume ingestion, from the latest saved checkpoint.
To do that you should run:
```
deeplake_vector_store.add(
texts=texts,
metadatas=metadatas,
ids=ids,
embeddings=embeddings,
total_samples_processed={self.total_samples_processed},
)
```
"""
) from e
last_checkpoint = self.dataset.version_state["commit_node"].parent
self.total_samples_processed += last_checkpoint.total_samples_processed

data_ingestion = DataIngestion(
elements=self.elements,
dataset=self.dataset,
embedding_function=self.embedding_function,
ingestion_batch_size=self.ingestion_batch_size,
num_workers=num_workers,
retry_attempt=self.retry_attempt,
total_samples_processed=self.total_samples_processed,
)
data_ingestion.run()


@deeplake.compute
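To make the scheduling arithmetic in collect_batched_data, get_num_workers, and get_checkpoint_interval concrete, here is a standalone worked example under assumed input sizes; it restates the formulas rather than importing the class:

```python
# Assumed inputs for illustration.
elements = list(range(40_000))      # 40,000 samples to ingest
ingestion_batch_size = 1024
num_workers = 2

batch_size = min(ingestion_batch_size, len(elements))                    # 1024
batched = [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]
print(len(batched))                                                      # 40 batches

workers = min(num_workers, len(batched) // max(num_workers, 1))          # min(2, 20) -> 2
checkpoint_interval = max(
    int((0.1 * len(batched) // max(num_workers, 1)) * max(num_workers, 1)),
    num_workers,
    1,
)
print(workers, checkpoint_interval)  # 2 4 -> a dataset checkpoint roughly every 4 batches
```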
39 changes: 39 additions & 0 deletions deeplake/core/vectorstore/vector_search/ingestion/ingest_data.py
@@ -0,0 +1,39 @@
from typing import Dict, List, Any, Callable

from deeplake.core.dataset import Dataset as DeepLakeDataset
from deeplake.core.vectorstore.vector_search.ingestion.data_ingestion import (
DataIngestion,
)


def run_data_ingestion(
elements: List[Dict[str, Any]],
dataset: DeepLakeDataset,
embedding_function: Callable,
ingestion_batch_size: int,
num_workers: int,
retry_attempt: int = 0,
total_samples_processed=None,
):
"""Running data ingestion into deeplake dataset.
Args:
elements (List[Dict[str, Any]]): List of dictionaries. Each dictionary contains mapping of
names of 4 tensors (i.e. "embedding", "metadata", "ids", "text") to their corresponding values.
dataset (DeepLakeDataset): deeplake dataset object.
embedding_function (Callable): function used to convert query into an embedding.
ingestion_batch_size (int): The batch size to use during ingestion.
num_workers (int): The number of workers to use for ingesting data in parallel.
"""

data_ingestion = DataIngestion(
elements=elements,
dataset=dataset,
embedding_function=embedding_function,
ingestion_batch_size=ingestion_batch_size,
num_workers=num_workers,
retry_attempt=retry_attempt,
total_samples_processed=total_samples_processed,
)

data_ingestion.run()
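
A minimal, hedged usage sketch of the new wrapper, loosely mirroring the test below; the element keys follow the docstring above ("embedding", "metadata", "ids", "text"), while the dataset path and the htypes chosen here are assumptions rather than part of this commit:

```python
import deeplake
import numpy as np

from deeplake.core.vectorstore.vector_search.ingestion import ingest_data

# In-memory dataset with the four tensors the ingestion pipeline writes to.
ds = deeplake.empty("mem://ingestion_demo")
ds.create_tensor("text", htype="text")
ds.create_tensor("metadata", htype="json")
ds.create_tensor("ids", htype="text")
ds.create_tensor("embedding", htype="generic", dtype=np.float32)

elements = [
    {"text": "a", "metadata": {"source": "doc"}, "ids": "1", "embedding": np.zeros(4, dtype=np.float32)},
    {"text": "b", "metadata": {"source": "doc"}, "ids": "2", "embedding": np.ones(4, dtype=np.float32)},
]

ingest_data.run_data_ingestion(
    elements=elements,
    dataset=ds,
    embedding_function=None,   # embeddings are supplied directly
    ingestion_batch_size=1,
    num_workers=0,             # run serially for the sketch
)
print(len(ds))  # 2
```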
(Sixth changed file: the data ingestion test module; its file header is not shown on this page.)
@@ -5,7 +5,7 @@

import deeplake
from deeplake.constants import MB
from deeplake.core.vectorstore.vector_search.ingestion import data_ingestion
from deeplake.core.vectorstore.vector_search.ingestion import ingest_data

random.seed(1)

@@ -14,10 +14,10 @@ def corrupted_embedding_function(emb):
p = random.uniform(0, 1)
if p > 0.9:
raise Exception("CorruptedEmbeddingFunction")
return np.zeros((1, 1536), dtype=np.float32)
return np.zeros((len(emb), 1536), dtype=np.float32)


def test_data_ingestion():
def test_ingest_data():
data = [
{
"text": "a",
@@ -45,7 +45,7 @@ def test_data_ingestion():
},
]

dataset = deeplake.empty("mem://xyz")
dataset = deeplake.empty("./xyzabc", overwrite=True)
dataset.create_tensor(
"text",
htype="text",
@@ -80,7 +80,7 @@ def test_data_ingestion():
chunk_compression="lz4",
)

data_ingestion.run_data_ingestion(
ingest_data.run_data_ingestion(
dataset=dataset,
elements=data,
embedding_function=None,
@@ -89,18 +89,18 @@
)

assert len(dataset) == 4
extended_data = data * 10
extended_data = data * 10000
with pytest.raises(Exception):
data_ingestion.run_data_ingestion(
ingest_data.run_data_ingestion(
dataset=dataset,
elements=extended_data,
embedding_function=corrupted_embedding_function,
ingestion_batch_size=1,
ingestion_batch_size=1024,
num_workers=2,
)

with pytest.raises(ValueError):
data_ingestion.run_data_ingestion(
ingest_data.run_data_ingestion(
dataset=dataset,
elements=extended_data,
embedding_function=corrupted_embedding_function,
