add tool choices for agent. (#1126)
lkk12014402 authored Jan 13, 2025
1 parent fe24dec commit 3a7ccb0
Showing 5 changed files with 80 additions and 33 deletions.
41 changes: 21 additions & 20 deletions comps/agent/src/agent.py
@@ -5,7 +5,7 @@
 import pathlib
 import sys
 from datetime import datetime
-from typing import Union
+from typing import List, Optional, Union

 from fastapi.responses import StreamingResponse
@@ -40,7 +40,10 @@
 agent_inst = instantiate_agent(args, args.strategy, with_memory=args.with_memory)


-class AgentCompletionRequest(LLMParamsDoc):
+class AgentCompletionRequest(ChatCompletionRequest):
+    # rewrite, specify tools in this turn of conversation
+    tool_choice: Optional[List[str]] = None
+    # for short/long term in-memory
     thread_id: str = "0"
     user_id: str = "0"

@@ -52,42 +55,40 @@ class AgentCompletionRequest(LLMParamsDoc):
     host="0.0.0.0",
     port=args.port,
 )
-async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, AgentCompletionRequest]):
+async def llm_generate(input: AgentCompletionRequest):
     if logflag:
         logger.info(input)

-    input.stream = args.stream
-    config = {"recursion_limit": args.recursion_limit}
+    # don't use global stream setting
+    # input.stream = args.stream
+    config = {"recursion_limit": args.recursion_limit, "tool_choice": input.tool_choice}

     if args.with_memory:
-        if isinstance(input, AgentCompletionRequest):
-            config["configurable"] = {"thread_id": input.thread_id}
-        else:
-            config["configurable"] = {"thread_id": "0"}
+        config["configurable"] = {"thread_id": input.thread_id}

     if logflag:
         logger.info(type(agent_inst))

-    if isinstance(input, LLMParamsDoc):
-        # use query as input
-        input_query = input.query
+    # openai compatible input
+    if isinstance(input.messages, str):
+        messages = input.messages
     else:
-        # openai compatible input
-        if isinstance(input.messages, str):
-            input_query = input.messages
-        else:
-            input_query = input.messages[-1]["content"]
+        # TODO: need handle multi-turn messages
+        messages = input.messages[-1]["content"]

     # 2. prepare the input for the agent
     if input.stream:
         logger.info("-----------STREAMING-------------")
-        return StreamingResponse(agent_inst.stream_generator(input_query, config), media_type="text/event-stream")
+        return StreamingResponse(
+            agent_inst.stream_generator(messages, config),
+            media_type="text/event-stream",
+        )

     else:
         logger.info("-----------NOT STREAMING-------------")
-        response = await agent_inst.non_streaming_run(input_query, config)
+        response = await agent_inst.non_streaming_run(messages, config)
         logger.info("-----------Response-------------")
-        return GeneratedDoc(text=response, prompt=input_query)
+        return GeneratedDoc(text=response, prompt=messages)


 @register_microservice(
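The endpoint now takes a single AgentCompletionRequest, so callers can pin the tools available for one turn via tool_choice. A minimal client sketch (the URL, port, and tool name here are illustrative assumptions, not values from this commit):

    import requests

    # Hypothetical service address; use the host/port the agent microservice runs on.
    url = "http://localhost:9090/v1/chat/completions"
    payload = {
        "messages": "What is OPEA?",
        "tool_choice": ["search_web"],  # offer only these tool names this turn
        "thread_id": "42",  # per-conversation memory key when with_memory is enabled
        "stream": False,
    }
    resp = requests.post(url, json=payload)
    print(resp.json()["text"])  # non-streaming replies return a GeneratedDoc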
50 changes: 40 additions & 10 deletions comps/agent/src/integrations/strategy/react/planner.py
@@ -11,7 +11,7 @@
 from langgraph.prebuilt import create_react_agent

 from ...global_var import threads_global_kv
-from ...utils import has_multi_tool_inputs, tool_renderer
+from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
 from ..base_agent import BaseAgent
 from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt

@@ -136,7 +136,8 @@ async def non_streaming_run(self, query, config):
 # does not rely on langchain bind_tools API
 # since tgi and vllm still do not have very good support for tool calling like OpenAI

-from typing import Annotated, Sequence, TypedDict
+import json
+from typing import Annotated, List, Optional, Sequence, TypedDict

 from langchain_core.messages import AIMessage, BaseMessage
 from langchain_core.prompts import PromptTemplate
@@ -154,6 +155,7 @@ class AgentState(TypedDict):
     """The state of the agent."""

     messages: Annotated[Sequence[BaseMessage], add_messages]
+    tool_choice: Optional[List[str]] = None
     is_last_step: IsLastStep


@@ -191,7 +193,11 @@ def __call__(self, state):
         history = assemble_history(messages)
         print("@@@ History: ", history)

-        tools_descriptions = tool_renderer(self.tools)
+        tools_used = self.tools
+        if state["tool_choice"] is not None:
+            tools_used = filter_tools(self.tools, state["tool_choice"])
+
+        tools_descriptions = tool_renderer(tools_used)
         print("@@@ Tools description: ", tools_descriptions)

         # invoke chain
@@ -279,21 +285,45 @@ def prepare_initial_state(self, query):

     async def stream_generator(self, query, config):
         initial_state = self.prepare_initial_state(query)
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
+
         try:
-            async for event in self.app.astream(initial_state, config=config):
-                for node_name, node_state in event.items():
-                    yield f"--- CALL {node_name} ---\n"
-                    for k, v in node_state.items():
-                        if v is not None:
-                            yield f"{k}: {v}\n"
+            async for event in self.app.astream(initial_state, config=config, stream_mode=["updates"]):
+                event_type = event[0]
+                data = event[1]
+                if event_type == "updates":
+                    for node_name, node_state in data.items():
+                        print(f"--- CALL {node_name} node ---\n")
+                        for k, v in node_state.items():
+                            if v is not None:
+                                print(f"------- {k}, {v} -------\n\n")
+                                if node_name == "agent":
+                                    if v[0].content == "":
+                                        tool_names = []
+                                        for tool_call in v[0].tool_calls:
+                                            tool_names.append(tool_call["name"])
+                                        result = {"tool": tool_names}
+                                    else:
+                                        result = {"content": [v[0].content.replace("\n\n", "\n")]}
+                                    # ui needs this format
+                                    yield f"data: {json.dumps(result)}\n\n"
+                                elif node_name == "tools":
+                                    full_content = v[0].content
+                                    tool_name = v[0].name
+                                    result = {"tool": tool_name, "content": [full_content]}
+                                    yield f"data: {json.dumps(result)}\n\n"
+                                    if not full_content:
+                                        continue

-                yield f"data: {repr(event)}\n\n"
             yield "data: [DONE]\n\n"
         except Exception as e:
             yield str(e)

     async def non_streaming_run(self, query, config):
         initial_state = self.prepare_initial_state(query)
+        if "tool_choice" in config:
+            initial_state["tool_choice"] = config.pop("tool_choice")
         try:
             async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                 message = s["messages"][-1]
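The streaming path now emits UI-oriented server-sent events instead of raw state dumps: {"tool": [...]} when the agent node picks tools, {"content": [...]} for model text, and {"tool": name, "content": [...]} for tool results, terminated by [DONE]. A consumer sketch under those assumptions (url is a placeholder):

    import json

    import requests

    with requests.post(url, json={"messages": "What is OPEA?", "stream": True}, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip keep-alives and non-SSE error text
            body = line[len("data: "):]
            if body == "[DONE]":
                break
            event = json.loads(body)
            if "tool" in event:
                print("tool event:", event)
            else:
                print("content:", "".join(event["content"]))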
8 changes: 8 additions & 0 deletions comps/agent/src/integrations/utils.py
@@ -86,6 +86,14 @@ def tool_renderer(tools):
     return "\n".join(tool_strings)


+def filter_tools(tools, tools_choices):
+    tool_used = []
+    for tool in tools:
+        if tool.name in tools_choices:
+            tool_used.append(tool)
+    return tool_used
+
+
 def has_multi_tool_inputs(tools):
     ret = False
     for tool in tools:
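filter_tools is a plain name-based filter over the loaded tool objects. A quick sketch of the intended behavior (the @tool stubs below are stand-ins for the repo's real tool loading, and filter_tools is restated inline so the snippet runs on its own):

    from langchain_core.tools import tool

    def filter_tools(tools, tools_choices):
        # equivalent to the helper added above
        return [t for t in tools if t.name in tools_choices]

    @tool
    def search_web(query: str) -> str:
        """Search the web knowledge for a given query."""
        return "..."

    @tool
    def search_weather(query: str) -> str:
        """Search the weather for a given query."""
        return "..."

    selected = filter_tools([search_web, search_weather], ["search_weather"])
    print([t.name for t in selected])  # ['search_weather']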
10 changes: 9 additions & 1 deletion comps/agent/src/tools/custom_tools.py
@@ -4,9 +4,17 @@

 # tool for unit test
 def search_web(query: str) -> str:
-    """Search the web for a given query."""
+    """Search the web knowledge for a given query."""
     ret_text = """
     The Linux Foundation AI & Data announced the Open Platform for Enterprise AI (OPEA) as its latest Sandbox Project.
     OPEA aims to accelerate secure, cost-effective generative AI (GenAI) deployments for businesses by driving interoperability across a diverse and heterogeneous ecosystem, starting with retrieval-augmented generation (RAG).
     """
     return ret_text
+
+
+def search_weather(query: str) -> str:
+    """Search the weather for a given query."""
+    ret_text = """
+    It's clear.
+    """
+    return ret_text
4 changes: 2 additions & 2 deletions tests/agent/test.py
@@ -10,7 +10,7 @@
 def generate_answer_agent_api(url, prompt):
     proxies = {"http": ""}
     payload = {
-        "query": prompt,
+        "messages": prompt,
     }
     response = requests.post(url, json=payload, proxies=proxies)
     answer = response.json()["text"]
@@ -21,7 +21,7 @@ def process_request(url, query, is_stream=False):
     proxies = {"http": ""}

     payload = {
-        "query": query,
+        "messages": query,
     }

     try:
