Enable UploadDirectory plugin to upload to scheduler #8986

Merged
93 changes: 72 additions & 21 deletions distributed/diagnostics/plugin.py
@@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Any, Callable, ClassVar

from dask.typing import Key
from dask.utils import funcname, tmpfile
from dask.utils import _deprecated_kwarg, funcname, tmpfile

from distributed.protocol.pickle import dumps

@@ -896,36 +896,46 @@ async def setup(self, nanny):
nanny.env.update(self.environ)


class UploadDirectory(NannyPlugin):
"""A NannyPlugin to upload a local file to workers.
UPLOAD_DIRECTORY_MODES = ["all", "scheduler", "workers"]


class UploadDirectory(SchedulerPlugin):
"""Scheduler to upload a local directory to the cluster.

Parameters
----------
path: str
A path to the directory to upload
path:
Path to the directory to upload
mode:
Whether to upload the directory to only the scheduler ("scheduler"), only the workers ("workers"), or both ("all"); defaults to "workers"

Examples
--------
>>> from distributed.diagnostics.plugin import UploadDirectory
>>> client.register_plugin(UploadDirectory("/path/to/directory"), nanny=True) # doctest: +SKIP
>>> client.register_plugin(UploadDirectory("/path/to/directory")) # doctest: +SKIP
"""

@_deprecated_kwarg("restart", "restart_workers")
def __init__(
self,
path,
restart=False,
restart_workers=False,
update_path=False,
skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
mode="workers",
):
"""
Initialize the plugin by reading in the data from the given directory.
"""
path = os.path.expanduser(path)
self.path = os.path.split(path)[-1]
self.restart = restart
self.restart_workers = restart_workers
self.update_path = update_path

if mode not in UPLOAD_DIRECTORY_MODES:
raise ValueError(
f"{mode=} not supported, expected one of {UPLOAD_DIRECTORY_MODES}"
)
self.mode = mode

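# The same name is reused for the worker-side nanny plugin registered in start().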
self.name = "upload-directory-" + os.path.split(path)[-1]

with tmpfile(extension="zip") as fn:
@@ -944,26 +954,67 @@ def __init__(
)
z.write(filename, archive_name)

with open(fn, "rb") as f:
with open(fn, mode="rb") as f:
self.data = f.read()
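# Keep the zipped directory in memory as bytes so it travels with the pickled plugin.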

async def setup(self, nanny):
fn = os.path.join(nanny.local_directory, f"tmp-{uuid.uuid4()}.zip")
with open(fn, "wb") as f:
f.write(self.data)
async def start(self, scheduler):
from distributed.core import clean_exception
from distributed.protocol.serialize import Serialized, deserialize

if self.mode in ("all", "scheduler"):
_extract_data(
scheduler.local_directory, self.path, self.data, self.update_path
)

if self.mode in ("all", "workers"):
nanny_plugin = _UploadDirectoryNannyPlugin(
self.path, self.data, self.restart_workers, self.update_path, self.name
)
responses = await scheduler.register_nanny_plugin(
comm=None,
plugin=dumps(nanny_plugin),
name=self.name,
idempotent=False,
)

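# Worker-side failures come back serialized; unpack and re-raise the first one.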
for response in responses.values():
if response["status"] == "error":
response = {
k: deserialize(v.header, v.frames)
for k, v in response.items()
if isinstance(v, Serialized)
}
_, exc, tb = clean_exception(**response)
raise exc.with_traceback(tb)


class _UploadDirectoryNannyPlugin(NannyPlugin):
def __init__(self, path, data, restart, update_path, name):
self.path = path
self.data = data
self.name = name
self.restart = restart
self.update_path = update_path

def setup(self, nanny):
_extract_data(nanny.local_directory, self.path, self.data, self.update_path)


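# Shared helper: runs directly on the scheduler and, via the nanny plugin, on each worker.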
def _extract_data(base_path, path, data, update_path):
with tmpfile(extension="zip") as fn:
with open(fn, mode="wb") as f:
f.write(data)

import zipfile

with zipfile.ZipFile(fn) as z:
z.extractall(path=nanny.local_directory)
z.extractall(path=base_path)

if self.update_path:
path = os.path.join(nanny.local_directory, self.path)
if update_path:
path = os.path.join(base_path, path)
if path not in sys.path:
sys.path.insert(0, path)

os.remove(fn)


class forward_stream:
def __init__(self, stream, worker):
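For context, a minimal usage sketch of the new mode argument (the scheduler address here is illustrative; only the path is required):

    from dask.distributed import Client, UploadDirectory

    client = Client("tcp://scheduler:8786")  # any running cluster
    # "workers" (the default) preserves the old NannyPlugin behaviour;
    # "scheduler" and "all" are the targets added by this PR.
    plugin = UploadDirectory("/path/to/directory", mode="all", update_path=True)
    client.register_plugin(plugin)
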
47 changes: 44 additions & 3 deletions distributed/tests/test_client.py
@@ -79,7 +79,8 @@
from distributed.comm import CommClosedError
from distributed.compatibility import LINUX, MACOS, WINDOWS
from distributed.core import Status
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.deploy.subprocess import SubprocessCluster
from distributed.diagnostics.plugin import UploadDirectory, WorkerPlugin
from distributed.metrics import time
from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
from distributed.shuffle import check_minimal_arrow_version
@@ -7363,7 +7364,6 @@ async def test_computation_object_code_client_compute(c, s, a, b):
assert comp.code[0][-1].code == test_function_code


@pytest.mark.slow
@gen_cluster(client=True, Worker=Nanny)
async def test_upload_directory(c, s, a, b, tmp_path):
from dask.distributed import UploadDirectory
@@ -7376,7 +7376,7 @@ async def test_upload_directory(c, s, a, b, tmp_path):
with open(tmp_path / "bar.py", "w") as f:
f.write("from foo import x")

plugin = UploadDirectory(tmp_path, restart=True, update_path=True)
plugin = UploadDirectory(tmp_path, restart_workers=True, update_path=True)
await c.register_plugin(plugin)

[name] = a.plugins
@@ -7399,6 +7399,47 @@ def f():
assert files_start == files_end # no change


def test_upload_directory_invalid_mode():
with pytest.raises(ValueError, match="mode"):
UploadDirectory(".", mode="invalid")


@pytest.mark.skipif(WINDOWS, reason="distributed#7434")
@pytest.mark.parametrize("mode", ["all", "scheduler"])
@gen_test()
async def test_upload_directory_to_scheduler(mode, tmp_path):
from dask.distributed import UploadDirectory

# Be sure to exclude code coverage reports
files_start = {f for f in os.listdir() if not f.startswith(".coverage")}

with open(tmp_path / "foo.py", "w") as f:
f.write("x = 123")
with open(tmp_path / "bar.py", "w") as f:
f.write("from foo import x")

def f():
import bar

return bar.x

async with SubprocessCluster(
asynchronous=True,
dashboard_address=":0",
scheduler_kwargs={"idle_timeout": "5s"},
worker_kwargs={"death_timeout": "5s"},
) as cluster:
async with Client(cluster, asynchronous=True) as client:
with pytest.raises(ModuleNotFoundError, match="'bar'"):
res = await client.run_on_scheduler(f)

plugin = UploadDirectory(
tmp_path, mode=mode, restart_workers=True, update_path=True
)
await client.register_plugin(plugin)
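# Both parametrized modes ("all" and "scheduler") upload to the scheduler, so the import now succeeds there.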
assert await client.run_on_scheduler(f) == 123


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_duck_typed_register_plugin_raises(c, s, a):
class DuckPlugin: