-
Notifications
You must be signed in to change notification settings - Fork 654
/
model.py
145 lines (119 loc) · 4.39 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import re
import typing
from typing import Optional
import tabulate
import typer
from openllm.accelerator_spec import DeploymentTarget, can_run
from openllm.analytic import OpenLLMTyper
from openllm.common import VERBOSE_LEVEL, BentoInfo, output
from openllm.repo import ensure_repo_updated, list_repo
# Command group for this module (help text: 'manage models'); the `get` and
# `list` commands below register themselves on it via `@app.command`.
app = OpenLLMTyper(help='manage models')
@app.command(help='get model')
def get(tag: str, repo: Optional[str] = None, verbose: bool = False):
    """Look up the single model matching ``tag`` and print it.

    Args:
        tag: Model tag forwarded to ``ensure_bento``.
        repo: Optional repository name forwarded to ``ensure_bento`` as
            ``repo_name``.
        verbose: When true, ``VERBOSE_LEVEL`` is set to 20 before the lookup.
    """
    if verbose:
        VERBOSE_LEVEL.set(20)

    found = ensure_bento(tag, repo_name=repo)
    if not found:
        return
    output(found)
@app.command(name='list', help='list available models')
def list_model(tag: Optional[str] = None, repo: Optional[str] = None, verbose: bool = False):
    """Print a table of the models returned by ``list_bento``.

    Rows are ordered by model name. The name is shown only on the first row
    of each model; later rows of the same model get an empty name cell.

    Args:
        tag: Optional tag filter forwarded to ``list_bento``.
        repo: Optional repository name forwarded to ``list_bento`` as
            ``repo_name``.
        verbose: When true, ``VERBOSE_LEVEL`` is set to 20 before listing.
    """
    if verbose:
        VERBOSE_LEVEL.set(20)

    # sorted() is stable, so entries sharing a name keep list_bento's order.
    bentos = sorted(list_bento(tag=tag, repo_name=repo), key=lambda x: x.name)

    printed_names = set()
    rows = []
    for bento in bentos:
        if bento.name in printed_names:
            name_cell = ''
        else:
            printed_names.add(bento.name)
            name_cell = bento.name
        rows.append([
            name_cell,
            bento.tag,
            bento.repo.name,
            bento.pretty_gpu,
            ','.join(bento.platforms),
        ])

    column_titles = ['model', 'version', 'repo', 'required GPU RAM', 'platforms']
    output(tabulate.tabulate(rows, headers=column_titles))
def ensure_bento(model: str, target: Optional[DeploymentTarget] = None, repo_name: Optional[str] = None) -> BentoInfo:
    """Resolve ``model`` to exactly one bento, or exit.

    Args:
        model: Tag passed to ``list_bento`` to find candidates.
        target: Optional deployment target. When given and ``can_run``
            reports a value <= 0 for the match, a warning is printed; the
            match is still returned.
        repo_name: Optional repository name passed to ``list_bento``.

    Returns:
        The single matching ``BentoInfo``.

    Raises:
        typer.Exit: With code 1 when nothing matches, or when several
            candidates match (the candidates are listed first).
    """
    candidates = list_bento(model, repo_name=repo_name)

    if not candidates:
        output(f'No model found for {model}', style='red')
        raise typer.Exit(1)

    if len(candidates) > 1:
        # Ambiguous tag: show the alternatives and abort instead of guessing.
        output(f'Multiple models match {model}, did you mean one of these?', style='red')
        list_model(model, repo=repo_name)
        raise typer.Exit(1)

    bento = candidates[0]
    output(f'Found model {bento}', style='green')
    if target is not None and can_run(bento, target) <= 0:
        output(
            f'The machine({target.name}) with {target.accelerators_repr} does not appear to have sufficient '
            f'resources to run model {bento}\n',
            style='yellow',
        )
    return bento
NUMBER_RE = re.compile(r'\d+')
def _extract_first_number(s: str):
match = NUMBER_RE.search(s)
if match:
return int(match.group())
else:
return 100
def list_bento(
    tag: typing.Optional[str] = None, repo_name: typing.Optional[str] = None, include_alias: bool = False
) -> typing.List[BentoInfo]:
    """Collect the bentos found under the configured repositories.

    Args:
        tag: Optional filter. ``None``/empty matches every bento;
            ``name`` matches every version of ``name``; ``name:version``
            matches that one version. When ``repo_name`` is not given and
            the tag contains ``/``, the part before the first ``/`` is used
            as the repository name.
        repo_name: Optional repository name passed to ``list_repo``.
        include_alias: When false (the default), the result is de-duplicated
            by the ``name:version`` read from each entry's ``bento_yaml``,
            keeping the first occurrence. When true, every entry is kept.

    Returns:
        A list of ``BentoInfo``, repository by repository, each repository's
        entries ordered by (bento name, first number in the version, version
        length, version).

    Raises:
        typer.Exit: With code 1 when ``repo_name`` is given but is not among
            the repositories returned by ``list_repo``.
    """
    ensure_repo_updated()

    if repo_name is None and tag and '/' in tag:
        repo_name, tag = tag.split('/', 1)

    repo_list = list_repo(repo_name)
    if repo_name is not None:
        repo_map = {repo.name: repo for repo in repo_list}
        if repo_name not in repo_map:
            output(f'Repo `{repo_name}` not found, did you mean one of these?')
            # NOTE(review): if list_repo already filters by name, this list is
            # always empty here - confirm against list_repo.
            for known_name in repo_map:
                output(f' {known_name}')
            raise typer.Exit(1)

    if not tag:
        glob_pattern = 'bentoml/bentos/*/*'
    elif ':' in tag:
        # maxsplit=1 so a tag with extra colons yields "no match" instead of
        # an unpacking ValueError.
        bento_name, version = tag.split(':', 1)
        glob_pattern = f'bentoml/bentos/{bento_name}/{version}'
    else:
        glob_pattern = f'bentoml/bentos/{tag}/*'

    model_list = []
    for repo in repo_list:
        paths = sorted(
            repo.path.glob(glob_pattern),
            key=lambda x: (x.parent.name, _extract_first_number(x.name), len(x.name), x.name),
        )
        for path in paths:
            if path.is_dir() and (path / 'bento.yaml').exists():
                model = BentoInfo(repo=repo, path=path)
            elif path.is_file():
                # A plain file is an alias: its content names the sibling
                # entry it points to. Read it with an explicit encoding so the
                # result does not depend on the platform locale.
                with open(path, encoding='utf-8') as f:
                    origin_name = f.read().strip()
                origin_path = path.parent / origin_name
                model = BentoInfo(alias=path.name, repo=repo, path=origin_path)
            else:
                model = None
            if model:
                model_list.append(model)

    if include_alias:
        return model_list

    # Keep only the first entry for each name:version pair.
    seen = set()
    unique_models = []
    for model in model_list:
        key = f'{model.bento_yaml["name"]}:{model.bento_yaml["version"]}'
        if key in seen:
            continue
        seen.add(key)
        unique_models.append(model)
    return unique_models