Skip to content

Commit

Permalink
task orchestration edge cases (langroid#324)
Browse files Browse the repository at this point in the history
* task orchestration edge cases

* Task: add `allow_null_result` arg

* meilisearch version

* update table_chat_agent sys msg; add extract example

* example fix
  • Loading branch information
pchalasani authored Dec 15, 2023
1 parent 8fe3ea8 commit 0397b37
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 20 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ into simplifying the developer experience; it does not use `Langchain`.
We welcome contributions -- See the [contributions](./CONTRIBUTING.md) document
for ideas on what to contribute.


Building LLM Applications? [Prasad Chalasani](https://www.linkedin.com/in/pchalasani/) is available for consulting
Are you building LLM Applications, or want help with Langroid for your company,
or want to prioritize Langroid features for your company use-cases?
[Prasad Chalasani](https://www.linkedin.com/in/pchalasani/) is available for consulting
(advisory/development): pchalasani at gmail dot com.

Sponsorship is also accepted via [GitHub Sponsors](https://github.com/sponsors/langroid)
Expand Down
68 changes: 68 additions & 0 deletions examples/extract/capitals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Extract structured information from a passage using a tool/function.
python3 examples/extract/capitals.py
"""

from rich import print
from pydantic import BaseModel
from typing import List
import langroid as lr


class City(BaseModel):
name: str
country: str
population: int


class CitiesData(BaseModel):
cities: List[City]


PASSAGE = """
Berlin is the capital of Germany. It has a population of 3,850,809.
Paris, France's capital, has 2.161 million residents.
Lisbon is the capital and the largest city of Portugal with the population of 504,718.
"""


class CitiesMessage(lr.agent.ToolMessage):
"""Tool/function to use to extract/present structured capitals info"""

request: str = "capital_info"
purpose: str = "Collect information about city <capitals> from a passage"
capitals: List[CitiesData]

def handle(self) -> str:
"""Tool handler: Print the info about the capitals.
Any format errors are intercepted by Langroid and passed to the LLM to fix."""
print(f"Correctly extracted Capitals Info: {self.capitals}")
return "DONE" # terminates task


agent = lr.ChatAgent(
lr.ChatAgentConfig(
name="CitiesExtractor",
use_functions_api=True,
use_tools=False,
system_message=f"""
From the passage below, extract info about city capitals, and present it
using the `capital_info` tool/function.
PASSAGE: {PASSAGE}
""",
)
)
# connect the Tool to the Agent, so it can use it to present extracted info
agent.enable_message(CitiesMessage)

# wrap the agent in a task and run it
task = lr.Task(
agent,
interactive=False,
llm_delegate=True,
single_round=False,
)

task.run()
6 changes: 0 additions & 6 deletions langroid/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,12 +632,6 @@ def handle_message(self, msg: str | ChatDocument) -> None | str | ChatDocument:

str_doc_results = [r for r in results_list if isinstance(r, str)]
final = "\n".join(str_doc_results)
if final == "":
logger.warning(
"""final result from a tool handler should not be empty str, since
it would be considered an invalid result and other responders
will be tried, and we may not necessarily want that"""
)
return final

def handle_message_fallback(
Expand Down
2 changes: 2 additions & 0 deletions langroid/agent/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ def llm_response(
if self.llm is None:
return None
hist, output_len = self._prep_llm_messages(message)
if len(hist) == 0:
return None
with StreamingIfAllowed(self.llm, self.llm.get_stream()):
response = self.llm_response_messages(hist, output_len)
# TODO - when response contains function_call we should include
Expand Down
7 changes: 7 additions & 0 deletions langroid/agent/special/recipient_validator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ def __init__(self, config: RecipientValidatorConfig):
self.config: RecipientValidatorConfig = config
self.llm = None
self.vecdb = None
logger.warning(
"""
RecipientValidator is deprecated. Use RecipientTool instead:
See code at langroid/agent/tools/recipient_tool.py, and usage examples in
tests/main/test_multi_agent_complex.py and
"""
)

def user_response(
self,
Expand Down
2 changes: 1 addition & 1 deletion langroid/agent/special/relevance_extractor_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def extract_segments(self, msg: SegmentExtractTool) -> str:
spec = msg.segment_list
if len(self.message_history) == 0:
return NO_ANSWER
if spec is None or spec.strip() == "":
if spec is None or spec.strip() in ["", NO_ANSWER]:
return NO_ANSWER
assert self.numbered_passage is not None, "No numbered passage"
# assume this has numbered segments
Expand Down
2 changes: 2 additions & 0 deletions langroid/agent/special/table_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
If you receive an error message, try using the `run_code` tool/function
again with the corrected code.
VERY IMPORTANT: When using the `run_code` tool/function, DO NOT EXPLAIN ANYTHING,
SIMPLY USE THE TOOL, with the CODE.
Start by asking me what I want to know about the data.
"""

Expand Down
50 changes: 42 additions & 8 deletions langroid/agent/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
interactive: bool = True,
only_user_quits_root: bool = True,
erase_substeps: bool = False,
allow_null_result: bool = True,
):
"""
A task to be performed by an agent.
Expand Down Expand Up @@ -96,6 +97,9 @@ def __init__(
erase all subtask agents' `message_history`.
Note: erasing can reduce prompt sizes, but results in repetitive
sub-task delegation.
allow_null_result (bool): if true, allow null (empty or NO_ANSWER)
as the result of a step or overall task result.
Optional, default is True.
"""
if agent is None:
agent = ChatAgent()
Expand All @@ -115,6 +119,8 @@ def __init__(
self.tsv_logger: None | logging.Logger = None
self.color_log: bool = False if settings.notebook else True
self.agent = agent
self.step_progress = False
self.task_progress = False
self.name = name or agent.config.name
self.default_human_response = default_human_response
self.interactive = interactive
Expand All @@ -129,6 +135,7 @@ def __init__(
# just the first outgoing message and last incoming message.
# Note this also completely erases sub-task agents' message_history.
self.erase_substeps = erase_substeps
self.allow_null_result = allow_null_result

agent_entity_responders = agent.entity_responders()
agent_entity_responders_async = agent.entity_responders_async()
Expand Down Expand Up @@ -300,7 +307,7 @@ def run(
) -> Optional[ChatDocument]:
"""Synchronous version of `run_async()`.
See `run_async()` for details."""

self.task_progress = False
assert (
msg is None or isinstance(msg, str) or isinstance(msg, ChatDocument)
), f"msg arg in Task.run() must be None, str, or ChatDocument, not {type(msg)}"
Expand Down Expand Up @@ -364,7 +371,7 @@ async def run_async(
# have come from another LLM), as far as this agent is concerned, the initial
# message can be considered to be from the USER
# (from the POV of this agent's LLM).

self.task_progress = False
if (
isinstance(msg, ChatDocument)
and msg.metadata.recipient != ""
Expand Down Expand Up @@ -463,6 +470,7 @@ def step(self, turns: int = -1) -> ChatDocument | None:
Synchronous version of `step_async()`. See `step_async()` for details.
"""
result = None
self.step_progress = False
parent = self.pending_message
recipient = (
""
Expand Down Expand Up @@ -533,6 +541,7 @@ async def step_async(self, turns: int = -1) -> ChatDocument | None:
different context.
"""
result = None
self.step_progress = False
parent = self.pending_message
recipient = (
""
Expand Down Expand Up @@ -615,18 +624,34 @@ def _process_responder_result(
if result.attachment is None:
self.pending_message.attachment = old_attachment
self.log_message(self.pending_sender, result, mark=True)
self.step_progress = True
self.task_progress = True
return True
else:
self.log_message(r, result)
return False

def _process_invalid_step_result(self, parent: ChatDocument | None) -> None:
responder = Entity.LLM if self.pending_sender == Entity.USER else Entity.USER
self.pending_message = ChatDocument(
content=NO_ANSWER,
metadata=ChatDocMetaData(sender=responder, parent=parent),
)
self.pending_sender = responder
"""
Since step had no valid result, decide whether to update the
self.pending_message to a NO_ANSWER message from the opposite entity,
or leave it as is.
Args:
parent (ChatDocument|None): parent message of the current message
"""
if not self.task_progress or self.allow_null_result:
# There has been no progress at all in this task, so we
# update the pending_message to a dummy NO_ANSWER msg
# from the entity 'opposite' to the current pending_sender,
# so we show "progress" and avoid getting stuck in an infinite loop.
responder = (
Entity.LLM if self.pending_sender == Entity.USER else Entity.USER
)
self.pending_message = ChatDocument(
content=NO_ANSWER,
metadata=ChatDocMetaData(sender=responder, parent=parent),
)
self.pending_sender = responder
self.log_message(self.pending_sender, self.pending_message, mark=True)

def _show_pending_message_if_debug(self) -> None:
Expand Down Expand Up @@ -745,6 +770,15 @@ def done(self) -> bool:
# for top-level task, only user can quit out
return user_quit

if (
not self.step_progress
and self.pending_sender == Entity.LLM
and not self.llm_delegate
):
# LLM is NOT driving the task, and no progress in latest step,
# and it is NOT the LLM's turn to respond, so we are done.
return True

return (
# no valid response from any entity/agent in current turn
self.pending_message is None
Expand Down
2 changes: 1 addition & 1 deletion langroid/vector_store/meilisearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def _async_add_documents(
async with self.client() as client:
index = client.index(collection_name)
await index.add_documents_in_batches(
documents=documents, # type: ignore
documents=documents,
batch_size=self.config.batch_size,
primary_key=self.config.primary_key,
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pytest-asyncio = "^0.21.1"
docstring-parser = "^0.15"
farm-haystack = {extras = ["ocr", "preprocessing", "file-conversion", "pdf"], version = "^1.21.1"}
meilisearch = "^0.28.3"
meilisearch-python-sdk = "^2.0.1"
meilisearch-python-sdk = "^2.2.3"
litellm = {version = "^1.0.0", optional = true}
scrapy = "^2.11.0"
async-generator = "^1.10"
Expand Down
2 changes: 1 addition & 1 deletion tests/main/test_table_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_table_chat_agent_file_blanks(
@pytest.mark.parametrize("fn_api", [True, False])
def test_table_chat_agent_url(test_settings: Settings, fn_api: bool) -> None:
"""
Test the TableChatAgent with a dataframe as data source
Test the TableChatAgent with a URL of a csv file as data source
"""
set_global(test_settings)
URL = "https://raw.githubusercontent.com/plotly/datasets/master/2011_us_ag_exports.csv"
Expand Down
31 changes: 31 additions & 0 deletions tests/main/test_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Other tests for Task are in test_chat_agent.py
"""
import pytest

from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.task import Task
from langroid.utils.configuration import Settings, set_global
from langroid.utils.constants import NO_ANSWER


@pytest.mark.parametrize("allow_null_result", [True, False])
def test_task_empty_response(test_settings: Settings, allow_null_result: bool):
set_global(test_settings)
agent = ChatAgent(ChatAgentConfig(name="Test"))
task = Task(
agent,
interactive=False,
single_round=True,
allow_null_result=allow_null_result,
system_message="""
User will send you a number.
If it is EVEN, repeat the number, else return empty string.
ONLY return these responses, say NOTHING ELSE
""",
)

response = task.run("4")
assert response.content == "4"
response = task.run("3")
assert response.content == NO_ANSWER

0 comments on commit 0397b37

Please sign in to comment.