Commit a990d13

Merge branch 'main' into upload-directory-uploads-to-scheduler

hendrikmakait authored Jan 13, 2025
2 parents 6d05589 + bcdbabe
Showing 11 changed files with 32 additions and 120 deletions.
5 changes: 1 addition & 4 deletions distributed/client.py
@@ -45,7 +45,6 @@
 from dask.core import flatten, validate_key
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import Layer
-from dask.optimization import SubgraphCallable
 from dask.tokenize import tokenize
 from dask.typing import Key, NestedKeys, NoDefault, no_default
 from dask.utils import (
@@ -1147,7 +1146,7 @@ def __init__(
         if security is None and isinstance(address, str):
             security = _maybe_call_security_loader(address)
 
-        if security is None:
+        if security is None or security is False:
             security = Security()
         elif isinstance(security, dict):
             security = Security(**security)
@@ -6120,8 +6119,6 @@ def futures_of(o, client=None):
             stack.extend(x)
         elif type(x) is dict:
            stack.extend(x.values())
-        elif type(x) is SubgraphCallable:
-            stack.extend(x.dsk.values())
         elif isinstance(x, TaskRef):
             if x not in seen:
                 seen.add(x)
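Note on the `Client.__init__` change above: `security=False` previously slipped past the `security is None` check and produced a broken client, whereas it now falls back to a default `Security()` object, which is what the new `test_security_bool_input_disabled_security` test in `distributed/tests/test_tls_functional.py` below exercises. A minimal sketch of the new branch in isolation (the helper name `_normalize_security` is made up for illustration and is not part of the diff):

```python
from distributed.security import Security

def _normalize_security(security):
    # Illustrative helper mirroring the updated Client.__init__ branch:
    # None and False now both fall back to a default, TLS-less Security().
    if security is None or security is False:
        return Security()
    if isinstance(security, dict):
        return Security(**security)
    return security

assert not _normalize_security(None).require_encryption
assert not _normalize_security(False).require_encryption
```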
19 changes: 10 additions & 9 deletions distributed/comm/tcp.py
@@ -129,17 +129,18 @@ def get_stream_address(comm):
 
 def convert_stream_closed_error(obj, exc):
     """
-    Re-raise StreamClosedError as CommClosedError.
+    Re-raise StreamClosedError or SSLError as CommClosedError.
     """
-    if exc.real_error is not None:
+    if hasattr(exc, "real_error"):
         # The stream was closed because of an underlying OS error
+        if exc.real_error is None:
+            raise CommClosedError(f"in {obj}: {exc}") from exc
         exc = exc.real_error
-        if isinstance(exc, ssl.SSLError):
-            if exc.reason and "UNKNOWN_CA" in exc.reason:
-                raise FatalCommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}")
-        raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
-    else:
-        raise CommClosedError(f"in {obj}: {exc}") from exc
+
+    if isinstance(exc, ssl.SSLError):
+        if exc.reason and "UNKNOWN_CA" in exc.reason:
+            raise FatalCommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}")
+    raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
 
 
 def _close_comm(ref):
@@ -230,7 +231,7 @@ async def read(self, deserializers=None):
                 buffer = await read_bytes_rw(stream, buffer_nbytes)
                 frames.append(buffer)
 
-        except StreamClosedError as e:
+        except (StreamClosedError, SSLError) as e:
             self.stream = None
             self._closed = True
             convert_stream_closed_error(self, e)
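The `hasattr` guard exists because `read()` can now catch a bare `ssl.SSLError`, which, unlike Tornado's `StreamClosedError`, carries no `real_error` attribute; the old unconditional `exc.real_error` access would itself fail with `AttributeError`. A quick sketch of the distinction (assuming Tornado is installed, as it is for distributed):

```python
import ssl
from tornado.iostream import StreamClosedError

# StreamClosedError stores the underlying OS error in .real_error;
# a raw SSLError from the TLS layer has no such attribute.
print(hasattr(StreamClosedError(), "real_error"))                # True
print(hasattr(ssl.SSLError("handshake failure"), "real_error"))  # False
```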
9 changes: 5 additions & 4 deletions distributed/diagnostics/nvml.py
@@ -32,7 +32,8 @@ class NVMLState(IntEnum):
 
 
 class CudaDeviceInfo(NamedTuple):
-    uuid: bytes | None = None
+    # Older versions of pynvml returned bytes, newer versions return str.
+    uuid: str | bytes | None = None
     device_index: int | None = None
     mig_index: int | None = None
 
@@ -278,13 +279,13 @@ def get_device_index_and_uuid(device):
     Examples
     --------
     >>> get_device_index_and_uuid(0)  # doctest: +SKIP
-    {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
+    {'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
     >>> get_device_index_and_uuid('GPU-e1006a74-5836-264f-5c26-53d19d212dfe')  # doctest: +SKIP
-    {'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
+    {'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
     >>> get_device_index_and_uuid('MIG-7feb6df5-eccf-5faa-ab00-9a441867e237')  # doctest: +SKIP
-    {'device-index': 0, 'uuid': b'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
+    {'device-index': 0, 'uuid': 'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
     """
     init_once()
     try:
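The widened `uuid` annotation tracks a pynvml API change: older releases return device UUIDs as `bytes`, newer ones as `str`. A hedged sketch of normalizing either form (the helper is hypothetical, not part of the diff):

```python
def normalize_uuid(uuid: str | bytes | None) -> str | None:
    # Accept either pynvml representation and always hand back str,
    # matching the widened CudaDeviceInfo.uuid annotation above.
    if isinstance(uuid, bytes):
        return uuid.decode("utf-8")
    return uuid

assert normalize_uuid(b"GPU-e1006a74") == "GPU-e1006a74"
assert normalize_uuid("GPU-e1006a74") == "GPU-e1006a74"
assert normalize_uuid(None) is None
```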
7 changes: 4 additions & 3 deletions distributed/diagnostics/tests/test_nvml.py
@@ -11,6 +11,7 @@
 pynvml = pytest.importorskip("pynvml")
 
 import dask
+from dask.utils import ensure_unicode
 
 from distributed.diagnostics import nvml
 from distributed.utils_test import gen_cluster
@@ -66,7 +67,7 @@ def run_has_cuda_context(queue):
         assert (
             ctx.has_context
             and ctx.device_info.device_index == 0
-            and isinstance(ctx.device_info.uuid, bytes)
+            and isinstance(ctx.device_info.uuid, str)
         )
 
         queue.put(None)
@@ -127,7 +128,7 @@ def test_visible_devices_uuid():
     assert info.uuid
 
     with mock.patch.dict(
-        os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
+        os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
     ):
         h = nvml._pynvml_handles()
         h_expected = pynvml.nvmlDeviceGetHandleByIndex(0)
@@ -147,7 +148,7 @@ def test_visible_devices_uuid_2(index):
     assert info.uuid
 
     with mock.patch.dict(
-        os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
+        os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
     ):
         h = nvml._pynvml_handles()
         h_expected = pynvml.nvmlDeviceGetHandleByIndex(index)
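Rather than decoding by hand, the tests now use `dask.utils.ensure_unicode`, which accepts both representations and always returns `str`:

```python
from dask.utils import ensure_unicode

# Works whether pynvml returns b"GPU-..." (old) or "GPU-..." (new).
assert ensure_unicode(b"GPU-e1006a74") == "GPU-e1006a74"
assert ensure_unicode("GPU-e1006a74") == "GPU-e1006a74"
```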
2 changes: 1 addition & 1 deletion distributed/protocol/tests/test_cupy.py
@@ -74,7 +74,7 @@ def test_serialize_cupy_from_rmm(size):
 )
 @pytest.mark.parametrize(
     "dtype",
-    [numpy.dtype("<f4"), numpy.dtype(">f4"), numpy.dtype("<f8"), numpy.dtype(">f8")],
+    [numpy.dtype("<f4"), numpy.dtype("<f8")],
 )
 @pytest.mark.parametrize("serializer", ["cuda", "dask", "pickle"])
 def test_serialize_cupy_sparse(sparse_name, dtype, serializer):
12 changes: 0 additions & 12 deletions distributed/protocol/tests/test_scipy.py
@@ -30,19 +30,7 @@
     "dtype",
     [
         numpy.dtype("<f4"),
-        pytest.param(
-            numpy.dtype(">f4"),
-            marks=pytest.mark.skipif(
-                SCIPY_GE_1_15_0, reason="https://github.com/scipy/scipy/issues/22258"
-            ),
-        ),
         numpy.dtype("<f8"),
-        pytest.param(
-            numpy.dtype(">f8"),
-            marks=pytest.mark.skipif(
-                SCIPY_GE_1_15_0, reason="https://github.com/scipy/scipy/issues/22258"
-            ),
-        ),
     ],
 )
 def test_serialize_scipy_sparse(sparse_type, dtype):
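Both the cupy and scipy sparse-serialization matrices drop the big-endian dtypes outright; previously they were only skipped on scipy >= 1.15 (scipy/scipy#22258). For reference, the parametrized dtype strings encode byte order in their first character:

```python
import numpy

# "<" is little-endian (native on most platforms), ">" is big-endian;
# only the little-endian float32/float64 variants remain under test.
little, big = numpy.dtype("<f8"), numpy.dtype(">f8")
print(little.str, big.str)  # '<f8' '>f8'
print(little == big)        # False: same width, different byte order
```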
45 changes: 0 additions & 45 deletions distributed/tests/test_client.py
@@ -42,7 +42,6 @@
 import dask
 import dask.bag as db
 from dask import delayed
-from dask.optimization import SubgraphCallable
 from dask.tokenize import tokenize
 from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile
 
@@ -2627,13 +2626,6 @@ async def test_futures_of_get(c, s, a, b):
     b = db.Bag({("b", i): f for i, f in enumerate([x, y, z])}, "b", 3)
     assert set(futures_of(b)) == {x, y, z}
 
-    sg = SubgraphCallable(
-        {"x": x, "y": y, "z": z, "out": (add, (add, (add, x, y), z), "in")},
-        "out",
-        ("in",),
-    )
-    assert set(futures_of(sg)) == {x, y, z}
-
 
 def test_futures_of_class():
     pytest.importorskip("numpy")
@@ -6192,43 +6184,6 @@ async def test_profile_bokeh(c, s, a, b):
     assert os.path.exists(fn)
 
 
-@gen_cluster(client=True, nthreads=[("", 1)])
-async def test_get_mix_futures_and_SubgraphCallable(c, s, a):
-    future = c.submit(add, 1, 2)
-
-    subgraph = SubgraphCallable(
-        {"_2": (add, "_0", "_1"), "_3": (add, future, "_2")},
-        "_3",
-        ("_0", "_1"),
-    )
-    dsk = {
-        "a": 1,
-        "b": 2,
-        "c": (subgraph, "a", "b"),
-        "d": (subgraph, "c", "b"),
-    }
-
-    future2 = c.get(dsk, "d", sync=False)
-    result = await future2
-    assert result == 11
-
-    # Nested subgraphs
-    subgraph2 = SubgraphCallable(
-        {
-            "_2": (subgraph, "_0", "_1"),
-            "_3": (subgraph, "_2", "_1"),
-            "_4": (add, "_3", future2),
-        },
-        "_4",
-        ("_0", "_1"),
-    )
-
-    dsk2 = {"e": 1, "f": 2, "g": (subgraph2, "e", "f")}
-
-    result = await c.get(dsk2, "g", sync=False)
-    assert result == 22
-
-
 @gen_cluster(client=True)
 async def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b):
     pd = pytest.importorskip("pandas")
10 changes: 10 additions & 0 deletions distributed/tests/test_tls_functional.py
@@ -209,6 +209,16 @@ async def test_security_dict_input_no_security():
         assert result == 2
 
 
+@gen_test()
+async def test_security_bool_input_disabled_security():
+    async with Scheduler(dashboard_address=":0", security=False) as s:
+        async with Worker(s.address, security=False):
+            async with Client(s.address, security=False, asynchronous=True) as c:
+                result = await c.submit(inc, 1)
+                assert c.security.require_encryption is False
+                assert result == 2
+
+
 @gen_test()
 async def test_security_dict_input():
     conf = tls_config()
16 changes: 0 additions & 16 deletions distributed/tests/test_utils_comm.py
@@ -7,7 +7,6 @@
 import pytest
 
 from dask._task_spec import TaskRef
-from dask.optimization import SubgraphCallable
 
 from distributed import wait
 from distributed.compatibility import asyncio_run
@@ -246,18 +245,3 @@ def assert_eq(keys1: set[TaskRef], keys2: set[TaskRef]) -> None:
     res, keys = unpack_remotedata(TaskRef("mykey"))
     assert res == "mykey"
     assert_eq(keys, {TaskRef("mykey")})
-
-    # Check unpack of SC that contains a wrapped key
-    sc = SubgraphCallable({"key": (TaskRef("data"),)}, outkey="key", inkeys=["arg1"])
-    dsk = (sc, "arg1")
-    res, keys = unpack_remotedata(dsk)
-    assert res[0] != sc  # Notice, the first item (the SC) has been changed
-    assert res[1:] == ("arg1", "data")
-    assert_eq(keys, {TaskRef("data")})
-
-    # Check unpack of SC when it takes a wrapped key as argument
-    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[TaskRef("arg1")])
-    dsk = (sc, "arg1")
-    res, keys = unpack_remotedata(dsk)
-    assert res == (sc, "arg1")  # Notice, the first item (the SC) has NOT been changed
-    assert_eq(keys, set())
2 changes: 1 addition & 1 deletion distributed/utils.py
@@ -1283,7 +1283,7 @@ def command_has_keyword(cmd, k):
 
 @toolz.memoize
 def color_of(x, palette=palette):
-    h = md5(str(x).encode())
+    h = md5(str(x).encode(), usedforsecurity=False)
     n = int(h.hexdigest()[:8], 16)
     return palette[n % len(palette)]
 
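`usedforsecurity=False` (available since Python 3.9) marks the digest as non-cryptographic, so FIPS-restricted OpenSSL builds that otherwise block MD5 still allow it; `color_of` only needs a stable hash to pick a palette entry. A standalone sketch of the same pattern (the three-color palette is illustrative):

```python
from hashlib import md5

palette = ["#1f77b4", "#ff7f0e", "#2ca02c"]  # illustrative palette

def color_of(x):
    # Stable, non-cryptographic mapping from any value to a color;
    # usedforsecurity=False keeps this working on FIPS-enabled builds.
    h = md5(str(x).encode(), usedforsecurity=False)
    return palette[int(h.hexdigest()[:8], 16) % len(palette)]

print(color_of("inc-123"))  # always the same color for the same key
```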
25 changes: 0 additions & 25 deletions distributed/utils_comm.py
@@ -13,7 +13,6 @@
 
 import dask.config
 from dask._task_spec import TaskRef
-from dask.optimization import SubgraphCallable
 from dask.typing import Key
 from dask.utils import is_namedtuple_instance, parse_timedelta
 
@@ -197,30 +196,6 @@ def _unpack_remotedata_inner(
     if typ is tuple:
         if not o:
             return o
-        if type(o[0]) is SubgraphCallable:
-            # Unpack futures within the arguments of the subgraph callable
-            futures: set[TaskRef] = set()
-            args = tuple(_unpack_remotedata_inner(i, byte_keys, futures) for i in o[1:])
-            found_futures.update(futures)
-
-            # Unpack futures within the subgraph callable itself
-            sc: SubgraphCallable = o[0]
-            futures = set()
-            dsk = {
-                k: _unpack_remotedata_inner(v, byte_keys, futures)
-                for k, v in sc.dsk.items()
-            }
-            future_keys: tuple = ()
-            if futures:  # If no futures is in the subgraph, we just use `sc` as-is
-                found_futures.update(futures)
-                future_keys = (
-                    tuple(f.key for f in futures)
-                    if byte_keys
-                    else tuple(f.key for f in futures)
-                )
-            inkeys = tuple(sc.inkeys) + future_keys
-            sc = SubgraphCallable(dsk, sc.outkey, inkeys, sc.name)
-            return (sc,) + args + future_keys
-        else:
-            return tuple(
-                _unpack_remotedata_inner(item, byte_keys, found_futures) for item in o
+        return tuple(
+            _unpack_remotedata_inner(item, byte_keys, found_futures) for item in o
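With the `SubgraphCallable` branch removed, `_unpack_remotedata_inner` treats every tuple uniformly. A usage sketch of the surviving behavior (output shapes inferred from the remaining tests):

```python
from dask._task_spec import TaskRef
from distributed.utils_comm import unpack_remotedata

# TaskRef placeholders anywhere in a plain container are collected into
# the returned set and replaced in place by their keys; no callable
# special-casing remains.
res, futures = unpack_remotedata((TaskRef("mykey"), "other-arg"))
print(res)      # ('mykey', 'other-arg')
print(futures)  # a set containing TaskRef('mykey')
```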
