diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml
index c8f7259f..5055f682 100644
--- a/arch/arch_config_schema.yaml
+++ b/arch/arch_config_schema.yaml
@@ -67,6 +67,8 @@ properties:
         type: boolean
       description:
         type: string
+      auto_llm_dispatch_on_response:
+        type: boolean
       parameters:
         type: array
         items:
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index c5f2d959..97f495a7 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -39,6 +39,7 @@ enum ResponseHandlerType {
     FunctionCall,
     ZeroShotIntent,
     ArchGuard,
+    DefaultTarget,
 }
 
 pub struct CallContext {
@@ -179,12 +180,16 @@ impl StreamContext {
 
         let prompt_target_names = prompt_targets
             .iter()
+            // exclude default target
+            .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
             .map(|(name, _)| name.clone())
             .collect();
 
         let similarity_scores: Vec<(String, f64)> = prompt_targets
             .iter()
-            .map(|(prompt_name, _prompt_target)| {
+            // exclude default prompt target
+            .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
+            .map(|(prompt_name, _)| {
                 let default_embeddings = HashMap::new();
                 let pte = prompt_target_embeddings
                     .get(prompt_name)
@@ -331,34 +336,84 @@ impl StreamContext {
 
         // check to ensure that the prompt target similarity score is above the threshold
         if prompt_target_similarity_score < prompt_target_intent_matching_threshold
-            && !arch_assistant
+            || arch_assistant
         {
+            debug!("intent score is low or arch assistant is handling the conversation");
             // if arch fc responded to the user message, then we don't need to check the similarity score
             // it may be that arch fc is handling the conversation for parameter collection
             if arch_assistant {
                 info!("arch assistant is handling the conversation");
             } else {
-                info!(
-                    "prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user",
-                    prompt_target_similarity_score, prompt_target_intent_matching_threshold
-                );
+                debug!("checking for default prompt target");
+                if let Some(default_prompt_target) = self
+                    .prompt_targets
+                    .read()
+                    .unwrap()
+                    .values()
+                    .find(|pt| pt.default.unwrap_or(false))
+                {
+                    debug!("default prompt target found");
+                    let endpoint = default_prompt_target.endpoint.clone().unwrap();
+                    let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
+
+                    let upstream_endpoint = endpoint.name;
+                    let mut params = HashMap::new();
+                    params.insert(
+                        ARCH_MESSAGES_KEY.to_string(),
+                        callout_context.request_body.messages.clone(),
+                    );
+                    let arch_messages_json = serde_json::to_string(&params).unwrap();
+                    debug!("no prompt target found with similarity score above threshold, using default prompt target");
+                    let token_id = match self.dispatch_http_call(
+                        &upstream_endpoint,
+                        vec![
+                            (":method", "POST"),
+                            (":path", &upstream_path),
+                            (":authority", &upstream_endpoint),
+                            ("content-type", "application/json"),
+                            ("x-envoy-max-retries", "3"),
+                            (
+                                "x-envoy-upstream-rq-timeout-ms",
+                                ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
+                            ),
+                        ],
+                        Some(arch_messages_json.as_bytes()),
+                        vec![],
+                        Duration::from_secs(5),
+                    ) {
+                        Ok(token_id) => token_id,
+                        Err(e) => {
+                            let error_msg =
+                                format!("Error dispatching HTTP call for default-target: {:?}", e);
+                            return self
+                                .send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
+                        }
+                    };
+
+                    self.metrics.active_http_calls.increment(1);
+                    callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
+                    callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
+                    if self.callouts.insert(token_id, callout_context).is_some() {
+                        panic!("duplicate token_id")
+                    }
+                    return;
+                }
                 self.resume_http_request();
                 return;
             }
         }
 
-        let prompt_target = self
-            .prompt_targets
-            .read()
-            .unwrap()
-            .get(&prompt_target_name)
-            .unwrap()
-            .clone();
+        let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
+            Some(prompt_target) => prompt_target.clone(),
+            None => {
+                return self.send_server_error(
+                    format!("Prompt target not found: {}", prompt_target_name),
+                    None,
+                );
+            }
+        };
 
         info!("prompt_target name: {:?}", prompt_target_name);
-
-        //TODO: handle default function resolver type
         let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
         for pt in self.prompt_targets.read().unwrap().values() {
             // only extract entity names
@@ -761,6 +816,83 @@ impl StreamContext {
             )
         }
     }
+
+    fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
+        let prompt_target = self
+            .prompt_targets
+            .read()
+            .unwrap()
+            .get(callout_context.prompt_target_name.as_ref().unwrap())
+            .unwrap()
+            .clone();
+        debug!(
+            "response received for default target: {}",
+            prompt_target.name
+        );
+        // check if the default target should be dispatched to the LLM provider
+        if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
+            let default_target_response_str = String::from_utf8(body).unwrap();
+            debug!(
+                "sending response back to developer: {}",
+                default_target_response_str
+            );
+            self.send_http_response(
+                StatusCode::OK.as_u16().into(),
+                vec![("Powered-By", "Katanemo")],
+                Some(default_target_response_str.as_bytes()),
+            );
+            // self.resume_http_request();
+            return;
+        }
+        debug!("default_target: sending api response to default llm");
+        let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
+            Ok(chat_completions_resp) => chat_completions_resp,
+            Err(e) => {
+                return self.send_server_error(
+                    format!("Error deserializing default target response: {:?}", e),
+                    None,
+                );
+            }
+        };
+        let api_resp = chat_completions_resp.choices[0]
+            .message
+            .content
+            .as_ref()
+            .unwrap();
+        let mut messages = callout_context.request_body.messages;
+
+        // add system prompt
+        match prompt_target.system_prompt.as_ref() {
+            None => {}
+            Some(system_prompt) => {
+                let system_prompt_message = Message {
+                    role: SYSTEM_ROLE.to_string(),
+                    content: Some(system_prompt.clone()),
+                    model: None,
+                    tool_calls: None,
+                };
+                messages.push(system_prompt_message);
+            }
+        }
+
+        messages.push(Message {
+            role: USER_ROLE.to_string(),
+            content: Some(api_resp.clone()),
+            model: None,
+            tool_calls: None,
+        });
+        let chat_completion_request = ChatCompletionsRequest {
+            model: GPT_35_TURBO.to_string(),
+            messages,
+            tools: None,
+            stream: callout_context.request_body.stream,
+            stream_options: callout_context.request_body.stream_options,
+        };
+        let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
+        debug!("sending response back to default llm: {}", json_resp);
+        self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
+        self.resume_http_request();
+    }
 }
 
 // HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@@ -1067,6 +1199,9 @@ impl Context for StreamContext {
                     self.function_call_response_handler(body, callout_context)
                 }
                 ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
+                ResponseHandlerType::DefaultTarget => {
+                    self.default_target_handler(body, callout_context)
+                }
             }
         } else {
             self.send_server_error(
diff --git a/demos/function_calling/api_server/Dockerfile b/demos/function_calling/api_server/Dockerfile
index abd21357..ae377412 100644
--- a/demos/function_calling/api_server/Dockerfile
+++ b/demos/function_calling/api_server/Dockerfile
@@ -16,4 +16,4 @@ COPY --from=builder /runtime /usr/local
 COPY /app /app
 WORKDIR /app
 
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py
index a2d03853..34dcac4f 100644
--- a/demos/function_calling/api_server/app/main.py
+++ b/demos/function_calling/api_server/app/main.py
@@ -1,12 +1,13 @@
+import json
 import random
 from fastapi import FastAPI, Response
 from datetime import datetime, date, timedelta, timezone
 import logging
 from pydantic import BaseModel
-import pytz
 
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
+
+logger = logging.getLogger('uvicorn.error')
+logger.setLevel(logging.INFO)
 
 app = FastAPI()
 
@@ -58,18 +59,28 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
     return claim_details
 
 
-@app.get("/current_time")
-async def current_time(timezone: str):
-    tz = None
-    try:
-        timezone.strip('"')
-        tz = pytz.timezone(timezone)
-    except pytz.exceptions.UnknownTimeZoneError:
-        return {
-            "error": "Invalid timezone: {}".format(timezone)
-        }
-    current_time = datetime.now(tz)
-    return {
-        "timezone": timezone,
-        "current_time": current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
-    }
+
+class DefaultTargetRequest(BaseModel):
+    arch_messages: list
+
+@app.post("/default_target")
+async def default_target(req: DefaultTargetRequest, res: Response):
+    logger.info(f"Received arch_messages: {req.arch_messages}")
+    resp = {
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": "hello world from api server"
+                },
+                "finish_reason": "completed",
+                "index": 0
+            }
+        ],
+        "model": "api_server",
+        "usage": {
+            "completion_tokens": 0
+        }
+    }
+    logger.info(f"sending response: {json.dumps(resp)}")
+    return resp
diff --git a/demos/function_calling/api_server/requirements.txt b/demos/function_calling/api_server/requirements.txt
index 566bcbcb..fc1a57d7 100644
--- a/demos/function_calling/api_server/requirements.txt
+++ b/demos/function_calling/api_server/requirements.txt
@@ -1,3 +1,3 @@
 fastapi
 uvicorn
-pytz
+pyyaml
diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml
index 7a6010f1..b13721f5 100644
--- a/demos/function_calling/arch_config.yaml
+++ b/demos/function_calling/arch_config.yaml
@@ -47,21 +47,6 @@ prompt_targets:
       - Use farenheight for temperature
       - Use miles per hour for wind speed
 
-  - name: system_time
-    description: This function provides the current system time.
-    parameters:
-      - name: timezone
-        description: The city for which the weather forecast is requested.
-        default: US/Pacific
-        type: string
-    endpoint:
-      name: api_server
-      path: /current_time
-    system_prompt: |
-      You are a helpful system time provider. Use system time data that is provided to you.
-      Please following following guidelines when responding to user queries:
-      - Use 12 hour time format
-      - Use AM/PM for time
   - name: insurance_claim_details
     description: This function resolver provides insurance claim details for a given policy number.
     parameters:
@@ -80,6 +65,18 @@ prompt_targets:
       You are a helpful insurance claim details provider. Use insurance claim data that is provided to you.
       Please following following guidelines when responding to user queries:
       - Use policy number to retrieve insurance claim details
+  - name: default_target
+    default: true
+    description: This is the default target for all unmatched prompts.
+    endpoint:
+      name: api_server
+      path: /default_target
+    system_prompt: |
+      You are a helpful assistant. Use the information that is provided to you.
+    # if it is set to false arch will send response that it received from this prompt target to the user
+    # if true arch will forward the response to the default LLM
+    auto_llm_dispatch_on_response: true
+
 ratelimits:
   - provider: gpt-3.5-turbo
     selector:
diff --git a/model_server/Dockerfile b/model_server/Dockerfile
index e90f332f..242f7893 100644
--- a/model_server/Dockerfile
+++ b/model_server/Dockerfile
@@ -1,41 +1,25 @@
-FROM python:3.10 AS base
+FROM python:3.10 AS builder
 
-#
-# builder
-#
-FROM base AS builder
+COPY requirements.txt .
+RUN pip install --prefix=/runtime -r requirements.txt
 
-WORKDIR /src
-RUN pip install --upgrade pip
-
-# Install git (needed for cloning the repository)
-RUN apt-get update && apt-get install -y git && apt-get clean
-
-COPY requirements.txt /src/
-
-RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
+FROM python:3.10-slim AS output
 
-COPY . /src
+# curl is needed for health check in docker-compose
+RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/lib/apt/lists/*
 
-#
-# output
-#
+COPY --from=builder /runtime /usr/local
 
-FROM python:3.10-slim AS output
+WORKDIR /src
 
 # specify list of models that will go into the image as a comma separated list
 # following models have been tested to work with this image
 # "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
 ENV MODELS="BAAI/bge-large-en-v1.5"
 
-COPY --from=builder /runtime /usr/local
-
-COPY ./ /app
-WORKDIR /app
-
-RUN apt-get update && apt-get install -y \
-    curl \
-    && rm -rf /var/lib/apt/lists/*
+COPY ./app ./app
+COPY ./guard_model_config.yaml .
+COPY ./openai_params.yaml .
 
 # comment it out for now as we don't want to download the model every time we build the image
 # we will mount host cache to docker image to avoid downloading the model every time
diff --git a/model_server/app/arch_fc/arch_fc.py b/model_server/app/arch_fc/arch_fc.py
index b60fdb5c..ec505eeb 100644
--- a/model_server/app/arch_fc/arch_fc.py
+++ b/model_server/app/arch_fc/arch_fc.py
@@ -9,6 +9,10 @@ from openai import OpenAI
 import os
 
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
 
 with open("openai_params.yaml") as f:
     params = yaml.safe_load(f)
@@ -20,7 +24,6 @@ if mode not in ["cloud", "local-gpu", "local-cpu"]:
     raise ValueError(f"Invalid mode: {mode}")
 
 arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
-logger = logging.getLogger("uvicorn.error")
 
 handler = None
 if ollama_model.startswith("Arch"):
@@ -28,17 +31,12 @@ else:
     handler = BoltHandler()
 
-
-# app = FastAPI()
-
 if mode == "cloud":
     client = OpenAI(
         base_url=fc_url,
         api_key="EMPTY",
     )
-    models = client.models.list()
-    model = models.data[0].id
-    chosen_model = model
+    chosen_model = "fc-cloud"
     endpoint = fc_url
 else:
     client = OpenAI(
@@ -47,12 +45,12 @@
     )
     chosen_model = ollama_model
     endpoint = ollama_endpoint
+
 logger.info(f"serving mode: {mode}")
 logger.info(f"using model: {chosen_model}")
 logger.info(f"using endpoint: {endpoint}")
 
-
 async def chat_completion(req: ChatMessage, res: Response):
     logger.info("starting request")
     tools_encoded = handler._format_system(req.tools)
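
Note on trying the default-target path: the sketch below is a minimal smoke test, not part of this change. It assumes the demo api_server is reachable at http://localhost:80 (adjust to however docker-compose exposes it) and that the `requests` package is installed separately, since it is not in the demo's requirements.txt. It only exercises the request/response contract that main.py defines for /default_target: Arch posts the conversation under the "arch_messages" key when no prompt target clears the similarity threshold, and the route answers in chat-completions shape so the gateway can either return that body to the caller or, with auto_llm_dispatch_on_response: true, forward the assistant content to the default LLM.

import requests

# Assumption: wherever the demo api_server container is exposed on the host.
BASE_URL = "http://localhost:80"

# Shape matches DefaultTargetRequest in demos/function_calling/api_server/app/main.py:
# a single "arch_messages" field carrying the original chat messages.
payload = {
    "arch_messages": [
        {"role": "user", "content": "a prompt that matches no configured prompt target"}
    ]
}

resp = requests.post(f"{BASE_URL}/default_target", json=payload, timeout=5)
resp.raise_for_status()

# The demo route always returns a chat-completions style body.
body = resp.json()
print(body["choices"][0]["message"]["content"])  # "hello world from api server"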