Skip to content

Commit

Permalink
[core] Support mutable plasma objects (ray-project#41515)
Browse files Browse the repository at this point in the history
This is a first pass on introducing an experimental "channel" concept that can be used for direct worker-worker communication, bypassing the usual Ray Core components like the driver and raylet.

Channels are implemented as mutable plasma objects. The object can be written multiple times by a client. The writer must specify the number of reads that can be made before the written object value is no longer valid. Reads block until the specified version or a later one is available. Writes block until all readers of the previous version have released the object. Synchronization between a single writer and multiple readers is performed through a new header for plasma objects that is stored in shared memory.

API:

    channel: Channel = ray.experimental.channel.Channel(buf_size): Client uses the normal ray.put path to create a mutable plasma object. Once created and sealed for the first time, the plasma store synchronously reads and releases the object. At this point, the object may be written by the original client and read by others.
    channel.write(val): Use the handle returned by the above to send a value through the channel. The caller waits until all readers of the previous version have released the object, then writes a new version.
    val = channel.begin_read(): Blocks until a value is available. Equivalent to ray.get. This is the beginning of the client's read.
    channel.end_read(): End the client's read, marking the channel as available to write again.

---------

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
  • Loading branch information
stephanie-wang authored Dec 9, 2023
1 parent bfa35fd commit cb5bb4e
Show file tree
Hide file tree
Showing 30 changed files with 1,445 additions and 118 deletions.
2 changes: 1 addition & 1 deletion .buildkite/core.rayci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ steps:

- label: ":ray: core: cpp ubsan tests"
tags: core_cpp
instance_type: medium
instance_type: large
commands:
- bazel run //ci/ray_ci:test_in_docker -- //:all //src/... core --build-type ubsan
--except-tags no_ubsan
Expand Down
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ PLASMA_LINKOPTS = [] + select({
ray_cc_library(
name = "plasma_client",
srcs = [
"src/ray/object_manager/common.cc",
"src/ray/object_manager/plasma/client.cc",
"src/ray/object_manager/plasma/connection.cc",
"src/ray/object_manager/plasma/malloc.cc",
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/ray/runtime/object/native_object_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ std::vector<std::shared_ptr<msgpack::sbuffer>> NativeObjectStore::GetRaw(
const std::vector<ObjectID> &ids, int timeout_ms) {
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
std::vector<std::shared_ptr<::ray::RayObject>> results;
::ray::Status status = core_worker.Get(ids, timeout_ms, &results);
::ray::Status status = core_worker.Get(
ids, timeout_ms, /*is_experimental_mutable_object=*/false, &results);
if (!status.ok()) {
if (status.IsTimedOut()) {
throw RayTimeoutException("Get object error:" + status.message());
Expand Down
87 changes: 87 additions & 0 deletions python/ray/_private/ray_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import multiprocessing
import ray

import ray.experimental.channel as ray_channel

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -288,6 +290,91 @@ def async_actor_multi():
results += timeit("n:n async-actor calls async", async_actor_multi, m * n)
ray.shutdown()

#################################################
# Perf tests for channels, used in compiled DAGs.
#################################################

# Start a Ray instance for the channel benchmarks below; the preceding
# benchmark section ended with ray.shutdown().
ray.init()

def put_channel_small(chans, do_get=False, do_release=False):
    """Write a one-byte payload to every channel, optionally reading it back.

    Args:
        chans: Iterable of channel objects exposing ``write()``,
            ``begin_read()`` and ``end_read()``.
        do_get: If true, call ``begin_read()`` on each channel right after
            the write.
        do_release: If true, call ``end_read()`` on each channel after the
            write (and after ``begin_read()`` when ``do_get`` is also set).
    """
    payload = b"0"
    for channel in chans:
        channel.write(payload)
        if do_get:
            channel.begin_read()
        if do_release:
            channel.end_read()

@ray.remote
class ChannelReader:
    """Actor that continuously reads and releases the channels handed to it."""

    def ready(self):
        """Return immediately; callers ray.get() this as a startup barrier."""
        return None

    def read(self, chans):
        """Loop forever, beginning and then ending a read on each channel.

        This method never returns, so the actor has to be killed to stop it.
        """
        while True:
            for channel in chans:
                channel.begin_read()
                channel.end_read()

# Case 1: one channel, written and released from this process only.
# Channel(1000) passes a buffer size of 1000 (presumably bytes -- confirm
# against ray.experimental.channel.Channel).
chans = [ray_channel.Channel(1000)]
results += timeit(
"local put, single channel calls",
lambda: put_channel_small(chans, do_release=True),
)
# Same channel, but each iteration also begins a read before releasing it.
results += timeit(
"local put:local get, single channel calls",
lambda: put_channel_small(chans, do_get=True, do_release=True),
)

# Case 2: one channel, one remote reader actor. The ready() round trip makes
# sure the actor is up before timing starts. read() loops forever, so its
# result is never fetched and the actor is killed once the timing is done.
chans = [ray_channel.Channel(1000)]
reader = ChannelReader.remote()
ray.get(reader.ready.remote())
reader.read.remote(chans)
results += timeit(
"local put:1 remote get, single channel calls", lambda: put_channel_small(chans)
)
ray.kill(reader)

# The multi-reader / multi-channel cases scale with half the CPU count.
n_cpu = multiprocessing.cpu_count() // 2
print(f"Testing multiple readers/channels, n={n_cpu}")

# Case 3: one channel created with num_readers=n_cpu, read by n_cpu actors.
chans = [ray_channel.Channel(1000, num_readers=n_cpu)]
readers = [ChannelReader.remote() for _ in range(n_cpu)]
ray.get([reader.ready.remote() for reader in readers])
for reader in readers:
reader.read.remote(chans)
results += timeit(
"local put:n remote get, single channel calls",
lambda: put_channel_small(chans),
)
for reader in readers:
ray.kill(reader)

# Case 4: n_cpu channels, all drained by a single reader actor.
chans = [ray_channel.Channel(1000) for _ in range(n_cpu)]
reader = ChannelReader.remote()
ray.get(reader.ready.remote())
reader.read.remote(chans)
results += timeit(
"local put:1 remote get, n channels calls", lambda: put_channel_small(chans)
)
ray.kill(reader)

# Case 5: n_cpu channels, each paired with its own reader actor.
chans = [ray_channel.Channel(1000) for _ in range(n_cpu)]
readers = [ChannelReader.remote() for _ in range(n_cpu)]
ray.get([reader.ready.remote() for reader in readers])
for chan, reader in zip(chans, readers):
reader.read.remote([chan])
results += timeit(
"local put:n remote get, n channels calls", lambda: put_channel_small(chans)
)
for reader in readers:
ray.kill(reader)

# Tear down the Ray instance used for the channel benchmarks.
ray.shutdown()

############################
# End of channel perf tests.
############################

NUM_PGS = 100
NUM_BUNDLES = 1
ray.init(resources={"custom": 100})
Expand Down
43 changes: 38 additions & 5 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,13 @@ def set_mode(self, mode):
def set_load_code_from_local(self, load_code_from_local):
self._load_code_from_local = load_code_from_local

def put_object(self, value, object_ref=None, owner_address=None):
def put_object(
self,
value: Any,
object_ref: Optional["ray.ObjectRef"] = None,
owner_address: Optional[str] = None,
_is_experimental_mutable_object: bool = False,
):
"""Put value in the local object store with object reference `object_ref`.
This assumes that the value for `object_ref` has not yet been placed in
Expand All @@ -727,6 +733,10 @@ def put_object(self, value, object_ref=None, owner_address=None):
object_ref: The object ref of the value to be
put. If None, one will be generated.
owner_address: The serialized address of object's owner.
_is_experimental_mutable_object: An experimental flag for mutable
objects. If True, then the returned object will not have a
valid value. The object must be written to using the
ray.experimental.channel API before readers can read.
Returns:
ObjectRef: The object ref the object was put under.
Expand Down Expand Up @@ -760,6 +770,11 @@ def put_object(self, value, object_ref=None, owner_address=None):
f"{sio.getvalue()}"
)
raise TypeError(msg) from e

# If the object is mutable, then the raylet should never read the
# object. Instead, clients will keep the object pinned.
pin_object = not _is_experimental_mutable_object

# This *must* be the first place that we construct this python
# ObjectRef because an entry with 0 local references is created when
# the object is Put() in the core worker, expecting that this python
Expand All @@ -768,7 +783,11 @@ def put_object(self, value, object_ref=None, owner_address=None):
# reference counter.
return ray.ObjectRef(
self.core_worker.put_serialized_object_and_increment_local_ref(
serialized_value, object_ref=object_ref, owner_address=owner_address
serialized_value,
object_ref=object_ref,
pin_object=pin_object,
owner_address=owner_address,
_is_experimental_mutable_object=_is_experimental_mutable_object,
),
# The initial local reference is already acquired internally.
skip_adding_local_ref=True,
Expand All @@ -790,7 +809,12 @@ def deserialize_objects(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
return context.deserialize_objects(data_metadata_pairs, object_refs)

def get_objects(self, object_refs: list, timeout: Optional[float] = None):
def get_objects(
self,
object_refs: list,
timeout: Optional[float] = None,
_is_experimental_mutable_object: bool = False,
):
"""Get the values in the object store associated with the IDs.
Return the values from the local object store for object_refs. This
Expand All @@ -806,6 +830,10 @@ def get_objects(self, object_refs: list, timeout: Optional[float] = None):
list: List of deserialized objects
bytes: UUID of the debugger breakpoint we should drop
into or b"" if there is no breakpoint.
_is_experimental_mutable_object: An experimental flag for mutable
objects. If True, then wait until there is a value available to
read. The object must also already be local, or else the get
call will hang.
"""
# Make sure that the values are object refs.
for object_ref in object_refs:
Expand All @@ -817,7 +845,10 @@ def get_objects(self, object_refs: list, timeout: Optional[float] = None):

timeout_ms = int(timeout * 1000) if timeout is not None else -1
data_metadata_pairs = self.core_worker.get_objects(
object_refs, self.current_task_id, timeout_ms
object_refs,
self.current_task_id,
timeout_ms,
_is_experimental_mutable_object,
)
debugger_breakpoint = b""
for data, metadata in data_metadata_pairs:
Expand Down Expand Up @@ -2648,7 +2679,9 @@ def get(
@PublicAPI
@client_mode_hook
def put(
value: Any, *, _owner: Optional["ray.actor.ActorHandle"] = None
value: Any,
*,
_owner: Optional["ray.actor.ActorHandle"] = None,
) -> "ray.ObjectRef":
"""Store an object in the object store.
Expand Down
3 changes: 2 additions & 1 deletion python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ cdef class CoreWorker:
CObjectID *c_object_id, shared_ptr[CBuffer] *data,
c_bool created_by_worker,
owner_address=*,
c_bool inline_small_object=*)
c_bool inline_small_object=*,
c_bool is_experimental_mutable_object=*)
cdef unique_ptr[CAddress] _convert_python_address(self, address=*)
cdef store_task_output(
self, serialized_object,
Expand Down
71 changes: 61 additions & 10 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3370,14 +3370,15 @@ cdef class CoreWorker:
return self.plasma_event_handler

def get_objects(self, object_refs, TaskID current_task_id,
int64_t timeout_ms=-1):
int64_t timeout_ms=-1,
c_bool _is_experimental_mutable_object=False):
cdef:
c_vector[shared_ptr[CRayObject]] results
CTaskID c_task_id = current_task_id.native()
c_vector[CObjectID] c_object_ids = ObjectRefsToVector(object_refs)
with nogil:
op_status = CCoreWorkerProcess.GetCoreWorker().Get(
c_object_ids, timeout_ms, &results)
c_object_ids, timeout_ms, _is_experimental_mutable_object, &results)
check_status(op_status)

return RayObjectsToDataMetadataPairs(results)
Expand Down Expand Up @@ -3412,7 +3413,9 @@ cdef class CoreWorker:
CObjectID *c_object_id, shared_ptr[CBuffer] *data,
c_bool created_by_worker,
owner_address=None,
c_bool inline_small_object=True):
c_bool inline_small_object=True,
c_bool is_experimental_mutable_object=False,
):
cdef:
unique_ptr[CAddress] c_owner_address

Expand All @@ -3422,7 +3425,8 @@ cdef class CoreWorker:
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker()
.CreateOwnedAndIncrementLocalRef(
metadata, data_size, contained_ids,
is_experimental_mutable_object, metadata,
data_size, contained_ids,
c_object_id, data, created_by_worker,
move(c_owner_address),
inline_small_object))
Expand Down Expand Up @@ -3511,11 +3515,57 @@ cdef class CoreWorker:
generator_id=CObjectID.Nil(),
owner_address=c_owner_address))

