Skip to content

Commit

Permalink
Deep memory support in VectorStore (activeloopai#2594)
Browse files Browse the repository at this point in the history
  • Loading branch information
adolkhan authored Oct 3, 2023
1 parent 881f4a2 commit 178e28e
Show file tree
Hide file tree
Showing 37 changed files with 2,221 additions and 113 deletions.
2 changes: 2 additions & 0 deletions deeplake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import sys
from deeplake.util.check_latest_version import warn_if_update_required

from deeplake.core.vectorstore import VectorStore


if sys.platform == "darwin":
multiprocessing.set_start_method("fork", force=True)
Expand Down
2 changes: 1 addition & 1 deletion deeplake/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,7 +2363,7 @@ def dataset_handler_error_check(runner, username, password):

@pytest.mark.slow
def test_hub_related_permission_exceptions(
hub_cloud_dev_credentials, hub_cloud_dev_token, hub_dev_token
hub_cloud_dev_credentials,
):
username, password = hub_cloud_dev_credentials
runner = CliRunner()
Expand Down
113 changes: 112 additions & 1 deletion deeplake/client/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import deeplake
import requests
from typing import Any, Optional, Dict
import textwrap
from typing import Any, Optional, Dict, List, Union
from deeplake.util.exceptions import (
AgreementNotAcceptedError,
AuthorizationException,
Expand All @@ -18,6 +19,7 @@
write_token,
read_token,
remove_token,
JobResponseStatusSchema,
)
from deeplake.client.config import (
ACCEPT_AGREEMENTS_SUFFIX,
Expand Down Expand Up @@ -527,3 +529,112 @@ def has_indra_org_permission(self, org_id: str) -> Dict[str, Any]:
).json()

return response


class DeepMemoryBackendClient(DeepLakeBackendClient):
    """Client for the Deep Memory HTTP API: start/cancel/delete training jobs,
    query job status, and check feature availability for an organization.
    """

    def __init__(self, token: Optional[str] = None):
        """Creates a Deep Memory backend client.

        Args:
            token (Optional[str]): Auth token forwarded to the base client.
        """
        super().__init__(token=token)

    def deepmemory_is_available(self, org_id: str):
        """Checks if DeepMemory is available for the user.

        Args:
            org_id (str): The name of the user/organization to which the dataset belongs.

        Returns:
            bool: True if DeepMemory is available, False otherwise.
        """
        response = self.request(
            "GET",
            f"/api/organizations/{org_id}/features/deepmemory",
            endpoint=self.endpoint(),
        )
        return response.json()["available"]

    def start_training(
        self,
        corpus_path: str,
        queries_path: str,
    ) -> Dict[str, Any]:
        """Starts training of DeepMemory model.

        Args:
            corpus_path (str): The path to the corpus dataset.
            queries_path (str): The path to the queries dataset.

        Returns:
            Dict[str, Any]: The json response containing job_id.
        """
        response = self.request(
            method="POST",
            relative_url="/api/deepmemory/v1/train",
            json={"corpus_dataset": corpus_path, "query_dataset": queries_path},
        )
        check_response_status(response)
        return response.json()

    # Backward-compatible alias: the method originally shipped under this
    # misspelled name, so existing callers must keep working.
    start_taining = start_training

    def cancel_job(self, job_id: str):
        """Cancels a job with job_id.

        Best-effort: any failure is printed rather than raised, so callers can
        treat the boolean result as the sole outcome.

        Args:
            job_id (str): The job_id of the job to be cancelled.

        Returns:
            bool: True if job was cancelled successfully, False otherwise.
        """
        try:
            response = self.request(
                method="POST",
                relative_url=f"/api/deepmemory/v1/jobs/{job_id}/cancel",
            )
            check_response_status(response)
        except Exception as e:
            print(f"Job with job_id='{job_id}' was not cancelled!\n Error: {e}")
            return False
        print("Job cancelled successfully")
        return True

    def check_status(self, job_id: str, recall: str, improvement: str):
        """Checks status of a job with job_id and prints a status report.

        Args:
            job_id (str): The job_id of the job to be checked.
            recall (str): Current best top 10 recall
            improvement (str): Current best improvement over baseline

        Returns:
            Dict[str, Any]: The json response containing job status.
        """
        response = self.request(
            method="GET",
            relative_url=f"/api/deepmemory/v1/jobs/{job_id}/status",
        )
        check_response_status(response)
        response_status_schema = JobResponseStatusSchema(response=response.json())
        response_status_schema.print_status(job_id, recall, improvement)
        return response.json()

    def list_jobs(self, dataset_path: str):
        """Lists all jobs for a dataset.

        Args:
            dataset_path (str): The path to the dataset, expected in
                ``hub://org/dataset`` form.

        Returns:
            Dict[str, Any]: The json response containing list of jobs.
        """
        # Strip the "hub://" scheme; the backend route takes the bare id.
        dataset_id = dataset_path[len("hub://") :]
        response = self.request(
            method="GET",
            relative_url=f"/api/deepmemory/v1/{dataset_id}/jobs",
        )
        check_response_status(response)
        return response.json()

    def delete_job(self, job_id: str):
        """Deletes a job with job_id.

        Best-effort: any failure is printed rather than raised.

        Args:
            job_id (str): The job_id of the job to be deleted.

        Returns:
            bool: True if job was deleted successfully, False otherwise.
        """
        try:
            response = self.request(
                method="DELETE",
                relative_url=f"/api/deepmemory/v1/jobs/{job_id}",
            )
            check_response_status(response)
            return True
        except Exception as e:
            print(f"Job with job_id='{job_id}' was not deleted!\n Error: {e}")
            return False
234 changes: 233 additions & 1 deletion deeplake/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pytest
from deeplake.cli.commands import login, logout
from click.testing import CliRunner
from deeplake.client.client import DeepLakeBackendClient
from deeplake.client.client import (
DeepLakeBackendClient,
DeepMemoryBackendClient,
JobResponseStatusSchema,
)
from deeplake.client.utils import (
write_token,
read_token,
remove_token,
)

from time import sleep


@pytest.mark.slow
def test_client_requests(hub_cloud_dev_credentials):
Expand Down Expand Up @@ -52,3 +58,229 @@ def test_client_workspace_organizations(
assert username in deeplake_client.get_user_organizations()

runner.invoke(logout)


def create_response(
    job_id="6508464cd80cab681bfcfff3",
    dataset_id="some_dataset_id",
    organization_id="some_organization_id",
    status="training",
    progress=...,
):
    """Build a fake Deep Memory job-status payload for tests.

    Args:
        job_id (str): Fake job identifier.
        dataset_id (str): Fake dataset identifier.
        organization_id (str): Fake organization identifier.
        status (str): Job state, e.g. "pending", "training", "completed", "failed".
        progress (dict | None): Progress sub-payload, or None for a job with no
            progress yet. Defaults (via the ``...`` sentinel, because ``None``
            is itself a meaningful explicit value) to a representative
            in-training progress payload.

    Returns:
        dict: A response-shaped dict with keys id, dataset_id,
        organization_id, status, progress.
    """
    if progress is ...:
        # Built fresh per call — avoids the shared-mutable-default pitfall
        # (every caller previously received the same dict object).
        progress = {
            "eta": 100.34,
            "last_update_at": "2021-08-31T15:00:00.000000",
            "error": None,
            "train_recall@10": "87.8%",
            "best_recall@10": "85.5% (+2.6)%",
            "epoch": 0,
            "base_val_recall@10": 0.8292181491851807,
            "val_recall@10": "85.5%",
            "dataset": "query",
            "split": 0,
            "loss": -0.05437087118625641,
            "delta": 2.572011947631836,
        }
    return {
        "id": job_id,
        "dataset_id": dataset_id,
        "organization_id": organization_id,
        "status": status,
        "progress": progress,
    }


class Status:
    # Expected terminal output of JobResponseStatusSchema.print_status for each
    # job state. These fixtures must match the formatter byte-for-byte,
    # including column padding and the trailing "\n\n\n".

    # Job not started yet: no progress, no results.
    pending = (
        "--------------------------------------------------------------\n"
        "| 1238464cd80cab681bfcfff3 |\n"
        "--------------------------------------------------------------\n"
        "| status | pending |\n"
        "--------------------------------------------------------------\n"
        "| progress | None |\n"
        "--------------------------------------------------------------\n"
        "| results | not available yet |\n"
        "--------------------------------------------------------------\n\n\n"
    )

    # Job in progress: eta and current recall shown, no final results.
    training = (
        "--------------------------------------------------------------\n"
        "| 3218464cd80cab681bfcfff3 |\n"
        "--------------------------------------------------------------\n"
        "| status | training |\n"
        "--------------------------------------------------------------\n"
        "| progress | eta: 100.3 seconds |\n"
        "| | recall@10: 85.5% (+2.6%) |\n"
        "--------------------------------------------------------------\n"
        "| results | not available yet |\n"
        "--------------------------------------------------------------\n"
        "| results | not available yet |\n"
        "--------------------------------------------------------------\n\n\n"
    )

    # Job finished: results section now carries the final recall.
    completed = (
        "--------------------------------------------------------------\n"
        "| 2138464cd80cab681bfcfff3 |\n"
        "--------------------------------------------------------------\n"
        "| status | completed |\n"
        "--------------------------------------------------------------\n"
        "| progress | eta: 100.3 seconds |\n"
        "| | recall@10: 85.5% (+2.6%) |\n"
        "--------------------------------------------------------------\n"
        "| results | recall@10: 85.5% (+2.6%) |\n"
        "--------------------------------------------------------------\n\n\n"
    )

    # Job errored out: the error message is wrapped across padded lines.
    failed = (
        "--------------------------------------------------------------\n"
        "| 1338464cd80cab681bfcfff3 |\n"
        "--------------------------------------------------------------\n"
        "| status | failed |\n"
        "--------------------------------------------------------------\n"
        "| progress | eta: None seconds |\n"
        "| | error: list indices must be |\n"
        "| | integers or slices, |\n"
        "| | not str |\n"
        "--------------------------------------------------------------\n"
        "| results | not available yet |\n"
        "--------------------------------------------------------------\n\n\n"
    )


def test_deepmemory_response_without_job_id():
    """Schema construction must reject payloads missing required identifiers."""
    payload = create_response()
    payload.pop("dataset_id")
    payload.pop("id")

    # Missing both "id" and "dataset_id": invalid payload.
    with pytest.raises(ValueError):
        JobResponseStatusSchema(response=payload)

    # Restoring "dataset_id" alone is not enough — "id" is still absent.
    payload["dataset_id"] = "some id"
    with pytest.raises(ValueError):
        JobResponseStatusSchema(response=payload)


def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
    """Verifies print_status output for each job state against the Status
    fixtures, then verifies print_jobs against the precomputed listing.
    """
    # for training that is just started: no progress payload at all
    job_id = "1238464cd80cab681bfcfff3"
    pending_response = create_response(
        job_id=job_id,
        status="pending",
        progress=None,
    )
    response_schema = JobResponseStatusSchema(response=pending_response)
    response_schema.print_status(job_id)
    captured = capsys.readouterr()
    assert captured.out == Status.pending

    # for training that is in progress: recall/improvement rendered in output
    # NOTE(review): "importvement" looks like a typo, but it may mirror the
    # keyword name in print_status's actual signature — confirm before renaming.
    job_id = "3218464cd80cab681bfcfff3"
    training_response = create_response(job_id=job_id)
    response_schema = JobResponseStatusSchema(response=training_response)
    response_schema.print_status(job_id, recall="85.5", importvement="2.6")
    captured = capsys.readouterr()
    assert captured.out == Status.training

    # for training jobs that are finished: results section becomes available
    job_id = "2138464cd80cab681bfcfff3"
    completed_response = create_response(
        job_id=job_id,
        status="completed",
    )
    response_schema = JobResponseStatusSchema(response=completed_response)
    response_schema.print_status(job_id, recall="85.5", importvement="2.6")
    captured = capsys.readouterr()
    assert captured.out == Status.completed

    # for jobs that failed: the error string is wrapped into the report
    job_id = "1338464cd80cab681bfcfff3"
    failed_response = create_response(
        job_id=job_id,
        status="failed",
        progress={
            "eta": None,
            "last_update_at": "2021-08-31T15:00:00.000000",
            "error": "list indices must be integers or slices, not str",
            "dataset": "query",
        },
    )
    response_schema = JobResponseStatusSchema(response=failed_response)
    response_schema.print_status(job_id)
    captured = capsys.readouterr()
    assert captured.out == Status.failed

    # Finally, feed all four responses through print_jobs and compare the
    # debug string against the precomputed fixture. recalls/improvements are
    # keyed by job id; None marks jobs with no metrics to display.
    responses = [
        pending_response,
        training_response,
        completed_response,
        failed_response,
    ]
    recalls = {
        "1238464cd80cab681bfcfff3": None,
        "3218464cd80cab681bfcfff3": "85.5",
        "2138464cd80cab681bfcfff3": "85.5",
        "1338464cd80cab681bfcfff3": None,
    }
    improvements = {
        "1238464cd80cab681bfcfff3": None,
        "3218464cd80cab681bfcfff3": "2.6",
        "2138464cd80cab681bfcfff3": "2.6",
        "1338464cd80cab681bfcfff3": None,
    }
    response_schema = JobResponseStatusSchema(response=responses)
    output_str = response_schema.print_jobs(
        debug=True,
        recalls=recalls,
        improvements=improvements,
    )
    assert output_str == precomputed_jobs_list


@pytest.mark.slow
def test_deepmemory_train_and_cancel(job_id, capsys, hub_cloud_dev_token):
    """Cancelling a nonexistent or non-pending job returns False and prints
    the backend error instead of raising.
    """
    client = DeepMemoryBackendClient(hub_cloud_dev_token)

    # Unknown job id: cancel_job swallows the backend error and returns False.
    canceled = client.cancel_job(job_id="non-existent-job-id")
    captured = capsys.readouterr()
    expected = "Job with job_id='non-existent-job-id' was not cancelled!\n Error: Entity non-existent-job-id does not exist.\n"
    assert canceled is False
    assert captured.out == expected

    # Existing job that is past the pending state cannot be cancelled either.
    canceled = client.cancel_job(job_id=job_id)
    captured = capsys.readouterr()
    expected = (
        f"Job with job_id='{job_id}' was not cancelled!\n"
        f" Error: Job {job_id} is not in pending state, skipping cancellation.\n"
    )
    assert canceled is False
    # BUG FIX: the original `assert expected in captured.out == expected` was a
    # chained comparison — (expected in captured.out) AND (captured.out ==
    # expected) — which accidentally demanded exact equality as well.
    assert expected in captured.out


@pytest.mark.slow
def test_deepmemory_delete(
    capsys,
    hub_cloud_dev_credentials,
    corpus_query_relevances_copy,
    hub_cloud_dev_token,
):
    """Deleting a real (cancelled) job works; deleting an unknown job returns
    False and prints the backend error.
    """
    # Only the corpus path is needed from the fixture tuple.
    corpus_path = corpus_query_relevances_copy[0]
    username, _ = hub_cloud_dev_credentials
    query_path = f"hub://{username}/deepmemory_test_queries_managed"

    client = DeepMemoryBackendClient(hub_cloud_dev_token)
    job = client.start_taining(
        corpus_path=corpus_path,
        queries_path=query_path,
    )
    # A freshly-started job is cancelled first, then removed.
    client.cancel_job(job_id=job["job_id"])
    client.delete_job(job_id=job["job_id"])

    # Unknown job id: delete_job swallows the backend error and returns False.
    deleted = client.delete_job(job_id="non-existent-job-id")
    captured = capsys.readouterr()
    expected = "Job with job_id='non-existent-job-id' was not deleted!\n Error: Entity non-existent-job-id does not exist.\n"
    assert not deleted
    assert expected in captured.out
Loading

0 comments on commit 178e28e

Please sign in to comment.