Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][compiled-graphs] Support wait-and-get to round-robin the acquisition of mutable objects allowing for fast failure #49444

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e631160
pass tests
kevin85421 Dec 26, 2024
bf585fa
pass tests, refactor
kevin85421 Dec 26, 2024
bf7a26e
pass tests, waitables
kevin85421 Dec 26, 2024
93bff72
update
kevin85421 Dec 26, 2024
d68f8db
pass tests
kevin85421 Dec 27, 2024
aa7b831
pass tests, retrieve obj one by one in sync reader
kevin85421 Dec 27, 2024
a22bbaa
update
kevin85421 Dec 30, 2024
d1aac6c
update comment and move import to top-level
kevin85421 Dec 30, 2024
a84e561
remove logs and update comments for WaitAndGetExperimentalMutableObjects
kevin85421 Dec 30, 2024
9e44591
update comments
kevin85421 Dec 30, 2024
3ddad3a
add some utils
kevin85421 Dec 30, 2024
4ca28e9
fix test_channel tests
kevin85421 Dec 30, 2024
d8dadb4
update
kevin85421 Dec 31, 2024
5d18be7
Merge remote-tracking branch 'upstream/master' into 20241224-2
kevin85421 Dec 31, 2024
cddaef7
fix test_channel
kevin85421 Dec 31, 2024
abd4d28
update type hint
kevin85421 Dec 31, 2024
a684329
remove c++ log
kevin85421 Dec 31, 2024
2a1b3fc
update _read_list
kevin85421 Dec 31, 2024
9a3ae27
refactor
kevin85421 Dec 31, 2024
9f2f1d1
remove comment for visualize tests
kevin85421 Dec 31, 2024
ba489c3
add tests
kevin85421 Dec 31, 2024
ebbee6f
address comments
kevin85421 Dec 31, 2024
8ebe78c
move retrieve_obj_refs to util
kevin85421 Dec 31, 2024
d958b24
fix lint error
kevin85421 Dec 31, 2024
551bacb
fix nccl channel tests
kevin85421 Jan 1, 2025
1ae976b
fix test
kevin85421 Jan 1, 2025
1cc8210
fix typo
kevin85421 Jan 1, 2025
8c4e61d
refactor
kevin85421 Jan 1, 2025
6d22187
Merge remote-tracking branch 'upstream/master' into 20241224-2
kevin85421 Jan 6, 2025
38c8652
fix lint
kevin85421 Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TypeVar,
Union,
overload,
Set,
)
from urllib.parse import urlparse

Expand Down Expand Up @@ -922,6 +923,67 @@ def get_objects(

return values, debugger_breakpoint

def experimental_wait_and_get_mutable_objects(
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
self,
object_refs: List[ObjectRef],
num_returns: int,
timeout_ms: int = -1,
return_exceptions: bool = False,
skip_deserialization: bool = False,
suppress_timeout_errors: bool = False,
) -> Tuple[List[Any], Set[ObjectRef]]:
"""
Wait for `num_returns` experimental mutable objects in `object_refs` to
be ready and read them.

Args:
object_refs: List of object refs to read.
num_returns: Number of objects to read in this round.
timeout_ms: Timeout in milliseconds.
return_exceptions: If any of the objects deserialize to an
Exception object, whether to return them as values in the
returned list. If False, then the first found exception will be
raised.
skip_deserialization: If True, only the buffer will be released and
the object associated with the buffer will not be deserailized.
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
suppress_timeout_errors: If True, suppress timeout errors.
Returns:
A tuple containing the list of objects read and the set of
object refs that were not read in this round.
"""
for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
raise TypeError(
"Attempting to call `experimental_wait_and_get_mutable_objects` "
f"on the value {object_ref}, which is not an ray.ObjectRef."
)
data_metadata_pairs: List[
Tuple[ray._raylet.Buffer, bytes]
] = self.core_worker.experimental_wait_and_get_mutable_objects(
object_refs, num_returns, timeout_ms, suppress_timeout_errors
)

if skip_deserialization:
return data_metadata_pairs, set()

non_complete_object_refs_set = set()
for i, (data, _) in enumerate(data_metadata_pairs):
if data is None:
non_complete_object_refs_set.add(object_refs[i])

values = self.deserialize_objects(data_metadata_pairs, object_refs)
if not return_exceptions:
# Raise exceptions instead of returning them to the user.
for value in values:
if isinstance(value, RayError):
if isinstance(value, ray.exceptions.ObjectLostError):
global_worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
raise value.as_instanceof_cause()
else:
raise value
return values, non_complete_object_refs_set

def main_loop(self):
"""The main loop a worker runs to receive and execute tasks."""

Expand Down
36 changes: 36 additions & 0 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3405,6 +3405,42 @@ cdef class CoreWorker:
CCoreWorkerProcess.GetCoreWorker()
.ExperimentalRegisterMutableObjectReader(c_object_id))

