Skip to content

Commit

Permalink
🔨 fix tool error
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Sep 29, 2023
1 parent 4f630cd commit c290aa2
Showing 1 changed file with 83 additions and 116 deletions.
199 changes: 83 additions & 116 deletions codeinterpreterapi/agents/functions_agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
import json
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union

from langchain.agents import BaseSingleActionAgent
from langchain.agents.agent import AgentOutputParser
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks # type: ignore
from langchain.chat_models.openai import ChatOpenAI
Expand All @@ -18,131 +21,91 @@
from langchain.schema import (
AgentAction,
AgentFinish,
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
BasePromptTemplate,
OutputParserException,
SystemMessage,
)
from langchain.tools import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function


@dataclass
class _FunctionsAgentAction(AgentAction):
message_log: List[BaseMessage]
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import ChatGeneration, Generation
from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function


class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
    """Parses a message into agent action/finish.

    Is meant to be used with OpenAI models, as it relies on the specific
    function_call parameter from OpenAI to convey what tools to use.

    If a function_call parameter is passed, then that is used to get
    the tool and tool input.

    If one is not passed, then the AIMessage is assumed to be the final output.
    """

    @property
    def _type(self) -> str:
        # Identifier used by LangChain's serialization machinery.
        return "openai-functions-agent"

    @staticmethod
    def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
        """Parse an AI message into an AgentAction or AgentFinish.

        Args:
            message: The model's response message. Must be an ``AIMessage``;
                the OpenAI ``function_call`` payload is read from
                ``message.additional_kwargs``.

        Returns:
            An ``AgentActionMessageLog`` describing the tool invocation when a
            ``function_call`` is present, otherwise an ``AgentFinish`` whose
            output is the plain message content.

        Raises:
            TypeError: If ``message`` is not an ``AIMessage``.
            OutputParserException: If the function arguments are not valid
                JSON and cannot be recovered as raw ``python`` code.
        """
        if not isinstance(message, AIMessage):
            raise TypeError(f"Expected an AI message got {type(message)}")

        function_call = message.additional_kwargs.get("function_call", {})

        if function_call:
            function_name = function_call["name"]
            try:
                _tool_input = json.loads(function_call["arguments"])
            except JSONDecodeError:
                if function_name == "python":
                    # The model often emits raw (non-JSON) source code for the
                    # python tool; recover it verbatim instead of failing.
                    code = function_call["arguments"]
                    _tool_input = {
                        "code": code,
                    }
                else:
                    raise OutputParserException(
                        f"Could not parse tool input: {function_call} because "
                        f"the `arguments` is not valid JSON."
                    )

            # HACK HACK HACK:
            # The code that encodes tool input into Open AI uses a special variable
            # name called `__arg1` to handle old style tools that do not expose a
            # schema and expect a single string argument as an input.
            # We unpack the argument here if it exists.
            # Open AI does not support passing in a JSON array as an argument.
            if "__arg1" in _tool_input:
                tool_input = _tool_input["__arg1"]
            else:
                tool_input = _tool_input

            content_msg = f"responded: {message.content}\n" if message.content else "\n"
            log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
            return AgentActionMessageLog(
                tool=function_name,
                tool_input=tool_input,
                log=log,
                message_log=[message],
            )

        # No function_call: the message content is the final answer.
        return AgentFinish(
            return_values={"output": message.content}, log=message.content
        )

    def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
        """Parse a list of generations; only ChatGeneration is supported."""
        if not isinstance(result[0], ChatGeneration):
            raise ValueError("This output parser only works on ChatGeneration output")
        message = result[0].message
        return self._parse_ai_message(message)

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Plain-text parsing is unsupported; this parser requires messages."""
        raise ValueError("Can only parse messages")


class OpenAIFunctionsAgent(BaseSingleActionAgent):
Expand Down Expand Up @@ -206,7 +169,7 @@ def plan(
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
Expand All @@ -224,7 +187,9 @@ def plan(
messages,
callbacks=callbacks,
)
agent_decision = _parse_ai_message(predicted_message)
agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(
predicted_message
)
return agent_decision

async def aplan(
Expand All @@ -243,7 +208,7 @@ async def aplan(
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
agent_scratchpad = format_to_openai_functions(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
Expand All @@ -253,7 +218,9 @@ async def aplan(
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(
predicted_message
)
return agent_decision

def return_stopped_response(
Expand Down Expand Up @@ -339,7 +306,7 @@ def from_llm_and_tools(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls( # type: ignore
return cls(
llm=llm,
prompt=prompt,
tools=tools,
Expand Down

0 comments on commit c290aa2

Please sign in to comment.