Standardize Chat Message Stream (onyx-dot-app#1098)
yuhongsun96 authored Feb 19, 2024
1 parent 31278fc commit 15335dc
Showing 2 changed files with 37 additions and 23 deletions.
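In short, the commit splits streaming into two layers: stream_chat_message_objects now yields typed Pydantic packets (errors, retrieved documents, answer pieces, citations, message details), and a thin stream_chat_message wrapper serializes each packet to a JSON line via get_json_line(obj.dict()). A minimal sketch of that split, using an illustrative stand-in Packet model rather than the real Danswer types:

# Minimal sketch of the two-layer pattern; Packet and the function names here are
# illustrative stand-ins, not the models or helpers used in this repository.
import json
from collections.abc import Iterator

from pydantic import BaseModel


class Packet(BaseModel):
    kind: str
    payload: str


def stream_objects() -> Iterator[Packet]:
    # Business logic yields typed objects and never touches JSON.
    yield Packet(kind="answer_piece", payload="Hello")
    yield Packet(kind="answer_piece", payload=" world")


def stream_json_lines() -> Iterator[str]:
    # Transport wrapper: one JSON object per line, mirroring get_json_line(obj.dict()).
    for obj in stream_objects():
        yield json.dumps(obj.dict()) + "\n"

Backend callers (tests, other flows) can consume the object generator directly, while the HTTP layer consumes the serialized one.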
56 changes: 35 additions & 21 deletions backend/danswer/chat/process_message.py
@@ -59,6 +59,7 @@
 from danswer.search.search_runner import inference_documents_from_ids
 from danswer.secondary_llm_flows.choose_search import check_if_need_search
 from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
+from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
@@ -153,8 +154,7 @@ def translate_citations(
     return citation_to_saved_doc_id_map


-@log_generator_function_time()
-def stream_chat_message(
+def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
     db_session: Session,
@@ -164,7 +164,14 @@ def stream_chat_message(
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
-) -> Iterator[str]:
+) -> Iterator[
+    StreamingError
+    | QADocsResponse
+    | LLMRelevanceFilterResponse
+    | ChatMessageDetail
+    | DanswerAnswerPiece
+    | CitationInfo
+]:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
     2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
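The widened return annotation spells out every packet type the generator can yield, so downstream code can branch on an object's class instead of parsing JSON. A sketch of such a consumer, assuming these models live in danswer.chat.models and that DanswerAnswerPiece carries an answer_piece field; the handling logic itself is illustrative, not part of this commit:

# Illustrative consumer of stream_chat_message_objects; import paths and field
# access are assumptions based on the names in this diff.
from danswer.chat.models import CitationInfo, DanswerAnswerPiece, StreamingError


def print_streamed_answer(packets) -> None:
    citations: list[CitationInfo] = []
    for packet in packets:
        if isinstance(packet, StreamingError):
            print(f"\nerror: {packet.error}")
            return
        if isinstance(packet, DanswerAnswerPiece):
            print(packet.answer_piece or "", end="")  # incremental answer tokens
        elif isinstance(packet, CitationInfo):
            citations.append(packet)  # collected much as the diff below does
    print(f"\n{len(citations)} citation(s)")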
@@ -313,10 +320,8 @@ def stream_chat_message(
             # only allow the final document to get truncated
             # if more than that, then the user message is too long
             if final_doc_ind != len(tokens_per_doc) - 1:
-                yield get_json_line(
-                    StreamingError(
-                        error="LLM context window exceeded. Please de-select some documents or shorten your query."
-                    ).dict()
+                yield StreamingError(
+                    error="LLM context window exceeded. Please de-select some documents or shorten your query."
                 )
                 return

@@ -417,8 +422,8 @@ def stream_chat_message(
             applied_source_filters=retrieval_request.filters.source_type,
             applied_time_cutoff=time_cutoff,
             recency_bias_multiplier=recency_bias_multiplier,
-        ).dict()
-        yield get_json_line(initial_response)
+        )
+        yield initial_response

         # Get the final ordering of chunks for the LLM call
         llm_chunk_selection = cast(list[bool], next(documents_generator))
@@ -430,8 +435,8 @@ def stream_chat_message(
             ]
             if run_llm_chunk_filter
             else []
-        ).dict()
-        yield get_json_line(llm_relevance_filtering_response)
+        )
+        yield llm_relevance_filtering_response

         # Prep chunks to pass to LLM
         num_llm_chunks = (
@@ -497,7 +502,7 @@ def stream_chat_message(
             gen_ai_response_message
         )

-        yield get_json_line(msg_detail_response.dict())
+        yield msg_detail_response

         # Stop here after saving message details, the above still needs to be sent for the
         # message id to send the next follow-up message
@@ -530,17 +535,13 @@ def stream_chat_message(
                 citations.append(packet)
                 continue

-            yield get_json_line(packet.dict())
+            yield packet
     except Exception as e:
         logger.exception(e)

         # Frontend will erase whatever answer and show this instead
         # This will be the issue 99% of the time
-        error_packet = StreamingError(
-            error="LLM failed to respond, have you set your API key?"
-        )
-
-        yield get_json_line(error_packet.dict())
+        yield StreamingError(error="LLM failed to respond, have you set your API key?")
         return

     # Post-LLM answer processing
@@ -564,11 +565,24 @@ def stream_chat_message(
             gen_ai_response_message
         )

-        yield get_json_line(msg_detail_response.dict())
+        yield msg_detail_response
     except Exception as e:
         logger.exception(e)

         # Frontend will erase whatever answer and show this instead
-        error_packet = StreamingError(error="Failed to parse LLM output")
+        yield StreamingError(error="Failed to parse LLM output")

-        yield get_json_line(error_packet.dict())
+
+@log_generator_function_time()
+def stream_chat_message(
+    new_msg_req: CreateChatMessageRequest,
+    user: User | None,
+    db_session: Session,
+) -> Iterator[str]:
+    objects = stream_chat_message_objects(
+        new_msg_req=new_msg_req,
+        user=user,
+        db_session=db_session,
+    )
+    for obj in objects:
+        yield get_json_line(obj.dict())
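Because the wrapper still emits one JSON object per line, the HTTP contract stays the same; only the serialization point moves. A hedged sketch of how an endpoint might stream this generator out with FastAPI; the route path and dependency imports are assumptions, not part of this diff:

# Assumed endpoint shape; Danswer's real route lives elsewhere in the server package.
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from danswer.auth.users import current_user  # assumed import path
from danswer.chat.process_message import stream_chat_message
from danswer.db.engine import get_session  # assumed import path
from danswer.db.models import User  # assumed import path
from danswer.server.query_and_chat.models import CreateChatMessageRequest

router = APIRouter(prefix="/chat")


@router.post("/send-message")
def handle_new_chat_message(
    chat_message_req: CreateChatMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    packets = stream_chat_message(
        new_msg_req=chat_message_req,
        user=user,
        db_session=db_session,
    )
    # Each item is a serialized packet; get_json_line presumably yields newline-delimited JSON.
    return StreamingResponse(packets, media_type="application/json")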
4 changes: 2 additions & 2 deletions backend/danswer/search/models.py
@@ -86,10 +86,10 @@ class RetrievalDetails(BaseModel):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
     # If the Persona is configured to not run search (0 chunks), this is bypassed
     # If no Prompt is configured, the only search results are shown, this is bypassed
-    run_search: OptionalSearchSetting
+    run_search: OptionalSearchSetting = OptionalSearchSetting.ALWAYS
     # Is this a real-time/streaming call or a question where Danswer can take more time?
     # Used to determine reranking flow
-    real_time: bool
+    real_time: bool = True
     # The following have defaults in the Persona settings which can be overriden via
     # the query, if None, then use Persona settings
     filters: BaseFilters | None = None
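With defaults on run_search and real_time, a bare RetrievalDetails() now means "always run search, treat the call as real-time," so request payloads that omit retrieval options fall back to that behavior. A small sketch of the implied behavior, assuming the remaining fields of the model (not shown in this hunk) also carry defaults and that OptionalSearchSetting is importable from the same module:

# Implied by the new defaults in this diff; import path for OptionalSearchSetting assumed.
from danswer.search.models import OptionalSearchSetting, RetrievalDetails

details = RetrievalDetails()
assert details.run_search == OptionalSearchSetting.ALWAYS
assert details.real_time is True
assert details.filters is None  # still falls back to the Persona's settings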