def put_serialized_object_and_increment_local_ref(self, serialized_object,
ObjectRef object_ref=None,
c_bool pin_object=True,
owner_address=None,
c_bool inline_small_object=True):
def experimental_mutable_object_put_serialized(self, serialized_object,
                                               ObjectRef object_ref,
                                               num_readers,
                                               ):
    """
    For experimental.channel.Channel.

    Write a new value into the mutable object referenced by object_ref.
    The object is first acquired for writing, the serialized bytes are
    copied into the buffer handed back by the core worker (skipped when
    the payload is empty), and the write is then released.

    Args:
        serialized_object: The serialized value; its metadata and
            total_bytes are passed to the write-acquire call.
        object_ref: The ref of the mutable object to write to.
        num_readers: Forwarded to the write-acquire call as the reader
            count for this version of the value.
    """
    cdef:
        CObjectID c_object_id = object_ref.native()
        shared_ptr[CBuffer] data

    metadata = string_to_buffer(serialized_object.metadata)
    data_size = serialized_object.total_bytes
    # NOTE(review): this call runs with the GIL held, unlike the read
    # release path which uses `with nogil`. If the acquire can block on
    # outstanding readers, other Python threads stall -- confirm intent.
    check_status(CCoreWorkerProcess.GetCoreWorker()
                 .ExperimentalMutableObjectWriteAcquire(
                     c_object_id,
                     metadata,
                     data_size,
                     num_readers,
                     &data,
                 ))
    # Nothing to copy for an empty payload.
    if data_size > 0:
        (<SerializedObject>serialized_object).write_to(
            Buffer.make(data))
    check_status(CCoreWorkerProcess.GetCoreWorker()
                 .ExperimentalMutableObjectWriteRelease(
                     c_object_id,
                 ))

