feat: openai.Model.list() #499

Merged (9 commits, Oct 14, 2023)
1 change: 1 addition & 0 deletions changelog.d/499.feature.md
@@ -0,0 +1 @@
+Add `/v1/models` endpoint for OpenAI compatible API
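
The entry above summarizes the feature: the PR mounts a `GET /v1/models` route on the OpenAI-compatible API. As a rough illustration (not part of the diff), the route can be exercised directly over HTTP; the host below assumes the default local OpenLLM address used elsewhere in this PR:

```python
# Minimal sketch: query the new /v1/models route directly.
# Assumes an OpenLLM server is already running on http://localhost:3000.
import requests

resp = requests.get("http://localhost:3000/v1/models", timeout=10)
resp.raise_for_status()
for card in resp.json()["data"]:  # ModelList -> {"object": "list", "data": [...]}
    print(card["id"], card["object"], card["owned_by"])
```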
6 changes: 4 additions & 2 deletions examples/openai_client.py
@@ -1,8 +1,10 @@
-import openai
+import openai, os
 
-openai.api_base = "http://localhost:3000/v1"
+openai.api_base = os.getenv('OPENLLM_ENDPOINT', "http://localhost:3000") + '/v1'
 openai.api_key = "na"
 
+print("Model:", openai.Model.list())
+
 response = openai.Completion.create(model="gpt-3.5-turbo-instruct", prompt="Write a tagline for an ice cream shop.", max_tokens=256)
 
 print(response)
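
A side note on the example client: it still targets the pre-1.0 `openai` module-level API. A minimal sketch (assuming the same endpoint and key handling as the example above) of feeding the listing back into the completion call, so the `model` argument matches what the server actually serves:

```python
import os

import openai

openai.api_base = os.getenv("OPENLLM_ENDPOINT", "http://localhost:3000") + "/v1"
openai.api_key = "na"

# The server reports a single model card whose id is the served runner.llm_type.
served_model = openai.Model.list()["data"][0]["id"]

response = openai.Completion.create(
    model=served_model,  # matches the server, so the new "not supported" warning is not logged
    prompt="Write a tagline for an ice cream shop.",
    max_tokens=256,
)
print(response)
```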
20 changes: 16 additions & 4 deletions openllm-python/src/openllm/_service.py
@@ -1,5 +1,6 @@
 # mypy: disable-error-code="call-arg,misc,attr-defined,type-abstract,type-arg,valid-type,arg-type"
 from __future__ import annotations
+import logging
 import typing as t
 import warnings
 
@@ -23,6 +24,8 @@
 warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
 warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
 
+logger = logging.getLogger(__name__)
+
 model = svars.model
 model_id = svars.model_id
 adapter_map = svars.adapter_map
@@ -62,6 +65,8 @@ async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[s
 input=bentoml.io.JSON.from_sample(openllm.utils.bentoml_cattr.unstructure(openllm.openai.CompletionRequest(prompt='What is 1+1?', model=runner.llm_type))),
 output=bentoml.io.Text())
 async def completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  _model = input_dict.get('model', None)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
   prompt = input_dict.pop('prompt', None)
   if prompt is None: raise ValueError("'prompt' should not be None.")
   stream = input_dict.pop('stream', False)
@@ -118,6 +123,8 @@ async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t
 }], model=runner.llm_type))),
 output=bentoml.io.Text())
 async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context) -> str | t.AsyncGenerator[str, None]:
+  _model = input_dict.get('model', None)
+  if _model != runner.llm_type: logger.warning("Model '%s' is not supported. Run openai.Model.list() to see all supported models.", _model)
   prompt = openllm.openai.messages_to_prompt(input_dict['messages'])
   stream = input_dict.pop('stream', False)
   config = {
@@ -158,12 +165,17 @@ async def stream_response_generator(responses: t.AsyncGenerator[str, None]) -> t
   responses = await runner.generate.async_run(prompt, **config)
   return orjson.dumps(
     openllm.utils.bentoml_cattr.unstructure(
-      openllm.openai.ChatCompletionResponse(choices=[
-        openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)
-      ],
-      model=model) # TODO: logprobs, finish_reason and usage
+      openllm.openai.ChatCompletionResponse(
+        choices=[openllm.openai.ChatCompletionChoice(index=i, message=openllm.openai.Message(role='assistant', content=response)) for i, response in enumerate(responses)],
+        model=runner.llm_type) # TODO: logprobs, finish_reason and usage
     )).decode('utf-8')
 
+def models_v1(_: Request) -> Response:
+  return JSONResponse(openllm.utils.bentoml_cattr.unstructure(openllm.openai.ModelList(data=[openllm.openai.ModelCard(id=runner.llm_type)])), status_code=200)
+
+openai_app = Starlette(debug=True, routes=[Route('/models', models_v1, methods=['GET'])])
+svc.mount_asgi_app(openai_app, path='/v1')
+
 @svc.api(route='/v1/metadata',
 input=bentoml.io.Text(),
 output=bentoml.io.JSON.from_sample({
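
For readers less familiar with the serving side: the new `/v1/models` route in `_service.py` is just a plain Starlette sub-application that BentoML mounts under `/v1`. A standalone sketch of that pattern follows; the model id below is only a placeholder, and the mount call is shown as a comment since it needs a BentoML service object:

```python
# Minimal sketch of the mounting pattern used in _service.py.
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

def models_v1(_: Request) -> JSONResponse:
    # Same shape the service returns: a ModelList holding a single ModelCard.
    return JSONResponse(
        {"object": "list", "data": [{"id": "gpt-3.5-turbo-instruct", "object": "model", "owned_by": "na"}]},
        status_code=200,
    )

openai_app = Starlette(debug=True, routes=[Route("/models", models_v1, methods=["GET"])])
# In the service itself this sub-app is attached with:
#   svc.mount_asgi_app(openai_app, path="/v1")
```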
12 changes: 12 additions & 0 deletions openllm-python/src/openllm/protocol/openai.py
@@ -119,6 +119,18 @@ class ChatCompletionResponseStream:
   id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
   created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
 
+@attr.define
+class ModelCard:
+  id: str
+  object: str = 'model'
+  created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
+  owned_by: str = 'na'
+
+@attr.define
+class ModelList:
+  object: str = 'list'
+  data: t.List[ModelCard] = attr.field(factory=list)
+
 def messages_to_prompt(messages: list[Message]) -> str:
   formatted = '\n'.join([f"{message['role']}: {message['content']}" for message in messages])
   return f'{formatted}\nassistant:'
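
The two new attrs classes mirror OpenAI's model-card schema. A small sketch of how they serialize, using the plain `cattrs` converter in place of `openllm.utils.bentoml_cattr` (assumed here to behave equivalently for these flat classes):

```python
# Sketch: construct and unstructure ModelList/ModelCard the way models_v1 does.
import time
import typing as t

import attr
import cattrs

@attr.define
class ModelCard:
    id: str
    object: str = 'model'
    created: int = attr.field(default=attr.Factory(lambda: int(time.time())))
    owned_by: str = 'na'

@attr.define
class ModelList:
    object: str = 'list'
    data: t.List[ModelCard] = attr.field(factory=list)

payload = cattrs.unstructure(ModelList(data=[ModelCard(id='gpt-3.5-turbo-instruct')]))
print(payload)
# {'object': 'list', 'data': [{'id': 'gpt-3.5-turbo-instruct', 'object': 'model', 'created': ..., 'owned_by': 'na'}]}
```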