Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][compiled-graphs] Support wait-and-get to round-robin the acquisition of mutable objects allowing for fast failure #49444

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e631160
pass tests
kevin85421 Dec 26, 2024
bf585fa
pass tests, refactor
kevin85421 Dec 26, 2024
bf7a26e
pass tests, waitables
kevin85421 Dec 26, 2024
93bff72
update
kevin85421 Dec 26, 2024
d68f8db
pass tests
kevin85421 Dec 27, 2024
aa7b831
pass tests, retrieve obj one by one in sync reader
kevin85421 Dec 27, 2024
a22bbaa
update
kevin85421 Dec 30, 2024
d1aac6c
update comment and move import to top-level
kevin85421 Dec 30, 2024
a84e561
remove logs and update comments for WaitAndGetExperimentalMutableObjects
kevin85421 Dec 30, 2024
9e44591
update comments
kevin85421 Dec 30, 2024
3ddad3a
add some utils
kevin85421 Dec 30, 2024
4ca28e9
fix test_channel tests
kevin85421 Dec 30, 2024
d8dadb4
update
kevin85421 Dec 31, 2024
5d18be7
Merge remote-tracking branch 'upstream/master' into 20241224-2
kevin85421 Dec 31, 2024
cddaef7
fix test_channel
kevin85421 Dec 31, 2024
abd4d28
update type hint
kevin85421 Dec 31, 2024
a684329
remove c++ log
kevin85421 Dec 31, 2024
2a1b3fc
update _read_list
kevin85421 Dec 31, 2024
9a3ae27
refactor
kevin85421 Dec 31, 2024
9f2f1d1
remove comment for visualize tests
kevin85421 Dec 31, 2024
ba489c3
add tests
kevin85421 Dec 31, 2024
ebbee6f
address comments
kevin85421 Dec 31, 2024
8ebe78c
move retrieve_obj_refs to util
kevin85421 Dec 31, 2024
d958b24
fix lint error
kevin85421 Dec 31, 2024
551bacb
fix nccl channel tests
kevin85421 Jan 1, 2025
1ae976b
fix test
kevin85421 Jan 1, 2025
1cc8210
fix typo
kevin85421 Jan 1, 2025
8c4e61d
refactor
kevin85421 Jan 1, 2025
6d22187
Merge remote-tracking branch 'upstream/master' into 20241224-2
kevin85421 Jan 6, 2025
38c8652
fix lint
kevin85421 Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
  • Loading branch information
kevin85421 committed Dec 31, 2024
commit d8dadb47e95d518e2b62fce3327b611d1303c275
2 changes: 1 addition & 1 deletion python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def experimental_wait_and_get_mutable_objects(
)

if skip_deserialization:
return None, set()
return data_metadata_pairs, set()

non_complete_object_refs_set = set()
for i, (data, _) in enumerate(data_metadata_pairs):
Expand Down
110 changes: 86 additions & 24 deletions python/ray/experimental/channel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,12 @@ def read(self, timeout: Optional[float] = None) -> Any:
"""
raise NotImplementedError

def get_ray_waitables(self) -> List[ObjectRef]:
def get_ray_waitables(self) -> List[Tuple[ObjectRef, bool]]:
"""
Get the ObjectRefs that will be read in the next read() call.
Get a list of tuples containing an ObjectRef and a boolean flag.
The flag indicates whether the ObjectRef should skip deserialization
in `experimental_wait_and_get_mutable_objects` and instead be
deserialized in the channel's `read()` method.
"""
raise NotImplementedError

Expand Down Expand Up @@ -308,11 +311,20 @@ def _read_list(self, timeout: Optional[float] = None) -> List[Any]:

def _get_all_waitables_to_num_consumers(self) -> Dict[ObjectRef, int]:
waitable_to_num_consumers = {}
skip_deserialization_waitables_to_num_consumers = {}
for c in self._input_channels:
waitables = c.get_ray_waitables()
for w in waitables:
waitable_to_num_consumers[w] = waitable_to_num_consumers.get(w, 0) + 1
return waitable_to_num_consumers
for waitable, skip_deserialization in waitables:
target_dict = (
skip_deserialization_waitables_to_num_consumers
if skip_deserialization
else waitable_to_num_consumers
)
target_dict[waitable] = target_dict.get(waitable, 0) + 1
return (
waitable_to_num_consumers,
skip_deserialization_waitables_to_num_consumers,
)

def read(self, timeout: Optional[float] = None) -> List[Any]:
"""
Expand Down Expand Up @@ -377,30 +389,48 @@ def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
timeout = 1e6 if timeout is None or timeout == -1 else timeout
self._consume_non_complete_object_refs_if_needed(timeout)

