diff --git a/examples/docqa/chat-local.py b/examples/docqa/chat-local.py index c63aaea51..f6839fe05 100644 --- a/examples/docqa/chat-local.py +++ b/examples/docqa/chat-local.py @@ -92,6 +92,12 @@ def main( # "local/localhost:8000" chat_context_length=2048, # adjust based on model ) + + relevance_extractor_config = lr.agent.special.RelevanceExtractorAgentConfig( + llm=llm_config, # or this could be a different llm_config + # system_message="...override default RelevanceExtractorAgent system msg here", + ) + config = DocChatAgentConfig( n_query_rephrases=0, cross_encoder_reranking_model="cross-encoder/ms-marco-MiniLM-L-6-v2", @@ -99,6 +105,7 @@ def main( # set it to > 0 to retrieve a window of k chunks on either side of a match n_neighbor_chunks=0, llm=llm_config, + relevance_extractor_config=relevance_extractor_config, # or None to turn off # system_message="...override default DocChatAgent system msg here", # user_message="...override default DocChatAgent user msg here", # summarize_prompt="...override default DocChatAgent summarize prompt here", diff --git a/langroid/agent/base.py b/langroid/agent/base.py index bca0eeaf1..b87f00bb4 100644 --- a/langroid/agent/base.py +++ b/langroid/agent/base.py @@ -554,6 +554,16 @@ def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]: tool_name = msg.function_call.name tool_msg = msg.function_call.arguments or {} if tool_name not in self.llm_tools_handled: + logger.warning( + f""" + The function_call '{tool_name}' is not handled + by this agent named '{self.config.name}'! + If you intended this agent to handle this function_call, + either the fn-call name is incorrectly generated by the LLM, + (in which case you may need to adjust your LLM instructions), + or you need to enable this agent to handle this fn-call. 
+ """ + ) raise ValueError(f"{tool_name} is not a valid function_call!") tool_class = self.llm_tools_map[tool_name] tool_msg.update(dict(request=tool_name)) diff --git a/langroid/agent/special/doc_chat_agent.py b/langroid/agent/special/doc_chat_agent.py index 36656c11e..dbfdbca15 100644 --- a/langroid/agent/special/doc_chat_agent.py +++ b/langroid/agent/special/doc_chat_agent.py @@ -107,7 +107,7 @@ class DocChatAgentConfig(ChatAgentConfig): cache: bool = True # cache results debug: bool = False stream: bool = True # allow streaming where needed - relevance_extractor_config: RelevanceExtractorAgentConfig = ( + relevance_extractor_config: None | RelevanceExtractorAgentConfig = ( RelevanceExtractorAgentConfig() ) doc_paths: List[str] = [] @@ -834,6 +834,10 @@ def get_verbatim_extracts( List[Document]: list of Documents containing extracts and metadata. """ agent_cfg = self.config.relevance_extractor_config + if agent_cfg is None: + # no relevance extraction: simply return passages + return passages + agent_cfg.query = query agent_cfg.segment_length = 1 agent_cfg.llm.stream = False # disable streaming for concurrent calls diff --git a/langroid/agent/task.py b/langroid/agent/task.py index 9cfd7f3f7..64185f5c0 100644 --- a/langroid/agent/task.py +++ b/langroid/agent/task.py @@ -67,6 +67,7 @@ def __init__( only_user_quits_root: bool = True, erase_substeps: bool = False, allow_null_result: bool = True, + max_stalled_steps: int = 3, ): """ A task to be performed by an agent. @@ -100,6 +101,8 @@ def __init__( allow_null_result (bool): if true, allow null (empty or NO_ANSWER) as the result of a step or overall task result. Optional, default is True. + max_stalled_steps (int): task considered done after this many consecutive + steps with no progress. Default is 3. 
""" if agent is None: agent = ChatAgent() @@ -119,8 +122,10 @@ def __init__( self.tsv_logger: None | logging.Logger = None self.color_log: bool = False if settings.notebook else True self.agent = agent - self.step_progress = False - self.task_progress = False + self.step_progress = False # progress in current step? + self.n_stalled_steps = 0 # how many consecutive steps with no progress? + self.max_stalled_steps = max_stalled_steps + self.task_progress = False # progress in current task (since run or run_async)? self.name = name or agent.config.name self.default_human_response = default_human_response self.interactive = interactive @@ -626,6 +631,7 @@ def _process_responder_result( self.log_message(self.pending_sender, result, mark=True) self.step_progress = True self.task_progress = True + self.n_stalled_steps = 0 # reset stuck counter since we made progress return True else: self.log_message(r, result) @@ -639,6 +645,7 @@ def _process_invalid_step_result(self, parent: ChatDocument | None) -> None: Args: parent (ChatDocument|None): parent message of the current message """ + self.n_stalled_steps += 1 if not self.task_progress or self.allow_null_result: # There has been no progress at all in this task, so we # update the pending_message to a dummy NO_ANSWER msg @@ -770,13 +777,18 @@ def done(self) -> bool: # for top-level task, only user can quit out return user_quit + if self.n_stalled_steps >= self.max_stalled_steps: + # we are stuck, so bail to avoid infinite loop + return True if ( not self.step_progress and self.pending_sender == Entity.LLM - and not self.llm_delegate + and (not self.llm_delegate or not self._can_respond(Entity.LLM)) ): - # LLM is NOT driving the task, and no progress in latest step, - # and it is NOT the LLM's turn to respond, so we are done. + # no progress in latest step, and pending msg is from LLM, and + # EITHER LLM is not driving the task, + # OR LLM IS driving the task, but CANNOT respond + # (e.g. 
because the pending message is a function call) return True return (