
refactor(breaking): unify LLM API #283

Merged
merged 9 commits on Sep 1, 2023
fix: run format
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Aug 30, 2023
commit 9d4ec1567a2faa75e4da8ebb487509d5d49b1db7
4 changes: 3 additions & 1 deletion openllm-core/src/openllm_core/_configuration.py
@@ -707,7 +707,9 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _
if not BACKENDS_MAPPING[library_stub][0](): default_implementation[rs] = 'pt'
_final_value_dct['default_implementation'] = default_implementation

env = openllm_core.utils.EnvVarMixin(model_name, get_default_implementation(default_implementation), model_id=_settings_attr.default_id)
env = openllm_core.utils.EnvVarMixin(model_name,
get_default_implementation(default_implementation),
model_id=_settings_attr.default_id)
_final_value_dct['env'] = env

_final_value_dct['service_name'] = f'generated_{model_name}_service.py'
4 changes: 4 additions & 0 deletions openllm-core/src/openllm_core/_typing_compat.py
@@ -161,19 +161,23 @@ def __repr_keys__(self) -> set[str]:
...

class load_model_protocol(t.Generic[M, T], t.Protocol):

def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
...

class load_tokenizer_protocol(t.Generic[M, T], t.Protocol):

def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T:
...

_R = t.TypeVar('_R', covariant=True)

class import_model_protocol(t.Generic[_R, M, T], t.Protocol):

def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R:
...

class llm_post_init_protocol(t.Generic[M, T], t.Protocol):

def __call__(self, llm: LLM[M, T]) -> T:
...
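
For context, these `__call__`-only protocols type-check loader callbacks structurally rather than by inheritance: any plain function with a matching signature satisfies them. A minimal sketch of how a type checker accepts such a function, assuming an illustrative `FakeLLM` stand-in and a hypothetical `load_pt_model` (neither is part of the codebase):

```python
import typing as t

M = t.TypeVar('M')
T = t.TypeVar('T')

class FakeLLM(t.Generic[M, T]):
    """Illustrative stand-in for the real LLM[M, T]."""
    model_id = 'facebook/opt-125m'

class load_model_protocol(t.Generic[M, T], t.Protocol):
    def __call__(self, llm: FakeLLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
        ...

def load_pt_model(llm: FakeLLM[t.Any, t.Any], *decls: t.Any, **attrs: t.Any) -> t.Any:
    # Matches the protocol's __call__ signature, so structural typing accepts it
    # without any explicit subclassing.
    return f'loaded {llm.model_id}'

loader: load_model_protocol[t.Any, t.Any] = load_pt_model  # passes a type checker
```
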
4 changes: 3 additions & 1 deletion openllm-core/src/openllm_core/utils/__init__.py
@@ -20,6 +20,7 @@
from circus.exc import ConflictError

import openllm_core

from bentoml._internal.configuration import DEBUG_ENV_VAR as DEBUG_ENV_VAR
from bentoml._internal.configuration import GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR
from bentoml._internal.configuration import QUIET_ENV_VAR as QUIET_ENV_VAR
@@ -107,7 +108,8 @@ def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
return '_'.join(filter(None, map(str.upper, ['OPENLLM', model_name, suffix.strip('_') if suffix else '', key])))

# Special debug flag controlled via OPENLLMDEVDEBUG
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment) or (str(os.environ.get(DEV_DEBUG_VAR, None)).upper() in ENV_VARS_TRUE_VALUES)
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment) or (str(os.environ.get(
DEV_DEBUG_VAR, None)).upper() in ENV_VARS_TRUE_VALUES)
# Whether to show the codegen for debug purposes
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get('OPENLLMDEVDEBUG', str(0))) > 3
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
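
As a concrete reference for the naming convention, `field_env_key` upper-cases each segment and drops empty ones. A self-contained sketch with the expected outputs worked out by hand; the model and key names are arbitrary examples:

```python
from __future__ import annotations

def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
    # Same one-liner as above: upper-case every part, drop empty segments, join with '_'.
    return '_'.join(filter(None, map(str.upper, ['OPENLLM', model_name, suffix.strip('_') if suffix else '', key])))

print(field_env_key('flan_t5', 'generation'))                         # OPENLLM_FLAN_T5_GENERATION
print(field_env_key('flan_t5', 'temperature', suffix='generation'))   # OPENLLM_FLAN_T5_GENERATION_TEMPERATURE
```
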
14 changes: 6 additions & 8 deletions openllm-core/src/openllm_core/utils/import_utils.py
@@ -380,14 +380,12 @@ def __getitem__(self, item: str | t.Any) -> t.Any:
elif hasattr(self, item): return getattr(self, item)
raise KeyError(f'Key {item} not found in {self}')

def __init__(
self,
model_name: str,
implementation: LiteralRuntime = 'pt',
model_id: str | None = None,
quantize: LiteralString | None = None,
runtime: t.Literal['ggml', 'transformers'] = 'transformers'
) -> None:
def __init__(self,
model_name: str,
implementation: LiteralRuntime = 'pt',
model_id: str | None = None,
quantize: LiteralString | None = None,
runtime: t.Literal['ggml', 'transformers'] = 'transformers') -> None:
'''EnvVarMixin is a mixin class that returns the value extracted from environment variables.'''
from openllm_core.utils import field_env_key
self.model_name = inflection.underscore(model_name)
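
The reformatted constructor keeps the same signature, so call sites such as the one in `_configuration.py` above are unaffected. A usage sketch based only on what this hunk shows; the argument values are arbitrary examples:

```python
from openllm_core.utils import EnvVarMixin

env = EnvVarMixin('flan-t5',
                  implementation='pt',
                  model_id='google/flan-t5-large',
                  quantize=None,
                  runtime='transformers')
print(env.model_name)    # 'flan_t5' -- __init__ applies inflection.underscore()
print(env['model_name'])  # __getitem__ falls back to getattr(self, item), per the hunk above
```
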
19 changes: 16 additions & 3 deletions openllm-python/src/openllm/__init__.py
@@ -48,8 +48,13 @@
"_quantisation": ["infer_quantisation_config"],
"_embeddings": ["GenericEmbeddingRunnable"],
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "EmbeddingsOutput"],
"_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
"models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"],
"_generation": [
"StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList",
"prepare_logits_processor"
],
"models.auto": [
"MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"
],
"models.chatglm": [],
"models.baichuan": [],
"models.dolly_v2": [],
@@ -188,7 +193,15 @@
from .models.opt import TFOPT as TFOPT

# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED, "__openllm_migration__": {"LLMEmbeddings": "EmbeddingsOutput"}})
__lazy = openllm_core.utils.LazyModule(__name__,
globals()["__file__"],
_import_structure,
extra_objects={
"COMPILED": COMPILED,
"__openllm_migration__": {
"LLMEmbeddings": "EmbeddingsOutput"
}
})
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__
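
The `LazyModule` call above defers the heavy submodule imports until an attribute is first touched. A generic sketch of the underlying PEP 562 pattern, assuming it lives in a package `__init__.py`; this is not `LazyModule`'s actual implementation, only the idea it builds on:

```python
import importlib
import typing as t

# Map each exported attribute back to the submodule that defines it.
_import_structure: dict[str, list[str]] = {
    '_llm': ['LLM', 'Runner'],
    '_quantisation': ['infer_quantisation_config'],
}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}

def __getattr__(name: str) -> t.Any:
    # Import the owning submodule only on first access.
    if name in _attr_to_module:
        module = importlib.import_module(f'.{_attr_to_module[name]}', __name__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

def __dir__() -> list[str]:
    return sorted(_attr_to_module)
```
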
5 changes: 4 additions & 1 deletion openllm-python/src/openllm/_embeddings.py
@@ -68,7 +68,10 @@ def encode(self, sentences: list[str]) -> t.Sequence[openllm.EmbeddingsOutput]:
model_output = self.model(**encoded_input)
# Perform pooling and normalize
sentence_embeddings = F.normalize(self.mean_pooling(model_output, attention_mask), p=2, dim=1)
return [openllm.EmbeddingsOutput(embeddings=sentence_embeddings.cpu().numpy(), num_tokens=int(torch.sum(attention_mask).item()))]
return [
openllm.EmbeddingsOutput(embeddings=sentence_embeddings.cpu().numpy(),
num_tokens=int(torch.sum(attention_mask).item()))
]

@staticmethod
def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
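
For context, the wrapped return above pairs masked mean pooling with L2 normalisation. The body of `mean_pooling` is outside this hunk, so the following is the standard recipe rather than the project's exact code:

```python
import torch
import torch.nn.functional as F

def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts

token_embeddings = torch.randn(2, 8, 384)
attention_mask = torch.ones(2, 8, dtype=torch.long)
sentence_embeddings = F.normalize(mean_pooling(token_embeddings, attention_mask), p=2, dim=1)
print(sentence_embeddings.shape)  # torch.Size([2, 384])
```
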
64 changes: 44 additions & 20 deletions openllm-python/src/openllm/_gen.py
@@ -14,6 +14,7 @@
import typing as t

import openllm

from openllm_core._configuration import _object_getattribute
from openllm_core._configuration import _setattr_class
from openllm_core._schema import unmarshal_vllm_outputs
@@ -32,19 +33,22 @@
from openllm_core.utils import is_torch_available

from .exceptions import OpenLLMException

if t.TYPE_CHECKING:
import torch
import transformers
import vllm

import bentoml

from openllm._llm import LLM
else:
transformers = LazyLoader('transformers', globals(), 'transformers')
torch = LazyLoader('torch', globals(), 'torch')
vllm = LazyLoader('vllm', globals(), 'vllm')

def import_model(fn: import_model_protocol[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]:

@functools.wraps(fn)
def inner(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model:
trust_remote_code = first_not_none(trust_remote_code, default=self.__llm_trust_remote_code__)
@@ -56,21 +60,19 @@ def inner(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None,
return inner

def load_model(fn: load_model_protocol[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.LLMEngine]:

@functools.wraps(fn)
def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
if self.__llm_implementation__ == 'vllm':
# TODO: Do some more processing with token_id once we support token streaming
try:
return vllm.LLMEngine.from_engine_args(
vllm.EngineArgs(
model=self._bentomodel.path,
tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
tokenizer_mode='auto',
tensor_parallel_size=1 if device_count() < 2 else device_count(),
dtype='auto',
worker_use_ray=False
)
)
vllm.EngineArgs(model=self._bentomodel.path,
tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
tokenizer_mode='auto',
tensor_parallel_size=1 if device_count() < 2 else device_count(),
dtype='auto',
worker_use_ray=False))
except Exception as err:
traceback.print_exc()
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from None
@@ -81,16 +83,19 @@ def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
return inner

def load_tokenizer(fn: load_tokenizer_protocol[M, T]) -> t.Callable[[LLM[M, T]], T]:

@functools.wraps(fn)
def inner(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T:
return fn(self, **{**self.llm_parameters[-1], **tokenizer_attrs})

return inner

def llm_post_init(fn: llm_post_init_protocol[M, T]) -> t.Callable[[LLM[M, T]], None]:

@functools.wraps(fn)
def inner(self: LLM[M, T]) -> None:
if self.__llm_implementation__ == 'pt' and is_torch_available(): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.__llm_implementation__ == 'pt' and is_torch_available():
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fn(self)

return inner
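
The `load_tokenizer` and `llm_post_init` wrappers above follow the same pattern: wrap the user-provided implementation with `functools.wraps` and merge instance-level defaults into the call, with call-time arguments winning. A self-contained sketch of that pattern, assuming an illustrative `FakeLLM` stand-in:

```python
import functools
import typing as t

class FakeLLM:
    # (model kwargs, tokenizer kwargs), mirroring the llm_parameters[-1] access above.
    llm_parameters = ({}, {'padding_side': 'left'})

def load_tokenizer(fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
    @functools.wraps(fn)
    def inner(self: FakeLLM, **tokenizer_attrs: t.Any) -> t.Any:
        # Instance defaults first, call-time overrides last so they take precedence.
        return fn(self, **{**self.llm_parameters[-1], **tokenizer_attrs})
    return inner

@load_tokenizer
def my_load_tokenizer(self: FakeLLM, **attrs: t.Any) -> dict:
    return attrs

print(my_load_tokenizer(FakeLLM()))                 # {'padding_side': 'left'}
print(my_load_tokenizer(FakeLLM(), use_fast=True))  # {'padding_side': 'left', 'use_fast': True}
```
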
@@ -104,21 +109,31 @@ def make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]]
anns: DictStrAny = {}
globs: DictStrAny = {'cls': cls, '__wrapped_llm_post_init': llm_post_init}
# _cached_LLMFunction_get and _cached_LLMSerialisation_get
globs.update({f'_cached_{cl_.__name__}_get': _object_getattribute.__get__(cl_) for cl_ in {LLMSerialisation, LLMFunction}})
globs.update(
{f'_cached_{cl_.__name__}_get': _object_getattribute.__get__(cl_) for cl_ in {LLMSerialisation, LLMFunction}})
# llm_post_init implementation
lines: ListStr = [f'_impl_{cls.__name__}_func=cls.llm_post_init', _setattr_class('llm_post_init', f'__wrapped_llm_post_init(_impl_{cls.__name__}_func)')]
lines: ListStr = [
f'_impl_{cls.__name__}_func=cls.llm_post_init',
_setattr_class('llm_post_init', f'__wrapped_llm_post_init(_impl_{cls.__name__}_func)')
]

serialisation_attr = {'import_model': import_model, 'load_model': load_model, 'load_tokenizer': load_tokenizer,}
for func, impl in serialisation_attr.items():
impl_name = f'__wrapped_{func}'
globs.update({f'__serialisation_{func}': getattr(openllm.serialisation, func, None), impl_name: impl})
cached_func_name = f'_cached_{cls.__name__}_func'
func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMSerialisation_get('{func}') else __serialisation_{func}"
lines.extend([f'{cached_func_name}=cls.{func}', func_call, _setattr_class(func, f'{impl_name}(_impl_{cls.__name__}_{func})')])
lines.extend([
f'{cached_func_name}=cls.{func}', func_call,
_setattr_class(func, f'{impl_name}(_impl_{cls.__name__}_{func})')
])

# assign vLLM implementation
if cls.__llm_implementation__ == 'vllm':
globs.update({f'_vllm_{it}': fn for it, fn in zip(LLMFunction.__abstractmethods__, (vllm_generate, vllm_generate_iterator, vllm_postprocess_generate))})
globs.update({
f'_vllm_{it}': fn for it, fn in zip(LLMFunction.__abstractmethods__, (vllm_generate, vllm_generate_iterator,
vllm_postprocess_generate))
})
lines.extend([_setattr_class(it, f'_vllm_{it}') for it in LLMFunction.__abstractmethods__])

# cached attribute initialisation
@@ -136,17 +151,24 @@ def dunder_support(key: str) -> str:
return f'__llm_supports_{key}__'

bool_attr = {it for it in LLMFunction.__dict__ if not it.startswith('_')}
lines.extend([_setattr_class(dunder_support(fn), f"cls.{fn} is not _cached_LLMFunction_get('{fn}')") for fn in bool_attr])
lines.extend(
[_setattr_class(dunder_support(fn), f"cls.{fn} is not _cached_LLMFunction_get('{fn}')") for fn in bool_attr])
anns.update({dunder_support(fn): interface_anns.get(dunder_support(fn)) for fn in bool_attr})

return codegen.generate_function(cls, '__assign_llm_attr', lines, args=('cls', *args), globs=globs, annotations=anns)
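
For readers unfamiliar with this style: the lines collected above become the body of a generated `__assign_llm_attr` function that is executed against the class. A generic sketch of exec-based function generation in that spirit; the internals of `codegen.generate_function` and `_setattr_class` are not shown in this diff, so this only illustrates the technique:

```python
import typing as t

def generate_function(name: str, lines: list[str], args: tuple[str, ...], globs: dict) -> t.Callable[..., t.Any]:
    # Build the function source from the collected lines, then exec it into a fresh namespace.
    src = f"def {name}({', '.join(args)}):\n" + ''.join(f'  {line}\n' for line in lines)
    ns: dict = {}
    exec(compile(src, f'<generated {name}>', 'exec'), globs, ns)
    return ns[name]

lines = ["setattr(cls, '__llm_supports_embeddings__', False)"]
assign = generate_function('__assign_llm_attr', lines, args=('cls',), globs={})

class Demo: ...
assign(Demo)
print(Demo.__llm_supports_embeddings__)  # False
```
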

def vllm_postprocess_generate(self: LLM['vllm.LLMEngine', T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str:
def vllm_postprocess_generate(self: LLM['vllm.LLMEngine', T], prompt: str, generation_result: list[dict[str, t.Any]],
**_: t.Any) -> str:
return generation_result[0]['outputs'][0]['text']

def vllm_generate_iterator(
self: LLM['vllm.LLMEngine', T], prompt: str, /, *, echo: bool = False, stop: str | t.Iterable[str] | None = None, stop_token_ids: list[int] | None = None, **attrs: t.Any
) -> t.Iterator[dict[str, t.Any]]:
def vllm_generate_iterator(self: LLM['vllm.LLMEngine', T],
prompt: str,
/,
*,
echo: bool = False,
stop: str | t.Iterable[str] | None = None,
stop_token_ids: list[int] | None = None,
**attrs: t.Any) -> t.Iterator[dict[str, t.Any]]:
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
if stop_token_ids is None: stop_token_ids = []
@@ -174,7 +196,9 @@ def vllm_generate(self: LLM['vllm.LLMEngine', T], prompt: str, **attrs: t.Any) -
if request_id is None: raise ValueError('request_id must not be None.')
outputs: list[vllm.RequestOutput] = []
# TODO: support prompt_token_ids
self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config())
self.model.add_request(request_id=request_id,
prompt=prompt,
sampling_params=self.config.model_construct_env(**attrs).to_sampling_config())
while self.model.has_unfinished_requests():
outputs.extend([r for r in self.model.step() if r.finished])
return [unmarshal_vllm_outputs(i) for i in outputs]
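
The blocking path above drives the engine to completion before returning. A condensed sketch using only the engine calls that appear in this hunk; engine construction and the sampling parameters are assumed to exist elsewhere:

```python
import uuid

import vllm

def blocking_generate(engine: vllm.LLMEngine, prompt: str, sampling_params: vllm.SamplingParams) -> list[vllm.RequestOutput]:
    request_id = uuid.uuid4().hex
    engine.add_request(request_id=request_id, prompt=prompt, sampling_params=sampling_params)
    finished: list[vllm.RequestOutput] = []
    # step() advances all queued requests; collect each one as it finishes.
    while engine.has_unfinished_requests():
        finished.extend([r for r in engine.step() if r.finished])
    return finished
```
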