add support for default target (#111)
* add support for default target

* add more fixes

adilhafeez authored Oct 3, 2024
1 parent c8d0dbe commit 1b57a49
Showing 8 changed files with 213 additions and 86 deletions.
2 changes: 2 additions & 0 deletions arch/arch_config_schema.yaml
@@ -67,6 +67,8 @@ properties:
    type: boolean
  description:
    type: string
  auto_llm_dispatch_on_response:
    type: boolean
  parameters:
    type: array
    items:
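
For context, the new flag sits alongside the existing default flag on a prompt target entry. A small sanity check of that shape is sketched below (the values are made up for illustration, and pyyaml is assumed to be available, as in the demo requirements):

import yaml

# Hypothetical prompt_targets snippet; the field names follow the schema change
# above, the values are purely illustrative.
snippet = """
- name: default_target
  default: true
  auto_llm_dispatch_on_response: true
"""

targets = yaml.safe_load(snippet)
assert targets[0]["auto_llm_dispatch_on_response"] is True
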
167 changes: 151 additions & 16 deletions arch/src/stream_context.rs
@@ -39,6 +39,7 @@ enum ResponseHandlerType {
    FunctionCall,
    ZeroShotIntent,
    ArchGuard,
    DefaultTarget,
}

pub struct CallContext {
@@ -179,12 +180,16 @@ impl StreamContext {

        let prompt_target_names = prompt_targets
            .iter()
            // exclude default target
            .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
            .map(|(name, _)| name.clone())
            .collect();

        let similarity_scores: Vec<(String, f64)> = prompt_targets
            .iter()
            .map(|(prompt_name, _prompt_target)| {
            // exclude default prompt target
            .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
            .map(|(prompt_name, _)| {
                let default_embeddings = HashMap::new();
                let pte = prompt_target_embeddings
                    .get(prompt_name)
@@ -331,34 +336,84 @@ impl StreamContext {

        // check to ensure that the prompt target similarity score is above the threshold
        if prompt_target_similarity_score < prompt_target_intent_matching_threshold
            && !arch_assistant
            || arch_assistant
        {
            debug!("intent score is low or arch assistant is handling the conversation");
            // if arch fc responded to the user message, then we don't need to check the similarity score
            // it may be that arch fc is handling the conversation for parameter collection
            if arch_assistant {
                info!("arch assistant is handling the conversation");
            } else {
                info!(
                    "prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user",
                    prompt_target_similarity_score,
                    prompt_target_intent_matching_threshold
                );
                debug!("checking for default prompt target");
                if let Some(default_prompt_target) = self
                    .prompt_targets
                    .read()
                    .unwrap()
                    .values()
                    .find(|pt| pt.default.unwrap_or(false))
                {
                    debug!("default prompt target found");
                    let endpoint = default_prompt_target.endpoint.clone().unwrap();
                    let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));

                    let upstream_endpoint = endpoint.name;
                    let mut params = HashMap::new();
                    params.insert(
                        ARCH_MESSAGES_KEY.to_string(),
                        callout_context.request_body.messages.clone(),
                    );
                    let arch_messages_json = serde_json::to_string(&params).unwrap();
                    debug!("no prompt target found with similarity score above threshold, using default prompt target");
                    let token_id = match self.dispatch_http_call(
                        &upstream_endpoint,
                        vec![
                            (":method", "POST"),
                            (":path", &upstream_path),
                            (":authority", &upstream_endpoint),
                            ("content-type", "application/json"),
                            ("x-envoy-max-retries", "3"),
                            (
                                "x-envoy-upstream-rq-timeout-ms",
                                ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
                            ),
                        ],
                        Some(arch_messages_json.as_bytes()),
                        vec![],
                        Duration::from_secs(5),
                    ) {
                        Ok(token_id) => token_id,
                        Err(e) => {
                            let error_msg =
                                format!("Error dispatching HTTP call for default-target: {:?}", e);
                            return self
                                .send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
                        }
                    };

                    self.metrics.active_http_calls.increment(1);
                    callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
                    callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
                    if self.callouts.insert(token_id, callout_context).is_some() {
                        panic!("duplicate token_id")
                    }
                    return;
                }
                self.resume_http_request();
                return;
            }
        }

        let prompt_target = self
            .prompt_targets
            .read()
            .unwrap()
            .get(&prompt_target_name)
            .unwrap()
            .clone();
        let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
            Some(prompt_target) => prompt_target.clone(),
            None => {
                return self.send_server_error(
                    format!("Prompt target not found: {}", prompt_target_name),
                    None,
                );
            }
        };

        info!("prompt_target name: {:?}", prompt_target_name);

        //TODO: handle default function resolver type
        let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
        for pt in self.prompt_targets.read().unwrap().values() {
            // only extract entity names
@@ -761,6 +816,83 @@ impl StreamContext {
            )
        }
    }

    fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
        let prompt_target = self
            .prompt_targets
            .read()
            .unwrap()
            .get(callout_context.prompt_target_name.as_ref().unwrap())
            .unwrap()
            .clone();
        debug!(
            "response received for default target: {}",
            prompt_target.name
        );
        // check if the default target should be dispatched to the LLM provider
        if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
            let default_target_response_str = String::from_utf8(body).unwrap();
            debug!(
                "sending response back to developer: {}",
                default_target_response_str
            );
            self.send_http_response(
                StatusCode::OK.as_u16().into(),
                vec![("Powered-By", "Katanemo")],
                Some(default_target_response_str.as_bytes()),
            );
            // self.resume_http_request();
            return;
        }
        debug!("default_target: sending api response to default llm");
        let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
            Ok(chat_completions_resp) => chat_completions_resp,
            Err(e) => {
                return self.send_server_error(
                    format!("Error deserializing default target response: {:?}", e),
                    None,
                );
            }
        };
        let api_resp = chat_completions_resp.choices[0]
            .message
            .content
            .as_ref()
            .unwrap();
        let mut messages = callout_context.request_body.messages;

        // add system prompt
        match prompt_target.system_prompt.as_ref() {
            None => {}
            Some(system_prompt) => {
                let system_prompt_message = Message {
                    role: SYSTEM_ROLE.to_string(),
                    content: Some(system_prompt.clone()),
                    model: None,
                    tool_calls: None,
                };
                messages.push(system_prompt_message);
            }
        }

        messages.push(Message {
            role: USER_ROLE.to_string(),
            content: Some(api_resp.clone()),
            model: None,
            tool_calls: None,
        });
        let chat_completion_request = ChatCompletionsRequest {
            model: GPT_35_TURBO.to_string(),
            messages,
            tools: None,
            stream: callout_context.request_body.stream,
            stream_options: callout_context.request_body.stream_options,
        };
        let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
        debug!("sending response back to default llm: {}", json_resp);
        self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
        self.resume_http_request();
    }
}

// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@@ -1067,6 +1199,9 @@ impl Context for StreamContext {
                    self.function_call_response_handler(body, callout_context)
                }
                ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
                ResponseHandlerType::DefaultTarget => {
                    self.default_target_handler(body, callout_context)
                }
            }
        } else {
            self.send_server_error(
2 changes: 1 addition & 1 deletion demos/function_calling/api_server/Dockerfile
@@ -16,4 +16,4 @@ COPY --from=builder /runtime /usr/local
COPY /app /app
WORKDIR /app

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
47 changes: 29 additions & 18 deletions demos/function_calling/api_server/app/main.py
@@ -1,12 +1,13 @@
import json
import random
from fastapi import FastAPI, Response
from datetime import datetime, date, timedelta, timezone
import logging
from pydantic import BaseModel
import pytz

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

logger = logging.getLogger('uvicorn.error')
logger.setLevel(logging.INFO)

app = FastAPI()

@@ -58,18 +59,28 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon

    return claim_details

@app.get("/current_time")
async def current_time(timezone: str):
tz = None
try:
timezone.strip('"')
tz = pytz.timezone(timezone)
except pytz.exceptions.UnknownTimeZoneError:
return {
"error": "Invalid timezone: {}".format(timezone)
}
current_time = datetime.now(tz)
return {
"timezone": timezone,
"current_time": current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
}

class DefaultTargetRequest(BaseModel):
    arch_messages: list

@app.post("/default_target")
async def default_target(req: DefaultTargetRequest, res: Response):
    logger.info(f"Received arch_messages: {req.arch_messages}")
    resp = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "hello world from api server"
                },
                "finish_reason": "completed",
                "index": 0
            }
        ],
        "model": "api_server",
        "usage": {
            "completion_tokens": 0
        }
    }
    logger.info(f"sending response: {json.dumps(resp)}")
    return resp
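
The snippet below is not part of the commit; it is a rough sketch for manually exercising the new endpoint, assuming the demo api_server is running locally on port 80 (per the Dockerfile above) and that the requests package is installed:

import requests

# The gateway posts the conversation under the "arch_messages" key,
# matching the ARCH_MESSAGES_KEY payload built in stream_context.rs.
payload = {"arch_messages": [{"role": "user", "content": "hello"}]}

# Host and port are assumptions based on the demo setup.
resp = requests.post("http://localhost:80/default_target", json=payload)
resp.raise_for_status()

# The endpoint returns a chat-completions style body.
print(resp.json()["choices"][0]["message"]["content"])
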
2 changes: 1 addition & 1 deletion demos/function_calling/api_server/requirements.txt
@@ -1,3 +1,3 @@
fastapi
uvicorn
pytz
pyyaml
27 changes: 12 additions & 15 deletions demos/function_calling/arch_config.yaml
@@ -47,21 +47,6 @@ prompt_targets:
      - Use Fahrenheit for temperature
      - Use miles per hour for wind speed
  - name: system_time
    description: This function provides the current system time.
    parameters:
      - name: timezone
        description: The timezone for which the current time is requested.
        default: US/Pacific
        type: string
    endpoint:
      name: api_server
      path: /current_time
    system_prompt: |
      You are a helpful system time provider. Use system time data that is provided to you. Please follow the following guidelines when responding to user queries:
      - Use 12 hour time format
      - Use AM/PM for time
  - name: insurance_claim_details
    description: This function resolver provides insurance claim details for a given policy number.
    parameters:
@@ -80,6 +65,18 @@ prompt_targets:
      You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please follow the following guidelines when responding to user queries:
      - Use policy number to retrieve insurance claim details
  - name: default_target
    default: true
    description: This is the default target for all unmatched prompts.
    endpoint:
      name: api_server
      path: /default_target
    system_prompt: |
      You are a helpful assistant. Use the information that is provided to you.
    # if set to false, arch sends the response it received from this prompt target back to the user
    # if set to true, arch forwards that response to the default LLM
    auto_llm_dispatch_on_response: true

ratelimits:
  - provider: gpt-3.5-turbo
    selector:
38 changes: 11 additions & 27 deletions model_server/Dockerfile
@@ -1,41 +1,25 @@
FROM python:3.10 AS base
FROM python:3.10 AS builder

#
# builder
#
FROM base AS builder
COPY requirements.txt .
RUN pip install --prefix=/runtime -r requirements.txt

WORKDIR /src
RUN pip install --upgrade pip

# Install git (needed for cloning the repository)
RUN apt-get update && apt-get install -y git && apt-get clean

COPY requirements.txt /src/

RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
FROM python:3.10-slim AS output

COPY . /src
# curl is needed for health check in docker-compose
RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/lib/apt/lists/*

#
# output
#
COPY --from=builder /runtime /usr/local

FROM python:3.10-slim AS output
WORKDIR /src

# specify list of models that will go into the image as a comma separated list
# following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5"

COPY --from=builder /runtime /usr/local

COPY ./ /app
WORKDIR /app

RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
COPY ./app ./app
COPY ./guard_model_config.yaml .
COPY ./openai_params.yaml .

# comment it out for now as we don't want to download the model every time we build the image
# we will mount host cache to docker image to avoid downloading the model every time