forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_utils.py
61 lines (53 loc) · 2.59 KB
/
model_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.uri import append_to_uri_path
def _get_flavor_configuration(model_path, flavor_name):
    """
    Load the configuration of a single flavor from a local MLflow model directory.

    The ``MLmodel`` file located directly under ``model_path`` is loaded with
    ``Model.load`` and the entry for ``flavor_name`` is taken from its ``flavors``.

    :param model_path: The path to the root directory of the MLflow model for which to load
                       the specified flavor configuration.
    :param flavor_name: The name of the flavor configuration to load.
    :return: The flavor configuration as a dictionary.
    :raises MlflowException: With error code ``RESOURCE_DOES_NOT_EXIST`` when no ``MLmodel``
                             file exists under ``model_path``, or when the model does not
                             declare the requested flavor.
    """
    mlmodel_file = os.path.join(model_path, "MLmodel")

    # Fail early with a descriptive MLflow error instead of letting Model.load
    # surface a raw I/O error for a missing file.
    if not os.path.exists(mlmodel_file):
        missing_file_msg = (
            'Could not find an "MLmodel" configuration file at "{model_path}"'.format(
                model_path=model_path))
        raise MlflowException(missing_file_msg, RESOURCE_DOES_NOT_EXIST)

    flavors = Model.load(mlmodel_file).flavors
    if flavor_name not in flavors:
        missing_flavor_msg = 'Model does not have the "{flavor_name}" flavor'.format(
            flavor_name=flavor_name)
        raise MlflowException(missing_flavor_msg, RESOURCE_DOES_NOT_EXIST)

    return flavors[flavor_name]
def _get_flavor_configuration_from_uri(model_uri, flavor_name):
    """
    Load the configuration of a single flavor from an MLflow model referenced by URI.

    Only the ``MLmodel`` file is fetched: the artifact at ``<model_uri>/MLmodel`` is
    downloaded via ``_download_artifact_from_uri`` and loaded with ``Model.load``, and
    the entry for ``flavor_name`` is taken from its ``flavors``.

    :param model_uri: The location, in URI format, of the MLflow model for which to load
                      the specified flavor configuration.
    :param flavor_name: The name of the flavor configuration to load.
    :return: The flavor configuration as a dictionary.
    :raises MlflowException: With error code ``RESOURCE_DOES_NOT_EXIST`` when building the
                             ``MLmodel`` URI or downloading the file fails for any reason,
                             or when the model does not declare the requested flavor.
    """
    try:
        local_mlmodel_path = _download_artifact_from_uri(
            artifact_uri=append_to_uri_path(model_uri, "MLmodel"))
    except Exception as download_error:
        # Deliberately broad: any failure to resolve or fetch the file is reported
        # uniformly as a missing resource, with the underlying error in the message.
        # NOTE(review): the original error is attached only implicitly (``__context__``
        # on Python 3); ``raise ... from download_error`` would make it the explicit
        # cause, but that syntax is Python-3-only.
        download_failure_msg = (
            'Failed to download an "MLmodel" model file from "{model_uri}": {ex}'.format(
                model_uri=model_uri, ex=download_error))
        raise MlflowException(download_failure_msg, RESOURCE_DOES_NOT_EXIST)

    flavors = Model.load(local_mlmodel_path).flavors
    if flavor_name not in flavors:
        raise MlflowException(
            'Model does not have the "{flavor_name}" flavor'.format(flavor_name=flavor_name),
            RESOURCE_DOES_NOT_EXIST)

    return flavors[flavor_name]