Skip to content

Commit

Permalink
DH-5567/adding intermediate steps as the response (#440)
Browse files Browse the repository at this point in the history
* DH-5567/adding intermediate steps as the response

* DH-5557add truncation

* Dh-5567/reformat
  • Loading branch information
MohammadrezaPourreza authored Mar 22, 2024
1 parent 88ee8fa commit 3dbd483
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 28 deletions.
1 change: 0 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,6 @@ def create_prompt_and_sql_generation(
return error_response(
e, prompt_sql_generation_request.dict(), "sql_generation_not_created"
)

return SQLGenerationResponse(**sql_generation.dict())

@override
Expand Down
3 changes: 2 additions & 1 deletion dataherald/api/types/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataherald.db_scanner.models.types import TableDescription
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.types import GoldenSQL, LLMConfig
from dataherald.types import GoldenSQL, IntermediateStep, LLMConfig


class BaseResponse(BaseModel):
Expand Down Expand Up @@ -33,6 +33,7 @@ class SQLGenerationResponse(BaseResponse):
status: str
completed_at: str | None
llm_config: LLMConfig | None
intermediate_steps: list[IntermediateStep] | None
sql: str | None
tokens_used: int | None
confidence_score: float | None
Expand Down
20 changes: 14 additions & 6 deletions dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@ def generate_response_with_timeout(self, sql_generator, user_prompt, db_connecti
user_prompt=user_prompt, database_connection=db_connection
)

def update_the_initial_sql_generation(
self, initial_sql_generation: SQLGeneration, sql_generation: SQLGeneration
):
initial_sql_generation.sql = sql_generation.sql
initial_sql_generation.tokens_used = sql_generation.tokens_used
initial_sql_generation.completed_at = datetime.now()
initial_sql_generation.status = sql_generation.status
initial_sql_generation.error = sql_generation.error
initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
return self.sql_generation_repository.update(initial_sql_generation)

def create(
self, prompt_id: str, sql_generation_request: SQLGenerationRequest
) -> SQLGeneration:
Expand Down Expand Up @@ -153,12 +164,9 @@ def create(
)
initial_sql_generation.evaluate = sql_generation_request.evaluate
initial_sql_generation.confidence_score = confidence_score
initial_sql_generation.sql = sql_generation.sql
initial_sql_generation.tokens_used = sql_generation.tokens_used
initial_sql_generation.completed_at = datetime.now()
initial_sql_generation.status = sql_generation.status
initial_sql_generation.error = sql_generation.error
return self.sql_generation_repository.update(initial_sql_generation)
return self.update_the_initial_sql_generation(
initial_sql_generation, sql_generation
)

def start_streaming(
self, prompt_id: str, sql_generation_request: SQLGenerationRequest, queue: Queue
Expand Down
59 changes: 42 additions & 17 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
from dataherald.types import LLMConfig, Prompt, SQLGeneration
from dataherald.types import IntermediateStep, LLMConfig, Prompt, SQLGeneration
from dataherald.utils.strings import contains_line_breaks


Expand Down Expand Up @@ -74,20 +74,6 @@ def create_sql_query_status(
) -> SQLGeneration:
return create_sql_query_status(db, query, sql_generation)

def format_intermediate_representations(
self, intermediate_representation: List[Tuple[AgentAction, str]]
) -> List[str]:
"""Formats the intermediate representation into a string."""
formatted_intermediate_representation = []
for item in intermediate_representation:
formatted_intermediate_representation.append(
f"Thought: '{str(item[0].log).split('Action:')[0]}'\n"
f"Action: '{item[0].tool}'\n"
f"Action Input: '{item[0].tool_input}'\n"
f"Observation: '{item[1]}'"
)
return formatted_intermediate_representation

def format_sql_query(self, sql_query: str) -> str:
comments = [
match.group() for match in re.finditer(r"--.*$", sql_query, re.MULTILINE)
Expand All @@ -110,14 +96,53 @@ def extract_query_from_intermediate_steps(
action = step[0]
if type(action) == AgentAction and action.tool == "SqlDbQuery":
if "SELECT" in self.format_sql_query(action.tool_input).upper():
sql_query = self.remove_markdown(sql_query)
sql_query = self.remove_markdown(action.tool_input)
if sql_query == "":
for step in intermediate_steps:
action = step[0]
if "SELECT" in action.tool_input.upper():
sql_query = self.remove_markdown(sql_query)
sql_query = self.remove_markdown(action.tool_input)
if not sql_query.upper().strip().startswith("SELECT"):
sql_query = ""
return sql_query

def construct_intermediate_steps(
self, intermediate_steps: List[Tuple[AgentAction, str]], suffix: str = ""
) -> List[IntermediateStep]:
"""Constructs the intermediate steps."""
formatted_intermediate_steps = []
for step in intermediate_steps:
if step[0].tool == "SqlDbQuery":
formatted_intermediate_steps.append(
IntermediateStep(
thought=str(step[0].log).split("Action:")[0],
action=step[0].tool,
action_input=step[0].tool_input,
observation="QUERY RESULTS ARE NOT STORED FOR PRIVACY REASONS.",
)
)
else:
formatted_intermediate_steps.append(
IntermediateStep(
thought=str(step[0].log).split("Action:")[0],
action=step[0].tool,
action_input=step[0].tool_input,
observation=self.truncate_observations(step[1]),
)
)
formatted_intermediate_steps[0].thought = suffix.split("Thought: ")[1].split(
"{agent_scratchpad}"
)[0]
return formatted_intermediate_steps

def truncate_observations(self, obervarion: str, max_length: int = 2000) -> str:
"""Truncate the tool input."""
return (
obervarion[:max_length] + "... (truncated)"
if len(obervarion) > max_length
else obervarion
)

@abstractmethod
def generate_response(
self,
Expand Down
5 changes: 4 additions & 1 deletion dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async def _arun(
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""

name = "ExecuteQuery"
name = "SqlDbQuery"
description = """
Input: SQL query.
Output: Result from the database or an error message if the query is incorrect.
Expand Down Expand Up @@ -591,6 +591,9 @@ def generate_response(
response.sql = replace_unprocessable_characters(sql_query)
response.tokens_used = cb.total_tokens
response.completed_at = datetime.datetime.now()
response.intermediate_steps = self.construct_intermediate_steps(
result["intermediate_steps"], FINETUNING_AGENT_SUFFIX
)
return self.create_sql_query_status(
self.database,
response.sql,
Expand Down
7 changes: 7 additions & 0 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,13 @@ def generate_response(
response.sql = replace_unprocessable_characters(sql_query)
response.tokens_used = cb.total_tokens
response.completed_at = datetime.datetime.now()
if number_of_samples > 0:
suffix = SUFFIX_WITH_FEW_SHOT_SAMPLES
else:
suffix = SUFFIX_WITHOUT_FEW_SHOT_SAMPLES
response.intermediate_steps = self.construct_intermediate_steps(
result["intermediate_steps"], suffix=suffix
)
return self.create_sql_query_status(
self.database,
response.sql,
Expand Down
8 changes: 8 additions & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,21 @@ class LLMConfig(BaseModel):
api_base: str | None = None


class IntermediateStep(BaseModel):
thought: str
action: str
action_input: str
observation: str


class SQLGeneration(BaseModel):
id: str | None = None
prompt_id: str
finetuning_id: str | None
low_latency_mode: bool = False
llm_config: LLMConfig | None
evaluate: bool = False
intermediate_steps: list[IntermediateStep] | None
sql: str | None
status: str = "INVALID"
completed_at: datetime | None
Expand Down
4 changes: 2 additions & 2 deletions dataherald/utils/agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
If SQL results has None or NULL values, handle them by adding a WHERE clause to filter them out.
If SQL query doesn't follow the instructions or return incorrect results modify the SQL query to fit the instructions and fix the errors.
Only make minor modifications to the SQL query, do not change the SQL query completely.
You MUST always use the ExecuteQuery tool to make sure the SQL query is correct before returning it.
You MUST always use the SqlDbQuery tool to make sure the SQL query is correct before returning it.
### Instructions from the database administrator:
{admin_instructions}
Expand All @@ -134,7 +134,7 @@
#
Here is the plan you have to follow:
1) Use the `GenerateSql` tool to generate a SQL query for the given question.
2) Always Use the `ExecuteQuery` tool to execute the SQL query on the database to check if the results are correct.
2) Always Use the `SqlDbQuery` tool to execute the SQL query on the database to check if the results are correct.
#
### Instructions from the database administrator:
Expand Down
16 changes: 16 additions & 0 deletions docs/api.create_prompt_sql_generation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ HTTP 201 code response
"llm_name": "gpt-4-turbo-preview",
"api_base": "string"
},
"intermediate_steps": [
{
"action": "string",
"action_input": "string",
"observation": "string"
}
],
"sql": "string",
"tokens_used": 0,
"confidence_score": 0,
Expand Down Expand Up @@ -113,6 +120,15 @@ HTTP 201 code response
"llm_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"api_base": "https://tt5h145hsc119q-8000.proxy.runpod.net/v1"
},
intermediate_steps": [
{
"thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
"action": "FewshotExamplesRetriever",
"action_input": "5",
"observation": "samples ... "
},
...
],
"sql": "SELECT metric_value \nFROM renthub_median_rent \nWHERE period_type = 'monthly' \nAND geo_type = 'city' \nAND location_name = 'Miami' \nAND property_type = 'All Residential' \nAND period_end = (SELECT DATE_TRUNC('MONTH', CURRENT_DATE()) - INTERVAL '1 day')\nLIMIT 10",
"tokens_used": 18115,
"confidence_score": 0.95,
Expand Down
16 changes: 16 additions & 0 deletions docs/api.create_sql_generation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ HTTP 201 code response
"llm_name": "gpt-4-turbo-preview",
"api_base": "string"
},
"intermediate_steps": [
{
"action": "string",
"action_input": "string",
"observation": "string"
}
],
"sql": "string",
"tokens_used": 0,
"confidence_score": 0,
Expand Down Expand Up @@ -102,6 +109,15 @@ HTTP 201 code response
"llm_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"api_base": "https://tt5h145hsc119q-8000.proxy.runpod.net/v1"
},
intermediate_steps": [
{
"thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
"action": "FewshotExamplesRetriever",
"action_input": "5",
"observation": "samples ... "
},
...
],
"sql": "SELECT metric_value \nFROM renthub_median_rent \nWHERE period_type = 'monthly' \nAND geo_type = 'city' \nAND location_name = 'Miami' \nAND property_type = 'All Residential' \nAND period_end = (SELECT DATE_TRUNC('MONTH', CURRENT_DATE()) - INTERVAL '1 day')\nLIMIT 10",
"tokens_used": 18115,
"confidence_score": null,
Expand Down
24 changes: 24 additions & 0 deletions docs/api.get_sql_generation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ HTTP 200 code response
"finetuning_id": "string",
"status": "string",
"completed_at": "string",
"llm_config": {
"llm_name": "gpt-4-turbo-preview",
"api_base": "string"
},
"intermediate_steps": [
{
"action": "string",
"action_input": "string",
"observation": "string"
}
],
"sql": "string",
"tokens_used": 0,
"confidence_score": 0,
Expand All @@ -54,6 +65,19 @@ HTTP 200 code response
"finetuning_id": null,
"status": "VALID",
"completed_at": "2024-01-04 21:11:27.235000+00:00",
"llm_config": {
"llm_name": "gpt-4-turbo-preview",
"api_base": null
},
"intermediate_steps": [
{
"thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
"action": "FewshotExamplesRetriever",
"action_input": "5",
"observation": "Found 5 examples of similar questions."
},
...
],
"sql": "\nSELECT dh_zip_code, MAX(metric_value) as highest_rent -- Select the zip code and the maximum rent value\nFROM renthub_average_rent\nWHERE dh_county_name = 'Los Angeles' -- Filter for Los Angeles county\nAND period_start <= '2022-05-01' -- Filter for the period that starts on or before May 1st, 2022\nAND period_end >= '2022-05-31' -- Filter for the period that ends on or after May 31st, 2022\nGROUP BY dh_zip_code -- Group by zip code to aggregate rent values\nORDER BY highest_rent DESC -- Order by the highest rent in descending order\nLIMIT 1; -- Limit to the top result\n",
"tokens_used": 12185,
"confidence_score": null,
Expand Down
24 changes: 24 additions & 0 deletions docs/api.list_sql_generations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ HTTP 200 code response
"finetuning_id": "string",
"status": "string",
"completed_at": "string",
"llm_config": {
"llm_name": "gpt-4-turbo-preview",
"api_base": "string"
},
"intermediate_steps": [
{
"action": "string",
"action_input": "string",
"observation": "string"
}
],
"sql": "string",
"tokens_used": 0,
"confidence_score": 0,
Expand Down Expand Up @@ -58,6 +69,19 @@ HTTP 200 code response
"finetuning_id": null,
"status": "VALID",
"completed_at": "2024-01-03 18:54:55.091000+00:00",
"llm_config": {
"llm_name": "gpt-4-turbo-preview",
"api_base": null
},
"intermediate_steps": [
{
"thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
"action": "FewshotExamplesRetriever",
"action_input": "5",
"observation": "Found 5 examples of similar questions."
},
...
],
"sql": "\nSELECT metric_value -- Rent price\nFROM renthub_median_rent\nWHERE geo_type='city' -- Focusing on city-level data\n AND dh_state_name = 'California' -- State is California\n AND dh_place_name = 'Los Angeles' -- City is Los Angeles\n AND period_start = '2023-06-01' -- Most recent data available\nORDER BY metric_value DESC -- In case there are multiple entries, order by price descending\nLIMIT 1; -- Only need the top result\n",
"tokens_used": 9491,
"confidence_score": null,
Expand Down

0 comments on commit 3dbd483

Please sign in to comment.