Fix test_restarting_does_not_deadlock #8849

Merged: 5 commits, Oct 29, 2024
distributed/shuffle/tests/test_shuffle.py (111 changes: 69 additions & 42 deletions)
@@ -325,23 +325,42 @@ async def test_bad_disk(c, s, a, b):
     await assert_scheduler_cleanup(s)


-async def wait_until_worker_has_tasks(
-    prefix: str, worker: str, count: int, scheduler: Scheduler, interval: float = 0.01
-) -> None:
-    ws = scheduler.workers[worker]
-    while (
-        len(
-            [
-                key
-                for key, ts in scheduler.tasks.items()
-                if prefix in key_split(key)
-                and ts.state == "memory"
-                and {ws} == ts.who_has
-            ]
-        )
-        < count
-    ):
-        await asyncio.sleep(interval)
+from distributed.diagnostics.plugin import SchedulerPlugin
+
+
+class ObserveTasksPlugin(SchedulerPlugin):
+    def __init__(self, prefixes, count, worker):
+        self.prefixes = prefixes
+        self.count = count
+        self.worker = worker
+        self.counter = defaultdict(int)
+        self.event = asyncio.Event()
+
+    async def start(self, scheduler):
+        self.scheduler = scheduler
+
+    def transition(self, key, start, finish, *args, **kwargs):
+        if (
+            finish == "processing"
+            and key_split(key) in self.prefixes
+            and self.scheduler.tasks[key].processing_on
+            and self.scheduler.tasks[key].processing_on.address == self.worker
+        ):
+            self.counter[key_split(key)] += 1
+            if self.counter[key_split(key)] == self.count:
+                self.event.set()
+        return key, start, finish
+
+
+@contextlib.asynccontextmanager
+async def wait_until_worker_has_tasks(prefix, worker, count, scheduler):
+    plugin = ObserveTasksPlugin([prefix], count, worker)
+    scheduler.add_plugin(plugin, name="observe-tasks")
+    await plugin.start(scheduler)
+    try:
+        yield plugin.event
+    finally:
+        scheduler.remove_plugin("observe-tasks")


 async def wait_for_tasks_in_state(
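The helper added above replaces a sleep-and-poll loop with an event-driven wait: a SchedulerPlugin watches task transitions on the scheduler and sets an asyncio.Event once the watched worker has started enough tasks with the requested prefix. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of that idea, separate from the diff; the plugin class and test names are illustrative, while SchedulerPlugin, gen_cluster, Scheduler.add_plugin, and Scheduler.remove_plugin are existing distributed APIs.

import asyncio

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils_test import gen_cluster


class WaitForFirstProcessing(SchedulerPlugin):
    """Illustrative plugin: set an event once any task starts processing."""

    def __init__(self):
        self.event = asyncio.Event()

    def transition(self, key, start, finish, *args, **kwargs):
        # The scheduler calls this hook for every task transition.
        if finish == "processing":
            self.event.set()


@gen_cluster(client=True)
async def test_wait_via_plugin(c, s, a, b):
    plugin = WaitForFirstProcessing()
    s.add_plugin(plugin, name="wait-for-first-processing")
    try:
        fut = c.submit(sum, [1, 2, 3])
        await plugin.event.wait()  # no fixed sleeps or polling intervals
        assert await fut == 6
    finally:
        s.remove_plugin("wait-for-first-processing")

ObserveTasksPlugin in the diff applies the same idea, additionally filtering by key prefix, by the worker the task is processing on, and by a minimum task count before setting the event.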
@@ -554,8 +573,12 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b):
 @pytest.mark.slow
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_crashed_worker_during_transfer(c, s, a):
-    async with Nanny(s.address, nthreads=1) as n:
-        killed_worker_address = n.worker_address
+    async with (
+        Nanny(s.address, nthreads=1) as n,
+        wait_until_worker_has_tasks(
+            "shuffle-transfer", n.worker_address, 1, s
+        ) as event,
+    ):
         df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-03-01",
@@ -565,9 +588,7 @@ async def test_crashed_worker_during_transfer(c, s, a):
         with dask.config.set({"dataframe.shuffle.method": "p2p"}):
             shuffled = df.shuffle("x")
         fut = c.compute([shuffled, df], sync=True)
-        await wait_until_worker_has_tasks(
-            "shuffle-transfer", killed_worker_address, 1, s
-        )
+        await event.wait()
         await n.process.process.kill()

         result, expected = await fut
@@ -597,20 +618,16 @@ async def test_restarting_does_not_deadlock(c, s):
     )
     df = await c.persist(df)
     expected = await c.compute(df)
-
-    async with Nanny(s.address) as b:
+    async with Worker(s.address) as b:
         with dask.config.set({"dataframe.shuffle.method": "p2p"}):
             out = df.shuffle("x")
         assert not s.workers[b.worker_address].has_what
         result = c.compute(out)
-        await wait_until_worker_has_tasks(
-            "shuffle-transfer", b.worker_address, 1, s
-        )
+        while not s.extensions["shuffle"].active_shuffles:
+            await asyncio.sleep(0)
         a.status = Status.paused
         await async_poll_for(lambda: len(s.running) == 1, timeout=5)
-        b.close_gracefully()
-        await b.process.process.kill()
-
+        b.batched_stream.close()
         await async_poll_for(lambda: not s.running, timeout=5)

         a.status = Status.running
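The reworked test_restarting_does_not_deadlock above waits on observable scheduler state rather than sleeping: it spins until the shuffle scheduler extension reports an active shuffle, then uses async_poll_for to wait, with a timeout, for the set of running workers to change. A small illustrative sketch of that async_poll_for idiom, separate from the diff; async_poll_for and gen_cluster are existing helpers in distributed.utils_test, while the test name and predicate here are made up.

from distributed.utils_test import async_poll_for, gen_cluster


@gen_cluster(client=True)
async def test_poll_for_scheduler_state(c, s, a, b):
    fut = c.submit(sum, [1, 2, 3])
    # Re-evaluate the predicate on a short period and fail if it does not
    # become true within the timeout, instead of sleeping a fixed amount.
    await async_poll_for(lambda: s.tasks, timeout=5)
    assert await fut == 6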
@@ -663,8 +680,12 @@ def mock_mock_get_worker_for_range_sharding(
         "distributed.shuffle._shuffle._get_worker_for_range_sharding",
         mock_mock_get_worker_for_range_sharding,
     ):
-        async with Nanny(s.address, nthreads=1) as n:
-            killed_worker_address = n.worker_address
+        async with (
+            Nanny(s.address, nthreads=1) as n,
+            wait_until_worker_has_tasks(
+                "shuffle-transfer", n.worker_address, 1, s
+            ) as event,
+        ):
             df = dask.datasets.timeseries(
                 start="2000-01-01",
                 end="2000-03-01",
@@ -674,9 +695,7 @@ def mock_mock_get_worker_for_range_sharding(
             with dask.config.set({"dataframe.shuffle.method": "p2p"}):
                 shuffled = df.shuffle("x")
             fut = c.compute([shuffled, df], sync=True)
-            await wait_until_worker_has_tasks(
-                "shuffle-transfer", n.worker_address, 1, s
-            )
+            await event.wait()
             await n.process.process.kill()

             result, expected = await fut
@@ -1024,8 +1043,10 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b):
 @pytest.mark.slow
 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_crashed_worker_during_unpack(c, s, a):
-    async with Nanny(s.address, nthreads=2) as n:
-        killed_worker_address = n.worker_address
+    async with (
+        Nanny(s.address, nthreads=2) as n,
+        wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
+    ):
         df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-03-01",
@@ -1037,7 +1058,7 @@ async def test_crashed_worker_during_unpack(c, s, a):
             shuffled = df.shuffle("x")
         result = c.compute(shuffled)

-        await wait_until_worker_has_tasks(UNPACK_PREFIX, killed_worker_address, 1, s)
+        await event.wait()
         await n.process.process.kill()

         result = await result
@@ -1477,7 +1498,10 @@ def block(df, in_event, block_event):
         block_event.wait()
         return df

-    async with Nanny(s.address, nthreads=1) as n:
+    async with (
+        Nanny(s.address, nthreads=1) as n,
+        wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
+    ):
         df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-03-01",
@@ -1498,7 +1522,7 @@ def block(df, in_event, block_event):
             allow_other_workers=True,
         )

-        await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
+        await event.wait()
         await in_event.wait()
         await n.process.process.kill()
         await block_event.set()
@@ -1515,7 +1539,10 @@ def block(df, in_event, block_event):

 @gen_cluster(client=True, nthreads=[("", 1)])
 async def test_crashed_worker_after_shuffle_persisted(c, s, a):
-    async with Nanny(s.address, nthreads=1) as n:
+    async with (
+        Nanny(s.address, nthreads=1) as n,
+        wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s) as event,
+    ):
         df = df = dask.datasets.timeseries(
             start="2000-01-01",
             end="2000-01-10",
@@ -1527,7 +1554,7 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a):
             out = df.shuffle("x")
         out = out.persist()

-        await wait_until_worker_has_tasks(UNPACK_PREFIX, n.worker_address, 1, s)
+        await event.wait()
         await out

         await n.process.process.kill()