Skip to content

Commit

Permalink
[DH-5444] Implement Error Codes on Engine (#408)
Browse files Browse the repository at this point in the history
* [DH-5444] Implement Error Codes on Engine

* Move error code logic into utils package
  • Loading branch information
jcjc712 authored Feb 27, 2024
1 parent 4e09e63 commit 2c70f16
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 290 deletions.
485 changes: 196 additions & 289 deletions dataherald/api/fastapi.py

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from dataherald.config import System
from dataherald.context_store import ContextStore
from dataherald.repositories.database_connections import (
DatabaseConnectionNotFoundError,
DatabaseConnectionRepository,
)
from dataherald.repositories.golden_sqls import GoldenSQLRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt
Expand Down Expand Up @@ -66,6 +70,7 @@ def retrieve_context_for_question(
def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL]:
"""Creates embeddings of the questions and adds them to the VectorDB. Also adds the golden sqls to the DB"""
golden_sqls_repository = GoldenSQLRepository(self.db)
db_connection_repository = DatabaseConnectionRepository(self.db)
stored_golden_sqls = []
for record in golden_sqls:
try:
Expand All @@ -74,6 +79,13 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
raise MalformedGoldenSQLError(
f"SQL {record.sql} is malformed. Please check the syntax."
) from e

db_connection = db_connection_repository.find_by_id(record.db_connection_id)
if not db_connection:
raise DatabaseConnectionNotFoundError(
f"Database connection not found, {record.db_connection_id}"
)

prompt_text = record.prompt_text
golden_sql = GoldenSQL(
prompt_text=prompt_text,
Expand Down
4 changes: 4 additions & 0 deletions dataherald/repositories/golden_sqls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
DB_COLLECTION = "golden_sqls"


class GoldenSQLNotFoundError(Exception):
    """Signal that a golden SQL record could not be located.

    NOTE(review): the raise sites are outside this view; the meaning is taken
    from the class name and its "golden_sql_not_found" error code.
    """


class GoldenSQLRepository:
def __init__(self, storage):
self.storage = storage
Expand Down
9 changes: 9 additions & 0 deletions dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class EmptyDBError(Exception):
pass


class SSHInvalidDatabaseConnectionError(Exception):
    """Signal that an engine could not be created over an SSH connection.

    get_sql_engine raises this, chained to the original exception, when
    building the engine through the SSH path fails for any reason.
    """


class DBConnections:
db_connections = {}

Expand Down Expand Up @@ -83,6 +87,11 @@ def get_sql_engine(
engine = cls.from_uri_ssh(database_info)
DBConnections.add(database_info.id, engine)
return engine
except Exception as e:
raise SSHInvalidDatabaseConnectionError(
f"Invalid SSH connection, {e}"
) from e
try:
db_uri = unquote(fernet_encrypt.decrypt(database_info.connection_uri))

file_path = database_info.path_to_credentials_file
Expand Down
6 changes: 5 additions & 1 deletion dataherald/sql_database/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def __getitem__(self, key: str) -> Any:
return getattr(self, key)


class InvalidURIFormatError(Exception):
    """Signal that a connection URI does not match the expected shape.

    Raised by the URI validation when the input does not match the
    ``scheme://host/database`` style pattern it checks against.
    """


class DatabaseConnection(BaseModel):
id: str | None
alias: str
Expand All @@ -88,7 +92,7 @@ def validate_uri(cls, input_string):
pattern = r"([^:/]+)://([^/]+)/([^/]+)"
match = re.match(pattern, input_string)
if not match:
raise ValueError(f"Invalid URI format: {input_string}")
raise InvalidURIFormatError(f"Invalid URI format: {input_string}")

@validator("connection_uri", pre=True, always=True)
def connection_uri_format(cls, value: str):
Expand Down
7 changes: 7 additions & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel, Field, validator

from dataherald.sql_database.models.types import FileStorage, SSHSettings
from dataherald.utils.models_context_window import OPENAI_FINETUNING_MODELS_WINDOW_SIZES


class DBConnectionValidation(BaseModel):
Expand Down Expand Up @@ -128,6 +129,12 @@ class BaseLLM(BaseModel):
model_name: str | None = None
model_parameters: dict[str, str] | None = None

@validator("model_name")
def validate_model_name(cls, v: str | None):
    """Reject model names that are missing from the supported-models table.

    A falsy value (``None`` or an empty string) is passed through untouched,
    so the field stays optional. Any other value must be a key of
    OPENAI_FINETUNING_MODELS_WINDOW_SIZES.

    Raises:
        ValueError: If a non-empty model name is not in the table.
    """
    # Nothing to check when no model was specified.
    if not v:
        return v
    if v in OPENAI_FINETUNING_MODELS_WINDOW_SIZES:
        return v
    raise ValueError(f"Model {v} not supported")


class Finetuning(BaseModel):
id: str | None = None
Expand Down
31 changes: 31 additions & 0 deletions dataherald/utils/error_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fastapi.responses import JSONResponse

# Maps an exception's bare class name to the machine-readable error code that
# error_response puts in the "error_code" field of the JSON body.
# The lookup uses the exact class name, so a subclass of a listed exception is
# NOT matched unless it has its own entry.
ERROR_MAPPING = {
    "InvalidId": "invalid_object_id",
    "InvalidDBConnectionError": "invalid_database_connection",
    "InvalidURIFormatError": "invalid_database_uri_format",
    "SSHInvalidDatabaseConnectionError": "ssh_invalid_database_connection",
    "EmptyDBError": "empty_database",
    "DatabaseConnectionNotFoundError": "database_connection_not_found",
    "GoldenSQLNotFoundError": "golden_sql_not_found",
    "LLMNotSupportedError": "llm_model_not_supported",
    "PromptNotFoundError": "prompt_not_found",
    "SQLGenerationError": "sql_generation_not_created",
    "SQLInjectionError": "sql_injection",
    "SQLGenerationNotFoundError": "sql_generation_not_found",
    "NLGenerationError": "nl_generation_not_created",
    "MalformedGoldenSQLError": "invalid_golden_sql",
}


def error_response(
    error, detail: dict, default_error_code="", *, status_code: int = 400
):
    """Build the JSON error response sent back to API clients.

    The error code is looked up in ERROR_MAPPING by the exact class name of
    ``error``; a class that is not listed there gets ``default_error_code``.

    Args:
        error: The exception being reported. ``str(error)`` becomes the
            ``message`` field.
        detail: Extra context returned unchanged under the ``detail`` key.
            NOTE(review): presumably must be JSON-serializable -- confirm
            against callers.
        default_error_code: Code used when the class name of ``error`` has no
            entry in ERROR_MAPPING.
        status_code: HTTP status for the response. Defaults to 400, the value
            that used to be hard-coded, so existing callers are unaffected;
            callers may now pass e.g. 404 for "not found" errors.

    Returns:
        A JSONResponse whose body has the keys ``error_code``, ``message``
        and ``detail``.
    """
    error_code = ERROR_MAPPING.get(error.__class__.__name__, default_error_code)
    return JSONResponse(
        status_code=status_code,
        content={
            "error_code": error_code,
            "message": str(error),
            "detail": detail,
        },
    )

0 comments on commit 2c70f16

Please sign in to comment.