Commit 5fc3d74

Merge branch 'main' into shuffle_failures
fjetter authored Sep 4, 2024
2 parents a009e71 + 7c2134e commit 5fc3d74
Showing 4 changed files with 90 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-report.yaml
@@ -54,7 +54,7 @@ jobs:
           mv test_report.html test_short_report.html deploy/
       - name: Deploy 🚀
-        uses: JamesIves/github-pages-deploy-action@v4.5.0
+        uses: JamesIves/github-pages-deploy-action@v4.6.4
         with:
           branch: gh-pages
           folder: deploy
4 changes: 3 additions & 1 deletion distributed/client.py
@@ -877,8 +877,9 @@ def _keys(self) -> Iterable[Key]:

         else:
             if self.pure:
+                tok = tokenize(self.func, self.kwargs)
                 keys = [
-                    self.key + "-" + tokenize(self.func, self.kwargs, args)  # type: ignore
+                    self.key + "-" + tokenize(tok, args)  # type: ignore
                     for args in zip(*self.iterables)
                 ]
             else:
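This hunk hoists the expensive part of key generation out of the per-task loop: for `pure` map calls, `self.func` and `self.kwargs` are now tokenized once, and only the cached token is mixed into each per-argument hash. A rough sketch of the before/after shape using `dask.base.tokenize`; the `make_keys_*` helpers are illustrative, not part of the client:

```python
from dask.base import tokenize

def make_keys_before(key, func, kwargs, iterables):
    # Re-hashes func and kwargs for every element of the iterables.
    return [key + "-" + tokenize(func, kwargs, args) for args in zip(*iterables)]

def make_keys_after(key, func, kwargs, iterables):
    # Hashes func and kwargs once, then folds the cached token into
    # each per-element hash; the loop now only tokenizes `args`.
    tok = tokenize(func, kwargs)
    return [key + "-" + tokenize(tok, args) for args in zip(*iterables)]
```

Note that the two schemes produce different key strings, but both are deterministic for the same inputs, which is all the `pure=True` path requires.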
@@ -5083,6 +5084,7 @@ def register_plugin(
"future version. Please mark your plugin as idempotent by setting its "
"`.idempotent` attribute to `True`.",
FutureWarning,
stacklevel=2,
)
else:
idempotent = getattr(plugin, "idempotent", False)
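Adding `stacklevel=2` makes the `FutureWarning` point at the user code that called `register_plugin` rather than at the `warnings.warn(...)` line inside distributed. A minimal standalone illustration of the mechanism (`deprecated_entry_point` is a made-up name):

```python
import warnings

def deprecated_entry_point():
    # stacklevel=2 attributes the warning to the caller's line,
    # which is where the fix (marking the plugin idempotent) belongs.
    warnings.warn("plugin should be idempotent", FutureWarning, stacklevel=2)

deprecated_entry_point()  # the reported filename/lineno is this call site
```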
47 changes: 27 additions & 20 deletions distributed/nanny.py
@@ -516,7 +516,7 @@ async def _():
             await self.instantiate()
 
         try:
-            await wait_for(_(), timeout)
+            await wait_for(asyncio.shield(_()), timeout)
         except asyncio.TimeoutError:
             logger.error(
                 f"Restart timed out after {timeout}s; returning before finished"
@@ -745,26 +745,30 @@ async def start(self) -> Status:
         os.environ.update(self.pre_spawn_env)
 
         try:
-            await self.process.start()
-        except OSError:
-            logger.exception("Nanny failed to start process", exc_info=True)
-            # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
-            await self.process.terminate()
-            self.status = Status.failed
-        try:
-            msg = await self._wait_until_connected(uid)
-        except Exception:
-            # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
-            await self.process.terminate()
-            self.status = Status.failed
-            raise
+            try:
+                await self.process.start()
+            except OSError:
+                # This can only happen if the actual process creation failed, e.g.
+                # multiprocessing.Process.start failed. This is not tested!
+                logger.exception("Nanny failed to start process", exc_info=True)
+                # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
+                await self.process.terminate()
+                self.status = Status.failed
+            try:
+                msg = await self._wait_until_connected(uid)
+            except Exception:
+                # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
+                await self.process.terminate()
+                self.status = Status.failed
+                raise
+        finally:
+            self.running.set()
         if not msg:
             return self.status
         self.worker_address = msg["address"]
         self.worker_dir = msg["dir"]
         assert self.worker_address
         self.status = Status.running
-        self.running.set()
 
         return self.status
 
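The restructured `start` wraps both the process spawn and the connection handshake in one outer `try` whose `finally` sets `self.running`, so the event now fires on every exit path, success or failure, and the unconditional `self.running.set()` on the success path becomes redundant. This pairs with `kill` below, which waits on `self.running` while the process is still starting; without the `finally`, a failed start would strand that wait forever. A condensed, runnable sketch of the pattern under simplified names (`Proc` is illustrative, not the real `WorkerProcess`):

```python
import asyncio

class Proc:
    def __init__(self) -> None:
        self.status = "starting"
        self.running = asyncio.Event()

    async def start(self) -> None:
        try:
            await asyncio.sleep(0.1)       # spawn work in progress
            raise OSError("spawn failed")  # simulate process.start() failing
        except OSError:
            self.status = "failed"
        finally:
            # Fire on every exit path so a waiter in kill() is never stranded.
            self.running.set()

    async def kill(self) -> None:
        if self.status == "starting":
            await self.running.wait()  # cannot deadlock even on a failed start
        print("kill sees settled status:", self.status)

async def main() -> None:
    p = Proc()
    await asyncio.gather(p.start(), p.kill())

asyncio.run(main())
```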
@@ -799,6 +803,7 @@ def mark_stopped(self):
             msg = self._death_message(self.process.pid, r)
             logger.info(msg)
         self.status = Status.stopped
+        self.running.clear()
         self.stopped.set()
         # Release resources
         self.process.close()
@@ -830,22 +835,24 @@ async def kill(
"""
deadline = time() + timeout

if self.status == Status.stopped:
return
if self.status == Status.stopping:
await self.stopped.wait()
return
# If the process is not properly up it will not watch the closing queue
# and we may end up leaking this process
# Therefore wait for it to be properly started before killing it
if self.status == Status.starting:
await self.running.wait()

assert self.status in (
Status.stopping,
Status.stopped,
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
Status.closing_gracefully,
), self.status
if self.status == Status.stopped:
return
if self.status == Status.stopping:
await self.stopped.wait()
return
self.status = Status.stopping
logger.info("Nanny asking worker to close. Reason: %s", reason)

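The reordering in `kill` matters because `self.status` can change during `await self.running.wait()`: a start that fails while `kill` sleeps leaves the process `failed` or `stopped`, so the early-exit checks are only trustworthy after the wait settles. The widened `assert` spells out which settled states are possible, and `running.clear()` in `mark_stopped` keeps the event honest across a stop/start cycle. A small sketch of the race this avoids (illustrative class, not the real one):

```python
import asyncio

class Proc:
    def __init__(self) -> None:
        self.status = "starting"
        self.running = asyncio.Event()

    async def start_then_die(self) -> None:
        await asyncio.sleep(0.1)
        self.status = "stopped"  # the process died while kill() was waiting
        self.running.set()

    async def kill(self) -> None:
        # Checking for "stopped" *before* this wait would have seen
        # "starting", skipped the early return, and then tripped over a
        # process that stopped in the meantime; waiting first makes the
        # branch below reliable.
        if self.status == "starting":
            await self.running.wait()
        if self.status == "stopped":
            print("already stopped; nothing to kill")
            return
        print("killing")

async def main() -> None:
    p = Proc()
    await asyncio.gather(p.start_then_die(), p.kill())

asyncio.run(main())
```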
68 changes: 59 additions & 9 deletions distributed/tests/test_nanny.py
@@ -208,25 +208,47 @@ async def test_scheduler_file():
     s.stop()
 
 
-@pytest.mark.xfail(
-    os.environ.get("MINDEPS") == "true",
-    reason="Timeout errors with mindeps environment",
-)
-@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)])
-async def test_nanny_timeout(c, s, a):
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_nanny_restart(c, s, a):
     x = await c.scatter(123)
+    assert await c.submit(lambda: 1) == 1
+
+    await a.restart()
+
+    while x.status != "cancelled":
+        await asyncio.sleep(0.1)
+
+    assert await c.submit(lambda: 1) == 1
+
+
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_nanny_restart_timeout(c, s, a):
+    x = await c.scatter(123)
     with captured_logger(
         logging.getLogger("distributed.nanny"), level=logging.ERROR
     ) as logger:
-        await a.restart(timeout=0.1)
+        await a.restart(timeout=0)
 
     out = logger.getvalue()
     assert "timed out" in out.lower()
 
-    start = time()
     while x.status != "cancelled":
         await asyncio.sleep(0.1)
-        assert time() < start + 7
 
     assert await c.submit(lambda: 1) == 1
+
+
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_nanny_restart_timeout_stress(c, s, a):
+    x = await c.scatter(123)
+    restarts = [a.restart(timeout=random.random()) for _ in range(100)]
+    await asyncio.gather(*restarts)
+
+    while x.status != "cancelled":
+        await asyncio.sleep(0.1)
+
+    assert await c.submit(lambda: 1) == 1
+    assert len(s.workers) == 1

 @gen_cluster(
@@ -582,6 +604,34 @@ async def test_worker_start_exception(s):
     assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()
 
 
+@gen_cluster(nthreads=[])
+async def test_worker_start_exception_while_killing(s):
+    nanny = Nanny(s.address, worker_class=BrokenWorker)
+
+    async def try_to_kill_nanny():
+        while not nanny.process or nanny.process.status != Status.starting:
+            await asyncio.sleep(0)
+        await nanny.kill()
+
+    kill_task = asyncio.create_task(try_to_kill_nanny())
+    with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
+        with raises_with_cause(
+            RuntimeError,
+            "Nanny failed to start",
+            RuntimeError,
+            "BrokenWorker failed to start",
+        ):
+            async with nanny:
+                pass
+        await kill_task
+    assert nanny.status == Status.failed
+    # ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed`
+    assert nanny.process is None
+    assert "Restarting worker" not in logs.getvalue()
+    # Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.)
+    assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()
+
+
 @gen_cluster(nthreads=[])
 async def test_failure_during_worker_initialization(s):
     with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
