Skip to content

Commit

Permalink
task done if stalled; agent friendly warn on misnamed tool (langroid#327)
Browse files Browse the repository at this point in the history

* task done if stalled; agent friendly warn on misnamed tool

* improve wrong tool name warning

* DocChatAgent: make relevance extractor optional
  • Loading branch information
pchalasani authored Dec 17, 2023
1 parent 36c192c commit 41bf092
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
7 changes: 7 additions & 0 deletions examples/docqa/chat-local.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,20 @@ def main(
# "local/localhost:8000"
chat_context_length=2048, # adjust based on model
)

relevance_extractor_config = lr.agent.special.RelevanceExtractorAgentConfig(
llm=llm_config, # or this could be a different llm_config
# system_message="...override default RelevanceExtractorAgent system msg here",
)

config = DocChatAgentConfig(
n_query_rephrases=0,
cross_encoder_reranking_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
hypothetical_answer=False,
# set it to > 0 to retrieve a window of k chunks on either side of a match
n_neighbor_chunks=0,
llm=llm_config,
relevance_extractor_config=relevance_extractor_config, # or None to turn off
# system_message="...override default DocChatAgent system msg here",
# user_message="...override default DocChatAgent user msg here",
# summarize_prompt="...override default DocChatAgent summarize prompt here",
Expand Down
10 changes: 10 additions & 0 deletions langroid/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,16 @@ def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
tool_name = msg.function_call.name
tool_msg = msg.function_call.arguments or {}
if tool_name not in self.llm_tools_handled:
logger.warning(
f"""
The function_call '{tool_name}' is not handled
by this agent named '{self.config.name}'!
If you intended this agent to handle this function_call,
either the fn-call name is incorrectly generated by the LLM,
(in which case you may need to adjust your LLM instructions),
or you need to enable this agent to handle this fn-call.
"""
)
raise ValueError(f"{tool_name} is not a valid function_call!")
tool_class = self.llm_tools_map[tool_name]
tool_msg.update(dict(request=tool_name))
Expand Down
6 changes: 5 additions & 1 deletion langroid/agent/special/doc_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DocChatAgentConfig(ChatAgentConfig):
cache: bool = True # cache results
debug: bool = False
stream: bool = True # allow streaming where needed
relevance_extractor_config: RelevanceExtractorAgentConfig = (
relevance_extractor_config: None | RelevanceExtractorAgentConfig = (
RelevanceExtractorAgentConfig()
)
doc_paths: List[str] = []
Expand Down Expand Up @@ -834,6 +834,10 @@ def get_verbatim_extracts(
List[Document]: list of Documents containing extracts and metadata.
"""
agent_cfg = self.config.relevance_extractor_config
if agent_cfg is None:
# no relevance extraction: simply return passages
return passages

agent_cfg.query = query
agent_cfg.segment_length = 1
agent_cfg.llm.stream = False # disable streaming for concurrent calls
Expand Down
22 changes: 17 additions & 5 deletions langroid/agent/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
only_user_quits_root: bool = True,
erase_substeps: bool = False,
allow_null_result: bool = True,
max_stalled_steps: int = 3,
):
"""
A task to be performed by an agent.
Expand Down Expand Up @@ -100,6 +101,8 @@ def __init__(
allow_null_result (bool): if true, allow null (empty or NO_ANSWER)
as the result of a step or overall task result.
Optional, default is True.
max_stalled_steps (int): task considered done after this many consecutive
steps with no progress. Default is 3.
"""
if agent is None:
agent = ChatAgent()
Expand All @@ -119,8 +122,10 @@ def __init__(
self.tsv_logger: None | logging.Logger = None
self.color_log: bool = False if settings.notebook else True
self.agent = agent
self.step_progress = False
self.task_progress = False
self.step_progress = False # progress in current step?
self.n_stalled_steps = 0 # how many consecutive steps with no progress?
self.max_stalled_steps = max_stalled_steps
self.task_progress = False # progress in current task (since run or run_async)?
self.name = name or agent.config.name
self.default_human_response = default_human_response
self.interactive = interactive
Expand Down Expand Up @@ -626,6 +631,7 @@ def _process_responder_result(
self.log_message(self.pending_sender, result, mark=True)
self.step_progress = True
self.task_progress = True
self.n_stalled_steps = 0 # reset stuck counter since we made progress
return True
else:
self.log_message(r, result)
Expand All @@ -639,6 +645,7 @@ def _process_invalid_step_result(self, parent: ChatDocument | None) -> None:
Args:
parent (ChatDocument|None): parent message of the current message
"""
self.n_stalled_steps += 1
if not self.task_progress or self.allow_null_result:
# There has been no progress at all in this task, so we
# update the pending_message to a dummy NO_ANSWER msg
Expand Down Expand Up @@ -770,13 +777,18 @@ def done(self) -> bool:
# for top-level task, only user can quit out
return user_quit

if self.n_stalled_steps >= self.max_stalled_steps:
# we are stuck, so bail to avoid infinite loop
return True
if (
not self.step_progress
and self.pending_sender == Entity.LLM
and not self.llm_delegate
and (not self.llm_delegate or not self._can_respond(Entity.LLM))
):
# LLM is NOT driving the task, and no progress in latest step,
# and it is NOT the LLM's turn to respond, so we are done.
# no progress in latest step, and pending msg is from LLM, and
# EITHER LLM is not driving the task,
# OR LLM IS driving the task, but CANNOT respond
# (e.g. b/c the pending message is a function call)
return True

return (
Expand Down

0 comments on commit 41bf092

Please sign in to comment.