def experimental_wait_and_get_mutable_objects(
self,
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
object_refs,
int num_returns,
int64_t timeout_ms=-1,
c_bool suppress_timeout_errors=False):
"""
Wait for `num_returns` experimental mutable objects in `object_refs` to
be ready and read them.

Args:
object_refs: List of object refs to read.
num_returns: Number of objects to read in this round.
timeout_ms: Timeout in milliseconds.
suppress_timeout_errors: If True, suppress timeout errors.
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
"""
cdef:
c_vector[shared_ptr[CRayObject]] results
c_vector[CObjectID] c_object_ids = ObjectRefsToVector(object_refs)
with nogil:
op_status = (CCoreWorkerProcess.GetCoreWorker()
.WaitAndGetExperimentalMutableObjects(
c_object_ids, timeout_ms, num_returns, results))

# The caller can determine whether `timeout` was raised by checking the value
# of `results`. At the same time, users still want to get the values of some
# objects even if some objects are not ready. Hence, we don't raise
# the exception if `suppress_timeout_errors` is set to True and instead return
# `results`.
try:
check_status(op_status)
except RayChannelTimeoutError:
if not suppress_timeout_errors:
raise
return RayObjectsToDataMetadataPairs(results)

def put_serialized_object_and_increment_local_ref(
self, serialized_object,
ObjectRef object_ref=None,
Expand Down
188 changes: 188 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,194 @@ def test_intra_process_channel_with_multi_readers(
assert ray.get(refs) == [3, 3]


@ray.remote
class FastFailActor:
def sleep_and_echo(self, x):
time.sleep(x)
return x

def fail_if_x_is_even(self, x):
if x % 2 == 0:
raise ValueError("x is even")
return x

def sleep_and_fail(self, x):
time.sleep(x)
raise ValueError("fail")


class TestFastFail:
@pytest.mark.parametrize("is_async", [True, False])
def test_first_input_fail(self, ray_start_regular, is_async):
"""
Tests the case where the failing input is at the beginning of the input list.
"""
a = FastFailActor.remote()
with InputNode() as inp:
dag = MultiOutputNode(
[a.fail_if_x_is_even.bind(inp), a.sleep_and_echo.bind(inp)]
)
compiled_dag = dag.experimental_compile(enable_asyncio=is_async)

if is_async:

async def main():
futs = await compiled_dag.execute_async(6)
start_time = time.time()
with pytest.raises(ValueError, match="x is even"):
await asyncio.gather(*futs)
end_time = time.time()
assert end_time - start_time < 6

loop = get_or_create_event_loop()
loop.run_until_complete(main())
else:
start_time = time.time()
with pytest.raises(ValueError, match="x is even"):
ray.get(compiled_dag.execute(6))
end_time = time.time()
assert end_time - start_time < 6

@pytest.mark.parametrize("is_async", [True, False])
def test_last_input_fail(self, ray_start_regular, is_async):
"""
Tests the case where the failing input is at the end of the input list.
The test cannot use the same actor for both `sleep_and_echo` and
`fail_if_x_is_even` tasks because the control dependency would make the
`fail_if_x_is_even` task execute after the `sleep_and_echo` task finishes.
"""
a = FastFailActor.remote()
b = FastFailActor.remote()
with InputNode() as inp:
dag = MultiOutputNode(
[a.sleep_and_echo.bind(inp), b.fail_if_x_is_even.bind(inp)]
)
compiled_dag = dag.experimental_compile(enable_asyncio=is_async)

if is_async:

async def main():
futs = await compiled_dag.execute_async(6)
start_time = time.time()
with pytest.raises(ValueError, match="x is even"):
await asyncio.gather(*futs)
end_time = time.time()
assert end_time - start_time < 6

loop = get_or_create_event_loop()
loop.run_until_complete(main())
else:
start_time = time.time()
with pytest.raises(ValueError, match="x is even"):
print(ray.get(compiled_dag.execute(6)))
end_time = time.time()
assert end_time - start_time < 6

@pytest.mark.parametrize("is_async", [True, False])
def test_middle_input_fail(self, ray_start_regular, is_async):
"""
Tests the case where the failing input is in the middle of the input list.
"""
a = FastFailActor.remote()
b = FastFailActor.remote()
c = FastFailActor.remote()
with InputNode() as inp:
dag = MultiOutputNode(
[
a.sleep_and_echo.bind(inp),
b.fail_if_x_is_even.bind(inp),
c.sleep_and_echo.bind(inp),
]
)
compiled_dag = dag.experimental_compile(enable_asyncio=is_async)

if is_async:

async def main():
futs = await compiled_dag.execute_async(6)
start_time = time.time()
with pytest.raises(ValueError, match="x is even"):
await asyncio.gather(*futs)
end_time = time.time()
assert end_time - start_time < 6

loop = get_or_create_event_loop()
loop.run_until_complete(main())
else:
start_time = time.time()
with pytest.raises(ValueError, match="x is even"):
print(ray.get(compiled_dag.execute(6)))
end_time = time.time()
assert end_time - start_time < 6

@pytest.mark.parametrize("is_async", [True, False])
def test_all_inputs_fail(self, ray_start_regular, is_async):
"""
Tests the case where all inputs fail with different sleep times.
"""
a = FastFailActor.remote()
b = FastFailActor.remote()
c = FastFailActor.remote()
with InputNode() as inp:
dag = MultiOutputNode(
[
a.sleep_and_fail.bind(inp[0]),
b.sleep_and_fail.bind(inp[1]),
c.sleep_and_fail.bind(inp[2]),
]
)
compiled_dag = dag.experimental_compile(enable_asyncio=is_async)

if is_async:

async def main():
futs = await compiled_dag.execute_async(6, 0, 6)
start_time = time.time()
with pytest.raises(ValueError, match="fail"):
await asyncio.gather(*futs)
end_time = time.time()
assert end_time - start_time < 6

loop = get_or_create_event_loop()
loop.run_until_complete(main())
else:
start_time = time.time()
with pytest.raises(ValueError, match="fail"):
ray.get(compiled_dag.execute(6, 0, 6))
end_time = time.time()
assert end_time - start_time < 6

@pytest.mark.parametrize("is_async", [True, False])
def test_fail_and_retry(self, ray_start_regular, is_async):
"""
Tests the case where the first input fails, but the second input succeeds.
"""
a = FastFailActor.remote()
with InputNode() as inp:
dag = MultiOutputNode(
[a.fail_if_x_is_even.bind(inp), a.sleep_and_echo.bind(inp)]
)
compiled_dag = dag.experimental_compile(enable_asyncio=is_async)

if is_async:

async def main():
futs = await compiled_dag.execute_async(2)
with pytest.raises(ValueError, match="x is even"):
await asyncio.gather(*futs)
for _ in range(3):
futs = await compiled_dag.execute_async(1)
assert await asyncio.gather(*futs) == [1, 1]

loop = get_or_create_event_loop()
loop.run_until_complete(main())
else:
with pytest.raises(ValueError, match="x is even"):
ray.get(compiled_dag.execute(2))
for _ in range(3):
assert ray.get(compiled_dag.execute(1)) == [1, 1]


class TestLeafNode:
"""
Leaf nodes are not allowed right now because the exception thrown by the leaf
Expand Down
14 changes: 13 additions & 1 deletion python/ray/experimental/channel/cached_channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from typing import Any, Optional
from typing import Any, List, Optional, Tuple

from ray import ObjectRef
from ray.experimental.channel.common import ChannelInterface


Expand Down Expand Up @@ -100,6 +101,17 @@ def read(self, timeout: Optional[float] = None) -> Any:
# https://github.com/ray-project/ray/issues/47409
return ctx.get_data(self._channel_id)

def get_ray_waitables(self) -> List[Tuple[ObjectRef, bool]]:
self.ensure_registered_as_reader()
from ray.experimental.channel import ChannelContext

ctx = ChannelContext.get_current().serialization_context
if ctx.has_data(self._channel_id):
return []
if self._inner_channel is not None:
return self._inner_channel.get_ray_waitables()
return []

def close(self) -> None:
from ray.experimental.channel import ChannelContext

Expand Down
Loading
Loading