waitable_to_num_consumers = self._get_all_waitables_to_num_consumers()
all_waitables = list(waitable_to_num_consumers.keys())
(
waitables_to_num_consumers,
skip_deserialization_waitables_to_num_consumers,
) = self._get_all_waitables_to_num_consumers()
normal_waitables = list(waitables_to_num_consumers.keys())
skip_deserialization_waitables = list(
skip_deserialization_waitables_to_num_consumers.keys()
)
Comment on lines +454 to +461
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like these are static? should we do it at init time?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to ReaderInterface constructor.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After giving it a second thought, I realized it is not static. For example, the get_ray_waitables method of BufferedSharedMemoryChannel should return the buffer that will be read in the current read operation. Therefore, the return value of get_ray_waitables is not always the same.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    def get_ray_waitables(self) -> List[Tuple[ObjectRef, bool]]:
        """Return the (ObjectRef, skip_deserialization) pairs for the next read.

        Delegates to the buffer that the upcoming read() call will consume,
        so the returned waitables track the channel's current read position.
        """
        self.ensure_registered_as_reader()
        active_buffer = self._buffers[self._next_read_index]
        return active_buffer.get_ray_waitables()


timeout_point = time.monotonic() + timeout
worker = ray._private.worker.global_worker
while len(all_waitables) > 0:
while len(normal_waitables) > 0 or len(skip_deserialization_waitables) > 0:
# Retrieve at most one object each time.
use_normal_waitables = len(normal_waitables) > 0
target_waitable_group = (
normal_waitables
if use_normal_waitables
else skip_deserialization_waitables
)
target_waitable_group_num_consumers = (
waitables_to_num_consumers
if use_normal_waitables
else skip_deserialization_waitables_to_num_consumers
)
(
values,
non_complete_object_refs_set,
) = worker.experimental_wait_and_get_mutable_objects(
all_waitables,
target_waitable_group,
num_returns=1,
timeout_ms=max(0, (timeout_point - time.monotonic()) * 1000),
return_exceptions=True,
skip_deserialization=not use_normal_waitables,
suppress_timeout_errors=True,
)
ctx = ChannelContext.get_current().serialization_context
for i, value in enumerate(values):
if all_waitables[i] in non_complete_object_refs_set:
if target_waitable_group[i] in non_complete_object_refs_set:
continue
if isinstance(value, ray.exceptions.RayTaskError):
self._non_complete_object_refs = list(non_complete_object_refs_set)
for w in all_waitables:
for w in target_waitable_group:
ctx.reset_data(w)
# If we raise an exception immediately, it will be considered
# as a system error which will cause the execution loop to
Expand All @@ -411,17 +441,21 @@ def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
# get an undefined partial result.
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
return [value for _ in range(len(self._input_channels))]
ctx.set_data(
all_waitables[i],
target_waitable_group[i],
value,
waitable_to_num_consumers[all_waitables[i]],
target_waitable_group_num_consumers[target_waitable_group[i]],
)
all_waitables = list(non_complete_object_refs_set)
if time.monotonic() > timeout_point and len(all_waitables) != 0:
target_waitable_group = list(non_complete_object_refs_set)
if time.monotonic() > timeout_point and len(target_waitable_group) != 0:
# This ensures that the reader attempts to retrieve
# data once even when the `timeout` is 0.
raise ray.exceptions.RayChannelTimeoutError(
"Timed out waiting for channel data."
)
if use_normal_waitables:
normal_waitables = target_waitable_group
else:
skip_deserialization_waitables = target_waitable_group

results = []
for c in self._input_channels:
Expand Down Expand Up @@ -467,32 +501,60 @@ def start(self):
self._background_task = asyncio.ensure_future(self.run())

def _run(self):
# TODO(kevin85421): Consume non-complete object refs.
# TODO(kevin85421): Consume waitable one by one.
(
waitables_to_num_consumers,
skip_deserialization_waitables_to_num_consumers,
) = self._get_all_waitables_to_num_consumers()
normal_waitables = list(waitables_to_num_consumers.keys())
skip_deserialization_waitables = list(
skip_deserialization_waitables_to_num_consumers.keys()
)

results = []
waitable_to_num_consumers = self._get_all_waitables_to_num_consumers()
all_waitables = list(waitable_to_num_consumers.keys())

worker = ray._private.worker.global_worker
while len(all_waitables) > 0:
while len(normal_waitables) > 0 or len(skip_deserialization_waitables) > 0:
use_normal_waitables = len(normal_waitables) > 0
target_waitable_group = (
normal_waitables
if use_normal_waitables
else skip_deserialization_waitables
)
target_waitable_group_num_consumers = (
waitables_to_num_consumers
if use_normal_waitables
else skip_deserialization_waitables_to_num_consumers
)

(
values,
non_complete_object_refs_set,
) = worker.experimental_wait_and_get_mutable_objects(
all_waitables,
len(all_waitables),
target_waitable_group,
num_returns=len(target_waitable_group),
timeout_ms=1000,
return_exceptions=True,
skip_deserialization=not use_normal_waitables,
suppress_timeout_errors=True,
)

ctx = ChannelContext.get_current().serialization_context
for i, value in enumerate(values):
if all_waitables[i] in non_complete_object_refs_set:
if target_waitable_group[i] in non_complete_object_refs_set:
continue
ctx.set_data(
all_waitables[i],
target_waitable_group[i],
value,
waitable_to_num_consumers[all_waitables[i]],
target_waitable_group_num_consumers[target_waitable_group[i]],
)
all_waitables = list(non_complete_object_refs_set)

