-
-
Notifications
You must be signed in to change notification settings - Fork 407
/
_patch_parser.py
103 lines (85 loc) · 4 KB
/
_patch_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
from langchain.agents.agent import AgentOutputParser
from langchain.agents.openai_functions_agent import base
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, Generation
class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
    """Parses a message into agent action/finish.

    Is meant to be used with OpenAI models, as it relies on the specific
    function_call parameter from OpenAI to convey what tools to use.

    If a function_call parameter is passed, then that is used to get
    the tool and tool input.

    If one is not passed, then the AIMessage is assumed to be the final output.
    """

    @property
    def _type(self) -> str:
        """Return the identifier string of this parser."""
        return "openai-functions-agent"

    @staticmethod
    def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
        """Parse an AI message into a tool invocation or a final answer.

        Args:
            message: The message to inspect; must be an ``AIMessage``.

        Returns:
            An ``AgentActionMessageLog`` when ``message.additional_kwargs``
            holds a non-empty ``function_call``; otherwise an ``AgentFinish``
            whose ``output`` is the message content.

        Raises:
            TypeError: If ``message`` is not an ``AIMessage``.
            OutputParserException: If the function arguments are not valid
                JSON and the function is not named ``python``.
        """
        if not isinstance(message, AIMessage):
            raise TypeError(f"Expected an AI message got {type(message)}")

        function_call = message.additional_kwargs.get("function_call", {})
        if not function_call:
            # No function call requested: the content is the final answer.
            return AgentFinish(
                return_values={"output": message.content}, log=str(message.content)
            )

        function_name = function_call["name"]
        # A missing or null `arguments` entry is handled like the empty
        # string, i.e. as a call that takes no arguments.
        raw_arguments = function_call.get("arguments") or ""
        try:
            if not raw_arguments.strip():
                # OpenAI returns an empty string for functions containing no args
                parsed_input = {}
            else:
                # otherwise it returns a json object
                parsed_input = json.loads(raw_arguments)
        except JSONDecodeError as err:
            if function_name != "python":
                raise OutputParserException(
                    f"Could not parse tool input: {function_call} because "
                    f"the `arguments` is not valid JSON."
                ) from err
            # For the `python` function the raw (non-JSON) arguments are
            # forwarded as the code to run.
            parsed_input = {"code": raw_arguments}

        # HACK HACK HACK:
        # The code that encodes tool input into Open AI uses a special variable
        # name called `__arg1` to handle old style tools that do not expose a
        # schema and expect a single string argument as an input.
        # We unpack the argument here if it exists.
        # Open AI does not support passing in a JSON array as an argument.
        #
        # `json.loads` can also yield a list, string, number, bool or None;
        # only a dict can carry the `__arg1` key, so anything else is passed
        # through unchanged instead of failing on the membership test.
        if isinstance(parsed_input, dict) and "__arg1" in parsed_input:
            tool_input = parsed_input["__arg1"]
        else:
            tool_input = parsed_input

        content_msg = f"responded: {message.content}\n" if message.content else "\n"
        log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
        return AgentActionMessageLog(
            tool=function_name,
            tool_input=tool_input,
            log=log,
            message_log=[message],
        )

    def parse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[AgentAction, AgentFinish]:
        """Parse the first generation of ``result``.

        Args:
            result: Generations to parse; only ``result[0]`` is examined.
            partial: Accepted as part of the signature; not used here.

        Returns:
            The action or finish produced from the first generation's message.

        Raises:
            ValueError: If ``result[0]`` is not a ``ChatGeneration``.
        """
        first_generation = result[0]
        if not isinstance(first_generation, ChatGeneration):
            raise ValueError("This output parser only works on ChatGeneration output")
        return self._parse_ai_message(first_generation.message)

    async def aparse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[AgentAction, AgentFinish]:
        """Async variant of ``parse_result``.

        Runs ``parse_result(result)`` in the running event loop's default
        executor and awaits its outcome. ``partial`` is not forwarded.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.parse_result, result)

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Always raise ``ValueError``: plain text cannot be parsed here."""
        raise ValueError("Can only parse messages")
# patch
# Monkey-patch executed when this module is imported: rebind the
# `OpenAIFunctionsAgentOutputParser` attribute of langchain's
# `openai_functions_agent.base` module to the class defined in this file, so
# lookups of that attribute on `base` made afterwards resolve to this parser.
base.OpenAIFunctionsAgentOutputParser = OpenAIFunctionsAgentOutputParser # type: ignore