Skip to content

Commit

Permalink
DH-5324/update the engine with latest models (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza authored Jan 26, 2024
1 parent 26cb3d8 commit dd440f2
Show file tree
Hide file tree
Showing 9 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Openai info. All these fields are required for the engine to work.
OPENAI_API_KEY = #This field is required for the engine to work.
ORG_ID =
LLM_MODEL = 'gpt-4-1106-preview' #the openAI llm model that you want to use. possible values: gpt-4-1106-preview.
LLM_MODEL = 'gpt-4-turbo-preview' # The OpenAI LLM model that you want to use. Possible values: gpt-4-turbo-preview.

# All of our SQL generation agents are using different tools to generate SQL queries, in order to limit the number of times that agents can
# use different tools you can set the "AGENT_MAX_ITERATIONS" env variable. By default it is set to 20 iterations.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ cp .env.example .env

Specifically the following 5 fields must be manually set before the engine is started.

LLM_MODEL is employed by the engine to generate SQL from natural language. You can use the default model (gpt-4-1106-preview) or use your own.
LLM_MODEL is employed by the engine to generate SQL from natural language. You can use the default model (gpt-4-turbo-preview) or use your own.

```
#OpenAI credentials and model
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
model_name=os.getenv("LLM_MODEL", "gpt-4-turbo-preview"),
)
repository = TableDescriptionRepository(storage)
db_scan = repository.get_all_tables_by_db(
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@


TOP_K = int(os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", "50"))
EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_MODEL = "text-embedding-3-large"


def catch_exceptions(): # noqa: C901
Expand Down Expand Up @@ -613,7 +613,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
model_name=os.getenv("LLM_MODEL", "gpt-4-turbo-preview"),
)
repository = TableDescriptionRepository(storage)
db_scan = repository.get_all_tables_by_db(
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
model_name=os.getenv("LLM_MODEL", "gpt-4-turbo-preview"),
)
self.database = SQLDatabase.get_sql_engine(database_connection)
tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
model_name=os.getenv("LLM_MODEL", "gpt-4-turbo-preview"),
)
self.database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
model_name=os.getenv("LLM_MODEL", "gpt-4-turbo-preview"),
)
token_counter = TokenCountingHandler(
tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
Expand Down
1 change: 1 addition & 0 deletions dataherald/utils/models_context_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"gpt-3.5-turbo-16k-0613": 16000,
"gpt-3.5-turbo-0301": 4000,
"gpt-4-1106-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-3.5-turbo-1106": 16000,
}
2 changes: 1 addition & 1 deletion dataherald/vector_store/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataherald.types import GoldenSQL
from dataherald.vector_store import VectorStore

EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_MODEL = "text-embedding-3-small"


class Pinecone(VectorStore):
Expand Down

0 comments on commit dd440f2

Please sign in to comment.