diff --git a/comps/agent/langchain/README.md b/comps/agent/langchain/README.md index 585ff5d964..67bce6a636 100644 --- a/comps/agent/langchain/README.md +++ b/comps/agent/langchain/README.md @@ -6,15 +6,16 @@ This agent microservice is built on Langchain/Langgraph frameworks. Agents integ ### 1.1 Supported agent types -We currently support the following types of agents: +We currently support the following types of agents. Please refer to the example config yaml (links in the table in [Section 1.2](#12-llm-engine)) for each agent strategy to see what environment variables need to be set up. 1. ReAct: use `react_langchain` or `react_langgraph` or `react_llama` as strategy. First introduced in this seminal [paper](https://arxiv.org/abs/2210.03629). The ReAct agent engages in "reason-act-observe" cycles to solve problems. Please refer to this [doc](https://python.langchain.com/v0.2/docs/how_to/migrate_agent/) to understand the differences between the langchain and langgraph versions of react agents. See table below to understand the validated LLMs for each react strategy. 2. RAG agent: use `rag_agent` or `rag_agent_llama` strategy. This agent is specifically designed for improving RAG performance. It has the capability to rephrase query, check relevancy of retrieved context, and iterate if context is not relevant. See table below to understand the validated LLMs for each rag agent strategy. 3. Plan and execute: `plan_execute` strategy. This type of agent first makes a step-by-step plan given a user request, and then execute the plan sequentially (or in parallel, to be implemented in future). If the execution results can solve the problem, then the agent will output an answer; otherwise, it will replan and execute again. +4. SQL agent: use `sql_agent_llama` or `sql_agent` strategy. This agent is specifically designed and optimized for answering questions aabout data in SQL databases. Users need to specify `db_name` and `db_path` for the agent to access the SQL database. For more technical details read descriptions [here](src/strategy/sqlagent/README.md). **Note**: -1. Due to the limitations in support for tool calling by TGI and vllm, we have developed subcategories of agent strategies (`rag_agent_llama` and `react_llama`) specifically designed for open-source LLMs served with TGI and vllm. +1. Due to the limitations in support for tool calling by TGI and vllm, we have developed subcategories of agent strategies (`rag_agent_llama`, `react_llama` and `sql_agent_llama`) specifically designed for open-source LLMs served with TGI and vllm. 2. For advanced developers who want to implement their own agent strategies, please refer to [Section 5](#5-customize-agent-strategy) below. ### 1.2 LLM engine @@ -25,14 +26,16 @@ Agents use LLM for reasoning and planning. We support 3 options of LLM engine: 2. Open-source LLMs served with vllm. Follow the instructions in [Section 2.2.2](#222-start-agent-microservices-with-vllm). 3. OpenAI LLMs via API calls. To use OpenAI llms, specify `llm_engine=openai` and `export OPENAI_API_KEY=` -| Agent type | `strategy` arg | Validated LLMs (serving SW) | Notes | -| ---------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -| ReAct | `react_langchain` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Only allows tools with one input variable | -| ReAct | `react_langgraph` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), | if using vllm, need to specify `--enable-auto-tool-choice --tool-call-parser ${model_parser}`, refer to vllm docs for more info | -| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs | -| RAG agent | `rag_agent` | GPT-4o-mini | | -| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) | Recommended for open-source LLMs, only allows 1 tool with input variable to be "query" | -| Plan and execute | `plan_execute` | GPT-4o-mini, [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) (vllm-gaudi), [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) (vllm-gaudi) | | +| Agent type | `strategy` arg | Validated LLMs (serving SW) | Notes | Example config yaml | +| ---------------- | ----------------- | ------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------- | +| ReAct | `react_langchain` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) (vllm-gaudi) | Only allows tools with one input variable | [react_langchain yaml](../../../tests/agent/react_langchain.yaml) | +| ReAct | `react_langgraph` | GPT-4o-mini, [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (vllm-gaudi), | if using vllm, need to specify `--enable-auto-tool-choice --tool-call-parser ${model_parser}`, refer to vllm docs for more info | [react_langgraph yaml](../../../tests/agent/react_vllm.yaml) | +| ReAct | `react_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) (vllm-gaudi) | Recommended for open-source LLMs, supports multiple tools and parallel tool calls. | [react_llama yaml](../../../tests/agent/reactllama.yaml) | +| RAG agent | `rag_agent` | GPT-4o-mini | | [rag_agent yaml](../../../tests/agent/ragagent_openai.yaml) | +| RAG agent | `rag_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (tgi-gaudi) (vllm-gaudi) | Recommended for open-source LLMs, only allows 1 tool with input variable to be "query" | [rag_agent_llama yaml](../../../tests/agent/ragagent.yaml) | +| Plan and execute | `plan_execute` | GPT-4o-mini, [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (vllm-gaudi) | use `--guided-decoding-backend lm-format-enforcer` when launching vllm. | [plan_execute yaml](../../../tests/agent/planexec_openai.yaml) | +| SQL agent | `sql_agent_llama` | [llama3.1-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct) (vllm-gaudi) | database query tool is natively integrated using Langchain's [QuerySQLDataBaseTool](https://python.langchain.com/api_reference/community/tools/langchain_community.tools.sql_database.tool.QuerySQLDatabaseTool.html). User can also register their own tools with this agent. | [sql_agent_llama yaml](../../../tests/agent/sql_agent_llama.yaml) | +| SQL agent | `sql_agent` | GPT-4o-mini | database query tool is natively integrated using Langchain's [QuerySQLDataBaseTool](https://python.langchain.com/api_reference/community/tools/langchain_community.tools.sql_database.tool.QuerySQLDatabaseTool.html). User can also register their own tools with this agent. | [sql_agent yaml](../../../tests/agent/sql_agent_openai.yaml) | ### 1.3 Tools @@ -57,25 +60,25 @@ Examples of how to register tools can be found in [Section 4](#-4-provide-your-o ```bash cd GenAIComps/ # back to GenAIComps/ folder -docker build -t opea/agent-langchain:latest -f comps/agent/langchain/Dockerfile . +docker build -t opea/agent-langchain:latest -f comps/agent/langchain/Dockerfile . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy ``` #### 2.2.1 Start Agent microservices with TGI ```bash export ip_address=$(hostname -I | awk '{print $1}') -export model=mistralai/Mistral-7B-Instruct-v0.3 +export model="meta-llama/Meta-Llama-3.1-70B-Instruct" export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} export HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -# TGI serving -docker run -d --runtime=habana --name "comps-tgi-gaudi-service" -p 8080:80 -v ./data:/data -e HF_TOKEN=$HF_TOKEN -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:latest --model-id $model --max-input-tokens 4096 --max-total-tokens 8092 +# TGI serving on 4 Gaudi2 cards +docker run -d --runtime=habana --name "comps-tgi-gaudi-service" -p 8080:80 -v ./data:/data -e HF_TOKEN=$HF_TOKEN -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true --cap-add=sys_nice --ipc=host ghcr.io/huggingface/tgi-gaudi:latest --model-id $model --max-input-tokens 8192 --max-total-tokens 16384 --sharded true --num-shard 4 # check status docker logs comps-tgi-gaudi-service -# Agent -docker run -d --runtime=runc --name="comps-langchain-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_langchain -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=tgi -e recursion_limit=5 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest +# Agent: react_llama strategy +docker run -d --runtime=runc --name="comps-langchain-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_llama -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=tgi -e recursion_limit=15 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest # check status docker logs comps-langchain-agent-endpoint @@ -85,24 +88,24 @@ docker logs comps-langchain-agent-endpoint ```bash export ip_address=$(hostname -I | awk '{print $1}') -export model=mistralai/Mistral-7B-Instruct-v0.3 +export model="meta-llama/Meta-Llama-3.1-70B-Instruct" export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} export HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN} export vllm_volume=${YOUR_LOCAL_DIR_FOR_MODELS} # build vLLM image -git clone https://github.com/HabanaAI/vllm-fork.git -cd ./vllm-fork +git clone https://github.com/vllm-project/vllm.git +cd ./vllm docker build -f Dockerfile.hpu -t opea/vllm-gaudi:latest --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -# TGI serving -docker run -d --runtime=habana --rm --name "comps-vllm-gaudi-service" -p 8080:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:latest --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser mistral +# vllm serving on 4 Gaudi2 cards +docker run -d --runtime=habana --rm --name "comps-vllm-gaudi-service" -p 8080:8000 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:latest --model ${model} --max-seq-len-to-capture 16384 --enable-auto-tool-choice --tool-call-parser llama3_json --guided-decoding-backend lm-format-enforcer --tensor-parallel-size 4 # check status docker logs comps-vllm-gaudi-service # Agent -docker run -d --runtime=runc --name="comps-langchain-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_langgraph -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=vllm -e recursion_limit=5 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest +docker run -d --runtime=runc --name="comps-langchain-agent-endpoint" -v $WORKPATH/comps/agent/langchain/tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_llama -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=vllm -e recursion_limit=15 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest # check status docker logs comps-langchain-agent-endpoint @@ -111,7 +114,7 @@ docker logs comps-langchain-agent-endpoint > debug mode > > ```bash -> docker run --rm --runtime=runc --name="comps-langchain-agent-endpoint" -v ./comps/agent/langchain/:/home/user/comps/agent/langchain/ -p 9090:9090 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_langchain -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=tgi -e recursion_limit=5 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest +> docker run --rm --runtime=runc --name="comps-langchain-agent-endpoint" -v ./comps/agent/langchain/:/home/user/comps/agent/langchain/ -p 9090:9090 --ipc=host -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_llama -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=vllm -e recursion_limit=15 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest > ``` ## 🚀 3. Validate Microservice @@ -120,12 +123,12 @@ Once microservice starts, user can use below script to invoke. ```bash curl http://${ip_address}:9090/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{ - "query": "What is the weather today in Austin?" + "query": "What is OPEA project?" }' # expected output -data: 'The temperature in Austin today is 78°F.' +data: 'The OPEA project is .....' # just showing partial example here. data: [DONE] @@ -186,7 +189,7 @@ def opea_rag_query(query): ```bash # Agent -docker run -d --runtime=runc --name="comps-langchain-agent-endpoint" -v my_tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_langchain -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=tgi -e recursive_limit=5 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest +docker run -d --runtime=runc --name="comps-langchain-agent-endpoint" -v my_tools:/home/user/comps/agent/langchain/tools -p 9090:9090 --ipc=host -e HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${model} -e ip_address=${ip_address} -e strategy=react_llama -e llm_endpoint_url=http://${ip_address}:8080 -e llm_engine=tgi -e recursive_limit=15 -e require_human_feedback=false -e tools=/home/user/comps/agent/langchain/tools/custom_tools.yaml opea/agent-langchain:latest ``` - validate with my_tools @@ -197,17 +200,10 @@ $ curl http://${ip_address}:9090/v1/chat/completions -X POST -H "Content-Type: a }' data: 'The Intel OPEA project is a initiative to incubate open source development of trusted, scalable open infrastructure for developer innovation and harness the potential value of generative AI. - - - - Thought: I now know the final answer. - - - - - - Thought: - - - -' -data: [DONE] - -$ curl http://${ip_address}:9090/v1/chat/completions -X POST -H "Content-Type: application/json" -d '{ - "query": "What is the weather today in Austin?" - }' -data: 'The weather information in Austin is not available from the Open Platform for Enterprise AI (OPEA). You may want to try checking another source such as a weather app or website. I apologize for not being able to find the information you were looking for. <|eot_id|>' - data: [DONE] ``` ## 5. Customize agent strategy For advanced developers who want to implement their own agent strategies, you can add a separate folder in `src\strategy`, implement your agent by inherit the `BaseAgent` class, and add your strategy into the `src\agent.py`. The architecture of this agent microservice is shown in the diagram below as a reference. -![Architecture Overview](agent_arch.jpg) +![Architecture Overview](assets/agent_arch.jpg) diff --git a/comps/agent/langchain/agent_arch.jpg b/comps/agent/langchain/assets/agent_arch.jpg similarity index 100% rename from comps/agent/langchain/agent_arch.jpg rename to comps/agent/langchain/assets/agent_arch.jpg diff --git a/comps/agent/langchain/assets/sql_agent.png b/comps/agent/langchain/assets/sql_agent.png new file mode 100644 index 0000000000..a4e2e3b33c Binary files /dev/null and b/comps/agent/langchain/assets/sql_agent.png differ diff --git a/comps/agent/langchain/assets/sql_agent_llama.png b/comps/agent/langchain/assets/sql_agent_llama.png new file mode 100644 index 0000000000..6d832d4d4f Binary files /dev/null and b/comps/agent/langchain/assets/sql_agent_llama.png differ diff --git a/comps/agent/langchain/requirements.txt b/comps/agent/langchain/requirements.txt index ab3ff0c6bc..431a5060a4 100644 --- a/comps/agent/langchain/requirements.txt +++ b/comps/agent/langchain/requirements.txt @@ -1,11 +1,11 @@ # used by microservice docarray[full] - -#used by tools -duckduckgo-search fastapi huggingface_hub langchain + +#used by tools +langchain-google-community langchain-huggingface langchain-openai langchain_community diff --git a/comps/agent/langchain/src/agent.py b/comps/agent/langchain/src/agent.py index 0533826c59..a7713a29bb 100644 --- a/comps/agent/langchain/src/agent.py +++ b/comps/agent/langchain/src/agent.py @@ -33,5 +33,15 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False): from .strategy.ragagent import RAGAgent return RAGAgent(args, with_memory, custom_prompt=custom_prompt) + elif strategy == "sql_agent_llama": + print("Initializing SQL Agent Llama") + from .strategy.sqlagent import SQLAgentLlama + + return SQLAgentLlama(args, with_memory, custom_prompt=custom_prompt) + elif strategy == "sql_agent": + print("Initializing SQL Agent") + from .strategy.sqlagent import SQLAgent + + return SQLAgent(args, with_memory, custom_prompt=custom_prompt) else: raise ValueError(f"Agent strategy: {strategy} not supported!") diff --git a/comps/agent/langchain/src/config.py b/comps/agent/langchain/src/config.py index 4178e2d9f9..6bb1b12dd2 100644 --- a/comps/agent/langchain/src/config.py +++ b/comps/agent/langchain/src/config.py @@ -72,3 +72,16 @@ if os.environ.get("timeout") is not None: env_config += ["--timeout", os.environ["timeout"]] + +# for sql agent +if os.environ.get("db_path") is not None: + env_config += ["--db_path", os.environ["db_path"]] + +if os.environ.get("db_name") is not None: + env_config += ["--db_name", os.environ["db_name"]] + +if os.environ.get("use_hints") is not None: + env_config += ["--use_hints", os.environ["use_hints"]] + +if os.environ.get("hints_file") is not None: + env_config += ["--hints_file", os.environ["hints_file"]] diff --git a/comps/agent/langchain/src/strategy/base_agent.py b/comps/agent/langchain/src/strategy/base_agent.py index beb4fa9f8f..8c0048b879 100644 --- a/comps/agent/langchain/src/strategy/base_agent.py +++ b/comps/agent/langchain/src/strategy/base_agent.py @@ -36,5 +36,37 @@ def compile(self): def execute(self, state: dict): pass - def non_streaming_run(self, query, config): + def prepare_initial_state(self, query): raise NotImplementedError + + async def stream_generator(self, query, config): + initial_state = self.prepare_initial_state(query) + try: + async for event in self.app.astream(initial_state, config=config): + for node_name, node_state in event.items(): + yield f"--- CALL {node_name} ---\n" + for k, v in node_state.items(): + if v is not None: + yield f"{k}: {v}\n" + + yield f"data: {repr(event)}\n\n" + yield "data: [DONE]\n\n" + except Exception as e: + yield str(e) + + async def non_streaming_run(self, query, config): + initial_state = self.prepare_initial_state(query) + print("@@@ Initial State: ", initial_state) + try: + async for s in self.app.astream(initial_state, config=config, stream_mode="values"): + message = s["messages"][-1] + if isinstance(message, tuple): + print(message) + else: + message.pretty_print() + + last_message = s["messages"][-1] + print("******Response: ", last_message.content) + return last_message.content + except Exception as e: + return str(e) diff --git a/comps/agent/langchain/src/strategy/sqlagent/README.md b/comps/agent/langchain/src/strategy/sqlagent/README.md new file mode 100644 index 0000000000..d3e7d7e14f --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/README.md @@ -0,0 +1,44 @@ +# SQL Agents + +We currently have two types of SQL agents: + +1. `sql_agent_llama`: for using with open-source LLMs, especially `meta-llama/Llama-3.1-70B-Instruct` model. +2. `sql_agent`: for using with OpenAI models, we developed and validated with GPT-4o-mini. + +## Overview of sql_agent_llama + +The architecture of `sql_agent_llama` is shown in the figure below. +The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem. + +![SQL Agent Llama Architecture](../../../assets/sql_agent_llama.png) + +### Database schema: + +We use langchain's [SQLDatabase](https://python.langchain.com/docs/integrations/tools/sql_database/) API to get table names and schemas from the SQL database. User just need to specify `db_path` and `db_name`. The table schemas are incorporated into the prompts for the agent. + +### Hints module: + +If you want to use the hints module, you need to prepare a csv file that has 3 columns: `table_name`, `column_name`, `description`, and make this file available to the agent microservice. The `description` should include useful information (for example, domain knowledge) about a certain column in a table in the database. The hints module will pick up to five relevant columns together with their descriptions based on the user question using similarity search. The hints module will then pass these column descriptions to the agent node. + +### Output parser: + +Due to the current limitations of open source LLMs and serving frameworks (tgi and vllm) in generating tool call objects, we developed and optimized a custom output parser, together with our specially designed prompt templates. The output parser has 3 functions: + +1. Decide if a valid final answer presents in the raw agent output. This is needed because: a) we found sometimes agent would make guess or hallucinate data, so it is critical to double check, b) sometimes LLM does not strictly follow instructions on output format so simple string parsing can fail. We use one additional LLM call to perform this function. +2. Pick out tool calls from raw agent output. And check if the agent has made same tool calls before. If yes, remove the repeated tool calls. +3. Parse and review SQL query, and fix SQL query if there are errors. This proved to improve SQL agent performance since the initial query may contain errors and having a "second pair of eyes" can often spot the errors while the agent node itself may not be able to identify the errors in subsequent execution steps. + +## Overview of sql_agent + +The architecture of `sql_agent` is shown in the figure below. +The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem. The basic idea is the same as `sql_agent_llama`. However, since OpenAI APIs produce well-structured tool call objects, we don't need a special output parser. Instead, we only keep the query fixer. + +![SQL Agent Architecture](../../../assets/sql_agent.png) + +## Limitations + +1. Agent is only allowed to issue "SELECT" commands to databases, i.e., agent can only query databases but cannot update databases. +2. We currently does not support "streaming" agent outputs on the fly for `sql_agent_llama`. +3. Users need to pass the SQL database URI to the agent with the `db_path` environment variable. We have only validated SQLite database connected in such way. + +Please submit issues if you want new features to be added. We also welcome community contributions! diff --git a/comps/agent/langchain/src/strategy/sqlagent/__init__.py b/comps/agent/langchain/src/strategy/sqlagent/__init__.py new file mode 100644 index 0000000000..f8bf69ff4f --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/__init__.py @@ -0,0 +1,5 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .planner import SQLAgentLlama +from .planner import SQLAgent diff --git a/comps/agent/langchain/src/strategy/sqlagent/hint.py b/comps/agent/langchain/src/strategy/sqlagent/hint.py new file mode 100644 index 0000000000..7704f55020 --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/hint.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import glob +import os + +import pandas as pd + + +def read_hints(hints_file): + """ + hints_file: csv with columns: table_name, column_name, description + """ + hints_df = pd.read_csv(hints_file) + cols_descriptions = [] + values_descriptions = [] + for _, row in hints_df.iterrows(): + table_name = row["table_name"] + col_name = row["column_name"] + description = row["description"] + if not pd.isnull(description): + cols_descriptions.append(f"{table_name}.{col_name}: {description}") + values_descriptions.append(f"{col_name}: {description}") + return cols_descriptions, values_descriptions + + +def sort_list(list1, list2): + import numpy as np + + # Use numpy's argsort function to get the indices that would sort the second list + idx = np.argsort(list2) # ascending order + return np.array(list1)[idx].tolist()[::-1], np.array(list2)[idx].tolist()[::-1] # descending order + + +def get_topk_cols(topk, cols_descriptions, similarities): + sorted_cols, similarities = sort_list(cols_descriptions, similarities) + top_k_cols = sorted_cols[:topk] + output = [] + for col, sim in zip(top_k_cols, similarities[:topk]): + # print(f"{col}: {sim}") + if sim > 0.5: + output.append(col) + return output + + +def pick_hints(query, model, column_embeddings, complete_descriptions, topk=5): + if len(complete_descriptions) < topk: + topk_cols_descriptions = complete_descriptions + else: + # use similarity to get the topk columns + query_embedding = model.encode(query, convert_to_tensor=True) + similarities = model.similarity(query_embedding, column_embeddings).flatten() + topk_cols_descriptions = get_topk_cols(topk, complete_descriptions, similarities) + + hint = "" + for col in topk_cols_descriptions: + hint += col + "\n" + return hint diff --git a/comps/agent/langchain/src/strategy/sqlagent/planner.py b/comps/agent/langchain/src/strategy/sqlagent/planner.py new file mode 100644 index 0000000000..a74342a076 --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/planner.py @@ -0,0 +1,324 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from typing import Annotated, Sequence, TypedDict + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnableLambda +from langgraph.graph import END, StateGraph +from langgraph.graph.message import add_messages +from langgraph.managed import IsLastStep +from langgraph.prebuilt import ToolNode + +from ...utils import setup_chat_model, tool_renderer +from ..base_agent import BaseAgent +from .hint import pick_hints, read_hints +from .prompt import AGENT_NODE_TEMPLATE, AGENT_SYSM, QUERYFIXER_PROMPT +from .sql_tools import get_sql_query_tool, get_table_schema +from .utils import ( + LlamaOutputParserAndQueryFixer, + assemble_history, + convert_json_to_tool_call, + remove_repeated_tool_calls, +) + + +class AgentState(TypedDict): + """The state of the agent.""" + + messages: Annotated[Sequence[BaseMessage], add_messages] + is_last_step: IsLastStep + hint: str + + +class AgentNodeLlama: + def __init__(self, args, tools): + self.llm = setup_chat_model(args) + self.args = args + # two types of tools: + # sql_db_query - always available, no need to specify + # other tools - user defined + # here, self.tools is a list of user defined tools + self.tools = tool_renderer(tools) + print("@@@@ Tools: ", self.tools) + + self.chain = self.llm + + self.output_parser = LlamaOutputParserAndQueryFixer(chat_model=self.llm) + + if args.use_hints: + from sentence_transformers import SentenceTransformer + + self.cols_descriptions, self.values_descriptions = read_hints(args.hints_file) + self.embed_model = SentenceTransformer("BAAI/bge-large-en-v1.5") + self.column_embeddings = self.embed_model.encode(self.values_descriptions) + print("Done embedding column descriptions") + + def __call__(self, state): + print("----------Call Agent Node----------") + question = state["messages"][0].content + table_schema, num_tables = get_table_schema(self.args.db_path) + print("@@@@ Table Schema: ", table_schema) + if self.args.use_hints: + if not state["hint"]: + hints = pick_hints(question, self.embed_model, self.column_embeddings, self.cols_descriptions) + else: + hints = state["hint"] + print("@@@ Hints: ", hints) + else: + hints = "" + + history = assemble_history(state["messages"]) + print("@@@ History: ", history) + + prompt = AGENT_NODE_TEMPLATE.format( + domain=self.args.db_name, + tools=self.tools, + num_tables=num_tables, + tables_schema=table_schema, + question=question, + hints=hints, + history=history, + ) + + output = self.chain.invoke(prompt) + output = self.output_parser.parse( + output.content, history, table_schema, hints, question, state["messages"] + ) # text: str, history: str, db_schema: str, hint: str + print("@@@@@ Agent output:\n", output) + + # convert output to tool calls + tool_calls = [] + for res in output: + if "tool" in res: + tool_call = convert_json_to_tool_call(res) + tool_calls.append(tool_call) + + # check if same tool calls have been made before + # if yes, then remove the repeated tool calls + if tool_calls: + new_tool_calls = remove_repeated_tool_calls(tool_calls, state["messages"]) + print("@@@@ New Tool Calls:\n", new_tool_calls) + else: + new_tool_calls = [] + + if new_tool_calls: + ai_message = AIMessage(content="", tool_calls=new_tool_calls) + elif tool_calls: + ai_message = AIMessage(content="Repeated previous steps.", tool_calls=tool_calls) + elif "answer" in output[0]: + ai_message = AIMessage(content=str(output[0]["answer"])) + else: + ai_message = AIMessage(content=str(output)) + + return {"messages": [ai_message], "hint": hints} + + +class SQLAgentLlama(BaseAgent): + # need new args: + # # db_name and db_path + # # use_hints, hints_file + def __init__(self, args, with_memory=False, **kwargs): + super().__init__(args, local_vars=globals(), **kwargs) + # note: here tools only include user defined tools + # we need to add the sql query tool as well + print("@@@@ user defined tools: ", self.tools_descriptions) + agent = AgentNodeLlama(args, self.tools_descriptions) + sql_tool = get_sql_query_tool(args.db_path) + print("@@@@ SQL Tool: ", sql_tool) + tools = self.tools_descriptions + [sql_tool] + print("@@@@ ALL Tools: ", tools) + tool_node = ToolNode(tools) + + workflow = StateGraph(AgentState) + + # Define the nodes we will cycle between + workflow.add_node("agent", agent) + workflow.add_node("tools", tool_node) + + workflow.set_entry_point("agent") + + workflow.add_conditional_edges( + "agent", + self.decide_next_step, + { + # If `tools`, then we call the tool node. + "tools": "tools", + "agent": "agent", + "end": END, + }, + ) + + # We now add a normal edge from `tools` to `agent`. + # This means that after `tools` is called, `agent` node is called next. + workflow.add_edge("tools", "agent") + + self.app = workflow.compile() + + def decide_next_step(self, state: AgentState): + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls and last_message.content == "Repeated previous steps.": + print("@@@@ Repeated tool calls from previous steps, go back to agent") + return "agent" + elif last_message.tool_calls and last_message.content != "Repeated previous steps.": + print("@@@@ New Tool calls, go to tools") + return "tools" + else: + return "end" + + def prepare_initial_state(self, query): + return {"messages": [HumanMessage(content=query)], "is_last_step": IsLastStep(False), "hint": ""} + + +################################################ +# Below is SQL agent using OpenAI models +################################################ +class AgentNode: + def __init__(self, args, llm, tools): + self.llm = llm.bind_tools(tools) + self.args = args + if args.use_hints: + from sentence_transformers import SentenceTransformer + + self.cols_descriptions, self.values_descriptions = read_hints(args.hints_file) + self.embed_model = SentenceTransformer("BAAI/bge-large-en-v1.5") + self.column_embeddings = self.embed_model.encode(self.values_descriptions) + + def __call__(self, state): + print("----------Call Agent Node----------") + question = state["messages"][0].content + table_schema, num_tables = get_table_schema(self.args.db_path) + if self.args.use_hints: + if not state["hint"]: + hints = pick_hints(question, self.embed_model, self.column_embeddings, self.cols_descriptions) + else: + hints = state["hint"] + else: + hints = "" + + sysm = AGENT_SYSM.format(num_tables=num_tables, tables_schema=table_schema, question=question, hints=hints) + _system_message = SystemMessage(content=sysm) + state_modifier_runnable = RunnableLambda( + lambda state: [_system_message] + state["messages"], + name="StateModifier", + ) + + chain = state_modifier_runnable | self.llm + response = chain.invoke(state) + + return {"messages": [response], "hint": hints} + + +class QueryFixerNode: + def __init__(self, args, llm): + prompt = PromptTemplate( + template=QUERYFIXER_PROMPT, + input_variables=["DATABASE_SCHEMA", "QUESTION", "HINT", "QUERY", "RESULT"], + ) + self.chain = prompt | llm + self.args = args + + def get_sql_query_and_result(self, state): + messages = state["messages"] + assert isinstance(messages[-1], ToolMessage), "The last message should be a tool message" + result = messages[-1].content + id = messages[-1].tool_call_id + query = "" + for msg in reversed(messages): + if isinstance(msg, AIMessage) and msg.tool_calls: + if msg.tool_calls[0]["id"] == id: + query = msg.tool_calls[0]["args"]["query"] + break + print("@@@@ Executed SQL Query: ", query) + print("@@@@ Execution Result: ", result) + return query, result + + def __call__(self, state): + print("----------Call Query Fixer Node----------") + table_schema, _ = get_table_schema(self.args.db_path) + question = state["messages"][0].content + hint = state["hint"] + query, result = self.get_sql_query_and_result(state) + response = self.chain.invoke( + { + "DATABASE_SCHEMA": table_schema, + "QUESTION": question, + "HINT": hint, + "QUERY": query, + "RESULT": result, + } + ) + # print("@@@@@ Query fixer output:\n", response.content) + return {"messages": [response]} + + +class SQLAgent(BaseAgent): + def __init__(self, args, with_memory=False, **kwargs): + super().__init__(args, local_vars=globals(), **kwargs) + + sql_tool = get_sql_query_tool(args.db_path) + tools = self.tools_descriptions + [sql_tool] + print("@@@@ ALL Tools: ", tools) + + tool_node = ToolNode(tools) + agent = AgentNode(args, self.llm, tools) + query_fixer = QueryFixerNode(args, self.llm) + + workflow = StateGraph(AgentState) + + # Define the nodes we will cycle between + workflow.add_node("agent", agent) + workflow.add_node("query_fixer", query_fixer) + workflow.add_node("tools", tool_node) + + workflow.set_entry_point("agent") + + # We now add a conditional edge + workflow.add_conditional_edges( + "agent", + self.should_continue, + { + # If `tools`, then we call the tool node. + "continue": "tools", + "end": END, + }, + ) + + workflow.add_conditional_edges( + "tools", + self.should_go_to_query_fixer, + {"true": "query_fixer", "false": "agent"}, + ) + workflow.add_edge("query_fixer", "agent") + + self.app = workflow.compile() + + # Define the function that determines whether to continue or not + def should_continue(self, state: AgentState): + messages = state["messages"] + last_message = messages[-1] + # If there is no function call, then we finish + if not last_message.tool_calls: + return "end" + # Otherwise if there is, we continue + else: + return "continue" + + def should_go_to_query_fixer(self, state: AgentState): + messages = state["messages"] + last_message = messages[-1] + assert isinstance(last_message, ToolMessage), "The last message should be a tool message" + print("@@@@ Called Tool: ", last_message.name) + if last_message.name == "sql_db_query": + print("@@@@ Going to Query Fixer") + return "true" + else: + print("@@@@ Going back to Agent") + return "false" + + def prepare_initial_state(self, query): + return {"messages": [HumanMessage(content=query)], "is_last_step": IsLastStep(False), "hint": ""} diff --git a/comps/agent/langchain/src/strategy/sqlagent/prompt.py b/comps/agent/langchain/src/strategy/sqlagent/prompt.py new file mode 100644 index 0000000000..dae63766fc --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/prompt.py @@ -0,0 +1,225 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +AGENT_NODE_TEMPLATE = """\ +You are an SQL expert tasked with answering questions about {domain}. +In addition to the database, you have the following tools to gather information: +{tools} + +You can access a database that has {num_tables} tables. The schema of the tables is as follows. Read the schema carefully. +**Table Schema:** +{tables_schema} + +**Hints:** +{hints} + +When querying the database, remember the following: +1. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to no more than 20 results. +2. Only query columns that are relevant to the question. Remember to also fetch the ranking or filtering columns to check if they contain nulls. +3. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +**Output format:** +1. Write down your thinking process. +2. When querying the database, write your SQL query in the following format: +```sql +SELECT column1, column2, ... +``` +3. When making tool calls, you must use the following format. Make ONLY one tool call at a time. +TOOL CALL: {{"tool": "tool1", "args": {{"arg1": "value1", "arg2": "value2", ...}}}} + +4. After you have arrived at the answer with data and reasoning, write your final answer after "FINAL ANSWER". + +You have done the following steps so far: +**Your previous steps:** +{history} + +**IMPORTANT:** +* Review your previous steps carefully and utilize them to answer the question. Do not repeat your previous steps. +* The database may not have all the information needed to answer the question. Use the additional tools provided if necessary. +* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box. + +Now take a deep breath and think step by step to answeer the following question. +Question: +{question} +""" + + +ANSWER_PARSER_PROMPT = """\ +Review the output from an SQL agent and determine if a correct answer has been provided and grounded on real data. + +Say "yes" when all the following conditions are met: +1. The answer is complete and does not require additional steps to be taken. +2. The answer does not have placeholders that need to be filled in. +3. The agent has acquired data from database and its execution history is Not empty. +4. If agent made mistakes in its execution history, the agent has corrected them. +5. If agent has tried to get data several times but cannot get all the data needed, the agent has come up with an answer based on available data and reasonable assumptions. + +If the conditions above are not met, say "no". + +Here is the output from the SQL agent: +{output} +====================== +Here is the agent execution history: +{history} +====================== + +Has a final answer been provided based on real data? Analyze the agent output and make your judgement "yes" or "no". +""" + + +SQL_QUERY_FIXER_PROMPT = """\ +You are an SQL database expert tasked with reviewing a SQL query written by an agent. +**Procedure:** +1. Review Database Schema: +- Examine the table creation statements to understand the database structure. +2. Review the Hint provided. +- Use the provided hints to understand the domain knowledge relevant to the query. +3. Check against the following common errors: +- Failure to exclude null values, ranking or filtering columns have nulls, syntax errors, incorrect table references, incorrect column references, logical mistakes. +4. Check if aggregation should be used: +- Read the user question, and determine if user is asking for specific instances or aggregated info. If aggregation is needed, check if the original SQL query has used appropriate functions like COUNT and SUM. +5. Correct the Query only when Necessary: +- If issues were identified, modify the SQL query to address the identified issues, ensuring it correctly fetches the requested data according to the database schema and query requirements. + +======= Your task ======= +************************** +Table creation statements +{DATABASE_SCHEMA} +************************** +Hint: +{HINT} +************************** +The SQL query to review: +{QUERY} +************************** +User question: +{QUESTION} +************************** + +Now analyze the SQL query step by step. Present your reasonings. + +If you identified issues in the original query, write down the corrected SQL query in the format below: +```sql +SELECT column1, column2, ... +``` + +If the original SQL query is correct, just say the query is correct. + +Note: Some user questions can only be answered partially with the database. This is OK. The agent may use other tools in subsequent steps to get additional info. In some cases, the agent may have got additional info with other tools and have incorporated those in its query. Your goal is to review the SQL query and fix it when necessary. +Only use the tables provided in the database schema in your corrected query. Do not join tables that are not present in the schema. Do not create any new tables. +If you cannot do better than the original query, just say the query is correct. +""" + +SQL_QUERY_FIXER_PROMPT_with_result = """\ +You are an SQL database expert tasked with reviewing a SQL query. +**Procedure:** +1. Review Database Schema: +- Examine the table creation statements to understand the database structure. +2. Review the Hint provided. +- Use the provided hints to understand the domain knowledge relevant to the query. +3. Analyze Query Requirements: +- User Question: Consider what information the query is supposed to retrieve. Decide if aggregation like COUNT or SUM is needed. +- Executed SQL Query: Review the SQL query that was previously executed. +- Execution Result: Analyze the outcome of the executed query. Think carefully if the result makes sense. +4. Check against the following common errors: +- Failure to exclude null values, ranking or filtering columns have nulls, syntax errors, incorrect table references, incorrect column references, logical mistakes. +5. Correct the Query only when Necessary: +- If issues were identified, modify the SQL query to address the identified issues, ensuring it correctly fetches the requested data according to the database schema and query requirements. + +======= Your task ======= +************************** +Table creation statements +{DATABASE_SCHEMA} +************************** +Hint: +{HINT} +************************** +User Question: +{QUESTION} +************************** +The SQL query executed was: +{QUERY} +************************** +The execution result: +{RESULT} +************************** + +Now analyze the SQL query step by step. Present your reasonings. + +If you identified issues in the original query, write down the corrected SQL query in the format below: +```sql +SELECT column1, column2, ... +``` + +If the original SQL query is correct, just say the query is correct. + +Note: Some user questions can only be answered partially with the database. This is OK. The agent may use other tools in subsequent steps to get additional info. In some cases, the agent may have got additional info with other tools and have incorporated those in its query. Your goal is to review the SQL query and fix it when necessary. +Only use the tables provided in the database schema in your corrected query. Do not join tables that are not present in the schema. Do not create any new tables. +If you cannot do better than the original query, just say the query is correct. +""" + + +########################################## +## Prompt templates for SQL agent using OpenAI models +########################################## +AGENT_SYSM = """\ +You are an SQL expert tasked with answering questions about schools in California. +You can access a database that has {num_tables} tables. The schema of the tables is as follows. Read the schema carefully. +{tables_schema} +**************** +Question: {question} + +Hints: +{hints} +**************** + +When querying the database, remember the following: +1. You MUST double check your SQL query before executing it. Reflect on the steps you have taken and fix errors if there are any. If you get an error while executing a query, rewrite the query and try again. +2. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to no more than 20 results. +3. Only query columns that are relevant to the question. +4. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +IMPORTANT: +* Divide the question into sub-questions and conquer sub-questions one by one. +* You may need to combine information from multiple tables to answer the question. +* If database does not have all the information needed to answer the question, use the web search tool or your own knowledge. +* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box. You hard work will be rewarded. + +Now take a deep breath and think step by step to solve the problem. +""" + +QUERYFIXER_PROMPT = """\ +You are an SQL database expert tasked with reviewing a SQL query. +**Procedure:** +1. Review Database Schema: +- Examine the table creation statements to understand the database structure. +2. Review the Hint provided. +- Use the provided hints to understand the domain knowledge relevant to the query. +3. Analyze Query Requirements: +- Original Question: Consider what information the query is supposed to retrieve. +- Executed SQL Query: Review the SQL query that was previously executed. +- Execution Result: Analyze the outcome of the executed query. Think carefully if the result makes sense. If the result does not make sense, identify the issues with the executed SQL query (e.g., null values, syntax +errors, incorrect table references, incorrect column references, logical mistakes). +4. Correct the Query if Necessary: +- If issues were identified, modify the SQL query to address the identified issues, ensuring it correctly fetches the requested data +according to the database schema and query requirements. +5. If the query is correct, provide the same query as the final answer. + +======= Your task ======= +************************** +Table creation statements +{DATABASE_SCHEMA} +************************** +Hint: +{HINT} +************************** +The original question is: +Question: +{QUESTION} +The SQL query executed was: +{QUERY} +The execution result: +{RESULT} +************************** +Based on the question, table schema, hint and the previous query, analyze the result. Fix the query if needed and provide your reasoning. If the query is correct, provide the same query as the final answer. +""" diff --git a/comps/agent/langchain/src/strategy/sqlagent/sql_tools.py b/comps/agent/langchain/src/strategy/sqlagent/sql_tools.py new file mode 100644 index 0000000000..1d11eb7a05 --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/sql_tools.py @@ -0,0 +1,31 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool +from langchain_community.utilities import SQLDatabase + + +def connect_to_db(db_path): + db = SQLDatabase.from_uri(db_path) + return db + + +def get_table_schema(db_path): + db = connect_to_db(db_path) + table_names = ", ".join(db.get_usable_table_names()) + num_tables = len(table_names.split(",")) + schema = db.get_table_info_no_throw([t.strip() for t in table_names.split(",")]) + return schema, num_tables + + +def get_sql_query_tool(db_path): + db = connect_to_db(db_path) + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. " + ) + db_query_tool = QuerySQLDataBaseTool(db=db, name="sql_db_query", description=query_sql_database_tool_description) + print("SQL Query Tool Created: ", db_query_tool) + return db_query_tool diff --git a/comps/agent/langchain/src/strategy/sqlagent/utils.py b/comps/agent/langchain/src/strategy/sqlagent/utils.py new file mode 100644 index 0000000000..32bf611a9a --- /dev/null +++ b/comps/agent/langchain/src/strategy/sqlagent/utils.py @@ -0,0 +1,219 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import uuid + +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.messages.tool import ToolCall + +from .prompt import ANSWER_PARSER_PROMPT, SQL_QUERY_FIXER_PROMPT, SQL_QUERY_FIXER_PROMPT_with_result + + +def parse_answer_with_llm(text, history, chat_model): + if "FINAL ANSWER:" in text.upper(): + if history == "": + history = "The agent execution history is empty." + + prompt = ANSWER_PARSER_PROMPT.format(output=text, history=history) + response = chat_model.invoke(prompt).content + print("@@@ Answer parser response: ", response) + + temp = response[:5] + if "yes" in temp.lower(): + return text.split("FINAL ANSWER:")[-1] + else: + temp = response.split("\n")[0] + if "yes" in temp.lower(): + return text.split("FINAL ANSWER:")[-1] + return None + else: + return None + + +def get_tool_calls_other_than_sql(text): + """Get the tool calls other than sql_db_query.""" + tool_calls = [] + text = text.replace("assistant", "") + json_lines = text.split("\n") + # only get the unique lines + json_lines = list(set(json_lines)) + for line in json_lines: + if "TOOL CALL:" in line: + if "sql_db_query" not in line: + line = line.replace("TOOL CALL:", "") + if "assistant" in line: + line = line.replace("assistant", "") + if "\\" in line: + line = line.replace("\\", "") + try: + parsed_line = json.loads(line) + if isinstance(parsed_line, dict): + if "tool" in parsed_line: + tool_calls.append(parsed_line) + + except: + pass + return tool_calls + + +def get_all_sql_queries(text): + queries = [] + if "```sql" in text: + temp = text.split("```sql") + for t in temp: + if "```" in t: + query = t.split("```")[0] + if "SELECT" in query.upper() and "TOOL CALL" not in query.upper(): + queries.append(query) + + return queries + + +def get_the_last_sql_query(text): + queries = get_all_sql_queries(text) + if queries: + return queries[-1] + else: + return None + + +def check_query_if_executed_and_result(query, messages): + # get previous sql_db_query tool calls + previous_tool_calls = [] + for m in messages: + if isinstance(m, AIMessage) and m.tool_calls: + for tc in m.tool_calls: + if tc["name"] == "sql_db_query": + previous_tool_calls.append(tc) + for tc in previous_tool_calls: + if query == tc["args"]["query"]: + return get_tool_output(messages, tc["id"]) + + return None + + +def parse_and_fix_sql_query_v2(text, chat_model, db_schema, hint, question, messages): + chosen_query = get_the_last_sql_query(text) + if chosen_query: + # check if the query has been executed before + # if yes, pass execution result to the fixer + # if not, pass only the query to the fixer + result = check_query_if_executed_and_result(chosen_query, messages) + if result: + prompt = SQL_QUERY_FIXER_PROMPT_with_result.format( + DATABASE_SCHEMA=db_schema, HINT=hint, QUERY=chosen_query, QUESTION=question, RESULT=result + ) + else: + prompt = SQL_QUERY_FIXER_PROMPT.format( + DATABASE_SCHEMA=db_schema, HINT=hint, QUERY=chosen_query, QUESTION=question + ) + + response = chat_model.invoke(prompt).content + print("@@@ SQL query fixer response: ", response) + if "query is correct" in response.lower(): + return chosen_query + else: + # parse the fixed query + fixed_query = get_the_last_sql_query(response) + return fixed_query + else: + return None + + +class LlamaOutputParserAndQueryFixer: + def __init__(self, chat_model): + self.chat_model = chat_model + + def parse(self, text: str, history: str, db_schema: str, hint: str, question: str, messages: list): + print("@@@ Raw output from llm:\n", text) + answer = parse_answer_with_llm(text, history, self.chat_model) + if answer: + print("Final answer exists.") + return answer + else: + tool_calls = get_tool_calls_other_than_sql(text) + sql_query = parse_and_fix_sql_query_v2(text, self.chat_model, db_schema, hint, question, messages) + if sql_query: + sql_tool_call = [{"tool": "sql_db_query", "args": {"query": sql_query}}] + tool_calls.extend(sql_tool_call) + if tool_calls: + return tool_calls + else: + return text + + +def convert_json_to_tool_call(json_str): + tool_name = json_str["tool"] + tool_args = json_str["args"] + tcid = str(uuid.uuid4()) + tool_call = ToolCall(name=tool_name, args=tool_args, id=tcid) + return tool_call + + +def get_tool_output(messages, id): + tool_output = "" + for msg in reversed(messages): + if isinstance(msg, ToolMessage): + if msg.tool_call_id == id: + tool_output = msg.content + tool_output = tool_output[:1000] # limit to 1000 characters + break + return tool_output + + +def assemble_history(messages): + """ + messages: AI, TOOL, AI, TOOL, etc. + """ + query_history = "" + breaker = "-" * 10 + n = 1 + for m in messages[1:]: # exclude the first message + if isinstance(m, AIMessage): + # if there is tool call + if hasattr(m, "tool_calls") and len(m.tool_calls) > 0 and m.content != "Repeated previous steps.": + for tool_call in m.tool_calls: + tool = tool_call["name"] + tc_args = tool_call["args"] + id = tool_call["id"] + tool_output = get_tool_output(messages, id) + if tool == "sql_db_query": + sql_query = tc_args["query"] + query_history += ( + f"Step {n}. Executed SQL query: {sql_query}\nQuery Result: {tool_output}\n{breaker}\n" + ) + else: + query_history += ( + f"Step {n}. Called tool: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n" + ) + n += 1 + elif m.content == "Repeated previous steps.": # repeated steps + query_history += f"Step {n}. Repeated tool calls from previous steps.\n{breaker}\n" + n += 1 + else: + # did not make tool calls + query_history += f"Assistant Output: {m.content}\n" + + return query_history + + +def remove_repeated_tool_calls(tool_calls, messages): + """Remove repeated tool calls in the messages. + + tool_calls: list of tool calls: ToolCall(name=tool_name, args=tool_args, id=tcid) + messages: list of messages: AIMessage, ToolMessage, HumanMessage + """ + # first get all the previous tool calls in messages + previous_tool_calls = [] + for m in messages: + if isinstance(m, AIMessage) and m.tool_calls and m.content != "Repeated previous steps.": + for tc in m.tool_calls: + previous_tool_calls.append({"tool": tc["name"], "args": tc["args"]}) + + unique_tool_calls = [] + for tc in tool_calls: + if {"tool": tc["name"], "args": tc["args"]} not in previous_tool_calls: + unique_tool_calls.append(tc) + + return unique_tool_calls diff --git a/comps/agent/langchain/src/tools.py b/comps/agent/langchain/src/tools.py index ff3918cc9f..5ab9559890 100644 --- a/comps/agent/langchain/src/tools.py +++ b/comps/agent/langchain/src/tools.py @@ -125,18 +125,19 @@ def load_python_tools(file_dir_path: str): def get_tools_descriptions(file_dir_path: str): tools = [] - file_path_list = [] - if os.path.isdir(file_dir_path): - file_path_list += glob.glob(file_dir_path + "/*") - else: - file_path_list.append(file_dir_path) - for file in file_path_list: - if os.path.basename(file).endswith(".yaml"): - tools += load_yaml_tools(file) - elif os.path.basename(file).endswith(".yml"): - tools += load_yaml_tools(file) - elif os.path.basename(file).endswith(".py"): - tools += load_python_tools(file) + if file_dir_path: + file_path_list = [] + if os.path.isdir(file_dir_path): + file_path_list += glob.glob(file_dir_path + "/*") else: - pass + file_path_list.append(file_dir_path) + for file in file_path_list: + if os.path.basename(file).endswith(".yaml"): + tools += load_yaml_tools(file) + elif os.path.basename(file).endswith(".yml"): + tools += load_yaml_tools(file) + elif os.path.basename(file).endswith(".py"): + tools += load_python_tools(file) + else: + pass return tools diff --git a/comps/agent/langchain/src/utils.py b/comps/agent/langchain/src/utils.py index e8a317a5df..f477b08d2b 100644 --- a/comps/agent/langchain/src/utils.py +++ b/comps/agent/langchain/src/utils.py @@ -120,7 +120,7 @@ def get_args(): parser.add_argument("--agent_name", type=str, default="OPEA_Default_Agent") parser.add_argument("--strategy", type=str, default="react_langchain") parser.add_argument("--role_description", type=str, default="LLM enhanced agent") - parser.add_argument("--tools", type=str, default="tools/custom_tools.yaml") + parser.add_argument("--tools", type=str, default=None, help="path to the tools file") parser.add_argument("--recursion_limit", type=int, default=5) parser.add_argument("--require_human_feedback", action="store_true", help="If this agent requires human feedback") parser.add_argument("--debug", action="store_true", help="Test with endpoint mode") @@ -139,8 +139,14 @@ def get_args(): parser.add_argument("--with_store", type=bool, default=False) parser.add_argument("--timeout", type=int, default=60) + # for sql agent + parser.add_argument("--db_path", type=str, help="database path") + parser.add_argument("--db_name", type=str, help="database name") + parser.add_argument("--use_hints", type=str, default="false", help="If this agent uses hints") + parser.add_argument("--hints_file", type=str, help="path to the hints file") + sys_args, unknown_args = parser.parse_known_args() - # print("env_config: ", env_config) + print("env_config: ", env_config) if env_config != []: env_args, env_unknown_args = parser.parse_known_args(env_config) unknown_args += env_unknown_args @@ -151,5 +157,12 @@ def get_args(): sys_args.streaming = True else: sys_args.streaming = False + + if sys_args.use_hints == "true": + print("SQL agent will use hints") + sys_args.use_hints = True + else: + sys_args.use_hints = False + print("==========sys_args==========:\n", sys_args) return sys_args, unknown_args diff --git a/tests/agent/sql_agent_llama.yaml b/tests/agent/sql_agent_llama.yaml new file mode 100644 index 0000000000..92730a72aa --- /dev/null +++ b/tests/agent/sql_agent_llama.yaml @@ -0,0 +1,35 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +services: + agent: + image: ${agent_image} + container_name: test-comps-agent-endpoint + volumes: + - ${TOOLSET_PATH}:/home/user/tools/ # tools + # - ${WORKDIR}/GenAIComps/comps:/home/user/comps # code + - ${WORKDIR}/TAG-Bench/:/home/user/TAG-Bench # SQL database and hints_file + ports: + - "9095:9095" + ipc: host + environment: + ip_address: ${ip_address} + strategy: sql_agent_llama + db_name: ${db_name} + db_path: ${db_path} + use_hints: false + hints_file: /home/user/TAG-Bench/${db_name}_hints.csv + recursion_limit: ${recursion_limit} + llm_engine: vllm + HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN} + llm_endpoint_url: ${LLM_ENDPOINT_URL} + model: ${LLM_MODEL_ID} + temperature: ${temperature} + max_new_tokens: ${max_new_tokens} + streaming: false + tools: /home/user/tools/custom_tools.yaml + require_human_feedback: false + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + port: 9095 diff --git a/tests/agent/sql_agent_openai.yaml b/tests/agent/sql_agent_openai.yaml new file mode 100644 index 0000000000..124eccae99 --- /dev/null +++ b/tests/agent/sql_agent_openai.yaml @@ -0,0 +1,36 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +services: + agent: + image: ${agent_image} + container_name: test-comps-agent-endpoint + volumes: + - ${TOOLSET_PATH}:/home/user/tools/ # tools + - ${WORKDIR}/GenAIComps/comps:/home/user/comps # code + - ${WORKDIR}/TAG-Bench/:/home/user/TAG-Bench # SQL database and hints_file + ports: + - "9095:9095" + ipc: host + environment: + ip_address: ${ip_address} + strategy: sql_agent + db_name: ${db_name} + db_path: ${db_path} + use_hints: false #true + hints_file: /home/user/TAG-Bench/${db_name}_hints.csv + recursion_limit: ${recursion_limit} + llm_engine: openai + OPENAI_API_KEY: ${OPENAI_API_KEY} + model: "gpt-4o-mini-2024-07-18" + temperature: 0 + max_new_tokens: ${max_new_tokens} + streaming: false + tools: /home/user/tools/sql_agent_tools.yaml # /home/user/tools/custom_tools.yaml # + require_human_feedback: false + no_proxy: ${no_proxy} + http_proxy: ${http_proxy} + https_proxy: ${https_proxy} + port: 9095 + GOOGLE_CSE_ID: ${GOOGLE_CSE_ID} #delete + GOOGLE_API_KEY: ${GOOGLE_API_KEY} # delete diff --git a/tests/agent/sql_agent_test/generate_hints_file.py b/tests/agent/sql_agent_test/generate_hints_file.py new file mode 100644 index 0000000000..3551b7306f --- /dev/null +++ b/tests/agent/sql_agent_test/generate_hints_file.py @@ -0,0 +1,45 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import glob +import os + +import pandas as pd + + +def generate_column_descriptions(db_name): + descriptions = [] + working_dir = os.getenv("WORKDIR") + assert working_dir is not None, "WORKDIR environment variable is not set." + DESCRIPTION_FOLDER = os.path.join( + working_dir, f"TAG-Bench/dev_folder/dev_databases/{db_name}/database_description/" + ) + table_files = glob.glob(os.path.join(DESCRIPTION_FOLDER, "*.csv")) + table_name_col = [] + col_name_col = [] + for table_file in table_files: + table_name = os.path.basename(table_file).split(".")[0] + print("Table name: ", table_name) + df = pd.read_csv(table_file) + for _, row in df.iterrows(): + col_name = row["original_column_name"] + if not pd.isnull(row["value_description"]): + description = str(row["value_description"]) + if description.lower() in col_name.lower(): + print("Description {} is same as column name {}".format(description, col_name)) + pass + else: + description = description.replace("\n", " ") + description = " ".join(description.split()) + descriptions.append(description) + table_name_col.append(table_name) + col_name_col.append(col_name) + hints_df = pd.DataFrame({"table_name": table_name_col, "column_name": col_name_col, "description": descriptions}) + tag_bench_dir = os.path.join(working_dir, "TAG-Bench") + output_file = os.path.join(tag_bench_dir, f"{db_name}_hints.csv") + hints_df.to_csv(output_file, index=False) + print(f"Generated hints file: {output_file}") + + +if __name__ == "__main__": + generate_column_descriptions("california_schools") diff --git a/tests/agent/sql_agent_test/run_data_split.sh b/tests/agent/sql_agent_test/run_data_split.sh new file mode 100644 index 0000000000..2fc2dfcb0e --- /dev/null +++ b/tests/agent/sql_agent_test/run_data_split.sh @@ -0,0 +1,6 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +DATAPATH=$WORKDIR/TAG-Bench/tag_queries.csv +OUTFOLDER=$WORKDIR/TAG-Bench/query_by_db +python3 split_data.py --path $DATAPATH --output $OUTFOLDER diff --git a/tests/agent/sql_agent_test/split_data.py b/tests/agent/sql_agent_test/split_data.py new file mode 100644 index 0000000000..1b3f5cfc79 --- /dev/null +++ b/tests/agent/sql_agent_test/split_data.py @@ -0,0 +1,27 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +import pandas as pd + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--path", type=str, required=True) + parser.add_argument("--output", type=str, required=True) + args = parser.parse_args() + + # if output folder does not exist, create it + if not os.path.exists(args.output): + os.makedirs(args.output) + + # Load the data + data = pd.read_csv(args.path) + + # Split the data by domain + domains = data["DB used"].unique() + for domain in domains: + domain_data = data[data["DB used"] == domain] + out = os.path.join(args.output, f"query_{domain}.csv") + domain_data.to_csv(out, index=False) diff --git a/tests/agent/sql_agent_test/sql_agent_tools.py b/tests/agent/sql_agent_test/sql_agent_tools.py new file mode 100644 index 0000000000..fc14efe8ee --- /dev/null +++ b/tests/agent/sql_agent_test/sql_agent_tools.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +def search_web(query: str) -> str: + """Search the web for information not contained in databases.""" + from langchain_core.tools import Tool + from langchain_google_community import GoogleSearchAPIWrapper + + search = GoogleSearchAPIWrapper() + + tool = Tool( + name="google_search", + description="Search Google for recent results.", + func=search.run, + ) + + response = tool.run(query) + return response diff --git a/tests/agent/sql_agent_test/sql_agent_tools.yaml b/tests/agent/sql_agent_test/sql_agent_tools.yaml new file mode 100644 index 0000000000..ccd5c8e718 --- /dev/null +++ b/tests/agent/sql_agent_test/sql_agent_tools.yaml @@ -0,0 +1,11 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +search_web: + description: Search the web for a given query. + callable_api: sql_agent_tools.py:search_web + args_schema: + query: + type: str + description: query + return_output: retrieved_data diff --git a/tests/agent/sql_agent_test/test_sql_agent.sh b/tests/agent/sql_agent_test/test_sql_agent.sh new file mode 100644 index 0000000000..502abaa318 --- /dev/null +++ b/tests/agent/sql_agent_test/test_sql_agent.sh @@ -0,0 +1,179 @@ +#!/bin/bash +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +#set -xe + +# this script should be run from tests directory +# bash agent/sql_agent_test/test_sql_agent.sh + +WORKPATH=$(dirname "$PWD") +echo $WORKPATH +LOG_PATH="$WORKPATH/tests" + +# WORKDIR is one level up from GenAIComps +export WORKDIR=$(dirname "$WORKPATH") +echo $WORKDIR + +export agent_image="opea/agent-langchain:comps" +export agent_container_name="test-comps-agent-endpoint" + +export ip_address=$(hostname -I | awk '{print $1}') + +vllm_port=8086 +vllm_volume=${HF_CACHE_DIR} + +export model=meta-llama/Meta-Llama-3.1-70B-Instruct +export HUGGINGFACEHUB_API_TOKEN=${HF_TOKEN} +export LLM_MODEL_ID="meta-llama/Meta-Llama-3.1-70B-Instruct" +export LLM_ENDPOINT_URL="http://${ip_address}:${vllm_port}" +export temperature=0.01 +export max_new_tokens=4096 +export TOOLSET_PATH=$WORKPATH/comps/agent/langchain/tools/ # $WORKPATH/tests/agent/sql_agent_test/ +echo "TOOLSET_PATH=${TOOLSET_PATH}" +export recursion_limit=15 +export db_name=california_schools +export db_path="sqlite:////home/user/TAG-Bench/dev_folder/dev_databases/${db_name}/${db_name}.sqlite" + +# for using Google search API +export GOOGLE_CSE_ID=${GOOGLE_CSE_ID} +export GOOGLE_API_KEY=${GOOGLE_API_KEY} + + +# download the test data +function prepare_data() { + cd $WORKDIR + + echo "Downloading data..." + git clone https://github.com/TAG-Research/TAG-Bench.git + cd TAG-Bench/setup + chmod +x get_dbs.sh + ./get_dbs.sh + + echo "Split data..." + cd $WORKPATH/tests/agent/sql_agent_test + bash run_data_split.sh + + echo "Data preparation done!" +} + +function remove_data() { + echo "Removing data..." + cd $WORKDIR + rm -rf TAG-Bench + echo "Data removed!" +} + + +function generate_hints_for_benchmark() { + echo "Generating hints for benchmark..." + cd $WORKPATH/tests/agent/sql_agent_test + python3 generate_hints_file.py +} + +function build_docker_images() { + echo "Building the docker images" + cd $WORKPATH + echo $WORKPATH + docker build --no-cache -t $agent_image --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -f comps/agent/langchain/Dockerfile . + if [ $? -ne 0 ]; then + echo "opea/agent-langchain built fail" + exit 1 + else + echo "opea/agent-langchain built successful" + fi +} + +function build_vllm_docker_images() { + echo "Building the vllm docker images" + cd $WORKPATH + echo $WORKPATH + if [ ! -d "./vllm" ]; then + git clone https://github.com/HabanaAI/vllm-fork.git + fi + cd ./vllm-fork + docker build --no-cache -f Dockerfile.hpu -t opea/vllm-gaudi:comps --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy + if [ $? -ne 0 ]; then + echo "opea/vllm-gaudi:comps failed" + exit 1 + else + echo "opea/vllm-gaudi:comps successful" + fi +} + +function start_vllm_service() { + # redis endpoint + echo "token is ${HF_TOKEN}" + + #single card + echo "start vllm gaudi service" + echo "**************model is $model**************" + docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=0,1,2,3 -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-seq-len-to-capture 16384 --tensor-parallel-size 4 + sleep 5s + echo "Waiting vllm gaudi ready" + n=0 + until [[ "$n" -ge 100 ]] || [[ $ready == true ]]; do + docker logs test-comps-vllm-gaudi-service &> ${LOG_PATH}/vllm-gaudi-service.log + n=$((n+1)) + if grep -q "Uvicorn running on" ${LOG_PATH}/vllm-gaudi-service.log; then + break + fi + if grep -q "No such container" ${LOG_PATH}/vllm-gaudi-service.log; then + echo "container test-comps-vllm-gaudi-service not found" + exit 1 + fi + sleep 5s + done + sleep 5s + echo "Service started successfully" +} +# launch the agent +function start_sql_agent_llama_service() { + echo "Starting sql_agent_llama agent microservice" + docker compose -f $WORKPATH/tests/agent/sql_agent_llama.yaml up -d + sleep 3m + docker logs test-comps-agent-endpoint + echo "Service started successfully" +} + + +function start_sql_agent_openai_service() { + export OPENAI_API_KEY=${OPENAI_API_KEY} + echo "Starting sql_agent_openai agent microservice" + docker compose -f $WORKPATH/tests/agent/sql_agent_openai.yaml up -d + sleep 3m + docker logs test-comps-agent-endpoint + echo "Service started successfully" +} + +# run the test +function run_test() { + echo "Running test..." + cd $WORKPATH/tests/agent/ + python3 test.py --test-sql-agent +} + +function run_benchmark() { + echo "Running benchmark..." + cd $WORKPATH/tests/agent/sql_agent_test + query_file=${WORKDIR}/TAG-Bench/query_by_db/query_california_schools.csv + outdir=$WORKDIR/sql_agent_output + outfile=california_school_agent_test_result.csv + python3 test_tag_bench.py --query_file $query_file --output_dir $outdir --output_file $outfile +} + + +echo "Preparing data...." +prepare_data + +echo "launching sql_agent_llama service...." +start_sql_agent_llama_service + +# echo "launching sql_agent_openai service...." +# start_sql_agent_openai_service + +echo "Running test...." +run_test + +echo "Removing data...." +remove_data diff --git a/tests/agent/sql_agent_test/test_tag_bench.py b/tests/agent/sql_agent_test/test_tag_bench.py new file mode 100644 index 0000000000..6664759f31 --- /dev/null +++ b/tests/agent/sql_agent_test/test_tag_bench.py @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +import pandas as pd +import requests + + +def generate_answer_agent_api(url, prompt): + proxies = {"http": ""} + payload = { + "query": prompt, + } + response = requests.post(url, json=payload, proxies=proxies) + answer = response.json()["text"] + return answer + + +def save_json_lines(json_lines, args): + outfile = "sql_agent_results.json" + output = os.path.join(args.output_dir, outfile) + with open(output, "w") as f: + for line in json_lines: + f.write(str(line) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--query_file", type=str) + parser.add_argument("--output_dir", type=str) + parser.add_argument("--output_file", type=str) + args = parser.parse_args() + + df = pd.read_csv(args.query_file) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + ip_address = os.getenv("ip_address", "localhost") + url = f"http://{ip_address}:9095/v1/chat/completions" + + json_lines = [] + for _, row in df.iterrows(): + query = row["Query"] + ref_answer = row["Answer"] + print("******Query:\n", query) + res = generate_answer_agent_api(url, query) + print("******Answer:\n", res) + # json_lines.append({"query": query,"answer":ref_answer, "agent_answer": res}) + # save_json_lines(json_lines, args) + print("=" * 20) + + df.to_csv(os.path.join(args.output_dir, args.output_file), index=False) diff --git a/tests/agent/test.py b/tests/agent/test.py index fdbfe1c5b3..e345e89420 100644 --- a/tests/agent/test.py +++ b/tests/agent/test.py @@ -45,11 +45,16 @@ def process_request(url, query, is_stream=False): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--stream", action="store_true", help="Stream the response") + parser.add_argument("--test-sql-agent", action="store_true", help="Test the SQL agent") args = parser.parse_args() ip_address = os.getenv("ip_address", "localhost") url = f"http://{ip_address}:9095/v1/chat/completions" - prompt = "What is OPEA?" + if args.test_sql_agent: + prompt = "How many schools have the average score in Math over 560 in the SAT test?" + else: + prompt = "What is OPEA?" + if args.stream: process_request(url, prompt, is_stream=True) else: diff --git a/tests/agent/test_agent_langchain_on_intel_hpu.sh b/tests/agent/test_agent_langchain_on_intel_hpu.sh index 9ba25228ad..6d5f2f1a4c 100644 --- a/tests/agent/test_agent_langchain_on_intel_hpu.sh +++ b/tests/agent/test_agent_langchain_on_intel_hpu.sh @@ -6,12 +6,17 @@ WORKPATH=$(dirname "$PWD") echo $WORKPATH +ls $WORKPATH +echo "=========================" LOG_PATH="$WORKPATH/tests" ip_address=$(hostname -I | awk '{print $1}') tgi_port=8085 tgi_volume=$WORKPATH/data + vllm_port=8086 -vllm_volume=$WORKPATH/data +export vllm_volume=$WORKPATH/data +echo "vllm_volume:" +ls $vllm_volume export WORKPATH=$WORKPATH @@ -23,7 +28,7 @@ export HUGGINGFACEHUB_API_TOKEN=${HF_TOKEN} export ip_address=$(hostname -I | awk '{print $1}') export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} export LLM_MODEL_ID="meta-llama/Meta-Llama-3.1-70B-Instruct" -export LLM_ENDPOINT_URL="http://${ip_address}:${tgi_port}" +export LLM_ENDPOINT_URL="http://${ip_address}:${vllm_port}" export temperature=0.01 export max_new_tokens=4096 export TOOLSET_PATH=$WORKPATH/comps/agent/langchain/tools/ @@ -48,10 +53,11 @@ function build_vllm_docker_images() { cd $WORKPATH echo $WORKPATH if [ ! -d "./vllm" ]; then - git clone https://github.com/HabanaAI/vllm-fork.git + git clone https://github.com/vllm-project/vllm.git fi - cd ./vllm-fork - docker build -f Dockerfile.hpu -t opea/vllm-gaudi:comps --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy + cd ./vllm + git checkout main + docker build --no-cache -f Dockerfile.hpu -t opea/vllm-gaudi:comps --shm-size=128g . --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy if [ $? -ne 0 ]; then echo "opea/vllm-gaudi:comps failed" exit 1 @@ -82,13 +88,11 @@ function start_tgi_service() { } function start_vllm_service() { - # redis endpoint echo "token is ${HF_TOKEN}" - #single card echo "start vllm gaudi service" echo "**************model is $model**************" - docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 + docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:8000 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --max-seq-len-to-capture 8192 --guided-decoding-backend lm-format-enforcer --tensor-parallel-size 4 sleep 5s echo "Waiting vllm gaudi ready" n=0 @@ -115,7 +119,32 @@ function start_vllm_auto_tool_choice_service() { #single card echo "start vllm gaudi service" echo "**************auto_tool model is $model**************" - docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:80 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --host 0.0.0.0 --port 80 --block-size 128 --max-num-seqs 4096 --max-seq_len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser ${model_parser} + docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=all -p $vllm_port:8000 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --max-seq-len-to-capture 8192 --enable-auto-tool-choice --tool-call-parser ${model_parser} --tensor-parallel-size 4 + sleep 5s + echo "Waiting vllm gaudi ready" + n=0 + until [[ "$n" -ge 100 ]] || [[ $ready == true ]]; do + docker logs test-comps-vllm-gaudi-service &> ${LOG_PATH}/vllm-gaudi-service.log + n=$((n+1)) + if grep -q "Uvicorn running on" ${LOG_PATH}/vllm-gaudi-service.log; then + break + fi + if grep -q "No such container" ${LOG_PATH}/vllm-gaudi-service.log; then + echo "container test-comps-vllm-gaudi-service not found" + exit 1 + fi + sleep 5s + done + sleep 5s + echo "Service started successfully" +} + +function start_vllm_service_70B() { + echo "token is ${HF_TOKEN}" + + echo "start vllm gaudi service" + echo "**************model is $model**************" + docker run -d --runtime=habana --rm --name "test-comps-vllm-gaudi-service" -e HABANA_VISIBLE_DEVICES=0,1,2,3 -p $vllm_port:8000 -v $vllm_volume:/data -e HF_TOKEN=$HF_TOKEN -e HUGGING_FACE_HUB_TOKEN=$HF_TOKEN -e HF_HOME=/data -e OMPI_MCA_btl_vader_single_copy_mechanism=none -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true -e http_proxy=$http_proxy -e https_proxy=$https_proxy -e no_proxy=$no_proxy -e VLLM_SKIP_WARMUP=true --cap-add=sys_nice --ipc=host opea/vllm-gaudi:comps --model ${model} --max-seq-len-to-capture 16384 --tensor-parallel-size 4 sleep 5s echo "Waiting vllm gaudi ready" n=0 @@ -138,7 +167,7 @@ function start_vllm_auto_tool_choice_service() { function start_react_langchain_agent_service() { echo "Starting react_langchain agent microservice" docker compose -f $WORKPATH/tests/agent/react_langchain.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -147,7 +176,7 @@ function start_react_langchain_agent_service() { function start_react_langgraph_agent_service_openai() { echo "Starting react_langchain agent microservice" docker compose -f $WORKPATH/tests/agent/react_langgraph_openai.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -156,7 +185,7 @@ function start_react_langgraph_agent_service_openai() { function start_react_llama_agent_service() { echo "Starting react_langgraph agent microservice" docker compose -f $WORKPATH/tests/agent/reactllama.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -164,7 +193,7 @@ function start_react_llama_agent_service() { function start_react_langgraph_agent_service_vllm() { echo "Starting react_langgraph agent microservice" docker compose -f $WORKPATH/tests/agent/react_vllm.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -172,7 +201,7 @@ function start_react_langgraph_agent_service_vllm() { function start_planexec_agent_service_vllm() { echo "Starting planexec agent microservice" docker compose -f $WORKPATH/tests/agent/planexec_vllm.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -180,7 +209,7 @@ function start_planexec_agent_service_vllm() { function start_ragagent_agent_service() { echo "Starting rag agent microservice" docker compose -f $WORKPATH/tests/agent/ragagent.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -188,7 +217,7 @@ function start_ragagent_agent_service() { function start_ragagent_agent_service_openai() { echo "Starting rag agent microservice" docker compose -f $WORKPATH/tests/agent/ragagent_openai.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -196,7 +225,7 @@ function start_ragagent_agent_service_openai() { function start_planexec_agent_service_openai() { echo "Starting plan execute agent microservice" docker compose -f $WORKPATH/tests/agent/planexec_openai.yaml up -d - sleep 5s + sleep 120s docker logs test-comps-agent-endpoint echo "Service started successfully" } @@ -229,8 +258,8 @@ function validate_microservice() { local EXIT_CODE="${EXIT_CODE:0-1}" echo "return value is $EXIT_CODE" if [ "$EXIT_CODE" == "1" ]; then - echo "==================TGI logs ======================" - docker logs test-comps-tgi-gaudi-service + echo "==================vllm logs ======================" + docker logs test-comps-vllm-gaudi-service echo "==================Agent logs ======================" docker logs test-comps-agent-endpoint exit 1 @@ -275,7 +304,12 @@ function stop_tgi_docker() { cid=$(docker ps -aq --filter "name=test-comps-tgi-gaudi-service") echo "Stopping the docker containers "${cid} if [[ ! -z "$cid" ]]; then docker rm $cid -f && sleep 1s; fi - echo "Docker containers stopped successfully" + echo "TGI Docker containers stopped successfully" + + cid=$(docker ps -aq --filter "name=tgi-server") + echo "Stopping the docker containers "${cid} + if [[ ! -z "$cid" ]]; then docker rm $cid -f && sleep 1s; fi + echo "TGI Docker containers stopped successfully" } function stop_vllm_docker() { @@ -298,22 +332,43 @@ function stop_docker() { stop_agent_docker } + +function validate_sql_agent(){ + cd $WORKPATH/tests/ + local CONTENT=$(bash agent/sql_agent_test/test_sql_agent.sh) + local EXIT_CODE=$(validate "$CONTENT" "173" "test-sql-agent") + echo "$EXIT_CODE" + local EXIT_CODE="${EXIT_CODE:0-1}" + echo "return value is $EXIT_CODE" + if [ "$EXIT_CODE" == "1" ]; then + echo "==================SQL Agent logs ======================" + docker logs test-comps-agent-endpoint + # echo "================== vllm gaudi service logs ======================" + # docker logs test-comps-vllm-gaudi-service + exit 1 + fi +} + + function main() { stop_agent_docker stop_docker build_docker_images + build_vllm_docker_images + + # ==================== Tests with 70B model ==================== + # RAG agent, react_llama, react_langchain, assistant apis - # ==================== TGI tests ==================== - start_tgi_service + start_vllm_service_70B - # test rag agent + # # test rag agent start_ragagent_agent_service echo "=============Testing RAG Agent=============" validate_microservice stop_agent_docker echo "=============================================" - # test react_llama + # # # test react_llama start_react_llama_agent_service echo "===========Testing ReAct Llama =============" validate_microservice @@ -321,7 +376,7 @@ function main() { echo "=============================================" - # test react_langchain + # # # test react_langchain start_react_langchain_agent_service echo "=============Testing ReAct Langchain=============" validate_microservice_streaming @@ -329,57 +384,29 @@ function main() { stop_agent_docker echo "=============================================" - stop_tgi_docker - - # ==================== VLLM tests ==================== - build_vllm_docker_images - - export model=mistralai/Mistral-7B-Instruct-v0.3 - export LLM_MODEL_ID=${model} - export model_parser=mistral - export LLM_ENDPOINT_URL="http://${ip_address}:${vllm_port}" - - # test react with vllm - Mistral - start_vllm_auto_tool_choice_service - start_react_langgraph_agent_service_vllm - echo "===========Testing ReAct Langgraph VLLM Mistral =============" - validate_microservice - # stop_agent_docker - # stop_vllm_docker + # # test sql agent + echo "=============Testing SQL llama=============" + validate_sql_agent + stop_docker echo "=============================================" - # test plan execute with vllm - Mistral + echo "===========Testing Plan Execute VLLM Llama3.1 =============" start_vllm_service start_planexec_agent_service_vllm - echo "===========Testing Plan Execute VLLM Mistral =============" validate_microservice stop_agent_docker stop_vllm_docker echo "=============================================" - export model=meta-llama/Llama-3.1-8B-Instruct - export LLM_MODEL_ID=${model} + echo "===========Testing ReAct Langgraph VLLM llama3.1 =============" export model_parser=llama3_json - - # test react with vllm - llama3 support has not been synced to vllm-gaudi yet - # start_vllm_auto_tool_choice_service - # start_react_langgraph_agent_service_vllm - # echo "===========Testing ReAct VLLM =============" - # validate_microservice - # stop_agent_docker - # stop_vllm_docker - # echo "=============================================" - - # test plan execute with vllm - llama3.1 - start_vllm_service - start_planexec_agent_service_vllm - echo "===========Testing Plan Execute VLLM Llama3.1 =============" + start_vllm_auto_tool_choice_service + start_react_langgraph_agent_service_vllm validate_microservice stop_agent_docker stop_vllm_docker echo "=============================================" - # # ==================== OpenAI tests ==================== # start_ragagent_agent_service_openai # echo "=============Testing RAG Agent OpenAI=============" @@ -399,6 +426,7 @@ function main() { # stop_agent_docker stop_docker + echo y | docker system prune 2>&1 > /dev/null }