feat: continuous batching with vLLM #349

Merged
5 commits merged on Sep 14, 2023
Changes from 1 commit

chore: wip
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Sep 13, 2023
commit 2e31d16cb4764e3797705b0a5846529ee29cd927

openllm-python/src/openllm/_assign.py (25 changes: 12 additions & 13 deletions)

@@ -22,6 +22,7 @@
 from openllm_core.utils import codegen
 from openllm_core.utils import device_count
 from openllm_core.utils import first_not_none
+from openllm_core.utils import get_debug_mode
 from openllm_core.utils import is_torch_available

 if t.TYPE_CHECKING:

@@ -38,11 +39,8 @@
 def import_model(fn: import_model_protocol[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]:
   @functools.wraps(fn)
   def inner(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model:
-    trust_remote_code = first_not_none(trust_remote_code, default=self.trust_remote_code)
     (model_decls, model_attrs), _ = self.llm_parameters
-    decls = (*model_decls, *decls)
-    attrs = {**model_attrs, **attrs}
-    return fn(self, *decls, trust_remote_code=trust_remote_code, **attrs)
+    return fn(self, *model_decls, *decls, trust_remote_code=first_not_none(trust_remote_code, default=self.trust_remote_code), **model_attrs, **attrs)

   return inner

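For context on the refactor in the hunk above: assuming openllm_core.utils.first_not_none has the semantics its name and call site suggest (return the first argument that is not None, otherwise the supplied default), the new one-liner is behaviourally equivalent to the four removed statements. A minimal sketch of that assumed helper, not the actual openllm_core implementation:

import typing as t

T = t.TypeVar('T')

def first_not_none(*args: t.Optional[T], default: T) -> T:
  # Assumed behaviour: return the first non-None positional argument,
  # otherwise fall back to the default.
  return next((arg for arg in args if arg is not None), default)
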
@@ -52,21 +50,22 @@ def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
     if self.__llm_backend__ == 'vllm':
       num_gpus, dev = 1, device_count()
       if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
       # TODO: Do some more processing with token_id once we support token streaming
       try:
-        return vllm.LLMEngine.from_engine_args(
-          vllm.EngineArgs(model=self._bentomodel.path,
-                          tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
-                          tokenizer_mode='auto',
-                          tensor_parallel_size=num_gpus,
-                          dtype='auto',
-                          worker_use_ray=False))
+        return vllm.AsyncLLMEngine.from_engine_args(
+          vllm.AsyncEngineArgs(model=self._bentomodel.path,
+                               tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
+                               tokenizer_mode='auto',
+                               tensor_parallel_size=num_gpus,
+                               dtype='auto',
+                               disable_log_requests=not get_debug_mode(),
+                               worker_use_ray=False,
+                               engine_use_ray=False))
       except Exception as err:
         traceback.print_exc()
         raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from None
     else:
-      (model_decls, model_attrs), _ = self.llm_parameters
-      return fn(self, *(*model_decls, *decls), **{**model_attrs, **attrs})
+      return fn(self, *model_decls, *decls, **model_attrs, **attrs)

   return inner

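A note on the GPU selection kept above: min(dev // 2 * 2, dev) picks the largest even number of devices that does not exceed the visible device count, so odd pools are rounded down (3 GPUs becomes 2, 5 becomes 4) before being used as tensor_parallel_size, presumably because even tensor-parallel world sizes are the usual configuration. A quick check of the arithmetic, with a hypothetical helper name for illustration:

# Hypothetical helper, for illustration only, mirroring the expression above.
def pick_num_gpus(dev: int) -> int:
  num_gpus = 1
  if dev >= 2:
    num_gpus = min(dev // 2 * 2, dev)  # largest even value <= dev
  return num_gpus

assert [pick_num_gpus(d) for d in (1, 2, 3, 4, 5, 8)] == [1, 2, 2, 4, 4, 8]
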
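The substantive change in this commit is the switch from vllm.LLMEngine to vllm.AsyncLLMEngine, which is what makes continuous batching usable for serving: the async engine keeps many requests in flight and yields partial RequestOutput objects as tokens are produced. A rough sketch of how such an engine is typically driven, based on the vLLM API of the 0.1.x/0.2.x releases this PR targets (engine construction is as in the diff; stream_completion and its arguments are illustrative, not OpenLLM's actual serving code):

import uuid

import vllm

async def stream_completion(engine: vllm.AsyncLLMEngine, prompt: str) -> str:
  # One request id per prompt; the engine batches all active requests
  # together on every scheduling step (continuous batching).
  params = vllm.SamplingParams(temperature=0.8, max_tokens=128)
  request_id = uuid.uuid4().hex
  text = ''
  # generate() is an async generator yielding a RequestOutput whenever new
  # tokens are available for this request id.
  async for output in engine.generate(prompt, params, request_id):
    text = output.outputs[0].text  # cumulative generated text so far
  return text

# Example (assuming `engine` was built as in the diff above):
#   final_text = asyncio.run(stream_completion(engine, 'Hello, world'))

In newer vLLM releases the generate() signature and engine-argument names have changed, so treat this as a sketch of the 2023-era API rather than a drop-in snippet.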