target_waitable_group = list(non_complete_object_refs_set)
if use_normal_waitables:
normal_waitables = target_waitable_group
else:
skip_deserialization_waitables = target_waitable_group
if sys.is_finalizing():
return results

Expand Down
4 changes: 2 additions & 2 deletions python/ray/experimental/channel/shared_memory_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ def read(self, timeout: Optional[float] = None) -> Any:
ret = rets[0]
return ret

def get_ray_waitables(self) -> List[ObjectRef]:
def get_ray_waitables(self) -> List[Tuple[ObjectRef, bool]]:
self.ensure_registered_as_reader()
return [self._local_reader_ref]
return [(self._local_reader_ref, False)]

def release_buffer(self, timeout: Optional[float] = None) -> None:
assert (
Expand Down
56 changes: 50 additions & 6 deletions python/ray/experimental/channel/torch_tensor_nccl_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def ensure_registered_as_writer(self):
self._cpu_data_channel.ensure_registered_as_writer()

def ensure_registered_as_reader(self):
reader = utils.get_self_actor()
if reader == self._writer:
self._local_channel.ensure_registered_as_reader()
return
self._gpu_data_channel.ensure_registered_as_reader()
if self._cpu_data_channel is not None:
self._cpu_data_channel.ensure_registered_as_reader()
Expand Down Expand Up @@ -194,12 +198,31 @@ def _send_cpu_and_gpu_data(self, value: Any, timeout: Optional[float]):
# normally.
self.serialization_ctx.set_use_external_transport(False)

# First send the extracted tensors through a GPU-specific channel.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NCCL write -> NCCL read -> all mutable objects are ready -> _cpu_data_channel.write -> NCCL write

self._gpu_data_channel.write(gpu_tensors)
# Next send the non-tensor data through a CPU-specific channel.
# The `write` operation of the shared memory channel must be called
# before the `write` operation of the GPU channel. This is because in
# `_read_list`, the channel's `read` operation waits for all underlying
# mutable objects for all input channels to be consumed.
#
# Step 1: `_cpu_data_channel.write` is called to write data into the
# mutable object.
# Step 2: `_read_list` consumes the mutable object.
# Step 3: After all underlying mutable objects of all input channels are
# consumed, `read` is called in the receiver of the NCCL channel.
#
# If we call NCCL write before the CPU channel write, then the shared
# memory channel's `write` operation will block because the NCCL write
# operation blocks forever until the NCCL read operation is called. However,
# the `read` operation of the NCCL channel will never be called because
# `_read_list` will never consume the mutable object that hasn't been
# written yet.

# First send the non-tensor data through a CPU-specific channel. The
# data contains placeholders for the extracted tensors.
self._cpu_data_channel.write(cpu_data)

# Next send the extracted tensors through a GPU-specific channel.
self._gpu_data_channel.write(gpu_tensors)

def write(self, value: Any, timeout: Optional[float] = None) -> None:
"""
Send a value that may contain torch.Tensors that should be sent via
Expand Down Expand Up @@ -275,17 +298,29 @@ def _recv_cpu_and_gpu_data(
# Next, read and deserialize the non-tensor data. The registered custom
# deserializer will replace the found tensor placeholders with
# `tensors`.
data = self._cpu_data_channel.read(
#
# We need to deserialize the CPU data channel first in `read` instead of
# `_read_list` because the deserialization of the CPU data channel relies
# on the out-of-band tensors in the serialization context. Therefore, the
# `read` method of the NCCL channel must be called first to ensure that
# the out-of-band tensors are ready.
serialized_data, metadata = self._cpu_data_channel.read(
timeout=timeout,
)
rets = self._worker.deserialize_objects(
[(serialized_data, metadata)], self._cpu_data_channel.get_ray_waitables()
)
assert len(rets) == 1
ret = rets[0]

# Check that all placeholders had a corresponding tensor.
(
_,
deserialized_tensor_placeholders,
) = self.serialization_ctx.reset_out_of_band_tensors([])
assert deserialized_tensor_placeholders == set(range(len(tensors)))

return data
return ret

def read(self, timeout: Optional[float] = None) -> Any:
"""
Expand Down Expand Up @@ -327,10 +362,19 @@ def read(self, timeout: Optional[float] = None) -> Any:

def get_ray_waitables(self) -> List[ObjectRef]:
kevin85421 marked this conversation as resolved.
Show resolved Hide resolved
self.ensure_registered_as_reader()
reader = utils.get_self_actor()
if reader == self._writer:
return self._local_channel.get_ray_waitables()
waitables = []
waitables.extend(self._gpu_data_channel.get_ray_waitables())
if self._cpu_data_channel is not None:
waitables.extend(self._cpu_data_channel.get_ray_waitables())
cpu_waitables = self._cpu_data_channel.get_ray_waitables()
assert len(cpu_waitables) == 1
# Skip deserialization of the CPU data in `_read_list` and
# handle the deserialization in the channel's `read()` method
# after the out-of-band tensors are ready in the serialization
# context instead.
waitables.append((cpu_waitables[0][0], True))
return waitables

def close(self) -> None:
Expand Down
Loading