Skip to content

Commit

Permalink
[DH-5261] Improve table description performance (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 authored Jan 31, 2024
1 parent 7943992 commit 435884e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 64 deletions.
7 changes: 3 additions & 4 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,8 @@ def scan_db(
status_code=400,
detail=f"Unable to connect to db: {scanner_request.db_connection_id}, {e}",
)
all_tables = database.get_tables_and_views()

scanner = self.system.instance(Scanner)
all_tables = scanner.get_all_tables_and_views(database)
if scanner_request.table_names:
for table in scanner_request.table_names:
if table not in all_tables:
Expand All @@ -154,6 +153,7 @@ def scan_db(
else:
scanner_request.table_names = all_tables

scanner = self.system.instance(Scanner)
rows = scanner.synchronizing(
scanner_request,
TableDescriptionRepository(self.storage),
Expand Down Expand Up @@ -271,9 +271,8 @@ def list_table_descriptions(
db_connection_repository = DatabaseConnectionRepository(self.storage)
db_connection = db_connection_repository.find_by_id(db_connection_id)
database = SQLDatabase.get_sql_engine(db_connection)
all_tables = database.get_tables_and_views()

scanner = self.system.instance(Scanner)
all_tables = scanner.get_all_tables_and_views(database)
if table_name:
all_tables = [table for table in all_tables if table == table_name]

Expand Down
4 changes: 0 additions & 4 deletions dataherald/db_scanner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,3 @@ def synchronizing(
repository: TableDescriptionRepository,
) -> list[TableDescription]:
""" "Update table_description status"""

@abstractmethod
def get_all_tables_and_views(self, database: SQLDatabase) -> list[str]:
""" "Retrieve all tables and views"""
7 changes: 0 additions & 7 deletions dataherald/db_scanner/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ def synchronizing(
)
return rows

@override
def get_all_tables_and_views(self, database: SQLDatabase) -> list[str]:
inspector = inspect(database.engine)
meta = MetaData(bind=database.engine)
MetaData.reflect(meta, views=True)
return inspector.get_table_names() + inspector.get_view_names()

def get_table_examples(
self, meta: MetaData, db_engine: SQLDatabase, table: str, rows_number: int = 3
) -> List[Any]:
Expand Down
67 changes: 18 additions & 49 deletions dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""SQL wrapper around SQLDatabase in langchain."""
import logging
import re
from typing import Any, List
from typing import List
from urllib.parse import unquote

import sqlparse
from langchain.sql_database import SQLDatabase as LangchainSQLDatabase
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import MetaData, create_engine, inspect, text
from sqlalchemy.engine import Engine
from sshtunnel import SSHTunnelForwarder

Expand Down Expand Up @@ -34,40 +33,27 @@ def add(uri, engine):
DBConnections.db_connections[uri] = engine


class SQLDatabase(LangchainSQLDatabase):
"""SQL Database.
Wrapper around SQLDatabase object from langchain. Offers
some helper utilities for insertion and querying.
See `langchain documentation <https://tinyurl.com/4we5ku8j>`_ for more details:
Args:
*args: Arguments to pass to langchain SQLDatabase.
**kwargs: Keyword arguments to pass to langchain SQLDatabase.
"""
class SQLDatabase:
def __init__(self, engine: Engine):
"""Create engine from database URI."""
self._engine = engine

@property
def engine(self) -> Engine:
"""Return SQL Alchemy engine."""
return self._engine

@property
def metadata_obj(self) -> MetaData:
"""Return SQL Alchemy metadata."""
return self._metadata

@classmethod
def from_uri(
cls, database_uri: str, engine_args: dict | None = None, **kwargs: Any
cls, database_uri: str, engine_args: dict | None = None
) -> "SQLDatabase":
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
if database_uri.lower().startswith("duckdb"):
config = {"autoload_known_extensions": False}
_engine_args["connect_args"] = {"config": config}
engine = create_engine(database_uri, **_engine_args)
return cls(engine, **kwargs)
return cls(engine)

@classmethod
def get_sql_engine(
Expand Down Expand Up @@ -208,30 +194,13 @@ def run_sql(self, command: str, top_k: int = None) -> tuple[str, dict]:
return str(result), {"result": result}
return "", {}

# from llama-index's sql-wrapper
def get_table_columns(self, table_name: str) -> List[Any]:
"""Get table columns."""
return self._inspector.get_columns(table_name)

# from llama-index's sql-wrapper
def get_single_table_info(self, table_name: str) -> str:
"""Get table info for a single table."""
# same logic as table_info, but with specific table names
template = (
"Table '{table_name}' has columns: {columns} "
"and foreign keys: {foreign_keys}."
)
columns = []
for column in self._inspector.get_columns(table_name):
columns.append(f"{column['name']} ({str(column['type'])})")
column_str = ", ".join(columns)
foreign_keys = []
for foreign_key in self._inspector.get_foreign_keys(table_name):
foreign_keys.append(
f"{foreign_key['constrained_columns']} -> "
f"{foreign_key['referred_table']}.{foreign_key['referred_columns']}"
)
foreign_key_str = ", ".join(foreign_keys)
return template.format(
table_name=table_name, columns=column_str, foreign_keys=foreign_key_str
)
def get_tables_and_views(self) -> List[str]:
inspector = inspect(self._engine)
meta = MetaData(bind=self._engine)
MetaData.reflect(meta, views=True)
return inspector.get_table_names() + inspector.get_view_names()

@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
return self._engine.dialect.name

0 comments on commit 435884e

Please sign in to comment.