Skip to content

Commit

Permalink
Add arg parsing for dspy.ReAct (#2039)
Browse files Browse the repository at this point in the history
* Add arg parsing for dspy.ReAct

* fix broken tests
  • Loading branch information
chenmoneygithub authored Jan 12, 2025
1 parent 26850ac commit b1ae7af
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 13 deletions.
29 changes: 19 additions & 10 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import inspect
from typing import Any, Callable, Literal, get_origin, get_type_hints

from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter

import dspy
from dspy.adapters.json_adapter import get_annotation_name
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature
from dspy.utils.callback import with_callbacks
Expand All @@ -16,13 +15,16 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic
self.func = func
self.name = name or getattr(func, "__name__", type(func).__name__)
self.desc = desc or getattr(func, "__doc__", None) or getattr(annotations_func, "__doc__", "")
self.args = {
k: v.schema()
if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel)
else get_annotation_name(v)
for k, v in (args or get_type_hints(annotations_func)).items()
if k != "return"
}
self.args = {}
self.arg_types = {}
for k, v in (args or get_type_hints(annotations_func)).items():
self.arg_types[k] = v
if k == "return":
continue
if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel):
self.args[k] = v.model_json_schema()
else:
self.args[k] = TypeAdapter(v).json_schema()

@with_callbacks
def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -96,7 +98,14 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
parsed_tool_args = {}
for k, v in pred.next_tool_args.items():
arg_type = self.tools[pred.next_tool_name].arg_types[k]
if isinstance((origin := get_origin(arg_type) or arg_type), type) and issubclass(origin, BaseModel):
parsed_tool_args[k] = arg_type.model_validate(v)
else:
parsed_tool_args[k] = v
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**parsed_tool_args)
except Exception as e:
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"

Expand Down
87 changes: 84 additions & 3 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dspy
from dspy.utils.dummies import DummyLM, dummy_rm
from dspy.predict import react
from pydantic import BaseModel


# def test_example_no_tools():
Expand Down Expand Up @@ -124,15 +125,17 @@
# assert react.react[0].signature.instructions is not None
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")


def test_tool_from_function():
def foo(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": "int", "b": "int"}
assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_tool_from_class():
class Foo:
Expand All @@ -146,4 +149,82 @@ def foo(self, a: int, b: int) -> int:
tool = react.Tool(Foo("123").foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": "int", "b": "int"}
assert tool.args == {"a": {"type": "integer"}, "b": {"type": "integer"}}


def test_tool_calling_with_pydantic_args():
class CalendarEvent(BaseModel):
name: str
date: str
participants: dict[str, str]

def write_invitation_letter(participant_name: str, event_info: CalendarEvent):
if participant_name not in event_info.participants:
return None
return f"It's my honor to invite {participant_name} to event {event_info.name} on {event_info.date}"

class InvitationSignature(dspy.Signature):
participant_name: str = dspy.InputField(desc="The name of the participant to invite")
event_info: CalendarEvent = dspy.InputField(desc="The information about the event")
invitation_letter: str = dspy.OutputField(desc="The invitation letter to be sent to the participant")

react = dspy.ReAct(InvitationSignature, tools=[write_invitation_letter])

lm = DummyLM(
[
{
"next_thought": "I need to write an invitation letter for Alice to the Science Fair event.",
"next_tool_name": "write_invitation_letter",
"next_tool_args": {
"participant_name": "Alice",
"event_info": {
"name": "Science Fair",
"date": "Friday",
"participants": {"Alice": "female", "Bob": "male"},
},
},
},
{
"next_thought": (
"I have successfully written the invitation letter for Alice to the Science Fair. Now "
"I can finish the task."
),
"next_tool_name": "finish",
"next_tool_args": {},
},
{
"reasoning": "This is a very rigorous reasoning process, trust me bro!",
"invitation_letter": "It's my honor to invite Alice to the Science Fair event on Friday.",
},
]
)
dspy.settings.configure(lm=lm)

outputs = react(
participant_name="Alice",
event_info=CalendarEvent(
name="Science Fair",
date="Friday",
participants={"Alice": "female", "Bob": "male"},
),
)
assert outputs.invitation_letter == "It's my honor to invite Alice to the Science Fair event on Friday."

expected_trajectory = {
"thought_0": "I need to write an invitation letter for Alice to the Science Fair event.",
"tool_name_0": "write_invitation_letter",
"tool_args_0": {
"participant_name": "Alice",
"event_info": {
"name": "Science Fair",
"date": "Friday",
"participants": {"Alice": "female", "Bob": "male"},
},
},
"observation_0": "It's my honor to invite Alice to event Science Fair on Friday",
"thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.",
"tool_name_1": "finish",
"tool_args_1": {},
"observation_1": "Completed.",
}
assert outputs.trajectory == expected_trajectory

0 comments on commit b1ae7af

Please sign in to comment.