Skip to content

Commit

Permalink
Use IO task marker in scheduling (#8950)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Jan 16, 2025
1 parent f28498e commit bbdd2ee
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
3 changes: 3 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3085,6 +3085,9 @@ def is_rootish(self, ts: TaskState) -> bool:
"""
if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
return False
# Check explicitly marked data producer tasks
if ts.run_spec and ts.run_spec.data_producer:
return True
tg = ts.group
# TODO short-circuit to True if `not ts.dependencies`?
return (
Expand Down
32 changes: 32 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import dask
from dask import bag, delayed
from dask.base import DaskMethodsMixin
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.utils import parse_timedelta, tmpfile, typename
Expand Down Expand Up @@ -5315,3 +5316,34 @@ async def test_alias_resolving_break_queuing(c, s, a):
while not s.tasks:
await asyncio.sleep(0.01)
assert sum([s.is_rootish(v) for v in s.tasks.values()]) == 18


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_data_producers(c, s, a):
    """Check that ``is_rootish`` reports tasks flagged as data producers.

    A small custom collection is submitted whose graph holds one literal
    data node and a linear chain of three tasks; only the first task of
    the chain is created with ``_data_producer=True``.
    """
    from dask._task_spec import DataNode, Task, TaskRef

    def func(*args):
        # Arguments are ignored on purpose; only the graph shape matters.
        return 100

    class MyArray(DaskMethodsMixin):
        """Minimal collection exposing a hand-written task graph."""

        def __dask_graph__(self):
            # a (data) -> b (explicit data producer) -> c -> d
            graph = {
                "a": DataNode("a", 10),
                "b": Task("b", func, TaskRef("a"), _data_producer=True),
                "c": Task("c", func, TaskRef("b")),
                "d": Task("d", func, TaskRef("c")),
            }
            return graph

        def __dask_keys__(self):
            # Only the end of the chain is requested as output.
            return ["d"]

        def __dask_postcompute__(self):
            return func, ()

    collection = MyArray()
    # Bind the result of ``compute`` to a local name for the rest of the test.
    future = c.compute(collection)

    # Block until the scheduler has registered at least one task.
    await async_poll_for(lambda: s.tasks, 5)

    # NOTE(review): only "b" is flagged explicitly, yet two matching tasks
    # are expected -- presumably the flag also ends up on another submitted
    # task; confirm against the graph the scheduler actually receives.
    rootish_producers = sum(
        s.is_rootish(ts) and ts.run_spec.data_producer for ts in s.tasks.values()
    )
    assert rootish_producers == 2

0 comments on commit bbdd2ee

Please sign in to comment.