Skip to content

Commit

Permalink
DH-4696/adding the API key as database connection inputs (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored Sep 22, 2023
1 parent 5937b35 commit 2d9e873
Show file tree
Hide file tree
Showing 18 changed files with 218 additions and 32 deletions.
2 changes: 2 additions & 0 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def create_database_connection(
alias=database_connection_request.alias,
uri=database_connection_request.connection_uri,
path_to_credentials_file=database_connection_request.path_to_credentials_file,
llm_credentials=database_connection_request.llm_credentials,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
)
Expand Down Expand Up @@ -170,6 +171,7 @@ def update_database_connection(
alias=database_connection_request.alias,
uri=database_connection_request.connection_uri,
path_to_credentials_file=database_connection_request.path_to_credentials_file,
llm_credentials=database_connection_request.llm_credentials,
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
)
Expand Down
3 changes: 1 addition & 2 deletions dataherald/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class Evaluator(Component, ABC):

def __init__(self, system: System):
self.system = system
model = ChatModel(self.system)
self.llm = model.get_model(temperature=0)
self.model = ChatModel(self.system)

def get_confidence_score(
self,
Expand Down
4 changes: 4 additions & 0 deletions dataherald/eval/eval_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def get_tools(self) -> List[BaseTool]:

class EvaluationAgent(Evaluator):
sample_rows: int = 10
llm: Any = None

def __init__(self, system: System):
super().__init__(system)
Expand Down Expand Up @@ -246,6 +247,9 @@ def evaluate(
logger.info(
f"Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
)
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
)
database = SQLDatabase.get_sql_engine(database_connection)
user_question = question.question
sql = generated_answer.sql_query
Expand Down
6 changes: 6 additions & 0 deletions dataherald/eval/simple_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
import time
from typing import Any

from langchain.chains import LLMChain
from langchain.prompts.chat import (
Expand Down Expand Up @@ -52,6 +53,8 @@


class SimpleEvaluator(Evaluator):
llm: Any = None

def __init__(self, system: System):
super().__init__(system)
self.system = system
Expand Down Expand Up @@ -85,6 +88,9 @@ def evaluate(
logger.info(
f"(Simple evaluator) Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
)
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
)
start_time = time.time()
system_message_prompt = SystemMessagePromptTemplate.from_template(
SYSTEM_TEMPLATE
Expand Down
8 changes: 7 additions & 1 deletion dataherald/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

from dataherald.config import Component, System
from dataherald.sql_database.models.types import DatabaseConnection


class LLMModel(Component, ABC):
Expand All @@ -12,5 +13,10 @@ def __init__(self, system: System):
self.system = system

@abstractmethod
def get_model(self, **kwargs: Any) -> Any:
def get_model(
self,
database_connection: DatabaseConnection,
model_family="openai",
**kwargs: Any
) -> Any:
pass
20 changes: 19 additions & 1 deletion dataherald/model/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from overrides import override

from dataherald.model import LLMModel
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.utils.encrypt import FernetEncrypt


class BaseModel(LLMModel):
Expand All @@ -17,7 +19,23 @@ def __init__(self, system):
self.cohere_api_key = os.environ.get("COHERE_API_KEY")

@override
def get_model(self, **kwargs: Any) -> Any:
def get_model(
self,
database_connection: DatabaseConnection,
model_family="openai",
**kwargs: Any
) -> Any:
if database_connection.llm_credentials is not None:
fernet_encrypt = FernetEncrypt()
api_key = fernet_encrypt.decrypt(
database_connection.llm_credentials.api_key
)
if model_family == "openai":
self.openai_api_key = api_key
elif model_family == "anthropic":
self.anthropic_api_key = api_key
elif model_family == "google":
self.google_api_key = api_key
if self.openai_api_key:
self.model = OpenAI(model_name=self.model_name, **kwargs)
elif self.aleph_alpha_api_key:
Expand Down
32 changes: 28 additions & 4 deletions dataherald/model/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from overrides import override

from dataherald.model import LLMModel
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.utils.encrypt import FernetEncrypt


class ChatModel(LLMModel):
Expand All @@ -16,13 +18,35 @@ def __init__(self, system):
self.google_api_key = os.environ.get("GOOGLE_API_KEY")

@override
def get_model(self, **kwargs: Any) -> Any:
def get_model(
self,
database_connection: DatabaseConnection,
model_family="openai",
**kwargs: Any
) -> Any:
if database_connection.llm_credentials is not None:
fernet_encrypt = FernetEncrypt()
api_key = fernet_encrypt.decrypt(
database_connection.llm_credentials.api_key
)
if model_family == "openai":
self.openai_api_key = api_key
elif model_family == "anthropic":
self.anthropic_api_key = api_key
elif model_family == "google":
self.google_api_key = api_key
if self.openai_api_key:
self.model = ChatOpenAI(model_name=self.model_name, **kwargs)
self.model = ChatOpenAI(
model_name=self.model_name, openai_api_key=self.openai_api_key, **kwargs
)
elif self.anthropic_api_key:
self.model = ChatAnthropic(model=self.model, **kwargs)
self.model = ChatAnthropic(
model=self.model, anthropic_api_key=self.anthropic_api_key, **kwargs
)
elif self.google_api_key:
self.model = ChatGooglePalm(model_name=self.model_name, **kwargs)
self.model = ChatGooglePalm(
model_name=self.model_name, google_api_key=self.google_api_key, **kwargs
)
else:
raise ValueError("No valid API key environment variable found")
return self.model
18 changes: 18 additions & 0 deletions dataherald/sql_database/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@
from dataherald.utils.encrypt import FernetEncrypt


class LLMCredentials(BaseSettings):
organization_id: str | None
api_key: str | None

@validator("api_key", "organization_id", pre=True, always=True)
def encrypt(cls, value: str):
fernet_encrypt = FernetEncrypt()
try:
fernet_encrypt.decrypt(value)
return value
except Exception:
return fernet_encrypt.encrypt(value)

def __getitem__(self, key: str) -> Any:
return getattr(self, key)


class SSHSettings(BaseSettings):
db_name: str | None
host: str | None
Expand Down Expand Up @@ -39,6 +56,7 @@ class DatabaseConnection(BaseModel):
use_ssh: bool = False
uri: str | None
path_to_credentials_file: str | None
llm_credentials: LLMCredentials | None = None
ssh_settings: SSHSettings | None = None

@validator("uri", pre=True, always=True)
Expand Down
3 changes: 1 addition & 2 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ class SQLGenerator(Component, ABC):

def __init__(self, system: System): # noqa: ARG002
self.system = system
model = ChatModel(self.system)
self.llm = model.get_model(temperature=0)
self.model = ChatModel(self.system)

def create_sql_query_status(
self, db: SQLDatabase, query: str, response: NLQueryResponse
Expand Down
5 changes: 5 additions & 0 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ class DataheraldSQLAgent(SQLGenerator):
"""Dataherald SQL agent"""

max_number_of_examples: int = 100 # maximum number of question/SQL pairs
llm: Any = None

def remove_duplicate_examples(self, fewshot_exmaples: List[dict]) -> List[dict]:
returned_result = []
Expand Down Expand Up @@ -569,6 +570,10 @@ def generate_response(
start_time = time.time()
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
)
repository = DBScannerRepository(storage)
db_scan = repository.get_all_tables_by_db(
db_connection_id=database_connection.id
Expand Down
7 changes: 5 additions & 2 deletions dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class GeneratesNlAnswer:
def __init__(self, system, storage):
self.system = system
self.storage = storage
model = ChatModel(self.system)
self.llm = model.get_model(temperature=0)
self.model = ChatModel(self.system)

def execute(self, nl_query_response: NLQueryResponse) -> NLQueryResponse:
nl_question_repository = NLQuestionRepository(self.storage)
Expand All @@ -40,6 +39,10 @@ def execute(self, nl_query_response: NLQueryResponse) -> NLQueryResponse:
database_connection = db_connection_repository.find_by_id(
nl_question.db_connection_id
)
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
)
database = SQLDatabase.get_sql_engine(database_connection)
nl_query_response = create_sql_query_status(
database, nl_query_response.sql_query, nl_query_response
Expand Down
7 changes: 6 additions & 1 deletion dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import time
from typing import List
from typing import Any, List

from langchain.agents import initialize_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
Expand All @@ -20,6 +20,8 @@


class LangChainSQLAgentSQLGenerator(SQLGenerator):
llm: Any | None = None

@override
def generate_response(
self,
Expand All @@ -28,6 +30,9 @@ def generate_response(
context: List[dict] = None,
) -> NLQueryResponse: # type: ignore
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
)
self.database = SQLDatabase.get_sql_engine(database_connection)
tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
start_time = time.time()
Expand Down
7 changes: 6 additions & 1 deletion dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import time
from typing import List
from typing import Any, List

from langchain import SQLDatabaseChain
from langchain.callbacks import get_openai_callback
Expand Down Expand Up @@ -38,6 +38,8 @@


class LangChainSQLChainSQLGenerator(SQLGenerator):
llm: Any | None = None

@override
def generate_response(
self,
Expand All @@ -46,6 +48,9 @@ def generate_response(
context: List[dict] = None,
) -> NLQueryResponse:
start_time = time.time()
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
)
self.database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
f"Generating SQL response to question: {str(user_question.dict())} with passed context {context}"
Expand Down
7 changes: 6 additions & 1 deletion dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import time
from typing import List
from typing import Any, List

import tiktoken
from langchain.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS
Expand All @@ -26,6 +26,8 @@


class LlamaIndexSQLGenerator(SQLGenerator):
llm: Any | None = None

@override
def generate_response(
self,
Expand All @@ -35,6 +37,9 @@ def generate_response(
) -> NLQueryResponse:
start_time = time.time()
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
)
token_counter = TokenCountingHandler(
tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
verbose=False, # set to true to see usage printed to the console
Expand Down
3 changes: 2 additions & 1 deletion dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from bson.objectid import ObjectId
from pydantic import BaseModel, validator

from dataherald.sql_database.models.types import SSHSettings
from dataherald.sql_database.models.types import LLMCredentials, SSHSettings


class DBConnectionValidation(BaseModel):
Expand Down Expand Up @@ -95,6 +95,7 @@ class DatabaseConnectionRequest(BaseModel):
use_ssh: bool = False
connection_uri: str | None
path_to_credentials_file: str | None
llm_credentials: LLMCredentials | None
ssh_settings: SSHSettings | None


Expand Down
Loading

0 comments on commit 2d9e873

Please sign in to comment.