diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 1f8ecc55a0e..7295682dd4f 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -59,6 +59,7 @@
 from danswer.search.search_runner import inference_documents_from_ids
 from danswer.secondary_llm_flows.choose_search import check_if_need_search
 from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
+from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
@@ -153,8 +154,7 @@ def translate_citations(
     return citation_to_saved_doc_id_map


-@log_generator_function_time()
-def stream_chat_message(
+def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
     db_session: Session,
@@ -164,7 +164,14 @@ def stream_chat_message(
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
-) -> Iterator[str]:
+) -> Iterator[
+    StreamingError
+    | QADocsResponse
+    | LLMRelevanceFilterResponse
+    | ChatMessageDetail
+    | DanswerAnswerPiece
+    | CitationInfo
+]:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
     2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
@@ -313,10 +320,8 @@ def stream_chat_message(
         # only allow the final document to get truncated
         # if more than that, then the user message is too long
         if final_doc_ind != len(tokens_per_doc) - 1:
-            yield get_json_line(
-                StreamingError(
-                    error="LLM context window exceeded. Please de-select some documents or shorten your query."
-                ).dict()
+            yield StreamingError(
+                error="LLM context window exceeded. Please de-select some documents or shorten your query."
             )
             return

@@ -417,8 +422,8 @@ def stream_chat_message(
             applied_source_filters=retrieval_request.filters.source_type,
             applied_time_cutoff=time_cutoff,
             recency_bias_multiplier=recency_bias_multiplier,
-        ).dict()
-        yield get_json_line(initial_response)
+        )
+        yield initial_response

         # Get the final ordering of chunks for the LLM call
         llm_chunk_selection = cast(list[bool], next(documents_generator))
@@ -430,8 +435,8 @@ def stream_chat_message(
             ]
             if run_llm_chunk_filter
             else []
-        ).dict()
-        yield get_json_line(llm_relevance_filtering_response)
+        )
+        yield llm_relevance_filtering_response

         # Prep chunks to pass to LLM
         num_llm_chunks = (
@@ -497,7 +502,7 @@ def stream_chat_message(
             gen_ai_response_message
         )

-        yield get_json_line(msg_detail_response.dict())
+        yield msg_detail_response

         # Stop here after saving message details, the above still needs to be sent for the
         # message id to send the next follow-up message
@@ -530,17 +535,13 @@ def stream_chat_message(
                 citations.append(packet)
                 continue

-            yield get_json_line(packet.dict())
+            yield packet
     except Exception as e:
         logger.exception(e)

         # Frontend will erase whatever answer and show this instead
         # This will be the issue 99% of the time
-        error_packet = StreamingError(
-            error="LLM failed to respond, have you set your API key?"
-        )
-
-        yield get_json_line(error_packet.dict())
+        yield StreamingError(error="LLM failed to respond, have you set your API key?")
         return

     # Post-LLM answer processing
@@ -564,11 +565,24 @@ def stream_chat_message(
             gen_ai_response_message
         )

-        yield get_json_line(msg_detail_response.dict())
+        yield msg_detail_response
     except Exception as e:
         logger.exception(e)

         # Frontend will erase whatever answer and show this instead
-        error_packet = StreamingError(error="Failed to parse LLM output")
+        yield StreamingError(error="Failed to parse LLM output")

-        yield get_json_line(error_packet.dict())
+
+@log_generator_function_time()
+def stream_chat_message(
+    new_msg_req: CreateChatMessageRequest,
+    user: User | None,
+    db_session: Session,
+) -> Iterator[str]:
+    objects = stream_chat_message_objects(
+        new_msg_req=new_msg_req,
+        user=user,
+        db_session=db_session,
+    )
+    for obj in objects:
+        yield get_json_line(obj.dict())
diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py
index 93c5a18688b..db3dc31f83b 100644
--- a/backend/danswer/search/models.py
+++ b/backend/danswer/search/models.py
@@ -86,10 +86,10 @@ class RetrievalDetails(BaseModel):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
     # If the Persona is configured to not run search (0 chunks), this is bypassed
     # If no Prompt is configured, the only search results are shown, this is bypassed
-    run_search: OptionalSearchSetting
+    run_search: OptionalSearchSetting = OptionalSearchSetting.ALWAYS
     # Is this a real-time/streaming call or a question where Danswer can take more time?
     # Used to determine reranking flow
-    real_time: bool
+    real_time: bool = True
     # The following have defaults in the Persona settings which can be overriden via
     # the query, if None, then use Persona settings
     filters: BaseFilters | None = None
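Note (not part of the patch): a minimal sketch of how the two generators might be consumed after this split, assuming the imports below resolve as in the modified module. stream_chat_message_objects yields typed Pydantic packets for in-process callers, while stream_chat_message keeps the existing newline-delimited JSON contract for the streaming HTTP endpoint. Construction of a real CreateChatMessageRequest is elided because its fields are not shown in this diff, and the two consume_* helpers are hypothetical.

# Illustrative sketch only -- not part of the diff above.
from sqlalchemy.orm import Session

from danswer.chat.process_message import stream_chat_message
from danswer.chat.process_message import stream_chat_message_objects
from danswer.server.query_and_chat.models import CreateChatMessageRequest


def consume_as_objects(req: CreateChatMessageRequest, db_session: Session) -> None:
    # In-process callers can now work with typed Pydantic packets directly
    # instead of re-parsing JSON strings.
    for packet in stream_chat_message_objects(
        new_msg_req=req, user=None, db_session=db_session
    ):
        print(type(packet).__name__)


def consume_as_json_lines(req: CreateChatMessageRequest, db_session: Session) -> None:
    # The HTTP endpoint keeps its existing contract: newline-delimited JSON strings.
    for line in stream_chat_message(new_msg_req=req, user=None, db_session=db_session):
        print(line, end="")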