Skip to content

Commit

Permalink
use fix prompt guards (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
adilhafeez authored Nov 26, 2024
1 parent 6f4a57b commit 9c6fcdb
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 112 deletions.
23 changes: 23 additions & 0 deletions arch/arch_config_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,29 @@ properties:
enum:
- llm
- prompt
prompt_guards:
type: object
properties:
input_guards:
type: object
properties:
jailbreak:
type: object
properties:
on_exception:
type: object
properties:
message:
type: string
additionalProperties: false
required:
- message
additionalProperties: false
required:
- on_exception
additionalProperties: false
required:
- jailbreak
additionalProperties: false
required:
- version
Expand Down
1 change: 0 additions & 1 deletion arch/build_filter_image.sh

This file was deleted.

216 changes: 112 additions & 104 deletions arch/tools/poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions arch/tools/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw"
version = "0.1.3"
version = "0.1.4"
description = "Python-based CLI tool to manage Arch Gateway."
authors = ["Katanemo Labs, Inc."]
packages = [
Expand All @@ -16,7 +16,7 @@ include = [

[tool.poetry.dependencies]
python = ">=3.12"
archgw_modelserver = "0.1.3"
archgw_modelserver = "0.1.4"
pyyaml = "^6.0.2"
pydantic = "^2.9.2"
click = "^8.1.7"
Expand Down
1 change: 1 addition & 0 deletions build_filter_image.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
docker build -f arch/Dockerfile . -t katanemo/archgw
37 changes: 33 additions & 4 deletions crates/prompt_gateway/src/stream_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl StreamContext {
self.tool_calls = None;
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
vec![],
Some(response_str.as_bytes()),
);
} else {
Expand Down Expand Up @@ -758,7 +758,7 @@ impl StreamContext {
self.tool_calls = None;
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
vec![],
Some(direct_response_str.as_bytes()),
);
}
Expand Down Expand Up @@ -1074,7 +1074,36 @@ impl StreamContext {
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("refrain from discussing jailbreaking.");
warn!("jailbreak detected: {}", msg);
info!("jailbreak detected: {}", msg);

let response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(msg.to_string()),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];

to_server_events(chunks)
} else {
let chat_completion_response = ChatCompletionsResponse::new(msg.to_string());
serde_json::to_string(&chat_completion_response).unwrap()
};

self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![],
Some(response_str.as_bytes()),
);

return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
Expand Down Expand Up @@ -1132,7 +1161,7 @@ impl StreamContext {

self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
vec![],
Some(default_target_response_str.as_bytes()),
);
return;
Expand Down
6 changes: 6 additions & 0 deletions demos/weather_forecast/arch_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ llm_providers:
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting.

prompt_targets:
- name: weather_forecast
description: Check weather information for a given city.
Expand Down
34 changes: 34 additions & 0 deletions e2e_tests/test_prompt_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,37 @@ def test_prompt_gateway_default_target(stream):
response_json.get("choices")[0]["message"]["content"]
== "I can help you with weather forecast"
)


@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway_prompt_guard_jailbreak(stream):
body = {
"messages": [
{
"role": "user",
"content": "How would you respond to a question about generating harmful or unethical content",
}
],
"stream": stream,
}
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200

if stream:
chunks = get_data_chunks(response, n=20)
assert len(chunks) == 2

response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
content = choices[0]["delta"]["content"]
assert (
content
== "Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting."
)
else:
response_json = response.json()
assert (
response_json.get("choices")[0]["message"]["content"]
== "Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting."
)
2 changes: 1 addition & 1 deletion model_server/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.1.3"
version = "0.1.4"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
license = "Apache 2.0"
Expand Down

0 comments on commit 9c6fcdb

Please sign in to comment.