diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 44f8a122ace..e5b92911139 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -38,12 +38,27 @@
 # - gpt4all-completion -> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
 # - gpt4all-chat-completion-> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
 # To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
+# These support HuggingFace Inference API, Inference Endpoints and servers running the text-generation-inference backend
+# - huggingface-inference-completion
+# - huggingface-inference-chat-completion
+
 INTERNAL_MODEL_VERSION = os.environ.get(
     "INTERNAL_MODEL_VERSION", "openai-chat-completion"
 )
 # For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
 GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
-GEN_AI_MAX_OUTPUT_TOKENS = 512
+GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
+# Use HuggingFace API Token for Huggingface inference client
+GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
+# Use the conversational API with the huggingface-inference-chat-completion internal model
+# Note - this only works with models that support conversational interfaces
+GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
+    os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
+)
+# Disable streaming responses. Set this to true to "polyfill" streaming for models that don't support streaming
+GEN_AI_HUGGINGFACE_DISABLE_STREAM = (
+    os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true"
+)
 
 # Danswer custom Deep Learning Models
 INTENT_MODEL_VERSION = "danswer/intent-model"
diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py
index 2ffe84cebdf..0f1c9ddefdb 100644
--- a/backend/danswer/direct_qa/__init__.py
+++ b/backend/danswer/direct_qa/__init__.py
@@ -4,9 +4,16 @@
 from openai.error import Timeout
 
 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
+from danswer.configs.model_configs import (
+    GEN_AI_HUGGINGFACE_API_TOKEN,
+    INTERNAL_MODEL_VERSION,
+)
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.huggingface_inference import (
+    HuggingFaceInferenceChatCompletionQA,
+    HuggingFaceInferenceCompletionQA,
+)
 from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA
@@ -44,6 +51,12 @@ def get_default_backend_qa_model(
         return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == "openai-chat-completion":
         return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
+    elif internal_model == "huggingface-inference-completion":
+        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
+        return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs)
+    elif internal_model == "huggingface-inference-chat-completion":
+        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
+        return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs)
     # Note GPT4All is not supported for M1 Mac machines currently, removing until support is added
     # elif internal_model == "gpt4all-completion":
     #     return GPT4AllCompletionQA(**kwargs)
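For reference, a minimal selection sketch for the two new internal model options wired in above. This is not part of the patch: the model id and token are placeholders, and the keyword arguments shown are taken from the function and class signatures in this diff (extra kwargs are forwarded to the QA class via **kwargs).

# Sketch only - placeholder model id and token, not values from this patch.
from danswer.direct_qa import get_default_backend_qa_model

qa_model = get_default_backend_qa_model(
    internal_model="huggingface-inference-completion",
    api_key="hf_xxx",  # omit to fall back to GEN_AI_HUGGINGFACE_API_TOKEN
    model_version="tiiuae/falcon-7b-instruct",  # example: an HF model id, Inference Endpoint, or TGI server URL
)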
diff --git a/backend/danswer/direct_qa/huggingface_inference.py b/backend/danswer/direct_qa/huggingface_inference.py
new file mode 100644
index 00000000000..7b9993586e9
--- /dev/null
+++ b/backend/danswer/direct_qa/huggingface_inference.py
@@ -0,0 +1,248 @@
+from collections.abc import Generator
+from typing import Any, Iterator
+
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.model_configs import (
+    GEN_AI_HUGGINGFACE_DISABLE_STREAM,
+    GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL,
+    GEN_AI_MAX_OUTPUT_TOKENS,
+)
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.qa_prompts import (
+    ChatPromptProcessor,
+    JsonChatProcessor,
+    JsonProcessor,
+)
+from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
+from danswer.direct_qa.qa_utils import process_answer
+from danswer.direct_qa.qa_utils import process_model_tokens
+from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time
+from huggingface_hub import InferenceClient
+
+logger = setup_logger()
+
+
+def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
+    """
+    Utility to add in some common default values so they don't have to be set every time.
+    """
+    return {
+        "do_sample": False,
+        "seed": 69,  # For reproducibility
+        **kwargs,
+    }
+
+
+def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str:
+    """
+    Utility to convert chat dialog to a text-generation prompt for models tuned for chat.
+    Note - This is a "best guess" attempt at a generic completions prompt for chat
+    completion models. It isn't optimized for all chat trained models, but tries
+    to serialize to a format that most models understand.
+    Models like Llama2-chat have been optimized for certain formatting of chat
+    completions, and this function doesn't take that into account, so you won't
+    always get the best possible outcome.
+    TODO - Add the ability to pass custom formatters for chat dialogue
+    """
+    DEFAULT_SYSTEM_PROMPT = """\
+    You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.
+    If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly.
+    If you don't know the answer to a question, don't share false information."""
+    prompt = ""
+    if dialog[0]["role"] != "system":
+        dialog = [
+            {
+                "role": "system",
+                "content": DEFAULT_SYSTEM_PROMPT,
+            }
+        ] + dialog
+    for message in dialog:
+        prompt += f"{message['role'].upper()}: {message['content']}\n"
+    prompt += "ASSISTANT:"
+    return prompt
+
+
+def _mock_streaming_response(tokens: str) -> Generator[str, None, None]:
+    """Utility to mock a streaming response"""
+    for token in tokens:
+        yield token
+
+
+class HuggingFaceInferenceCompletionQA(QAModel):
+    def __init__(
+        self,
+        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
+        model_version: str = GEN_AI_MODEL_VERSION,
+        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
+        include_metadata: bool = False,
+        api_key: str | None = None,
+    ) -> None:
+        self.prompt_processor = prompt_processor
+        self.max_output_tokens = max_output_tokens
+        self.include_metadata = include_metadata
+        self.client = InferenceClient(model=model_version, token=api_key)
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        filled_prompt = self.prompt_processor.fill_prompt(
+            query, context_docs, self.include_metadata
+        )
+        logger.debug(filled_prompt)
+        model_output = self.client.text_generation(
+            filled_prompt,
+            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
+        )
+        logger.debug(model_output)
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        return answer, quotes_dict
+
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        filled_prompt = self.prompt_processor.fill_prompt(
+            query, context_docs, self.include_metadata
+        )
+        logger.debug(filled_prompt)
+        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
+            model_stream = self.client.text_generation(
+                filled_prompt,
+                **_build_hf_inference_settings(
+                    max_new_tokens=self.max_output_tokens, stream=True
+                ),
+            )
+        else:
+            model_output = self.client.text_generation(
+                filled_prompt,
+                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
+            )
+            logger.debug(model_output)
+            model_stream = _mock_streaming_response(model_output)
+        yield from process_model_tokens(
+            tokens=model_stream,
+            context_docs=context_docs,
+            is_json_prompt=self.prompt_processor.specifies_json_output,
+        )
+
+
+class HuggingFaceInferenceChatCompletionQA(QAModel):
+    def __init__(
+        self,
+        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
+        model_version: str = GEN_AI_MODEL_VERSION,
+        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
+        include_metadata: bool = False,
+        api_key: str | None = None,
+    ) -> None:
+        self.prompt_processor = prompt_processor
+        self.max_output_tokens = max_output_tokens
+        self.include_metadata = include_metadata
+        self.client = InferenceClient(model=model_version, token=api_key)
+
+    @staticmethod
+    def convert_dialog_to_conversational_format(
+        dialog: list[dict[str, str]]
+    ) -> tuple[str, list[str], list[str]]:
+        if dialog[-1]["role"] != "user":
+            raise Exception(
+                "Last message in conversational dialog must be User message"
+            )
+        user_message = dialog[-1]["content"]
+        dialog = dialog[0:-1]
+        generated_responses = []
+        past_user_inputs = []
+        for message in dialog:
+            # HuggingFace inference client doesn't support system messages today
+            # so lumping them in with user messages
+            if message["role"] in ["user", "system"]:
+                past_user_inputs += [message["content"]]
+            else:
+                generated_responses += [message["content"]]
+        return user_message, generated_responses, past_user_inputs
+
+    @log_function_time()
+    def answer_question(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        filled_prompt = self.prompt_processor.fill_prompt(
+            query, context_docs, self.include_metadata
+        )
+        logger.debug(filled_prompt)
+        if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
+            (
+                message,
+                generated_responses,
+                past_user_inputs,
+            ) = self.convert_dialog_to_conversational_format(filled_prompt)
+            model_output = self.client.conversational(
+                message,
+                generated_responses=generated_responses,
+                past_user_inputs=past_user_inputs,
+                parameters={"max_length": self.max_output_tokens},
+            )
+        else:
+            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
+            logger.debug(chat_prompt)
+            model_output = self.client.text_generation(
+                chat_prompt,
+                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
+            )
+        logger.debug(model_output)
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        return answer, quotes_dict
+
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        filled_prompt = self.prompt_processor.fill_prompt(
+            query, context_docs, self.include_metadata
+        )
+        logger.debug(filled_prompt)
+        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
+            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
+                raise Exception(
+                    "Conversational API is not available with streaming enabled. Please either "
+                    + "disable streaming, or disable using conversational API."
+                )
+            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
+            logger.debug(chat_prompt)
+            model_stream = self.client.text_generation(
+                chat_prompt,
+                **_build_hf_inference_settings(
+                    max_new_tokens=self.max_output_tokens, stream=True
+                ),
+            )
+        else:
+            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
+                (
+                    message,
+                    generated_responses,
+                    past_user_inputs,
+                ) = self.convert_dialog_to_conversational_format(filled_prompt)
+                model_output = self.client.conversational(
+                    message,
+                    generated_responses=generated_responses,
+                    past_user_inputs=past_user_inputs,
+                    parameters={"max_length": self.max_output_tokens},
+                )
+            else:
+                chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
+                logger.debug(chat_prompt)
+                model_output = self.client.text_generation(
+                    chat_prompt,
+                    **_build_hf_inference_settings(
+                        max_new_tokens=self.max_output_tokens
+                    ),
+                )
+            logger.debug(model_output)
+            model_stream = _mock_streaming_response(model_output)
+        yield from process_model_tokens(
+            tokens=model_stream,
+            context_docs=context_docs,
+            is_json_prompt=self.prompt_processor.specifies_json_output,
+        )
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 80cf7f52cae..461613be3ef 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -16,6 +16,7 @@ google-auth-oauthlib==1.0.0
 httpcore==0.16.3
 httpx==0.23.3
 httpx-oauth==0.11.2
+huggingface-hub==0.16.4
 jira==3.5.1
 Mako==1.2.4
 nltk==3.8.1
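For reference, a small illustration of how the two dialog helpers in huggingface_inference.py behave; the toy dialog below is invented for this example and is not part of the patch.

# Illustration only - made-up dialog.
from danswer.direct_qa.huggingface_inference import (
    HuggingFaceInferenceChatCompletionQA,
    _generic_chat_dialog_to_prompt_formatter,
)

dialog = [
    {"role": "user", "content": "What is Danswer?"},
    {"role": "assistant", "content": "An open source question-answering tool."},
    {"role": "user", "content": "Which model backends does it support?"},
]

# Text-generation path: a default SYSTEM message is prepended (the dialog has none),
# each turn is serialized as "ROLE: content", and the prompt ends with "ASSISTANT:".
prompt = _generic_chat_dialog_to_prompt_formatter(dialog)

# Conversational path: the trailing user message is split out; earlier user/system
# messages become past_user_inputs and assistant messages become generated_responses.
(
    message,
    generated_responses,
    past_user_inputs,
) = HuggingFaceInferenceChatCompletionQA.convert_dialog_to_conversational_format(dialog)
assert message == "Which model backends does it support?"
assert generated_responses == ["An open source question-answering tool."]
assert past_user_inputs == ["What is Danswer?"]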