Commit b723627: Ability to pass through headers to LLM call
Weves committed Jun 10, 2024 (1 parent: 180b592)

Showing 13 changed files with 99 additions and 35 deletions.
13 changes: 10 additions & 3 deletions backend/danswer/chat/process_message.py
@@ -193,6 +193,7 @@ def stream_chat_message_objects(
     # on the `new_msg_req.message`. Currently, requires a state where the last message is a
     # user message (e.g. this can only be used for the chat-seeding flow).
     use_existing_user_message: bool = False,
+    litellm_additional_headers: dict[str, str] | None = None,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -228,7 +229,9 @@ def stream_chat_message_objects(

     try:
         llm = get_llm_for_persona(
-            persona, new_msg_req.llm_override or chat_session.llm_override
+            persona=persona,
+            llm_override=new_msg_req.llm_override or chat_session.llm_override,
+            additional_headers=litellm_additional_headers,
         )
     except GenAIDisabledException:
         raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -410,7 +413,7 @@ def stream_chat_message_objects(
             persona=persona,
             retrieval_options=retrieval_options,
             prompt_config=prompt_config,
-            llm_config=llm.config,
+            llm=llm,
             pruning_config=document_pruning_config,
             selected_docs=selected_llm_docs,
             chunks_above=new_msg_req.chunks_above,
@@ -455,7 +458,9 @@ def stream_chat_message_objects(
            llm=(
                llm
                or get_llm_for_persona(
-                    persona, new_msg_req.llm_override or chat_session.llm_override
+                    persona=persona,
+                    llm_override=new_msg_req.llm_override or chat_session.llm_override,
+                    additional_headers=litellm_additional_headers,
                )
            ),
            message_history=[
@@ -576,13 +581,15 @@ def stream_chat_message(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
     use_existing_user_message: bool = False,
+    litellm_additional_headers: dict[str, str] | None = None,
 ) -> Iterator[str]:
     with get_session_context_manager() as db_session:
         objects = stream_chat_message_objects(
             new_msg_req=new_msg_req,
             user=user,
             db_session=db_session,
             use_existing_user_message=use_existing_user_message,
+            litellm_additional_headers=litellm_additional_headers,
         )
         for obj in objects:
             yield get_json_line(obj.dict())
17 changes: 16 additions & 1 deletion backend/danswer/configs/model_configs.py
@@ -100,7 +100,7 @@
 ).lower() == "true"

 # extra headers to pass to LiteLLM
-LITELLM_EXTRA_HEADERS = None
+LITELLM_EXTRA_HEADERS: dict[str, str] | None = None
 _LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS")
 if _LITELLM_EXTRA_HEADERS_RAW:
     try:
@@ -113,3 +113,18 @@
         logger.error(
             "Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
         )
+
+# if specified, will pass through request headers to the call to the LLM
+LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None
+_LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
+if _LITELLM_PASS_THROUGH_HEADERS_RAW:
+    try:
+        LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW)
+    except Exception:
+        # need to import here to avoid circular imports
+        from danswer.utils.logger import setup_logger
+
+        logger = setup_logger()
+        logger.error(
+            "Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
+        )
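For context on the expected format: both values are parsed with json.loads, so LITELLM_EXTRA_HEADERS must be a JSON object and LITELLM_PASS_THROUGH_HEADERS a JSON list. A minimal sketch, with purely illustrative header names that are not part of this commit:

# Hypothetical environment values (illustrative only):
#   LITELLM_EXTRA_HEADERS='{"X-Org": "acme"}'
#   LITELLM_PASS_THROUGH_HEADERS='["Authorization", "X-User-Id"]'
import json
import os

os.environ["LITELLM_PASS_THROUGH_HEADERS"] = '["Authorization", "X-User-Id"]'
print(json.loads(os.environ["LITELLM_PASS_THROUGH_HEADERS"]))  # ['Authorization', 'X-User-Id']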
16 changes: 14 additions & 2 deletions backend/danswer/llm/factory.py
@@ -13,7 +13,9 @@


 def get_llm_for_persona(
-    persona: Persona, llm_override: LLMOverride | None = None
+    persona: Persona,
+    llm_override: LLMOverride | None = None,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     model_provider_override = llm_override.model_provider if llm_override else None
     model_version_override = llm_override.model_version if llm_override else None
@@ -25,6 +27,7 @@ def get_llm_for_persona(
         ),
         model_version=(model_version_override or persona.llm_model_version_override),
         temperature=temperature_override or GEN_AI_TEMPERATURE,
+        additional_headers=additional_headers,
     )

@@ -34,6 +37,7 @@ def get_default_llm(
     use_fast_llm: bool = False,
     model_provider_name: str | None = None,
     model_version: str | None = None,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()
@@ -65,6 +69,7 @@ def get_default_llm(
         custom_config=llm_provider.custom_config,
         timeout=timeout,
         temperature=temperature,
+        additional_headers=additional_headers,
     )

@@ -77,7 +82,14 @@ def get_llm(
     custom_config: dict[str, str] | None = None,
     temperature: float = GEN_AI_TEMPERATURE,
     timeout: int = QA_TIMEOUT,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
+    extra_headers = {}
+    if additional_headers:
+        extra_headers.update(additional_headers)
+    if LITELLM_EXTRA_HEADERS:
+        extra_headers.update(LITELLM_EXTRA_HEADERS)
+
     return DefaultMultiLLM(
         model_provider=provider,
         model_name=model,
@@ -87,5 +99,5 @@ def get_llm(
         timeout=timeout,
         temperature=temperature,
         custom_config=custom_config,
-        extra_headers=LITELLM_EXTRA_HEADERS,
+        extra_headers=extra_headers,
     )
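Note the merge order in get_llm above: the request-scoped additional_headers are applied first and LITELLM_EXTRA_HEADERS second, so a statically configured header wins when both set the same key. A minimal sketch of that dict.update behavior (the header name is illustrative):

# Later update() calls overwrite earlier keys, mirroring get_llm's merge.
extra_headers: dict[str, str] = {}
extra_headers.update({"X-User-Id": "from-request"})  # additional_headers
extra_headers.update({"X-User-Id": "from-env"})      # LITELLM_EXTRA_HEADERS
assert extra_headers == {"X-User-Id": "from-env"}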
22 changes: 22 additions & 0 deletions backend/danswer/llm/headers.py
@@ -0,0 +1,22 @@
+from fastapi.datastructures import Headers
+
+from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
+
+
+def get_litellm_additional_request_headers(
+    headers: dict[str, str] | Headers
+) -> dict[str, str]:
+    if not LITELLM_PASS_THROUGH_HEADERS:
+        return {}
+
+    pass_through_headers: dict[str, str] = {}
+    for key in LITELLM_PASS_THROUGH_HEADERS:
+        if key in headers:
+            pass_through_headers[key] = headers[key]
+        else:
+            # fastapi makes all header keys lowercase, handling that here
+            lowercase_key = key.lower()
+            if lowercase_key in headers:
+                pass_through_headers[lowercase_key] = headers[lowercase_key]
+
+    return pass_through_headers
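A quick usage sketch of the new helper (header names are illustrative; it assumes LITELLM_PASS_THROUGH_HEADERS was set to ["X-User-Id"] in the environment before this module was imported, since the config is read at import time):

from fastapi.datastructures import Headers

from danswer.llm.headers import get_litellm_additional_request_headers

# FastAPI/Starlette lowercases header keys, which the helper's fallback handles.
incoming = Headers({"x-user-id": "1234", "accept": "application/json"})
print(get_litellm_additional_request_headers(incoming))  # {'x-user-id': '1234'}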
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
@@ -172,7 +172,7 @@ def stream_answer_objects(
         persona=chat_session.persona,
         retrieval_options=query_req.retrieval_options,
         prompt_config=prompt_config,
-        llm_config=llm.config,
+        llm=llm,
         pruning_config=document_pruning_config,
     )

4 changes: 4 additions & 0 deletions backend/danswer/search/pipeline.py
@@ -10,6 +10,7 @@
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
+from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
 from danswer.search.models import IndexFilters
@@ -54,6 +55,7 @@ def __init__(
         self,
         search_request: SearchRequest,
         user: User | None,
+        llm: LLM,
         db_session: Session,
         bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
         retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -62,6 +64,7 @@ def __init__(
     ):
         self.search_request = search_request
         self.user = user
+        self.llm = llm
         self.db_session = db_session
         self.bypass_acl = bypass_acl
         self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -229,6 +232,7 @@ def _run_preprocessing(self) -> None:
         ) = retrieval_preprocessing(
             search_request=self.search_request,
             user=self.user,
+            llm=self.llm,
             db_session=self.db_session,
             bypass_acl=self.bypass_acl,
         )
6 changes: 4 additions & 2 deletions backend/danswer/search/preprocessing/preprocessing.py
@@ -6,6 +6,7 @@
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.db.models import User
+from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import BaseFilters
@@ -31,6 +32,7 @@
 def retrieval_preprocessing(
     search_request: SearchRequest,
     user: User | None,
+    llm: LLM,
     db_session: Session,
     bypass_acl: bool = False,
     include_query_intent: bool = True,
@@ -87,14 +89,14 @@ def retrieval_preprocessing(
     # Based on the query figure out if we should apply any hard time filters /
     # if we should bias more recent docs even more strongly
     run_time_filters = (
-        FunctionCall(extract_time_filter, (query,), {})
+        FunctionCall(extract_time_filter, (query, llm), {})
         if auto_detect_time_filter
         else None
     )

     # Based on the query, figure out if we should apply any source filters
     run_source_filters = (
-        FunctionCall(extract_source_filter, (query, db_session), {})
+        FunctionCall(extract_source_filter, (query, llm, db_session), {})
         if auto_detect_source_filter
         else None
     )
14 changes: 5 additions & 9 deletions backend/danswer/secondary_llm_flows/source_filter.py
@@ -6,8 +6,7 @@
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.constants import SOURCES_KEY
@@ -44,7 +43,7 @@


 def extract_source_filter(
-    query: str, db_session: Session
+    query: str, llm: LLM, db_session: Session
 ) -> list[DocumentSource] | None:
     """Returns a list of valid sources for search or None if no specific sources were detected"""

@@ -147,11 +146,6 @@ def _extract_source_filters_from_llm_out(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None

-    try:
-        llm = get_default_llm()
-    except GenAIDisabledException:
-        return None
-
     valid_sources = fetch_unique_document_sources(db_session)
     if not valid_sources:
         return None
@@ -165,9 +159,11 @@


 if __name__ == "__main__":
+    from danswer.llm.factory import get_default_llm
+
     # Just for testing purposes
     with Session(get_sqlalchemy_engine()) as db_session:
         while True:
             user_input = input("Query to Extract Sources: ")
-            sources = extract_source_filter(user_input, db_session)
+            sources = extract_source_filter(user_input, get_default_llm(), db_session)
             print(sources)
14 changes: 5 additions & 9 deletions backend/danswer/secondary_llm_flows/time_filter.py
@@ -5,8 +5,7 @@

 from dateutil.parser import parse

-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -41,7 +40,7 @@ def best_match_time(time_str: str) -> datetime | None:
     return None


-def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
+def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
     """Returns a datetime if a hard time filter should be applied for the given query
     Additionally returns a bool, True if more recently updated Documents should be
     heavily favored"""
@@ -147,11 +146,6 @@ def _extract_time_filter_from_llm_out(

         return None, False

-    try:
-        llm = get_default_llm()
-    except GenAIDisabledException:
-        return None, False
-
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     model_output = message_to_string(llm.invoke(filled_llm_prompt))
@@ -162,8 +156,10 @@

 if __name__ == "__main__":
     # Just for testing purposes, too tedious to unit test as it relies on an LLM
+    from danswer.llm.factory import get_default_llm
+
     while True:
         user_input = input("Query to Extract Time: ")
-        cutoff, recency_bias = extract_time_filter(user_input)
+        cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
         print(f"Time Cutoff: {cutoff}")
         print(f"Favor Recent: {recency_bias}")
2 changes: 2 additions & 0 deletions backend/danswer/server/gpts/api.py
@@ -7,6 +7,7 @@
 from sqlalchemy.orm import Session

 from danswer.db.engine import get_session
+from danswer.llm.factory import get_default_llm
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
@@ -71,6 +72,7 @@ def gpt_search(
             query=search_request.query,
         ),
         user=None,
+        llm=get_default_llm(),
         db_session=db_session,
     ).reranked_chunks

6 changes: 6 additions & 0 deletions backend/danswer/server/query_and_chat/chat_backend.py
@@ -4,6 +4,7 @@
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
+from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
 from fastapi.responses import StreamingResponse
@@ -41,6 +42,7 @@
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.headers import get_litellm_additional_request_headers
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
     get_renamed_conversation_name,
@@ -233,6 +235,7 @@ def delete_chat_session_by_id(
 @router.post("/send-message")
 def handle_new_chat_message(
     chat_message_req: CreateChatMessageRequest,
+    request: Request,
     user: User | None = Depends(current_user),
 ) -> StreamingResponse:
     """This endpoint is both used for all the following purposes:
@@ -256,6 +259,9 @@ def handle_new_chat_message(
         new_msg_req=chat_message_req,
         user=user,
         use_existing_user_message=chat_message_req.use_existing_user_message,
+        litellm_additional_headers=get_litellm_additional_request_headers(
+            request.headers
+        ),
     )

     return StreamingResponse(packets, media_type="application/json")
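End to end, a caller of the send-message endpoint can now attach headers that are forwarded to the LLM call. A hedged client sketch (base URL, route prefix, and payload fields are illustrative; it assumes LITELLM_PASS_THROUGH_HEADERS includes "X-User-Id"):

import requests

resp = requests.post(
    "http://localhost:8080/chat/send-message",  # route prefix may differ per deployment
    json={"chat_session_id": 1, "message": "hello"},  # illustrative request fields
    headers={"X-User-Id": "1234"},  # forwarded to LiteLLM by this change
    stream=True,
)
for line in resp.iter_lines():
    print(line.decode())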
(The remaining 2 changed files were not loaded in this view.)
