
refactor(breaking): unify LLM API #283

Merged (9 commits) on Sep 1, 2023
Changes from 1 commit
refactor: initial work to _gen
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Aug 30, 2023
commit 03c402ef839a5fa795e49f1e86ec966ceed65689
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -230,7 +230,7 @@ No significant changes.

```bash
docker run --rm --gpus all -it -v /home/ubuntu/.local/share/bentoml:/tmp/bentoml -e BENTOML_HOME=/tmp/bentoml \
-e OPENLLM_USE_LOCAL_LATEST=True -e OPENLLM_LLAMA_FRAMEWORK=vllm ghcr.io/bentoml/openllm:2b5e96f90ad314f54e07b5b31e386e7d688d9bb2 start llama --model-id meta-llama/Llama-2-7b-chat-hf --workers-per-resource conserved --debug`
-e OPENLLM_USE_LOCAL_LATEST=True -e OPENLLM_BACKEND=vllm ghcr.io/bentoml/openllm:2b5e96f90ad314f54e07b5b31e386e7d688d9bb2 start llama --model-id meta-llama/Llama-2-7b-chat-hf --workers-per-resource conserved --debug`
```

In conjunction with this, OpenLLM now also has a set of small CLI utilities via ``openllm ext`` for ease of use
Empty file.
2 changes: 1 addition & 1 deletion openllm-core/src/openllm_core/_schema.py
@@ -1,4 +1,4 @@
'''Schema definition for OpenLLM. This can be use for client interaction.'''
'''Schema definition for OpenLLM. This schema is used throughout the openllm-core components.'''
from __future__ import annotations
import functools
import typing as t
30 changes: 23 additions & 7 deletions openllm-core/src/openllm_core/_typing_compat.py
@@ -21,6 +21,8 @@
from bentoml._internal.runner.runnable import RunnableMethod
from bentoml._internal.runner.runner import RunnerMethod
from bentoml._internal.runner.strategy import Strategy
from openllm._llm import LLM
from openllm_core._schema import EmbeddingsOutput

from .utils.lazy import VersionInfo

@@ -73,10 +75,6 @@ class PeftAdapterOutput(t.TypedDict):
result: t.Dict[str, peft.PeftConfig]
error_msg: str

class LLMEmbeddings(t.TypedDict):
embeddings: t.List[t.List[float]]
num_tokens: int

class AdaptersTuple(TupleAny):
adapter_id: str
name: t.Optional[str]
@@ -94,7 +92,7 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
SUPPORTS_CPU_MULTI_THREADING = True
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal['success', 'error_msg'], bool | str]]
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], EmbeddingsOutput]
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
@@ -112,7 +110,7 @@ class LLMRunner(bentoml.Runner, t.Generic[M, T]):
supports_embeddings: bool
supports_hf_agent: bool
has_adapters: bool
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[LLMEmbeddings]]
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[EmbeddingsOutput]]
generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]]
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
@@ -136,7 +134,7 @@ def __call__(self, prompt: str, **attrs: t.Any) -> t.Any:
...

@abc.abstractmethod
def embed(self, prompt: str | list[str]) -> LLMEmbeddings:
def embed(self, prompt: str | list[str]) -> EmbeddingsOutput:
...

def run(self, prompt: str, **attrs: t.Any) -> t.Any:
@@ -158,3 +156,21 @@ def peft_adapters(self) -> PeftAdapterOutput:
@abc.abstractmethod
def __repr_keys__(self) -> set[str]:
...

class load_model_protocol(t.Generic[M, T], t.Protocol):
def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
...

class load_tokenizer_protocol(t.Generic[M, T], t.Protocol):
def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T:
...

_R = t.TypeVar('_R', covariant=True)

class import_model_protocol(t.Generic[_R, M, T], t.Protocol):
def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R:
...

class llm_post_init_protocol(t.Generic[M, T], t.Protocol):
def __call__(self, llm: LLM[M, T]) -> T:
...
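
For orientation, here is a minimal sketch of a callable that satisfies `load_model_protocol`. The class names and the `build_model_from` helper are illustrative assumptions, not part of this diff or the openllm API:

```python
from __future__ import annotations
import typing as t

if t.TYPE_CHECKING:
  from openllm._llm import LLM

class MyModel: ...      # hypothetical stand-in for the M type parameter
class MyTokenizer: ...  # hypothetical stand-in for the T type parameter

def build_model_from(model_id: str, **attrs: t.Any) -> MyModel:
  # Made-up helper; a real implementation would construct the backend model here.
  return MyModel()

def load_model(llm: LLM[MyModel, MyTokenizer], *decls: t.Any, **attrs: t.Any) -> MyModel:
  # Matches load_model_protocol: takes the LLM wrapper plus positional/keyword overrides
  # and returns the in-memory model object.
  return build_model_from(llm.model_id, **attrs)
```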
Empty file.
6 changes: 3 additions & 3 deletions openllm-core/src/openllm_core/utils/__init__.py
@@ -20,7 +20,6 @@
from circus.exc import ConflictError

import openllm_core

from bentoml._internal.configuration import DEBUG_ENV_VAR as DEBUG_ENV_VAR
from bentoml._internal.configuration import GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR
from bentoml._internal.configuration import QUIET_ENV_VAR as QUIET_ENV_VAR
@@ -105,10 +104,11 @@ def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
return '_'.join(filter(None, map(str.upper, ['OPENLLM', model_name, suffix.strip('_') if suffix else '', key])))

# Special debug flag controlled via OPENLLMDEVDEBUG
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get(DEV_DEBUG_VAR)))
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and str(os.environ.get(DEV_DEBUG_VAR, None)).upper() in ENV_VARS_TRUE_VALUES)
# Whether to show the generated code for debug purposes
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get('OPENLLMDEVDEBUG', str(0))) > 3
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
MYPY = False
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get('OPENLLMDEVDEBUG', str(0))) > 3

def get_debug_mode() -> bool:
return DEBUG or _get_debug_mode()
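
As a quick illustration of `field_env_key` above, the parts are upper-cased and joined with underscores, with empty segments dropped:

```python
from openllm_core.utils import field_env_key

field_env_key('llama', 'backend')                       # -> 'OPENLLM_LLAMA_BACKEND'
field_env_key('llama', 'backend', suffix='generation')  # -> 'OPENLLM_LLAMA_GENERATION_BACKEND'
```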
4 changes: 2 additions & 2 deletions openllm-core/src/openllm_core/utils/codegen.py
@@ -96,7 +96,7 @@ class MyClassAttributes(tuple):
else:
attr_class_template.append(' pass')
globs: DictStrAny = {'_attrs_itemgetter': itemgetter, '_attrs_property': property}
if SHOW_CODEGEN: logger.info('Generated class for %s:\n\n%s', attr_class_name, '\n'.join(attr_class_template))
if SHOW_CODEGEN: print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template))
_compile_and_eval('\n'.join(attr_class_template), globs)
return globs[attr_class_name]

@@ -110,7 +110,7 @@ def generate_function(
script = 'def %s(%s):\n %s\n' % (func_name, ', '.join(args) if args is not None else '', '\n '.join(lines) if lines else 'pass')
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
if annotations: meth.__annotations__ = annotations
if SHOW_CODEGEN: logger.info('Generated script for %s:\n\n%s', typ, script)
if SHOW_CODEGEN: print(f'Generated script for {typ}:\n\n', script)
return meth

def make_env_transformer(
6 changes: 3 additions & 3 deletions openllm-python/src/openllm/__init__.py
@@ -44,7 +44,7 @@
"cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
"_quantisation": ["infer_quantisation_config"],
"_embeddings": ["GenericEmbeddingRunnable"],
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"],
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "EmbeddingsOutput"],
"_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
"models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"],
"models.chatglm": [],
@@ -64,7 +64,7 @@
if _t.TYPE_CHECKING:
from . import bundle as bundle, cli as cli, client as client, models as models, playground as playground, serialisation as serialisation, testing as testing
from ._generation import LogitsProcessorList as LogitsProcessorList, StopOnTokens as StopOnTokens, StoppingCriteriaList as StoppingCriteriaList, StopSequenceCriteria as StopSequenceCriteria, prepare_logits_processor as prepare_logits_processor
from ._llm import LLM as LLM, LLMEmbeddings as LLMEmbeddings, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner
from ._llm import LLM as LLM, EmbeddingsOutput as EmbeddingsOutput, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner
from ._quantisation import infer_quantisation_config as infer_quantisation_config
from ._embeddings import GenericEmbeddingRunnable as GenericEmbeddingRunnable
from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc
@@ -175,7 +175,7 @@
from .models.opt import TFOPT as TFOPT

# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED})
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED, "__openllm_migration__": {"LLMEmbeddings": "EmbeddingsOutput"}})
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__
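
The practical effect of the rename on downstream code is a one-line import change; the `__openllm_migration__` entry suggests the lazy module keeps redirecting the old name, though the exact deprecation behaviour is an assumption here:

```python
# Previously:
#   from openllm import LLMEmbeddings
# After this PR:
from openllm import EmbeddingsOutput
```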
4 changes: 2 additions & 2 deletions openllm-python/src/openllm/_embeddings.py
@@ -56,7 +56,7 @@ def __init__(self) -> None:
self.model.to(self.device)

@bentoml.Runnable.method(batchable=True, batch_dim=0)
def encode(self, sentences: list[str]) -> t.Sequence[openllm.LLMEmbeddings]:
def encode(self, sentences: list[str]) -> t.Sequence[openllm.EmbeddingsOutput]:
import torch
import torch.nn.functional as F
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(self.device)
@@ -66,7 +66,7 @@ def encode(self, sentences: list[str]) -> t.Sequence[openllm.LLMEmbeddings]:
model_output = self.model(**encoded_input)
# Perform pooling and normalize
sentence_embeddings = F.normalize(self.mean_pooling(model_output, attention_mask), p=2, dim=1)
return [openllm.LLMEmbeddings(embeddings=sentence_embeddings.cpu().numpy(), num_tokens=int(torch.sum(attention_mask).item()))]
return [openllm.EmbeddingsOutput(embeddings=sentence_embeddings.cpu().numpy(), num_tokens=int(torch.sum(attention_mask).item()))]

@staticmethod
def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
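
For reference, `mean_pooling` in the sentence-embedding setup above is conventionally a masked average over token embeddings. A common sketch of that pattern follows (the exact body is elided by the diff, so treat this as an assumption):

```python
import torch

def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
  # Average the token embeddings, weighting by the attention mask so padding tokens are ignored.
  token_embeddings = model_output[0]  # last hidden state of the transformer output
  mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
  return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
```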
180 changes: 180 additions & 0 deletions openllm-python/src/openllm/_gen.py
@@ -0,0 +1,180 @@
'''Codegen magic for the openllm-core representation of LLM.

This exists because Python lacks proper overload and override semantics.

With the current subclass-based solution, users often have to prepend an underscore to some of
these functions. That is not ideal; a subclass should simply be able to implement the function
that the internal code expects.

Since Python allows nearly unrestricted runtime modification, this module generates dynamic
classes and overrides the generated class with the correct setattr attributes for `openllm.LLM`.'''
from __future__ import annotations
import functools
import traceback
import typing as t

import openllm
from openllm_core._configuration import _object_getattribute
from openllm_core._configuration import _setattr_class
from openllm_core._schema import unmarshal_vllm_outputs
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import ListStr
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from openllm_core._typing_compat import import_model_protocol
from openllm_core._typing_compat import llm_post_init_protocol
from openllm_core._typing_compat import load_model_protocol
from openllm_core._typing_compat import load_tokenizer_protocol
from openllm_core.utils import LazyLoader
from openllm_core.utils import codegen
from openllm_core.utils import device_count
from openllm_core.utils import first_not_none
from openllm_core.utils import is_torch_available

from .exceptions import OpenLLMException
if t.TYPE_CHECKING:
import torch
import transformers
import vllm

import bentoml
from openllm._llm import LLM
else:
transformers = LazyLoader('transformers', globals(), 'transformers')
torch = LazyLoader('torch', globals(), 'torch')
vllm = LazyLoader('vllm', globals(), 'vllm')

def import_model(fn: import_model_protocol[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]:
@functools.wraps(fn)
def inner(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model:
trust_remote_code = first_not_none(trust_remote_code, default=self.__llm_trust_remote_code__)
(model_decls, model_attrs), _ = self.llm_parameters
decls = (*model_decls, *decls)
attrs = {**model_attrs, **attrs}
return fn(self, *decls, trust_remote_code=trust_remote_code, **attrs)

return inner

def load_model(fn: load_model_protocol[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.LLMEngine]:
@functools.wraps(fn)
def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine:
if self.__llm_implementation__ == 'vllm':
# TODO: Do some more processing with token_id once we support token streaming
try:
return vllm.LLMEngine.from_engine_args(
vllm.EngineArgs(
model=self._bentomodel.path,
tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id,
tokenizer_mode='auto',
tensor_parallel_size=1 if device_count() < 2 else device_count(),
dtype='auto',
worker_use_ray=False
)
)
except Exception as err:
traceback.print_exc()
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from None
else:
(model_decls, model_attrs), _ = self.llm_parameters
return fn(self, *(*model_decls, *decls), **{**model_attrs, **attrs})

return inner

def load_tokenizer(fn: load_tokenizer_protocol[M, T]) -> t.Callable[[LLM[M, T]], T]:
@functools.wraps(fn)
def inner(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T:
return fn(self, **{**self.llm_parameters[-1], **tokenizer_attrs})

return inner

def llm_post_init(fn: llm_post_init_protocol[M, T]) -> t.Callable[[LLM[M, T]], None]:
@functools.wraps(fn)
def inner(self: LLM[M, T]) -> None:
if self.__llm_implementation__ == 'pt' and is_torch_available(): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fn(self)

return inner

def make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], None]:
from ._llm import LLMFunction
from ._llm import LLMInterface
from ._llm import LLMSerialisation

args: ListStr = []
anns: DictStrAny = {}
globs: DictStrAny = {'cls': cls, '__wrapped_llm_post_init': llm_post_init}
# _cached_LLMFunction_get and _cached_LLMSerialisation_get
globs.update({f'_cached_{cl_.__name__}_get': _object_getattribute.__get__(cl_) for cl_ in {LLMSerialisation, LLMFunction}})
# llm_post_init implementation
lines: ListStr = [f'_impl_{cls.__name__}_func=cls.llm_post_init', _setattr_class('llm_post_init', f'__wrapped_llm_post_init(_impl_{cls.__name__}_func)')]

serialisation_attr = {'import_model': import_model, 'load_model': load_model, 'load_tokenizer': load_tokenizer,}
for func, impl in serialisation_attr.items():
impl_name = f'__wrapped_{func}'
globs.update({f'__serialisation_{func}': getattr(openllm.serialisation, func, None), impl_name: impl})
cached_func_name = f'_cached_{cls.__name__}_func'
func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMSerialisation_get('{func}') else __serialisation_{func}"
lines.extend([f'{cached_func_name}=cls.{func}', func_call, _setattr_class(func, f'{impl_name}(_impl_{cls.__name__}_{func})')])

# assign vLLM implementation
if cls.__llm_implementation__ == 'vllm':
globs.update({f'_vllm_{it}': fn for it, fn in zip(LLMFunction.__abstractmethods__, (vllm_generate, vllm_generate_iterator, vllm_postprocess_generate))})
lines.extend([_setattr_class(it, f'_vllm_{it}') for it in LLMFunction.__abstractmethods__])

# cached attribute initialisation
st_attr = {'bentomodel', 'model', 'tokenizer', 'adapter_map'}
interface_anns = codegen.get_annotations(LLMInterface)

def dunder_cached(key: str) -> str:
return f'__llm_{key}__'

lines.extend([_setattr_class(dunder_cached(v), None) for v in st_attr])
anns.update({dunder_cached(v): interface_anns.get(dunder_cached(v)) for v in st_attr})

# boolean for better LLM implementation resolver
def dunder_support(key: str) -> str:
return f'__llm_supports_{key}__'

bool_attr = {it for it in LLMFunction.__dict__ if not it.startswith('_')}
lines.extend([_setattr_class(dunder_support(fn), f"cls.{fn} is not _cached_LLMFunction_get('{fn}')") for fn in bool_attr])
anns.update({dunder_support(fn): interface_anns.get(dunder_support(fn)) for fn in bool_attr})

return codegen.generate_function(cls, '__assign_llm_attr', lines, args=('cls', *args), globs=globs, annotations=anns)

def vllm_postprocess_generate(self: LLM['vllm.LLMEngine', T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str:
return generation_result[0]['outputs'][0]['text']

def vllm_generate_iterator(
self: LLM['vllm.LLMEngine', T], prompt: str, /, *, echo: bool = False, stop: str | t.Iterable[str] | None = None, stop_token_ids: list[int] | None = None, **attrs: t.Any
) -> t.Iterator[dict[str, t.Any]]:
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
if stop_token_ids is None: stop_token_ids = []
stop_token_ids.append(self.tokenizer.eos_token_id)
stop_: set[str] = set()
if isinstance(stop, str) and stop != '': stop_.add(stop)
elif isinstance(stop, list) and stop != []: stop_.update(stop)
for tid in stop_token_ids:
if tid: stop_.add(self.tokenizer.decode(tid))

if self.config['temperature'] <= 1e-5: top_p = 1.0
else: top_p = self.config['top_p']
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=config.to_sampling_config())
while self.model.has_unfinished_requests():
for request_output in self.model.step():
prompt = request_output.prompt
if echo: text_outputs = [prompt + output.text for output in request_output.outputs]
else: text_outputs = [output.text for output in request_output.outputs]
yield {'text': text_outputs, 'error_code': 0}
if request_output.finished: break

def vllm_generate(self: LLM['vllm.LLMEngine', T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]:
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
outputs: list[vllm.RequestOutput] = []
# TODO: support prompt_token_ids
self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config())
while self.model.has_unfinished_requests():
outputs.extend([r for r in self.model.step() if r.finished])
return [unmarshal_vllm_outputs(i) for i in outputs]
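
To make the pattern in `_gen.py` concrete, here is a self-contained toy sketch of the wrap-and-setattr approach the module docstring describes. All names are illustrative and not part of the openllm API:

```python
from __future__ import annotations
import functools
import typing as t

def _with_defaults(fn: t.Callable[..., t.Any], **defaults: t.Any) -> t.Callable[..., t.Any]:
  # Wrap a subclass-provided method so that library defaults are merged into every call.
  @functools.wraps(fn)
  def inner(self: t.Any, **attrs: t.Any) -> t.Any:
    return fn(self, **{**defaults, **attrs})
  return inner

class Base:
  def load_model(self, **attrs: t.Any) -> dict[str, t.Any]:
    raise NotImplementedError

def assign_llm_attr(cls: type[Base]) -> None:
  # Analogous in spirit to make_assignment_script: re-install the wrapped implementation
  # onto the class, so subclasses only need to define the plain method name.
  setattr(cls, 'load_model', _with_defaults(cls.load_model, device='auto'))

class MyLLM(Base):
  def load_model(self, **attrs: t.Any) -> dict[str, t.Any]:
    return attrs

assign_llm_attr(MyLLM)
print(MyLLM().load_model(dtype='float16'))  # {'device': 'auto', 'dtype': 'float16'}
```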