Skip to content

Commit

Permalink
feat(openai): dynamic model_type registration (#704)
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm authored Nov 20, 2023
1 parent 6505abd commit 513c08c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
16 changes: 11 additions & 5 deletions openllm-python/src/openllm/entrypoints/_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
example:
object: 'list'
data:
- id: meta-llama--Llama-2-13b-chat-hf
- id: __model_id__
object: model
created: 1686935002
owned_by: 'na'
Expand Down Expand Up @@ -69,7 +69,7 @@
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
Expand All @@ -83,7 +83,7 @@
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
Expand Down Expand Up @@ -206,7 +206,7 @@
summary: One-shot input example
value:
prompt: This is a test
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
logprobs: 1
Expand All @@ -217,7 +217,7 @@
summary: Streaming input example
value:
prompt: This is a test
model: meta-llama--Llama-2-13b-chat-hf
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
Expand Down Expand Up @@ -472,6 +472,12 @@
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}


def apply_schema(func, **attrs):
for k, v in attrs.items():
func.__doc__ = func.__doc__.replace(k, v)
return func


def add_schema_definitions(func):
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:
Expand Down
1 change: 1 addition & 0 deletions openllm-python/src/openllm/entrypoints/_openapi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class OpenLLMSchemaGenerator:
def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ...
def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ...

def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
def append_schemas(
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
Expand Down
21 changes: 17 additions & 4 deletions openllm-python/src/openllm/entrypoints/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from openllm_core._schemas import SampleLogprobs
from openllm_core.utils import converter, gen_random_uuid

from ._openapi import add_schema_definitions, append_schemas, get_generator
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
from ..protocol.openai import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand Down Expand Up @@ -100,12 +100,25 @@ def create_logprobs(token_ids, id_logprobs, initial_text_offset=0, *, llm):


def mount_to_svc(svc, llm):
list_models.__doc__ = list_models.__doc__.replace('__model_id__', llm.llm_type)
completions.__doc__ = completions.__doc__.replace('__model_id__', llm.llm_type)
chat_completions.__doc__ = chat_completions.__doc__.replace('__model_id__', llm.llm_type)
app = Starlette(
debug=True,
routes=[
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
Route(
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
),
Route(
'/completions',
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route(
'/chat/completions',
functools.partial(apply_schema(chat_completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
Expand Down

0 comments on commit 513c08c

Please sign in to comment.