Skip to content

Commit

Permalink
[DH-5344] Redesigning table-description endpoints (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 authored Feb 5, 2024
1 parent 48f50c4 commit 28b8130
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 28 deletions.
7 changes: 7 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GoldenSQL,
GoldenSQLRequest,
InstructionRequest,
RefreshTableDescriptionRequest,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
Expand All @@ -52,6 +53,12 @@ def scan_db(
) -> list[TableDescriptionResponse]:
pass

@abstractmethod
def refresh_table_description(
    self, refresh_table_description: RefreshTableDescriptionRequest
) -> list[TableDescriptionResponse]:
    """Refresh table descriptions for the given request.

    Abstract hook: concrete API implementations provide the behaviour.

    Args:
        refresh_table_description: The refresh request payload.

    Returns:
        The table descriptions produced by the refresh.
    """
    pass

@abstractmethod
def create_database_connection(
self, database_connection_request: DatabaseConnectionRequest
Expand Down
60 changes: 32 additions & 28 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from dataherald.db_scanner import Scanner
from dataherald.db_scanner.models.types import (
QueryHistory,
TableDescription,
TableDescriptionStatus,
)
from dataherald.db_scanner.repository.base import (
InvalidColumnNameError,
Expand Down Expand Up @@ -80,6 +78,7 @@
GoldenSQLRequest,
Instruction,
InstructionRequest,
RefreshTableDescriptionRequest,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
Expand Down Expand Up @@ -180,7 +179,6 @@ def create_database_connection(
metadata=database_connection_request.metadata,
)

SQLDatabase.get_sql_engine(db_connection, True)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) # noqa: B904
except InvalidDBConnectionError as e:
Expand All @@ -191,8 +189,32 @@ def create_database_connection(

db_connection_repository = DatabaseConnectionRepository(self.storage)
db_connection = db_connection_repository.insert(db_connection)

# Get tables and views and create table-descriptions as NOT_SCANNED
sql_database = SQLDatabase.get_sql_engine(db_connection, True)
scanner_repository = TableDescriptionRepository(self.storage)
scanner = self.system.instance(Scanner)
scanner.create_tables(sql_database, str(db_connection.id), scanner_repository)

return DatabaseConnectionResponse(**db_connection.dict())

@override
def refresh_table_description(
    self, refresh_table_description: RefreshTableDescriptionRequest
) -> list[TableDescriptionResponse]:
    """Re-sync the stored table-descriptions of a database connection.

    Loads the database connection referenced by the request, opens an SQL
    engine for it and delegates to the scanner's ``refresh_tables``, which
    gets the tables and views, creates missing table-descriptions as
    NOT_SCANNED and updates DEPRECATED ones.

    Args:
        refresh_table_description: Request carrying ``db_connection_id``.

    Returns:
        One ``TableDescriptionResponse`` per table-description returned by
        the scanner.
    """
    db_connection_repository = DatabaseConnectionRepository(self.storage)
    db_connection = db_connection_repository.find_by_id(
        refresh_table_description.db_connection_id
    )

    sql_database = SQLDatabase.get_sql_engine(db_connection, True)
    scanner_repository = TableDescriptionRepository(self.storage)
    scanner = self.system.instance(Scanner)
    refreshed = scanner.refresh_tables(
        sql_database, str(db_connection.id), scanner_repository
    )
    # The scanner returns TableDescription models; convert them so the
    # value actually matches the declared response type.
    return [
        TableDescriptionResponse(**table_description.dict())
        for table_description in refreshed
    ]

@override
def list_database_connections(self) -> list[DatabaseConnectionResponse]:
db_connection_repository = DatabaseConnectionRepository(self.storage)
Expand Down Expand Up @@ -221,7 +243,7 @@ def update_database_connection(
metadata=database_connection_request.metadata,
)

SQLDatabase.get_sql_engine(db_connection, True)
sql_database = SQLDatabase.get_sql_engine(db_connection, True)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) # noqa: B904
except InvalidDBConnectionError as e:
Expand All @@ -231,6 +253,12 @@ def update_database_connection(
)
db_connection_repository = DatabaseConnectionRepository(self.storage)
db_connection = db_connection_repository.update(db_connection)

# Get tables and views and create missing table-descriptions as NOT_SCANNED and update DEPRECATED
scanner_repository = TableDescriptionRepository(self.storage)
scanner = self.system.instance(Scanner)
scanner.refresh_tables(sql_database, str(db_connection.id), scanner_repository)

return DatabaseConnectionResponse(**db_connection.dict())

@override
Expand Down Expand Up @@ -267,30 +295,6 @@ def list_table_descriptions(
{"db_connection_id": str(db_connection_id), "table_name": table_name}
)

if db_connection_id:
db_connection_repository = DatabaseConnectionRepository(self.storage)
db_connection = db_connection_repository.find_by_id(db_connection_id)
database = SQLDatabase.get_sql_engine(db_connection)
all_tables = database.get_tables_and_views()

if table_name:
all_tables = [table for table in all_tables if table == table_name]

for table_description in table_descriptions:
if table_description.table_name not in all_tables:
table_description.status = TableDescriptionStatus.DEPRECATED.value
else:
all_tables.remove(table_description.table_name)
for table in all_tables:
table_descriptions.append(
TableDescription(
table_name=table,
status=TableDescriptionStatus.NOT_SCANNED.value,
db_connection_id=db_connection_id,
columns=[],
)
)

return [
TableDescriptionResponse(**table_description.dict())
for table_description in table_descriptions
Expand Down
2 changes: 2 additions & 0 deletions dataherald/db/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def rename_field(
def update_or_create(self, collection: str, query: dict, obj: dict) -> int:
    """Update the document matching ``query`` or insert ``obj`` if none exists.

    Args:
        collection: Name of the collection to operate on.
        query: Filter used both to look up and to update the document.
        obj: Field values to store. The dict is never modified.

    Returns:
        The ``_id`` of the existing document when one matched, otherwise
        whatever ``insert_one`` returns for the new document.
    """
    row = self.find_one(collection, query)
    if not row:
        return self.insert_one(collection, obj)

    # An update must not overwrite the original creation timestamp. Build a
    # filtered copy instead of deleting the key from the caller's dict, so
    # the argument is left untouched.
    fields = {key: value for key, value in obj.items() if key != "created_at"}
    self._data_store[collection].update_one(query, {"$set": fields})
    return row["_id"]
Expand Down
20 changes: 20 additions & 0 deletions dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,23 @@ def synchronizing(
repository: TableDescriptionRepository,
) -> list[TableDescription]:
""" "Update table_description status"""

@abstractmethod
def create_tables(
    self,
    sql_database: SQLDatabase,
    db_connection_id: str,
    repository: TableDescriptionRepository,
    metadata: dict | None = None,
) -> None:
    """Create table descriptions for a database connection.

    Args:
        sql_database: Database whose tables are to be registered.
        db_connection_id: Identifier of the owning database connection.
        repository: Repository handed to the implementation for storage.
        metadata: Optional metadata; how it is used is up to the
            implementation.
    """

@abstractmethod
def refresh_tables(
    self,
    sql_database: SQLDatabase,
    db_connection_id: str,
    repository: TableDescriptionRepository,
    metadata: dict | None = None,
) -> list[TableDescription]:
    """Refresh the table descriptions of a database connection.

    Args:
        sql_database: Database whose tables are to be refreshed.
        db_connection_id: Identifier of the owning database connection.
        repository: Repository handed to the implementation for storage.
        metadata: Optional metadata; how it is used is up to the
            implementation.

    Returns:
        The table descriptions resulting from the refresh.
    """
54 changes: 54 additions & 0 deletions dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,60 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scanner_service: AbstractScanner = None

@override
def create_tables(
    self,
    sql_database: SQLDatabase,
    db_connection_id: str,
    repository: TableDescriptionRepository,
    metadata: dict = None,
) -> None:
    """Save a NOT_SCANNED table-description for every table and view.

    Each name returned by ``sql_database.get_tables_and_views()`` is saved
    through ``repository.save_table_info`` with the given connection id and
    metadata.
    """
    for table_name in sql_database.get_tables_and_views():
        description = TableDescription(
            db_connection_id=db_connection_id,
            table_name=table_name,
            status=TableDescriptionStatus.NOT_SCANNED.value,
            metadata=metadata,
        )
        repository.save_table_info(description)

@override
def refresh_tables(
    self,
    sql_database: SQLDatabase,
    db_connection_id: str,
    repository: TableDescriptionRepository,
    metadata: dict | None = None,
) -> list[TableDescription]:
    """Reconcile stored table-descriptions with the database's tables/views.

    Stored descriptions whose table is no longer reported by
    ``sql_database`` are saved with status DEPRECATED. Tables reported by
    the database without a stored description are saved as NOT_SCANNED.
    Stored descriptions that are still present are returned as copies
    without being saved.

    Args:
        sql_database: Source of the current table and view names.
        db_connection_id: Connection whose stored descriptions are refreshed.
        repository: Repository used to look up and save descriptions.
        metadata: Metadata attached to newly created descriptions only.

    Returns:
        Stored descriptions first (in stored order), followed by the newly
        created ones (in source order).
    """
    stored_tables = repository.find_by({"db_connection_id": str(db_connection_id)})
    source_tables = sql_database.get_tables_and_views()

    # Sets give O(1) membership tests; the lists are still iterated so the
    # output order is unchanged.
    stored_names = {table.table_name for table in stored_tables}
    source_names = set(source_tables)

    rows = []
    for table_description in stored_tables:
        if table_description.table_name in source_names:
            rows.append(TableDescription(**table_description.dict()))
        else:
            table_description.status = TableDescriptionStatus.DEPRECATED.value
            rows.append(repository.save_table_info(table_description))

    for table in source_tables:
        if table in stored_names:
            continue
        rows.append(
            repository.save_table_info(
                TableDescription(
                    db_connection_id=db_connection_id,
                    table_name=table,
                    status=TableDescriptionStatus.NOT_SCANNED.value,
                    metadata=metadata,
                )
            )
        )
    return rows

@override
def synchronizing(
self,
Expand Down
14 changes: 14 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
GoldenSQL,
GoldenSQLRequest,
InstructionRequest,
RefreshTableDescriptionRequest,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
Expand Down Expand Up @@ -93,6 +94,14 @@ def __init__(self, settings: Settings):
tags=["Table descriptions"],
)

self.router.add_api_route(
"/api/v1/table-descriptions/refresh",
self.refresh_table_description,
methods=["POST"],
status_code=201,
tags=["Table descriptions"],
)

self.router.add_api_route(
"/api/v1/table-descriptions/{table_description_id}",
self.update_table_description,
Expand Down Expand Up @@ -362,6 +371,11 @@ def scan_db(
) -> list[TableDescriptionResponse]:
return self._api.scan_db(scanner_request, background_tasks)

def refresh_table_description(
    self, refresh_table_description: RefreshTableDescriptionRequest
):
    """Handle ``POST /api/v1/table-descriptions/refresh``.

    Passes the request to the API implementation's
    ``refresh_table_description`` and returns its result unchanged.
    """
    return self._api.refresh_table_description(refresh_table_description)

def create_prompt(self, prompt_request: PromptRequest) -> PromptResponse:
    """Pass the prompt request to the API implementation and return its response."""
    return self._api.create_prompt(prompt_request)

Expand Down
4 changes: 4 additions & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class InstructionRequest(DBConnectionValidation):
metadata: dict | None


class RefreshTableDescriptionRequest(DBConnectionValidation):
    """Request payload for refreshing table descriptions.

    Declares no fields of its own; everything comes from
    ``DBConnectionValidation``.
    """

    # NOTE(review): callers read ``db_connection_id`` from this request, so it
    # is presumably declared on DBConnectionValidation. Confirm.
    pass


class Instruction(BaseModel):
id: str | None = None
instruction: str
Expand Down

0 comments on commit 28b8130

Please sign in to comment.