Skip to content

Commit

Permalink
Auto-Detect if Better Default LLM available for OpenAI (onyx-dot-app#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 authored Feb 22, 2024
1 parent 918bc38 commit 7748f4d
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 73 deletions.
6 changes: 3 additions & 3 deletions backend/danswer/chat/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
Expand Down Expand Up @@ -577,7 +577,7 @@ def compute_max_document_tokens(
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = GEN_AI_MODEL_VERSION
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

Expand All @@ -603,7 +603,7 @@ def compute_max_document_tokens(

def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = GEN_AI_MODEL_VERSION
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import CHUNK_SIZE
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
Expand All @@ -48,6 +47,7 @@
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
Expand Down Expand Up @@ -445,7 +445,7 @@ def stream_chat_message_objects(
else default_num_chunks
)

llm_name = GEN_AI_MODEL_VERSION
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/configs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
INDEX_SEPARATOR = "==="


# Key-Value store constants
GEN_AI_DETECTED_MODEL = "gen_ai_detected_model"


# Messages
DISABLED_GEN_AI_MSG = (
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
Expand Down
9 changes: 3 additions & 6 deletions backend/danswer/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,11 @@
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = (
os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-16k-0613"
)
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")

# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = (
os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
)
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")

# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = (
Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/danswerbot/slack/handlers/handle_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
Expand All @@ -37,6 +36,7 @@
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
Expand Down Expand Up @@ -224,7 +224,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse:
max_document_tokens: int | None = None
max_history_tokens: int | None = None
if len(new_message_request.messages) > 1:
llm_name = GEN_AI_MODEL_VERSION
llm_name = get_default_llm_version()[0]
if persona and persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

Expand Down
8 changes: 6 additions & 2 deletions backend/danswer/llm/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.utils import should_be_verbose
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -92,7 +93,8 @@ def _get_model_str(
return model_version

# User specified something wrong, just use Danswer default
return GEN_AI_MODEL_VERSION
base, _ = get_default_llm_version()
return base


class DefaultMultiLLM(LangChainChatLLM):
Expand All @@ -109,7 +111,7 @@ def __init__(
api_key: str | None,
timeout: int,
model_provider: str = GEN_AI_MODEL_PROVIDER,
model_version: str = GEN_AI_MODEL_VERSION,
model_version: str | None = GEN_AI_MODEL_VERSION,
api_base: str | None = GEN_AI_API_ENDPOINT,
api_version: str | None = GEN_AI_API_VERSION,
custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
Expand All @@ -121,6 +123,8 @@ def __init__(
litellm.api_key = api_key or "dummy-key"
litellm.api_version = api_version

model_version = model_version or get_default_llm_version()[0]

self._llm = ChatLiteLLM( # type: ignore
model=model_version
if custom_llm_provider
Expand Down
8 changes: 3 additions & 5 deletions backend/danswer/llm/factory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_gen_ai_api_key


Expand All @@ -26,9 +25,8 @@ def get_default_llm(
if gen_ai_model_version_override:
model_version = gen_ai_model_version_override
else:
model_version = (
FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
)
base, fast = get_default_llm_version()
model_version = fast if use_fast_llm else base
if api_key is None:
api_key = get_gen_ai_api_key()

Expand Down
5 changes: 3 additions & 2 deletions backend/danswer/llm/gpt_4_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.llm.utils import get_default_llm_version
from danswer.utils.logger import setup_logger


Expand Down Expand Up @@ -52,14 +53,14 @@ def requires_api_key(self) -> bool:
def __init__(
self,
timeout: int,
model_version: str = GEN_AI_MODEL_VERSION,
model_version: str | None = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
):
self.timeout = timeout
self.max_output_tokens = max_output_tokens
self.temperature = temperature
self.gpt4all_model = GPT4All(model_version)
self.gpt4all_model = GPT4All(model_version or get_default_llm_version()[0])

def log_model_configs(self) -> None:
logger.debug(
Expand Down
35 changes: 33 additions & 2 deletions backend/danswer/llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from functools import lru_cache
from typing import Any
from typing import cast

Expand All @@ -19,8 +20,10 @@

from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_DETECTED_MODEL
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
Expand All @@ -39,6 +42,34 @@
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None


@lru_cache()
def get_default_llm_version() -> tuple[str, str]:
    """Resolve the default (primary, fast) LLM model version names.

    Resolution order for the primary model:
      1. the GEN_AI_MODEL_VERSION env-backed config, if set;
      2. for non-OpenAI providers, fall back to the hard-coded OpenAI default
         (the value is either unused or will surface an error downstream);
      3. for OpenAI, the auto-detected model stored in the dynamic config
         (key-value) store, falling back to the hard-coded default if none
         was detected.

    The fast model comes from FAST_GEN_AI_MODEL_VERSION when set; otherwise
    OpenAI providers use the hard-coded default and other providers reuse
    the primary model.

    NOTE: results are memoized via lru_cache, so a later change to the
    detected-model entry in the KV store is not picked up by this process
    until restart — presumably intentional; confirm with the commit author.

    Returns:
        A ``(primary_model_name, fast_model_name)`` tuple.
    """
    fallback_openai_model = "gpt-3.5-turbo-16k-0613"

    # --- primary model ---
    if GEN_AI_MODEL_VERSION:
        primary = GEN_AI_MODEL_VERSION
    elif GEN_AI_MODEL_PROVIDER != "openai":
        logger.warning("No LLM Model Version set")
        # Either this value is unused or it will throw an error
        primary = fallback_openai_model
    else:
        try:
            primary = cast(
                str, get_dynamic_config_store().load(GEN_AI_DETECTED_MODEL)
            )
        except ConfigNotFoundError:
            primary = fallback_openai_model

    # --- fast model ---
    if FAST_GEN_AI_MODEL_VERSION:
        fast = FAST_GEN_AI_MODEL_VERSION
    elif GEN_AI_MODEL_PROVIDER == "openai":
        fast = fallback_openai_model
    else:
        fast = primary

    return primary, fast


def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
Expand Down Expand Up @@ -215,8 +246,7 @@ def get_llm_max_tokens(
# This is an override, so always return this
return GEN_AI_MAX_TOKENS

if not model_name:
return 4096
model_name = model_name or get_default_llm_version()[0]

try:
if model_provider == "openai":
Expand All @@ -231,6 +261,7 @@ def get_max_input_tokens(
model_provider: str = GEN_AI_MODEL_PROVIDER,
output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
) -> int:
model_name = model_name or get_default_llm_version()[0]
input_toks = (
get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
- output_tokens
Expand Down
12 changes: 5 additions & 7 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import delete_old_default_personas
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
Expand All @@ -52,6 +50,7 @@
from danswer.db.index_attempt import expire_index_attempts
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_version
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
Expand Down Expand Up @@ -234,11 +233,10 @@ def startup_event() -> None:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
if GEN_AI_MODEL_VERSION != FAST_GEN_AI_MODEL_VERSION:
logger.info(
f"Using Fast LLM Model Version: {FAST_GEN_AI_MODEL_VERSION}"
)
base, fast = get_default_llm_version()
logger.info(f"Using LLM Model Version: {base}")
if base != fast:
logger.info(f"Using Fast LLM Model Version: {fast}")
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/server/features/persona/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import get_persona_by_id
from danswer.db.chat import get_personas
from danswer.db.chat import get_prompts_by_ids
Expand All @@ -18,6 +17,7 @@
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.utils import get_default_llm_version
from danswer.one_shot_answer.qa_block import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
Expand Down Expand Up @@ -239,4 +239,4 @@ def get_default_model(
if GEN_AI_MODEL_PROVIDER != "openai":
return ""

return GEN_AI_MODEL_VERSION
return get_default_llm_version()[0]
Loading

0 comments on commit 7748f4d

Please sign in to comment.