# databricks_utils.py
import logging
import os
import subprocess

from databricks_cli.configure import provider

from mlflow.exceptions import MlflowException
from mlflow.utils.rest_utils import MlflowHostCreds

_logger = logging.getLogger(__name__)
def _get_dbutils():
    try:
        import IPython
        ip_shell = IPython.get_ipython()
        if ip_shell is None:
            raise _NoDbutilsError
        return ip_shell.ns_table["user_global"]["dbutils"]
    except (ImportError, KeyError):
        raise _NoDbutilsError


class _NoDbutilsError(Exception):
    pass
def _get_java_dbutils():
    dbutils = _get_dbutils()
    return dbutils.notebook.entry_point.getDbutils()


def _get_extra_context(context_key):
    return _get_java_dbutils().notebook().getContext().extraContext().get(context_key).get()


def _get_context_tag(context_tag_key):
    tag_opt = _get_java_dbutils().notebook().getContext().tags().get(context_tag_key)
    if tag_opt.isDefined():
        return tag_opt.get()
    else:
        return None
def _get_property_from_spark_context(key):
    try:
        from pyspark import TaskContext  # pylint: disable=import-error
        task_context = TaskContext.get()
        if task_context:
            return task_context.getLocalProperty(key)
    except Exception:  # pylint: disable=broad-except
        return None
def is_in_databricks_notebook():
    if _get_property_from_spark_context("spark.databricks.notebook.id") is not None:
        return True
    try:
        return _get_extra_context("aclPathOfAclRoot").startswith('/workspace')
    except Exception:  # pylint: disable=broad-except
        return False


def is_in_databricks_job():
    try:
        return _get_context_tag("jobId") is not None and _get_context_tag("idInJob") is not None
    except Exception:  # pylint: disable=broad-except
        return False


def is_dbfs_fuse_available():
    with open(os.devnull, 'w') as devnull_stderr, open(os.devnull, 'w') as devnull_stdout:
        try:
            return subprocess.call(
                ["mountpoint", "/dbfs"], stderr=devnull_stderr, stdout=devnull_stdout) == 0
        except Exception:  # pylint: disable=broad-except
            return False
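

# A minimal usage sketch (illustrative, not part of the original module): per the
# docstrings below, callers are expected to branch on these predicates before
# reading notebook- or job-specific context, e.g.:
#
#     if is_in_databricks_notebook():
#         notebook_id = get_notebook_id()
#     elif is_in_databricks_job():
#         job_id, job_run_id = get_job_id(), get_job_run_id()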
def get_notebook_id():
    """Should only be called if is_in_databricks_notebook() is True."""
    notebook_id = _get_property_from_spark_context("spark.databricks.notebook.id")
    if notebook_id is not None:
        return notebook_id
    acl_path = _get_extra_context("aclPathOfAclRoot")
    if acl_path.startswith('/workspace'):
        return acl_path.split('/')[-1]
    return None


def get_notebook_path():
    """Should only be called if is_in_databricks_notebook() is True."""
    path = _get_property_from_spark_context("spark.databricks.notebook.path")
    if path is not None:
        return path
    return _get_extra_context("notebook_path")


def get_job_id():
    """Should only be called if is_in_databricks_job() is True."""
    return _get_context_tag("jobId")


def get_job_run_id():
    """Should only be called if is_in_databricks_job() is True."""
    return _get_context_tag("idInJob")


def get_job_type():
    """Should only be called if is_in_databricks_job() is True."""
    return _get_context_tag("jobTaskType")


def get_webapp_url():
    """Should only be called if is_in_databricks_notebook() is True."""
    url = _get_property_from_spark_context("spark.databricks.api.url")
    if url is not None:
        return url
    return _get_extra_context("api_url")
def _fail_malformed_databricks_auth(profile):
    raise MlflowException("Got malformed Databricks CLI profile '%s'. Please make sure the "
                          "Databricks CLI is properly configured as described at "
                          "https://github.com/databricks/databricks-cli." % profile)


def get_databricks_host_creds(profile=None):
    """
    Reads in the configuration necessary to make HTTP requests to a Databricks server. This
    uses the Databricks CLI's ConfigProvider interface to load the DatabricksConfig object.
    This method will raise an exception if sufficient auth cannot be found.

    :param profile: Databricks CLI profile. If not provided, we will read the default profile.
    :return: :py:class:`mlflow.utils.rest_utils.MlflowHostCreds` which includes the hostname and
        authentication information necessary to talk to the Databricks server.
    """
    if not hasattr(provider, 'get_config'):
        _logger.warning(
            "Support for databricks-cli<0.8.0 is deprecated and will be removed"
            " in a future version.")
        config = provider.get_config_for_profile(profile)
    elif profile:
        config = provider.ProfileConfigProvider(profile).get_config()
    else:
        config = provider.get_config()

    if not config or not config.host:
        _fail_malformed_databricks_auth(profile)

    insecure = hasattr(config, 'insecure') and config.insecure
    if config.username is not None and config.password is not None:
        return MlflowHostCreds(config.host, username=config.username, password=config.password,
                               ignore_tls_verification=insecure)
    elif config.token:
        return MlflowHostCreds(config.host, token=config.token, ignore_tls_verification=insecure)
    _fail_malformed_databricks_auth(profile)
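

# A minimal usage sketch (hypothetical; ``requests``, the profile name, and the
# endpoint path are assumptions, not part of this module). The returned
# MlflowHostCreds exposes the hostname and auth needed to call the Databricks
# REST API directly:
#
#     import requests
#     creds = get_databricks_host_creds("my-profile")  # profile name is illustrative
#     resp = requests.get(
#         creds.host + "/api/2.0/clusters/list",
#         headers={"Authorization": "Bearer %s" % creds.token},  # assumes token-based auth
#     )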