[core][compiled-graphs] Support wait-and-get to round-robin the acquisition of mutable objects allowing for fast failure #49444

Status: Closed (wants to merge 30 commits)
Changes from 11 commits

Commits (30)
e631160
pass tests
kevin85421 Dec 26, 2024
bf585fa
pass tests, refactor
kevin85421 Dec 26, 2024
bf7a26e
pass tests, waitables
kevin85421 Dec 26, 2024
93bff72
update
kevin85421 Dec 26, 2024
d68f8db
pass tests
kevin85421 Dec 27, 2024
aa7b831
pass tests, retrieve obj one by one in sync reader
kevin85421 Dec 27, 2024
a22bbaa
update
kevin85421 Dec 30, 2024
d1aac6c
update comment and move import to top-level
kevin85421 Dec 30, 2024
a84e561
remove logs and update comments for WaitAndGetExperimentalMutableObjects
kevin85421 Dec 30, 2024
9e44591
update comments
kevin85421 Dec 30, 2024
3ddad3a
add some utils
kevin85421 Dec 30, 2024
4ca28e9
fix test_channel tests
kevin85421 Dec 30, 2024
d8dadb4
update
kevin85421 Dec 31, 2024
5d18be7
Merge remote-tracking branch 'upstream/master' into 20241224-2
kevin85421 Dec 31, 2024
cddaef7
fix test_channel
kevin85421 Dec 31, 2024
abd4d28
update type hint
kevin85421 Dec 31, 2024
a684329
remove c++ log
kevin85421 Dec 31, 2024
2a1b3fc
update _read_list
kevin85421 Dec 31, 2024
9a3ae27
refactor
kevin85421 Dec 31, 2024
9f2f1d1
remove comment for visualize tests
kevin85421 Dec 31, 2024
ba489c3
add tests
kevin85421 Dec 31, 2024
ebbee6f
address comments
kevin85421 Dec 31, 2024
8ebe78c
move retrieve_obj_refs to util
kevin85421 Dec 31, 2024
d958b24
fix lint error
kevin85421 Dec 31, 2024
551bacb
fix nccl channel tests
kevin85421 Jan 1, 2025
1ae976b
fix test
kevin85421 Jan 1, 2025
1cc8210
fix typo
kevin85421 Jan 1, 2025
8c4e61d
refactor
kevin85421 Jan 1, 2025
6d22187
Merge remote-tracking branch 'upstream/master' into 20241224-2
kevin85421 Jan 6, 2025
38c8652
fix lint
kevin85421 Jan 6, 2025
62 changes: 62 additions & 0 deletions python/ray/_private/worker.py
@@ -31,6 +31,7 @@
TypeVar,
Union,
overload,
Set,
)
from urllib.parse import urlparse

@@ -923,6 +924,67 @@ def get_objects(

return values, debugger_breakpoint

def experimental_wait_and_get_mutable_objects(
self,
object_refs: List[ObjectRef],
num_returns: int,
timeout_ms: int = -1,
return_exceptions: bool = False,
skip_deserialization: bool = False,
suppress_timeout_errors: bool = False,
) -> Tuple[List[Any], Set[ObjectRef]]:
"""
Wait for `num_returns` experimental mutable objects in `object_refs` to
be ready and read them.

Args:
object_refs: List of object refs to read.
num_returns: Number of objects to read in this round.
timeout_ms: Timeout in milliseconds.
return_exceptions: If any of the objects deserialize to an
Exception object, whether to return them as values in the
returned list. If False, then the first found exception will be
raised.
skip_deserialization: If True, only the buffer will be released and
the object associated with the buffer will not be deserialized.
suppress_timeout_errors: If True, suppress timeout errors.
Returns:
A tuple containing the list of objects read and the set of
object refs that were not read in this round.
"""
for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
raise TypeError(
"Attempting to call `experimental_wait_and_get_mutable_objects` "
f"on the value {object_ref}, which is not an ray.ObjectRef."
)
data_metadata_pairs: List[
Tuple[ray._raylet.Buffer, bytes]
] = self.core_worker.experimental_wait_and_get_mutable_objects(
object_refs, num_returns, timeout_ms, suppress_timeout_errors
)

if skip_deserialization:
return None, set()

non_complete_object_refs_set = set()
for i, (data, _) in enumerate(data_metadata_pairs):
if data is None:
non_complete_object_refs_set.add(object_refs[i])

values = self.deserialize_objects(data_metadata_pairs, object_refs)
if not return_exceptions:
# Raise exceptions instead of returning them to the user.
for value in values:
if isinstance(value, RayError):
if isinstance(value, ray.exceptions.ObjectLostError):
global_worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
else:
raise value
return values, non_complete_object_refs_set

def main_loop(self):
"""The main loop a worker runs to receive and execute tasks."""

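A minimal sketch of how this new worker-level helper could be driven, based only on the signature and return contract documented above; the wrapper name read_one_ready is hypothetical, and the refs are assumed to come from compiled-graph mutable-object channels rather than plain ray.put() calls.

import ray

def read_one_ready(object_refs, timeout_ms=1000):
    # Ask for at most one ready mutable object per call; the remaining refs
    # come back in the "non-complete" set so the caller can retry them in the
    # next round instead of blocking on the slowest writer.
    worker = ray._private.worker.global_worker
    values, pending = worker.experimental_wait_and_get_mutable_objects(
        object_refs,
        num_returns=1,
        timeout_ms=timeout_ms,
        return_exceptions=True,
        suppress_timeout_errors=True,
    )
    return values, pending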
36 changes: 36 additions & 0 deletions python/ray/_raylet.pyx
@@ -3423,6 +3423,42 @@ cdef class CoreWorker:
.ExperimentalChannelReadRelease(c_object_ids))
check_status(op_status)

def experimental_wait_and_get_mutable_objects(
self,
object_refs,
int num_returns,
int64_t timeout_ms=-1,
c_bool suppress_timeout_errors=False):
"""
Wait for `num_returns` experimental mutable objects in `object_refs` to
be ready and read them.

Args:
object_refs: List of object refs to read.
num_returns: Number of objects to read in this round.
timeout_ms: Timeout in milliseconds.
suppress_timeout_errors: If True, suppress timeout errors.
"""
cdef:
c_vector[shared_ptr[CRayObject]] results
c_vector[CObjectID] c_object_ids = ObjectRefsToVector(object_refs)
with nogil:
op_status = (CCoreWorkerProcess.GetCoreWorker()
.WaitAndGetExperimentalMutableObjects(
c_object_ids, timeout_ms, num_returns, results))

# The caller can determine whether `timeout` was raised by checking the value
# of `results`. At the same time, users still want to get the values of the
# ready objects even if others are not ready. Hence, we don't raise
# the exception if `suppress_timeout_errors` is set to True and instead return
# `results`.
try:
check_status(op_status)
except RayChannelTimeoutError:
if not suppress_timeout_errors:
raise
return RayObjectsToDataMetadataPairs(results)

def put_serialized_object_and_increment_local_ref(
self, serialized_object,
ObjectRef object_ref=None,
34 changes: 17 additions & 17 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -246,24 +246,24 @@ def test_two_returns_two_readers(ray_start_regular, single_fetch):
assert res == [1, 2]


@pytest.mark.parametrize("single_fetch", [True, False])
def test_inc_two_returns(ray_start_regular, single_fetch):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.inc_and_return_two.bind(i)
dag = MultiOutputNode([o1, o2])
# @pytest.mark.parametrize("single_fetch", [True, False])
# def test_inc_two_returns(ray_start_regular, single_fetch):
# a = Actor.remote(0)
# with InputNode() as i:
# o1, o2 = a.inc_and_return_two.bind(i)
# dag = MultiOutputNode([o1, o2])

compiled_dag = dag.experimental_compile()
compiled_dag.visualize(channel_details=True)
for i in range(3):
refs = compiled_dag.execute(1)
if single_fetch:
for j, ref in enumerate(refs):
res = ray.get(ref)
assert res == i + j + 1
else:
res = ray.get(refs)
assert res == [i + 1, i + 2]
# compiled_dag = dag.experimental_compile()
# compiled_dag.visualize(channel_details=True)
# for i in range(3):
# refs = compiled_dag.execute(1)
# if single_fetch:
# for j, ref in enumerate(refs):
# res = ray.get(ref)
# assert res == i + j + 1
# else:
# res = ray.get(refs)
# assert res == [i + 1, i + 2]


def test_two_as_one_return(ray_start_regular):
14 changes: 13 additions & 1 deletion python/ray/experimental/channel/cached_channel.py
@@ -1,6 +1,7 @@
import uuid
from typing import Any, Optional
from typing import Any, List, Optional

from ray import ObjectRef
from ray.experimental.channel.common import ChannelInterface


@@ -100,6 +101,17 @@ def read(self, timeout: Optional[float] = None) -> Any:
# https://github.com/ray-project/ray/issues/47409
return ctx.get_data(self._channel_id)

def get_ray_waitables(self) -> List[ObjectRef]:
self.ensure_registered_as_reader()
from ray.experimental.channel import ChannelContext

ctx = ChannelContext.get_current().serialization_context
if ctx.has_data(self._channel_id):
return []
if self._inner_channel is not None:
return self._inner_channel.get_ray_waitables()
return []

def close(self) -> None:
from ray.experimental.channel import ChannelContext

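For context, the toy channel below illustrates the get_ray_waitables() contract introduced in this PR (see common.py further down): a reader-side channel reports the ObjectRefs its next read() will block on, and an empty list when the value is already cached locally. The class and attribute names here are illustrative only and are not part of the PR.

from typing import Any, List, Optional

from ray import ObjectRef

class ToyCachedChannel:
    def __init__(self, inner_ref: Optional[ObjectRef] = None):
        self._inner_ref = inner_ref      # backing mutable object, if any
        self._cached_value: Any = None   # value consumed in an earlier round

    def get_ray_waitables(self) -> List[ObjectRef]:
        # Nothing to wait on once the value is cached locally.
        if self._cached_value is not None:
            return []
        return [self._inner_ref] if self._inner_ref is not None else []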
118 changes: 118 additions & 0 deletions python/ray/experimental/channel/common.py
@@ -18,6 +18,7 @@

import ray
import ray.exceptions
from ray import ObjectRef
from ray.experimental.channel.communicator import Communicator
from ray.experimental.channel.serialization_context import _SerializationContext
from ray.util.annotations import DeveloperAPI, PublicAPI
@@ -255,6 +256,12 @@ def read(self, timeout: Optional[float] = None) -> Any:
"""
raise NotImplementedError

def get_ray_waitables(self) -> List[ObjectRef]:
"""
Get the ObjectRefs that will be read in the next read() call.
"""
raise NotImplementedError

def close(self) -> None:
"""
Close this channel. This method must not block and it must be made
@@ -278,6 +285,7 @@ def __init__(
self._input_channels = input_channels
self._closed = False
self._num_reads = 0
self._non_complete_object_refs: List[ObjectRef] = []

def get_num_reads(self) -> int:
return self._num_reads
@@ -298,6 +306,14 @@ def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
"""
raise NotImplementedError

def _get_all_waitables_to_num_consumers(self) -> Dict[ObjectRef, int]:
waitable_to_num_consumers = {}
for c in self._input_channels:
waitables = c.get_ray_waitables()
for w in waitables:
waitable_to_num_consumers[w] = waitable_to_num_consumers.get(w, 0) + 1
return waitable_to_num_consumers

def read(self, timeout: Optional[float] = None) -> List[Any]:
"""
Read from this reader.
@@ -332,7 +348,81 @@ def __init__(
def start(self):
pass

def _consume_non_complete_object_refs_if_needed(
self, timeout: Optional[float] = None
) -> None:
timeout_point = time.monotonic() + timeout
worker = ray._private.worker.global_worker
if len(self._non_complete_object_refs) > 0:
# If the last read failed early, we need to consume the data from
# the non-complete object refs before the next read. If we don't do
# this, the read operation will read different versions of the
# object refs.
(
_,
non_complete_object_refs_set,
) = worker.experimental_wait_and_get_mutable_objects(
self._non_complete_object_refs,
num_returns=len(self._non_complete_object_refs),
timeout_ms=max(0, (timeout_point - time.monotonic()) * 1000),
return_exceptions=True,
# Skip deserialization to speed up this step.
skip_deserialization=True,
suppress_timeout_errors=False,
)
assert len(non_complete_object_refs_set) == 0
self._non_complete_object_refs = []

def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
timeout = 1e6 if timeout is None or timeout == -1 else timeout
self._consume_non_complete_object_refs_if_needed(timeout)

waitable_to_num_consumers = self._get_all_waitables_to_num_consumers()
all_waitables = list(waitable_to_num_consumers.keys())

timeout_point = time.monotonic() + timeout
worker = ray._private.worker.global_worker
while len(all_waitables) > 0:
# Retrieve at most one object each time.
(
values,
non_complete_object_refs_set,
) = worker.experimental_wait_and_get_mutable_objects(
all_waitables,
num_returns=1,
timeout_ms=max(0, (timeout_point - time.monotonic()) * 1000),
return_exceptions=True,
suppress_timeout_errors=True,
)
ctx = ChannelContext.get_current().serialization_context
for i, value in enumerate(values):
if all_waitables[i] in non_complete_object_refs_set:
continue
if isinstance(value, ray.exceptions.RayTaskError):
self._non_complete_object_refs = list(non_complete_object_refs_set)
for w in all_waitables:
ctx.reset_data(w)
# If we raise an exception immediately, it will be considered
# a system error, which will cause the execution loop to
# exit. Hence, return immediately and let `_process_return_vals`
# handle the exception.
#
# Return a list of RayTaskError so that the caller will not
# get an undefined partial result.
return [value for _ in range(len(self._input_channels))]
ctx.set_data(
all_waitables[i],
value,
waitable_to_num_consumers[all_waitables[i]],
)
all_waitables = list(non_complete_object_refs_set)
if time.monotonic() > timeout_point and len(all_waitables) != 0:
# This ensures that the reader attempts to retrieve
# data once even when the `timeout` is 0.
raise ray.exceptions.RayChannelTimeoutError(
"Timed out waiting for channel data."
)

results = []
for c in self._input_channels:
start_time = time.monotonic()
@@ -378,6 +468,34 @@ def start(self):

def _run(self):
results = []
waitable_to_num_consumers = self._get_all_waitables_to_num_consumers()
all_waitables = list(waitable_to_num_consumers.keys())

worker = ray._private.worker.global_worker
while len(all_waitables) > 0:
(
values,
non_complete_object_refs_set,
) = worker.experimental_wait_and_get_mutable_objects(
all_waitables,
len(all_waitables),
timeout_ms=1000,
return_exceptions=True,
suppress_timeout_errors=True,
)
ctx = ChannelContext.get_current().serialization_context
for i, value in enumerate(values):
if all_waitables[i] in non_complete_object_refs_set:
continue
ctx.set_data(
all_waitables[i],
value,
waitable_to_num_consumers[all_waitables[i]],
)
all_waitables = list(non_complete_object_refs_set)
if sys.is_finalizing():
return results

for c in self._input_channels:
exiting = retry_and_check_interpreter_exit(
lambda: results.append(c.read(timeout=1))
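The synchronous reader above requests at most one mutable object per call so that a failed upstream task surfaces as soon as its error object is written, instead of blocking on the slowest channel. The standalone sketch below shows the same fast-failure idea using only the stable public ray.wait/ray.get APIs; it is an analogy, not the mutable-object code path used by compiled graphs.

import time

import ray

ray.init()

@ray.remote
def work(i):
    if i == 2:
        raise ValueError("boom")  # one task fails immediately
    time.sleep(5)                 # the other tasks are slow
    return i

refs = [work.remote(i) for i in range(4)]
pending = list(refs)
results = {}
try:
    while pending:
        # Take whichever object becomes ready first (round-robin acquisition).
        ready, pending = ray.wait(pending, num_returns=1)
        # ray.get raises as soon as the failed task's error object is ready,
        # so the caller does not wait out the slow tasks.
        results[ready[0]] = ray.get(ready[0])
except ValueError:
    print("failed fast without waiting for the slow tasks")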
7 changes: 6 additions & 1 deletion python/ray/experimental/channel/intra_process_channel.py
@@ -1,6 +1,7 @@
import uuid
from typing import Any, Optional
from typing import Any, List, Optional

from ray import ObjectRef
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.common import ChannelInterface
from ray.util.annotations import PublicAPI
@@ -63,6 +64,10 @@ def read(self, timeout: Optional[float] = None, deserialize: bool = True) -> Any
ctx = ChannelContext.get_current().serialization_context
return ctx.get_data(self._channel_id)

def get_ray_waitables(self) -> List[ObjectRef]:
self.ensure_registered_as_reader()
return []

def close(self) -> None:
ctx = ChannelContext.get_current().serialization_context
ctx.reset_data(self._channel_id)