[core][compiled graphs] Throw unrecoverable actor exceptions at ray.get() #49461

Open · wants to merge 8 commits into master
Changes from 1 commit
[core][compiled graphs] Throw unrecoverable actor exceptions at ray.get()

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
ruisearch42 committed Dec 27, 2024
commit 3f4867df8bebea48a65b5ad69480c121a494ecad
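
In short: when a compiled graph's actor task fails in a way that cannot be propagated through a pure NCCL channel, ray.get() on the CompiledDAGRef now raises the task's own exception rather than a generic RayChannelError. A minimal sketch of the intended user-visible behavior (assuming a 2-GPU cluster and the TorchTensorWorker actor from the test file below; import paths mirror that file):

    import pytest
    import torch
    import ray
    from ray.dag import InputNode
    from ray.experimental.channel.torch_tensor_type import TorchTensorType

    # Assumed: TorchTensorWorker as defined in test_torch_tensor_dag.py below.
    sender = TorchTensorWorker.options(num_gpus=1).remote()
    receiver = TorchTensorWorker.options(num_gpus=1).remote()

    shape, dtype = (10,), torch.float16
    with InputNode() as inp:
        # send_or_raise raises a RuntimeError inside the actor task.
        dag = sender.send_or_raise.bind(shape, dtype, inp, raise_exception=True)
        dag = dag.with_type_hint(TorchTensorType(transport="nccl", _static_shape=True))
        dag = receiver.recv.bind(dag)

    compiled_dag = dag.experimental_compile()
    ref = compiled_dag.execute(1)

    # Before this change: a generic RayChannelError surfaced here.
    # After this change: the RuntimeError from send_or_raise is raised instead.
    with pytest.raises(RuntimeError):
        ray.get(ref)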
13 changes: 11 additions & 2 deletions python/ray/dag/compiled_dag_node.py
@@ -2156,11 +2156,20 @@ def execute(

         if self._returns_list:
             ref = [
-                CompiledDAGRef(self, self._execution_index, channel_index)
+                CompiledDAGRef(
+                    self,
+                    self._execution_index,
+                    list(self.worker_task_refs.values()),
+                    channel_index,
+                )
                 for channel_index in range(len(self.dag_output_channels))
             ]
         else:
-            ref = CompiledDAGRef(self, self._execution_index)
+            ref = CompiledDAGRef(
+                self,
+                self._execution_index,
+                list(self.worker_task_refs.values()),
+            )

         self._execution_index += 1
         return ref
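The execute() change above is plumbing: every CompiledDAGRef now carries a snapshot of worker_task_refs, the refs of the long-running actor tasks that drive the compiled DAG. A rough sketch of the resulting shape of the ref, as an illustrative dataclass (not the real class):

    from dataclasses import dataclass, field
    from typing import Any, List, Optional
    import ray

    @dataclass
    class CompiledDAGRefSketch:
        # Illustrative stand-in for ray.experimental.CompiledDAGRef.
        dag: Any
        execution_index: int
        # New in this commit: the background actor tasks executing the DAG.
        # get() can call ray.get() on these to recover a task's exception
        # when a channel error masks the root cause (see compiled_dag_ref.py).
        actor_task_refs: List[ray.ObjectRef] = field(default_factory=list)
        channel_index: Optional[int] = None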
53 changes: 49 additions & 4 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -64,6 +64,9 @@ def send_or_raise(self, shape, dtype, value: int, raise_exception=False):
             raise RuntimeError()
         return torch.ones(shape, dtype=dtype, device=self.device) * value

+    def send_int(self, value: int):
+        return value
+
     def recv(self, tensor):
         # Check that tensor got loaded to the correct device.
         assert tensor.device == self.device
@@ -828,7 +831,7 @@ def test_torch_tensor_exceptions(
     ray_start_regular, static_shape, direct_return, overlap_gpu_communication
 ):
     """
-    Test exceptions being thrown by a NCCL sending task.
+    Test exceptions being thrown by a NCCL sending task's execution.
     """
     if not USE_GPU:
         pytest.skip("NCCL tests require GPUs")
@@ -881,10 +884,9 @@ def test_torch_tensor_exceptions(
                 value=i,
                 raise_exception=True,
             )
+
             if static_shape or direct_return:
-                with pytest.raises(RayChannelError):
-                    # TODO(swang): Ideally return the RuntimeError thrown by the
-                    # application instead of a generic RayChannelError.
+                with pytest.raises(RuntimeError):
                     ray.get(ref)

             with pytest.raises(RayChannelError):
@@ -911,6 +913,49 @@
         assert result == (i, shape, dtype)


+@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
+def test_torch_tensor_exceptions2(
+    ray_start_regular,
+):
+    """
+    Test exceptions being thrown by a NCCL sending task's write operation.
+    """
+    if not USE_GPU:
+        pytest.skip("NCCL tests require GPUs")
+
+    assert (
+        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
+    ), "This test requires at least 2 GPUs"
+
+    actor_cls = TorchTensorWorker.options(num_gpus=1)
+    sender = actor_cls.remote()
+    receiver = actor_cls.remote()
+
+    with InputNode() as inp:
+        dag = sender.send_int.bind(inp)
+        dag = dag.with_type_hint(
+            TorchTensorType(
+                transport="nccl",
+                _direct_return=True,
+                _static_shape=True,
+            )
+        )
+        dag = receiver.recv.bind(dag)
+
+    compiled_dag = dag.experimental_compile()
+
+    ref = compiled_dag.execute(1)
+    with pytest.raises(
+        ValueError,
+        match="Task annotated with _direct_return=True must return a CUDA torch.Tensor, instead found value `1`. DAG will shut down.",
+    ):
+        ray.get(ref)
+
+    with pytest.raises(RayChannelError):
+        # The DAG is not usable after the exception.
+        ref = compiled_dag.execute(2)
+
+
 @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
 def test_torch_tensor_nccl_all_reduce(ray_start_regular):
     """
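The new send_int helper exists only to violate the _direct_return=True contract: a plain int reaches the NCCL channel's write path, which can ship only CUDA tensors, so the ValueError asserted above is raised. A sketch of the invariant being exercised (the real check lives in the channel's write(), shown in the next file; the message text matches the test's expectation):

    import torch

    def check_direct_return(value):
        # With _direct_return=True, the task's return value must itself be a
        # CUDA torch.Tensor; anything else is fatal to the DAG because a pure
        # NCCL channel has no way to carry the exception to the receiver.
        if not isinstance(value, torch.Tensor):
            raise ValueError(
                "Task annotated with _direct_return=True must "
                "return a CUDA torch.Tensor, instead found value "
                f"`{value}`. DAG will shut down."
            )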
14 changes: 7 additions & 7 deletions python/ray/experimental/channel/torch_tensor_nccl_channel.py
@@ -166,8 +166,8 @@ def write(self, value: Any, timeout: Optional[float] = None) -> None:
         if isinstance(value, ray.exceptions.RayTaskError):
             if self._typ.static_shape or self._typ.direct_return:
                 # Raise a fatal error to teardown the DAG.
-                # TODO(swang): Write exceptions to the tensor metadata or
-                # non-tensor data channel if it is available.
+                # This error will also be caught from `CompiledDAGRef.get()`
+                # and raised to the user
                 raise value

         if self._cpu_data_channel is None:
@@ -176,12 +176,12 @@
             # directly without trying to serialize it first.
             import torch

+            # These ValueErrors will also be caught from `CompiledDAGRef.get()`
+            # and raised to the user
             if not isinstance(value, torch.Tensor):
-                # TODO(swang): These errors are currently fatal for the DAG
-                # because there is no way for the receiver to receive the
-                # exception. This could be improved by sending the exception
-                # through the gpu_data_channel's CPU-based metadata channel,
-                # if one exists.
+                # TODO(swang): These errors are currently fatal for the DAG.
+                # This could be improved by sending the exception through the
+                # gpu_data_channel's CPU-based metadata channel, if one exists.
                 raise ValueError(
                     "Task annotated with _direct_return=True must "
                     "return a CUDA torch.Tensor, instead found value "
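Taken together, write() now routes errors one of two ways, depending on whether a CPU-based channel exists to carry a serialized exception. A condensed sketch of that routing (simplified paraphrase of the hunks above; attribute names follow the diff):

    import ray

    def route_error_sketch(channel, value):
        # Case 1: the upstream task raised; the executor wrapped it as a
        # RayTaskError and passed it to write().
        if isinstance(value, ray.exceptions.RayTaskError):
            if channel._typ.static_shape or channel._typ.direct_return:
                # No side channel for the exception: fatal, tears down the
                # DAG. CompiledDAGRef.get() recovers it from the actor task
                # refs (see compiled_dag_ref.py below).
                raise value
            # Otherwise the exception is serialized over the CPU data
            # channel and surfaces as a normal task error on the receiver.
        # Case 2: the task returned a non-tensor under _direct_return=True;
        # write() raises the ValueError shown above, which is likewise fatal.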
38 changes: 33 additions & 5 deletions python/ray/experimental/compiled_dag_ref.py
@@ -2,7 +2,7 @@
 from typing import Any, List, Optional

 import ray
-from ray.exceptions import RayTaskError
+from ray.exceptions import RayChannelError, RayTaskError
 from ray.util.annotations import PublicAPI


@@ -46,6 +46,7 @@ def __init__(
         self,
         dag: "ray.experimental.CompiledDAG",
         execution_index: int,
+        actor_task_refs: List[ray.ObjectRef],
         channel_index: Optional[int] = None,
     ):
         """
@@ -54,14 +55,19 @@
             execution_index: The index of the execution for the DAG.
                 A DAG can be executed multiple times, and execution index
                 indicates which execution this CompiledDAGRef corresponds to.
+            actor_task_refs: The actor task refs that are used to execute
+                the DAG. This can be used internally to check for task
+                execution errors in case of exceptions.
             channel_index: The index of the DAG's output channel to fetch
                 the result from. A DAG can have multiple output channels, and
                 channel index indicates which channel this CompiledDAGRef
                 corresponds to. If channel index is not provided, this CompiledDAGRef
                 wraps the results from all output channels.
+
         """
         self._dag = dag
         self._execution_index = execution_index
+        self._actor_task_refs = actor_task_refs
         self._channel_index = channel_index
         # Whether ray.get() was called on this CompiledDAGRef.
         self._ray_get_called = False
@@ -100,10 +106,32 @@ def get(self, timeout: Optional[float] = None):
             )

         self._ray_get_called = True
-        return_vals = self._dag._execute_until(
-            self._execution_index, self._channel_index, timeout
-        )
-        return _process_return_vals(return_vals, True)
+        try:
+            return_vals = self._dag._execute_until(
+                self._execution_index, self._channel_index, timeout
+            )
+            return _process_return_vals(return_vals, True)
+        except RayChannelError as channel_error:
+            # If we get a channel error, we'd like to call ray.get()
+            # on the actor task refs to check if this is the result of
+            # a task execution error which could not be passed down
+            # (e.g., when a pure NCCL channel is used, it is only
+            # able to send tensors, but not the wrapped exceptions).
+            # In this case, we'd like to raise the task execution error
+            # (which is the actual cause of the channel error) instead
+            # of the channel error itself.
+            # TODO(rui): determine which error to raise if multiple
+            # actor task refs have errors.
+            try:
+                ray.get(self._actor_task_refs)
+            except Exception as task_error:
+                # Use 'from None' to suppress the context of the original
+                # channel error, which is not useful to the user.
+                raise task_error from None
+            else:
+                raise channel_error
+        except Exception:
+            raise


 @PublicAPI(stability="alpha")
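The get() change above is the heart of the PR: a RayChannelError is treated as a symptom, and the underlying actor task refs are probed for the root cause. The same pattern, reduced to a self-contained runnable example with plain Ray actors (the helper names and the ConnectionError stand-in are illustrative, not Ray APIs):

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def run(self):
            # The real root cause, analogous to a failed compiled-graph task.
            raise RuntimeError("actor task failed")

    def get_with_root_cause(channel_op, actor_task_refs):
        # Mirrors CompiledDAGRef.get(): if the channel operation fails, check
        # whether an underlying actor task died, and prefer its exception.
        try:
            return channel_op()
        except ConnectionError as channel_error:  # stand-in for RayChannelError
            try:
                ray.get(actor_task_refs)
            except Exception as task_error:
                # 'from None' suppresses the less useful channel-error context.
                raise task_error from None
            raise channel_error

    worker = Worker.remote()
    task_ref = worker.run.remote()

    def failing_channel_op():
        raise ConnectionError("channel closed")

    try:
        get_with_root_cause(failing_channel_op, [task_ref])
    except ray.exceptions.RayTaskError as err:
        # The actor's RuntimeError surfaces (wrapped), not the ConnectionError.
        print("root cause:", err)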