
[core][compiled-graphs] Support wait-and-get to round-robin the acquisition of mutable objects allowing for fast failure #49444

Closed

Conversation


@kevin85421 (Member) commented on Dec 26, 2024

Why are these changes needed?

Issue statement

https://gist.github.com/kevin85421/a7f14ea38d64420b105fbd79fd31fb8a

Without this PR, both SynchronousReader._read_list and AwaitableBackgroundReader._run call each input channel's read function sequentially. If the first input channel is backed by a long-running task and the second one fails immediately, the reader still has to wait for the first read to return before it reaches the failing channel and raises the error.
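
As a rough illustration, and not the actual Ray code (the function below is a placeholder), a strictly sequential read loop looks like this, which is why a failure in a later channel only surfaces after every earlier read completes:

    # Hypothetical sketch of the pre-PR behavior: channels are read strictly in
    # order, so an error sitting in a later channel is only observed after every
    # earlier (possibly long-running) read has returned.
    def read_list_sequentially(input_channels, timeout=None):
        results = []
        for channel in input_channels:
            # If input_channels[0] is backed by a long-running task, this call
            # blocks until that task finishes, even if input_channels[1] already
            # holds an exception that could be raised immediately.
            results.append(channel.read(timeout))
        return results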

As described in #46337:

This is a problem for tensor-parallel inference, because all of the workers execute in lockstep, and if one actor throws an exception, the others may hang. Depending on the order of the actors, ray.get() may never return.

Implementation details

  • experimental_wait_and_get_mutable_objects / CoreWorker::WaitAndGetExperimentalMutableObjects

    • WaitAndGetExperimentalMutableObjects iterates through a list of mutable object references and retrieves the data when the objects are ready. The function returns when either num_objects mutable objects are acquired or the operation times out.
  • SynchronousReader._read_list / AwaitableBackgroundReader._run

    • _get_all_waitables_to_num_consumers: Iterate through self._input_channels and call get_ray_waitables to retrieve all of the underlying mutable object references.
    • worker.experimental_wait_and_get_mutable_objects: Attempt to retrieve one mutable object at a time from the list of mutable object references. If a returned value is a RayTaskError, return immediately and raise the exception; otherwise, write the returned value into ChannelContext. (A simplified sketch of this flow follows this list.)
    • After all mutable objects have been retrieved, iterate through self._input_channels and call channel.read().
  • Channel.read (shared_memory_channel.py)

    • Because the mutable objects have already been retrieved in _read_list, Channel.read fetches the data from ChannelContext instead of the object store.
  • get_ray_waitables

    • Retrieves the underlying mutable object references that the next read operation will access. Because of this, get_ray_waitables may return different results across read calls on the same channel. For example, BufferedSharedMemoryChannel's get_ray_waitables:

      def get_ray_waitables(self) -> List[Tuple[ObjectRef, bool]]:
          self.ensure_registered_as_reader()
          return self._buffers[self._next_read_index].get_ray_waitables()
  • Special case: TorchTensorNcclChannel: See the comments in the file for more details.

    • Call CPU write before NCCL write. Since the channel's read is only called after all required mutable objects have been retrieved, NCCL read will only be invoked after the reader has already retrieved the mutable object. Therefore, CPU write must occur before NCCL write to avoid a deadlock.
    • _read_list and _run will skip deserialization for the TorchTensorNcclChannel's mutable object because TorchTensorNcclChannel relies on a custom serializer, which replaces placeholders in the CPU data with tensors read from the NCCL channel during deserialization.
      • If we deserialize the mutable object in _read_list or _run before the reader has retrieved the GPU tensors via the NCCL channel and placed the out-of-band tensors into the serialization context, issues may arise.
      • Instead, the reader will deserialize the CPU data after the out-of-band tensors are ready in the channel's read operation.
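
The following is a simplified sketch (referenced from the _read_list item above) of the flow this list describes. It is an illustration only: the exact signature and return type of worker.experimental_wait_and_get_mutable_objects, the num_objects and timeout arguments, and the dictionary bookkeeping are assumptions made for readability, and the skip-deserialization handling for TorchTensorNcclChannel is omitted.

    from typing import Any, Dict, List

    from ray.exceptions import RayTaskError

    def read_list_with_wait_and_get(worker, input_channels, timeout=None) -> List[Any]:
        # 1. Collect the underlying mutable object refs that each channel's next
        #    read plans to access.
        waitables = []
        for channel in input_channels:
            waitables.extend(ref for ref, _ in channel.get_ray_waitables())

        # 2. Round-robin over the waitables, acquiring whichever becomes ready
        #    first, and fail fast if any acquired value is an error.
        acquired: Dict[Any, Any] = {}
        while len(acquired) < len(waitables):
            pending = [ref for ref in waitables if ref not in acquired]
            # Assumed to return a mapping from ref to value once num_objects of
            # the pending refs are ready, or to time out otherwise.
            ready = worker.experimental_wait_and_get_mutable_objects(
                pending, num_objects=1, timeout=timeout
            )
            for ref, value in ready.items():
                if isinstance(value, RayTaskError):
                    # Fast failure: surface the error without waiting on the
                    # remaining, possibly long-running, channels.
                    raise value
                acquired[ref] = value  # In the PR this is written into ChannelContext.

        # 3. Only after every mutable object is available, perform the channel
        #    reads, which now fetch already-acquired data instead of blocking.
        return [channel.read(timeout) for channel in input_channels]

The key property is that the reader never blocks on a single slow channel while another channel already holds an error.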

Related issue number

Closes #46337

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@kevin85421 changed the title from "WIP - 2" to "[core][experimental] ray.get of accelerated DAG result may not throw exception for MultiOutputNode" on Dec 30, 2024
@kevin85421 changed the title from "[core][experimental] ray.get of accelerated DAG result may not throw exception for MultiOutputNode" to "[core][compiled-graphs] ray.get of accelerated DAG result may not throw exception for MultiOutputNode" on Dec 31, 2024
@kevin85421 changed the title to "[core][compiled-graphs] Support wait-and-get to round-robin the acquisition of mutable objects, allowing for fast failure" on Dec 31, 2024
@kevin85421 marked this pull request as ready for review on December 31, 2024
@kevin85421 changed the title to "[core][compiled-graphs] Support wait-and-get to round-robin the acquisition of mutable objects allowing for fast failure" on Dec 31, 2024

@ruisearch42 (Contributor) left a comment:


Initial pass, partial review

Comment on lines +435 to +442
(
    waitables_to_num_consumers,
    skip_deserialization_waitables_to_num_consumers,
) = self._get_all_waitables_to_num_consumers()
normal_waitables = list(waitables_to_num_consumers.keys())
skip_deserialization_waitables = list(
    skip_deserialization_waitables_to_num_consumers.keys()
)

Contributor:
looks like these are static? should we do it at init time?

kevin85421 (Member, Author):
move to ReaderInterface constructor.

kevin85421 (Member, Author):
After giving it a second thought, I realized it is not static. For example, the get_ray_waitables method of BufferedSharedMemoryChannel should return the buffer that will be read in the current read operation. Therefore, the return value of get_ray_waitables is not always the same.

kevin85421 (Member, Author):

    def get_ray_waitables(self) -> List[Tuple[ObjectRef, bool]]:
        self.ensure_registered_as_reader()
        return self._buffers[self._next_read_index].get_ray_waitables()

@@ -193,12 +198,31 @@ def _send_cpu_and_gpu_data(self, value: Any, timeout: Optional[float]):
        # normally.
        self.serialization_ctx.set_use_external_transport(False)

        # First send the extracted tensors through a GPU-specific channel.

kevin85421 (Member, Author):
NCCL write -> NCCL read -> all mutable objects are ready -> _cpu_data_channel.write -> NCCL write

int64_t remaining_timeout = timeout_ms == -1 ? 1e9 : timeout_ms;
auto timeout_point = ToTimeoutPoint(remaining_timeout);
int64_t iteration_timeout =
    std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds());

Contributor:
I think this env variable should be renamed; it's kind of misleading for both core and cgraph. Maybe in a separate PR, but it should be something like get_iteration_timeout_milliseconds.

kevin85421 (Member, Author):
Agree, the name is misleading.


// Try to acquire the object.
Status s = experimental_mutable_object_provider_->ReadAcquire(
    ids[i], results[i], iteration_timeout);

Contributor:
Also, now that a timeout is always guaranteed to be present, ReadAcquire shouldn't take an int that could be -1; just pass it a non-optional timeout_point from here.

kevin85421 (Member, Author):
Good point. I will address it in a separate PR to avoid making this PR bigger.

@stephanie-wang (Contributor):

Will this approach work? Modify SynchronousReader._read_list to try to read each channel up to a set timeout (say 100ms), then try the next one?

@kevin85421 (Member, Author):

> Will this approach work? Modify SynchronousReader._read_list to try to read each channel up to a set timeout (say 100ms), then try the next one?

NcclCommunicator does not seem to support send/recv with a timeout, and we need to ensure that other channels in the future also support timeouts.

If a read operation in a channel has multiple points where a timeout exception can be thrown, we need to distinguish them and handle them differently to recover from any side effects, especially if we still want to reuse the DAG.

@kevin85421 (Member, Author):

I synced with @stephanie-wang offline. I will try using the channel read with a timeout and ensure that the channels are idempotent.

> If a read operation in a channel has multiple points where a timeout exception can be thrown, we need to distinguish them and handle them differently to recover from any side effects, especially if we still want to reuse the DAG.

For this concern, we decided to simply rely on the end-to-end (e2e) timeout.
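
For reference, the per-channel-timeout round-robin idea discussed above could look roughly like the sketch below. This is an illustration only: the 0.1-second timeout, the TimeoutError type, and the assumption that channel.read accepts a timeout keyword are placeholders, and it presumes channel reads are idempotent.

    # Hypothetical sketch of the per-channel-timeout alternative: cycle through
    # the channels, giving each read a short timeout, until every channel has
    # produced a value or raised a non-timeout error.
    def read_list_round_robin(input_channels, per_read_timeout=0.1):
        results = {}
        while len(results) < len(input_channels):
            for i, channel in enumerate(input_channels):
                if i in results:
                    continue
                try:
                    results[i] = channel.read(timeout=per_read_timeout)
                except TimeoutError:
                    # Not ready yet; retry this channel on the next pass.
                    continue
        return [results[i] for i in range(len(input_channels))]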

@kevin85421 (Member, Author):

We decided to proceed with #49711 instead of this PR.

@kevin85421 closed this on Jan 11, 2025

Successfully merging this pull request may close these issues.

[core][experimental] ray.get of accelerated DAG result may not throw exception for MultiOutputNode