Huggingface Inference backend internal models (onyx-dot-app#265)
pkabra authored Aug 5, 2023
1 parent df62648 commit 0e667d3
Showing 4 changed files with 279 additions and 2 deletions.
17 changes: 16 additions & 1 deletion backend/danswer/configs/model_configs.py
@@ -38,12 +38,27 @@
# - gpt4all-completion -> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
# - gpt4all-chat-completion-> Due to M1 Macs not having compatible gpt4all version, please install dependency yourself
# To use gpt4all, run: pip install --upgrade gpt4all==1.0.5
# The following support the HuggingFace Inference API, Inference Endpoints, and servers running the text-generation-inference backend
# - huggingface-inference-completion
# - huggingface-inference-chat-completion

INTERNAL_MODEL_VERSION = os.environ.get(
"INTERNAL_MODEL_VERSION", "openai-chat-completion"
)
# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
GEN_AI_MAX_OUTPUT_TOKENS = 512
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
# HuggingFace API token used by the HuggingFace inference client
GEN_AI_HUGGINGFACE_API_TOKEN = os.environ.get("GEN_AI_HUGGINGFACE_API_TOKEN", None)
# Use the conversational API with the huggingface-inference-chat-completion internal model
# Note - this only works with models that support conversational interfaces
GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL = (
    os.environ.get("GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL", "").lower() == "true"
)
# Disable streaming responses. Set this to true to "polyfill" streaming for models that don't support streaming
GEN_AI_HUGGINGFACE_DISABLE_STREAM = (
    os.environ.get("GEN_AI_HUGGINGFACE_DISABLE_STREAM", "").lower() == "true"
)

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
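As a rough usage sketch (not part of this diff), the new backend might be selected with environment variables along these lines; the model id and token values below are placeholders:

import os

# Illustrative only: choose the HuggingFace inference backend before danswer's config is loaded
os.environ["INTERNAL_MODEL_VERSION"] = "huggingface-inference-completion"
os.environ["GEN_AI_MODEL_VERSION"] = "tiiuae/falcon-7b-instruct"  # placeholder model id or endpoint URL
os.environ["GEN_AI_HUGGINGFACE_API_TOKEN"] = "hf_xxx"  # placeholder token
os.environ["GEN_AI_MAX_OUTPUT_TOKENS"] = "512"
os.environ["GEN_AI_HUGGINGFACE_DISABLE_STREAM"] = "true"  # polyfill streaming for models without it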
15 changes: 14 additions & 1 deletion backend/danswer/direct_qa/__init__.py
@@ -4,9 +4,16 @@
from openai.error import Timeout

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import (
    GEN_AI_HUGGINGFACE_API_TOKEN,
    INTERNAL_MODEL_VERSION,
)
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.huggingface_inference import (
    HuggingFaceInferenceChatCompletionQA,
    HuggingFaceInferenceCompletionQA,
)
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA

@@ -44,6 +51,12 @@ def get_default_backend_qa_model(
        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == "openai-chat-completion":
        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == "huggingface-inference-completion":
        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
        return HuggingFaceInferenceCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == "huggingface-inference-chat-completion":
        api_key = api_key if api_key is not None else GEN_AI_HUGGINGFACE_API_TOKEN
        return HuggingFaceInferenceChatCompletionQA(api_key=api_key, **kwargs)
    # Note GPT4All is not supported for M1 Mac machines currently, removing until support is added
    # elif internal_model == "gpt4all-completion":
    #     return GPT4AllCompletionQA(**kwargs)
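A minimal calling sketch (assuming the function's existing defaults for timeout and api_key; the query is hypothetical):

from danswer.direct_qa import get_default_backend_qa_model

# With INTERNAL_MODEL_VERSION="huggingface-inference-chat-completion", this returns a
# HuggingFaceInferenceChatCompletionQA; api_key falls back to GEN_AI_HUGGINGFACE_API_TOKEN.
qa_model = get_default_backend_qa_model()
answer, quotes = qa_model.answer_question("Where is the onboarding doc?", context_docs=[])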
248 changes: 248 additions & 0 deletions backend/danswer/direct_qa/huggingface_inference.py
@@ -0,0 +1,248 @@
from collections.abc import Generator
from typing import Any, Iterator

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import (
    GEN_AI_HUGGINGFACE_DISABLE_STREAM,
    GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL,
    GEN_AI_MAX_OUTPUT_TOKENS,
)
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import (
    ChatPromptProcessor,
    JsonChatProcessor,
    JsonProcessor,
)
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from huggingface_hub import InferenceClient

logger = setup_logger()


def _build_hf_inference_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"do_sample": False,
"seed": 69, # For reproducibility
**kwargs,
}
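

# For illustration: _build_hf_inference_settings(max_new_tokens=512) returns
# {"do_sample": False, "seed": 69, "max_new_tokens": 512}; since **kwargs is merged after the
# defaults, a caller passing do_sample=True would override the default value.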


def _generic_chat_dialog_to_prompt_formatter(dialog: list[dict[str, str]]) -> str:
"""
Utility to convert chat dialog to a text-generation prompt for models tuned for chat.
Note - This is a "best guess" attempt at a generic completions prompt for chat
completion models. It isn't optimized for all chat trained models, but tries
to serialize to a format that most models understand.
Models like Llama2-chat have been optimized for certain formatting of chat
completions, and this function doesn't take that into account, so you won't
always get the best possible outcome.
TODO - Add the ability to pass custom formatters for chat dialogue
"""
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.
If a question does not make any sense or is not factually coherent, explain why instead of answering incorrectly.
If you don't know the answer to a question, don't share false information."""
prompt = ""
if dialog[0]["role"] != "system":
dialog = [
{
"role": "system",
"content": DEFAULT_SYSTEM_PROMPT,
}
] + dialog
for message in dialog:
prompt += f"{message['role'].upper()}: {message['content']}\n"
prompt += "ASSISTANT:"
return prompt


def _mock_streaming_response(tokens: str) -> Generator[str, None, None]:
"""Utility to mock a streaming response"""
for token in tokens:
yield token


class HuggingFaceInferenceCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        model_output = self.client.text_generation(
            filled_prompt,
            **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
        )
        logger.debug(model_output)
        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
            model_stream = self.client.text_generation(
                filled_prompt,
                **_build_hf_inference_settings(
                    max_new_tokens=self.max_output_tokens, stream=True
                ),
            )
        else:
            model_output = self.client.text_generation(
                filled_prompt,
                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
            )
            logger.debug(model_output)
            model_stream = _mock_streaming_response(model_output)
        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class HuggingFaceInferenceChatCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata
        self.client = InferenceClient(model=model_version, token=api_key)

    @staticmethod
    def convert_dialog_to_conversational_format(
        dialog: list[dict[str, str]]
    ) -> tuple[str, list[str], list[str]]:
        if dialog[-1]["role"] != "user":
            raise Exception(
                "Last message in conversational dialog must be User message"
            )
        user_message = dialog[-1]["content"]
        dialog = dialog[0:-1]
        generated_responses = []
        past_user_inputs = []
        for message in dialog:
            # HuggingFace inference client doesn't support system messages today
            # so lumping them in with user messages
            if message["role"] in ["user", "system"]:
                past_user_inputs += [message["content"]]
            else:
                generated_responses += [message["content"]]
        return user_message, generated_responses, past_user_inputs

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
            (
                message,
                generated_responses,
                past_user_inputs,
            ) = self.convert_dialog_to_conversational_format(filled_prompt)
            model_output = self.client.conversational(
                message,
                generated_responses=generated_responses,
                past_user_inputs=past_user_inputs,
                parameters={"max_length": self.max_output_tokens},
            )
        else:
            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
            logger.debug(chat_prompt)
            model_output = self.client.text_generation(
                chat_prompt,
                **_build_hf_inference_settings(max_new_tokens=self.max_output_tokens),
            )
        logger.debug(model_output)
        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)
        if not GEN_AI_HUGGINGFACE_DISABLE_STREAM:
            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
                raise Exception(
                    "Conversational API is not available with streaming enabled. Please either "
                    + "disable streaming, or disable using conversational API."
                )
            chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
            logger.debug(chat_prompt)
            model_stream = self.client.text_generation(
                chat_prompt,
                **_build_hf_inference_settings(
                    max_new_tokens=self.max_output_tokens, stream=True
                ),
            )
        else:
            if GEN_AI_HUGGINGFACE_USE_CONVERSATIONAL:
                (
                    message,
                    generated_responses,
                    past_user_inputs,
                ) = self.convert_dialog_to_conversational_format(filled_prompt)
                model_output = self.client.conversational(
                    message,
                    generated_responses=generated_responses,
                    past_user_inputs=past_user_inputs,
                    parameters={"max_length": self.max_output_tokens},
                )
            else:
                chat_prompt = _generic_chat_dialog_to_prompt_formatter(filled_prompt)
                logger.debug(chat_prompt)
                model_output = self.client.text_generation(
                    chat_prompt,
                    **_build_hf_inference_settings(
                        max_new_tokens=self.max_output_tokens
                    ),
                )
            logger.debug(model_output)
            model_stream = _mock_streaming_response(model_output)
        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
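To make the two chat-prompt paths concrete, a small worked example of what the helpers above produce (the dialog content is invented):

dialog = [
    {"role": "user", "content": "Where is the onboarding doc?"},
]

# _generic_chat_dialog_to_prompt_formatter(dialog) prepends the default system prompt (since no
# system message is present) and serializes role-prefixed turns, ending with the assistant cue:
#   SYSTEM: You are a helpful, respectful and honest assistant. ...
#   USER: Where is the onboarding doc?
#   ASSISTANT:

# convert_dialog_to_conversational_format(dialog) splits off the trailing user turn for the
# conversational API and returns ("Where is the onboarding doc?", [], []).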
1 change: 1 addition & 0 deletions backend/requirements/default.txt
@@ -16,6 +16,7 @@ google-auth-oauthlib==1.0.0
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2
huggingface-hub==0.16.4
jira==3.5.1
Mako==1.2.4
nltk==3.8.1
