add capacity for self-hosted llama model #57

Merged 2 commits on Oct 10, 2023

4 changes: 4 additions & 0 deletions src/.env.config.example
@@ -1,7 +1,11 @@
MODEL_PROVIDER=vertex
#MODEL_PROVIDER=openai
#MODEL_PROVIDER=hosted
MODEL_TEMPERATURE=0.0

### Hosted Model ###
HOSTED_MODEL_URI="http://somehosted-model-uri"

### Vertex AI ###
VERTEX_PROJECT_ID=shadowbot-YOURNAME
VERTEX_REGION=us-central1
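To exercise the new provider, MODEL_PROVIDER is switched to hosted and HOSTED_MODEL_URI pointed at the model server. A minimal sketch of reading those values with os.getenv follows; it is illustrative only, since the application actually loads them through its own Config class in src/config.py:

import os

model_provider = os.getenv("MODEL_PROVIDER", "vertex")
hosted_model_uri = os.getenv("HOSTED_MODEL_URI")  # e.g. the placeholder http://somehosted-model-uri

if model_provider == "hosted" and not hosted_model_uri:
    raise RuntimeError("HOSTED_MODEL_URI must be set when MODEL_PROVIDER=hosted")
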
53 changes: 53 additions & 0 deletions src/hosted_llm.py
@@ -0,0 +1,53 @@
"""Module for handling self-hosted LLama2 models"""

from typing import Any, List, Mapping, Optional
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output_parser import BaseOutputParser


class HostedLLM(LLM):
"""
Class to define interaction with the hosted LLM at a specified URI
"""
uri: str

@property
def _llm_type(self) -> str:
return "custom"

def _call(self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
response = requests.get(self.uri,
params={"text" : prompt},timeout=600)
if response.status_code == 200:
return str(response.content)
return f"Model Server is not Working due to error {response.status_code}"


@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"uri": self.uri}

class CustomLlamaParser(BaseOutputParser[str]): # pylint: disable=R0903
"""Class to correctly parse model outputs"""

def parse(self, text:str) -> str:
"""Parse the output of our LLM"""
if text.startswith("Model Server is not Working due"):
return text
cleaned = str(text).split("[/INST]")
return cleaned[1]

@property
def _type(self) -> str:
return "custom_output_parser"

27 changes: 26 additions & 1 deletion src/v1/endpoints.py
@@ -3,13 +3,14 @@
from fastapi import APIRouter, Body, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel # pylint: disable=E0611

from langchain.llms import OpenAI
from langchain.llms import VertexAI
from langchain.callbacks import get_openai_callback
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from src.hosted_llm import HostedLLM
from src.hosted_llm import CustomLlamaParser
from src.config import Config
from src.embeddings import EmbeddingSource
from src.logging_setup import setup_logger
@@ -42,10 +43,31 @@ def call_language_model(input_val):
        result = call_openai(input_val, prompt)
    elif model_provider == 'vertex':
        result = call_vertexai(input_val, prompt)
    elif model_provider == 'hosted':
        result = call_hosted_llm(input_val, prompt)
    else:
        raise ValueError(f"Invalid model name: {model_provider}")
    return result


def call_hosted_llm(input_val, prompt):
    """Call the hosted language model and return the result.

    Args:
        input_val: The input value to pass to the language model.
        prompt: The prompt template to run against the model.

    Returns:
        The result from the language model.
    """
    hosted_model_name = config.get("HOSTED_MODEL_NAME", "Llama2-Hosted")
    logger.debug("Using self-hosted model: %s", hosted_model_name)
    hosted_model_uri = config.get("HOSTED_MODEL_URI", None)
    llm = HostedLLM(uri=hosted_model_uri)
    chain = LLMChain(llm=llm, prompt=prompt, output_parser=CustomLlamaParser())
    result = chain.run(input_val)
    return result


def call_vertexai(input_val, prompt):
"""Call the Vertex AI language model and return the result.

@@ -256,6 +278,7 @@ def synthesize_response(

    if prompt is None:
        prompt = (
            "<s>[INST] <<SYS>> \n"
            "Below is the only information you know.\n"
            "It was obtained by doing a vector search for the user's query:\n\n"
            "---START INFO---\n\n{embedding_results}\n\n"
@@ -266,6 +289,8 @@
"Use no other knowledge to respond. Do not make anything up. "
"You can let the reader know if do not think you have enough information "
"to respond to their query...\n\n"
"<</SYS>>"
f"{query} [/INST]"
)

prompt = prompt.format(embedding_results=embedding_results_text)
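The prompt above follows the Llama 2 chat format: an [INST] ... [/INST] block wrapping a <<SYS>> system section, with the user query placed just before the closing tag. CustomLlamaParser relies on the hosted server echoing that prompt back ahead of the completion, which is why it splits on "[/INST]" and keeps the second piece. A small illustrative check with made-up response text:

from src.hosted_llm import CustomLlamaParser

# Made-up server response: echoed prompt followed by the model's completion.
sample = "<s>[INST] <<SYS>> ...system text... <</SYS>> What is shadowbot? [/INST] Shadowbot answers questions..."
print(CustomLlamaParser().parse(sample))  # -> " Shadowbot answers questions..."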