def experimental_mutable_object_read_release(self, object_refs):
"""
For experimental.channel.Channel.

Signal to the writer that the channel is ready to write again. The read
began when the caller calls ray.get and a written value is available. If
ray.get is not called first, then this call will block until a value is
written, then drop the value.

Args:
object_refs: The refs of the mutable objects whose reads are released;
they are converted to a vector of CObjectID before the call.
"""
cdef:
c_vector[CObjectID] c_object_ids = ObjectRefsToVector(object_refs)
# The core worker call is made without the GIL, since (per the docstring)
# it may block until a value is written.
with nogil:
op_status = (CCoreWorkerProcess.GetCoreWorker()
.ExperimentalMutableObjectReadRelease(c_object_ids))
# Presumably raises a Python exception on a non-OK status -- confirm
# against check_status.
check_status(op_status)

def put_serialized_object_and_increment_local_ref(
self, serialized_object,
ObjectRef object_ref=None,
c_bool pin_object=True,
owner_address=None,
c_bool inline_small_object=True,
c_bool _is_experimental_mutable_object=False,
):
cdef:
CObjectID c_object_id
shared_ptr[CBuffer] data
Expand All @@ -3531,7 +3581,8 @@ cdef class CoreWorker:
object_already_exists = self._create_put_buffer(
metadata, total_bytes, object_ref,
contained_object_ids,
&c_object_id, &data, True, owner_address, inline_small_object)
&c_object_id, &data, True, owner_address, inline_small_object,
_is_experimental_mutable_object)

logger.debug(
f"Serialized object size of {c_object_id.Hex()} is {total_bytes} bytes")
Expand Down
Loading

0 comments on commit cb5bb4e

Please sign in to comment.