fix(build): only load model when eager is True
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Nov 20, 2023
1 parent 5b92e84 commit f753662
Showing 2 changed files with 17 additions and 6 deletions.
openllm-python/src/openllm/_llm.py (16 changes: 10 additions & 6 deletions)
@@ -163,6 +163,7 @@ def __init__(
     embedded=False,
     dtype='auto',
     low_cpu_mem_usage=True,
+    _eager=True,
     **attrs,
   ):
     # fmt: off
@@ -201,12 +202,15 @@ def __init__(
       llm_trust_remote_code__=trust_remote_code,
     )

-    try:
-      model = bentoml.models.get(self.tag)
-    except bentoml.exceptions.NotFound:
-      model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
-    # resolve the tag
-    self._tag = model.tag
+    if _eager:
+      try:
+        model = bentoml.models.get(self.tag)
+      except bentoml.exceptions.NotFound:
+        model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
+      # resolve the tag
+      self._tag = model.tag
+    if not _eager and embedded:
+      raise RuntimeError("Embedded mode is not supported when '_eager' is False.")
     if embedded and not get_disable_warnings() and not get_quiet_mode():
       logger.warning(
         'You are using embedded mode, which means the models will be loaded into memory. This is often not recommended in production and should only be used for local development only.'
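What the new flag does, as a minimal sketch (an assumption for illustration: openllm.LLM is the public constructor behind this __init__, and 'facebook/opt-125m' stands in for any supported model id):

import bentoml
import openllm

# Default (_eager=True): the constructor resolves the model immediately,
# importing it into the BentoML store if it is not already there.
llm = openllm.LLM('facebook/opt-125m')

# Lazy (_eager=False): construction skips model resolution entirely, so the
# caller must resolve llm.tag itself before the model is used. This mirrors
# what build_command now does in the entrypoint.py hunk below.
lazy = openllm.LLM('facebook/opt-125m', _eager=False)
try:
  model = bentoml.models.get(lazy.tag)
except bentoml.exceptions.NotFound:
  model = openllm.serialisation.import_model(lazy, trust_remote_code=lazy.trust_remote_code)
lazy._tag = model.tag

Note that embedded mode still requires eager loading: combining embedded=True with _eager=False now raises a RuntimeError.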
openllm-python/src/openllm_cli/entrypoint.py (7 changes: 7 additions & 0 deletions)
@@ -1047,10 +1047,17 @@ def build_command(
       serialisation=first_not_none(
         serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
       ),
+      _eager=False,
     )
     if llm.__llm_backend__ not in llm.config['backend']:
       raise click.ClickException(f"'{backend}' is not supported with {model_id}")
     backend_warning(llm.__llm_backend__, build=True)
+    try:
+      model = bentoml.models.get(llm.tag)
+    except bentoml.exceptions.NotFound:
+      model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
+    llm._tag = model.tag
+
     os.environ.update(
       **process_environ(
         llm.config,
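The ordering here is the point of passing _eager=False (a reading of the hunk above, not stated in the commit message): the backend-compatibility check and backend_warning now run before any weights are fetched, and the bentoml.models.get / import_model fallback only executes once the build is known to be viable, so an unsupported backend fails fast instead of after a potentially large download.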
