feat: openai.Model.list() (#499)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
3 people authored Oct 14, 2023
1 parent c1ca7cc commit d918326
Showing 4 changed files with 33 additions and 6 deletions.
1 change: 1 addition & 0 deletions changelog.d/499.feature.md
@@ -0,0 +1 @@
+Add `/v1/models` endpoint for OpenAI compatible API
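
The new route can also be exercised without the OpenAI SDK. The sketch below is illustrative rather than part of this commit: it queries the endpoint with `requests`, assumes a server listening on `http://localhost:3000`, and relies on the response following the `ModelList`/`ModelCard` schema added in `protocol/openai.py` further down.

import requests

resp = requests.get('http://localhost:3000/v1/models', timeout=10)  # GET the new endpoint
resp.raise_for_status()
for card in resp.json()['data']:  # ModelList.data holds one ModelCard per served model
  print(card['id'], card['object'], card['owned_by'])
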
6 changes: 4 additions & 2 deletions examples/openai_client.py
@@ -1,8 +1,10 @@
-import openai
+import openai, os

-openai.api_base = "http://localhost:3000/v1"
+openai.api_base = os.getenv('OPENLLM_ENDPOINT', "http://localhost:3000") + '/v1'
openai.api_key = "na"

+print("Model:", openai.Model.list())
+
response = openai.Completion.create(model="gpt-3.5-turbo-instruct", prompt="Write a tagline for an ice cream shop.", max_tokens=256)

print(response)
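
Note that this example targets the pre-1.0 `openai` package (module-level `api_base`, `Model`, `Completion`). A rough equivalent with the `openai>=1.0` client, shown here only as an illustrative sketch and not part of this commit, would be:

import os
from openai import OpenAI

client = OpenAI(base_url=os.getenv('OPENLLM_ENDPOINT', 'http://localhost:3000') + '/v1', api_key='na')
print('Model:', client.models.list())  # hits the new /v1/models route
response = client.completions.create(model='gpt-3.5-turbo-instruct', prompt='Write a tagline for an ice cream shop.', max_tokens=256)
print(response)
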
20 changes: 16 additions & 4 deletions openllm-python/src/openllm/_service.py
@@ -1,5 +1,6 @@
# mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
from __future__ import annotations
+import logging
import typing as t
import warnings

@@ -23,6 +24,8 @@
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')

+logger = logging.getLogger(__name__)
+
model = svars.model
model_id = svars.model_id
adapter_map = svars.adapter_map
@@ -62,6 +65,8 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
         input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
         output=bentoml.io.Text())
async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  _model = input_dict.get('model', None)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
  prompt = input_dict.pop('prompt', None)
  if prompt is None: raise ValueError("'prompt' should not be None.")
  stream = input_dict.pop('stream', False)
@@ -118,6 +123,8 @@ async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t
         }], model=runner.llm_type))),
         output=bentoml.io.Text())
async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  _model = input_dict.get('model', None)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
  prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
  stream = input_dict.pop('stream', False)
  config = {
@@ -158,12 +165,17 @@ async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t
  responses = await runner.generate.async_run(prompt, **config)
  return orjson.dumps(
      openllm.utils.bentoml_cattr.unstructure(
-         openllm.openai.ChatCompletionResponse(choices=[
-             openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
-         ],
-         model=model) # TODO: logprobs, finish_reason and usage
+         openllm.openai.ChatCompletionResponse(
+             choices=[openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)],
+             model=runner.llm_type) # TODO: logprobs, finish_reason and usage
      )).decode('utf-8')

+def models_v1(_: Request) -> Response:
+  return JSONResponse(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ModelList(data=[openllm.openai.ModelCard(id=runner.llm_type)])), status_code=200)
+
+openai_app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
+svc.mount_asgi_app(openai_app, path='/v1')
+
@svc.api(route='/v1/metadata',
         input=bentoml.io.Text(),
         output=bentoml.io.JSON.from_sample({
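
Here `models_v1` is a plain Starlette handler mounted under `/v1` alongside the BentoML APIs, so `GET /v1/models` never touches a runner. The following self-contained sketch of the same pattern is not part of the commit; `MODEL_ID` is a placeholder standing in for `runner.llm_type`, and the payload mirrors the unstructured `ModelList`.

import time
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient

MODEL_ID = 'example-model'  # placeholder for runner.llm_type

def models_v1(_: Request) -> JSONResponse:
  # Same shape as the unstructured ModelList: {"object": "list", "data": [ModelCard, ...]}
  return JSONResponse({'object': 'list', 'data': [{'id': MODEL_ID, 'object': 'model', 'created': int(time.time()), 'owned_by': 'na'}]}, status_code=200)

app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
print(TestClient(app).get('/models').json())  # {'object': 'list', 'data': [{'id': 'example-model', ...}]}
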
12 changes: 12 additions & 0 deletions openllm-python/src/openllm/protocol/openai.py
@@ -119,6 +119,18 @@ class ChatCompletionResponseStream:
  id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))

+@attr.define
+class ModelCard:
+  id: str
+  object: str = 'model'
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  owned_by: str = 'na'
+
+@attr.define
+class ModelList:
+  object: str = 'list'
+  data: t.List[ModelCard] = attr.field(factory=list)
+
def messages_to_prompt(messages: list[Message]) -> str:
  formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
  return f'{formatted}\nassistant:'
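
For reference, `messages_to_prompt` simply flattens OpenAI-style chat messages into a single prompt string. Reproducing its two lines standalone with a hypothetical two-message conversation:

messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': 'What is 1+1?'}]
formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
print(f'{formatted}\nassistant:')
# system: You are a helpful assistant.
# user: What is 1+1?
# assistant: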
