diff --git a/CMakeLists.txt b/CMakeLists.txt index 379674c8f..683ee98ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,6 +147,10 @@ include(cmake/setup_iwyu.cmake) # Keep all source files sorted!!! add_library(libsrf + src/internal/data_plane/callbacks.cpp + src/internal/data_plane/client.cpp + src/internal/data_plane/resources.cpp + src/internal/data_plane/request.cpp src/internal/data_plane/server.cpp src/internal/executor/executor.cpp src/internal/executor/iexecutor.cpp diff --git a/include/srf/codable/encoded_object.hpp b/include/srf/codable/encoded_object.hpp index 12cee1011..abd91af77 100644 --- a/include/srf/codable/encoded_object.hpp +++ b/include/srf/codable/encoded_object.hpp @@ -34,6 +34,7 @@ #include #include #include + namespace srf::codable { /** @@ -60,7 +61,7 @@ class EncodedObject const protos::EncodedObject& proto() const; /** - * @brief Access const memory::buffer_view of the RemoteDescriptor at the required index + * @brief Access const memory::buffer_view of the RemoteMemoryDescriptor at the required index * @return memory::const_buffer_view */ memory::const_buffer_view memory_block(std::size_t idx) const; @@ -119,20 +120,20 @@ class EncodedObject memory::buffer_view mutable_memory_block(std::size_t idx) const; /** - * @brief Converts a memory block to a RemoteDescriptor proto + * @brief Converts a memory block to a RemoteMemoryDescriptor proto * * @param view - * @return protos::RemoteDescriptor + * @return protos::RemoteMemoryDescriptor */ - static protos::RemoteDescriptor encode_descriptor(memory::const_buffer_view view); + static protos::RemoteMemoryDescriptor encode_descriptor(memory::const_buffer_view view); /** - * @brief Converts a RemoteDescriptor proto to a mutable memory block + * @brief Converts a RemoteMemoryDescriptor proto to a mutable memory block * * @param desc * @return memory::buffer_view */ - static memory::buffer_view decode_descriptor(const protos::RemoteDescriptor& desc); + static memory::buffer_view decode_descriptor(const protos::RemoteMemoryDescriptor& desc); /** * @brief Add a custom protobuf meta data to the descriptor list @@ -266,13 +267,4 @@ MetaDataT EncodedObject::meta_data(std::size_t idx) const return meta_data; } -std::size_t EncodedObject::add_buffer(std::shared_ptr mr, std::size_t bytes) -{ - CHECK(m_context_acquired); - memory::buffer buff(bytes, mr); - auto index = add_memory_block(buff); - m_buffers[index] = std::move(buff); - return index; -} - } // namespace srf::codable diff --git a/protos/srf/protos/codable.proto b/protos/srf/protos/codable.proto index ca078c952..406846e86 100644 --- a/protos/srf/protos/codable.proto +++ b/protos/srf/protos/codable.proto @@ -30,7 +30,7 @@ enum MemoryKind None = 99; } -message RemoteDescriptor +message RemoteMemoryDescriptor { uint32 instance_id = 1; uint32 object_id = 2; @@ -63,7 +63,7 @@ message Descriptor { oneof desc { - RemoteDescriptor remote_desc = 1; + RemoteMemoryDescriptor remote_desc = 1; PackedDescriptor packed_desc = 2; EagerDescriptor eager_desc = 3; MetaDataDescriptor meta_data_desc = 4; diff --git a/src/internal/data_plane/WIP.md b/src/internal/data_plane/WIP.md new file mode 100644 index 000000000..97fa68e31 --- /dev/null +++ b/src/internal/data_plane/WIP.md @@ -0,0 +1,5 @@ +# Work In Progress + +Files in this directory should be considered a WIP up until this file is removed from the directory. + +See: https://github.com/nv-morpheus/SRF/issues/144 diff --git a/src/internal/data_plane/callbacks.cpp b/src/internal/data_plane/callbacks.cpp new file mode 100644 index 000000000..d1bbafaf4 --- /dev/null +++ b/src/internal/data_plane/callbacks.cpp @@ -0,0 +1,87 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/data_plane/callbacks.hpp" + +#include "internal/data_plane/request.hpp" + +#include +#include + +namespace srf::internal::data_plane { + +void Callbacks::send(void* request, ucs_status_t status, void* user_data) +{ + DVLOG(10) << "send callback start for request " << request; + + DCHECK(user_data); + auto* user_req = static_cast(user_data); + DCHECK(user_req->m_state == Request::State::Running); + + if (user_req->m_rkey != nullptr) + { + ucp_rkey_destroy(reinterpret_cast(user_req->m_rkey)); + } + + if (status == UCS_OK) + { + ucp_request_free(request); + user_req->m_request = nullptr; + user_req->m_state = Request::State::OK; + } + + else if (status == UCS_ERR_CANCELED) + { + ucp_request_free(request); + user_req->m_request = nullptr; + user_req->m_state = Request::State::Cancelled; + } + else + { + // todo(ryan) - set the promise exception ptr + LOG(FATAL) << "data_plane: pre_posted_recv_callback failed with status: " << ucs_status_string(status); + user_req->m_state = Request::State::Error; + } +} + +void Callbacks::recv(void* request, ucs_status_t status, const ucp_tag_recv_info_t* msg_info, void* user_data) +{ + DCHECK(user_data); + auto* user_req = static_cast(user_data); + DCHECK(user_req->m_state == Request::State::Running); + + if (status == UCS_OK) // cpp20 [[likely]] + { + ucp_request_free(request); + user_req->m_request = nullptr; + user_req->m_state = Request::State::OK; + } + else if (status == UCS_ERR_CANCELED) + { + ucp_request_free(request); + user_req->m_request = nullptr; + user_req->m_state = Request::State::Cancelled; + } + else + { + // todo(ryan) - set the promise exception ptr + LOG(FATAL) << "data_plane: pre_posted_recv_callback failed with status: " << ucs_status_string(status); + user_req->m_state = Request::State::Error; + } +} + +} // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/client_worker.cpp b/src/internal/data_plane/callbacks.hpp similarity index 65% rename from src/internal/data_plane/client_worker.cpp rename to src/internal/data_plane/callbacks.hpp index 870fc7ccb..775fd8992 100644 --- a/src/internal/data_plane/client_worker.cpp +++ b/src/internal/data_plane/callbacks.hpp @@ -15,26 +15,19 @@ * limitations under the License. */ -#include "internal/data_plane/client_worker.hpp" +#pragma once -#include "internal/ucx/worker.hpp" - -#include -#include +#include +#include +#include namespace srf::internal::data_plane { -void DataPlaneClientWorker::on_data(void*&& data) +struct Callbacks final { - while (ucp_request_is_completed(data) == 0) - { - if (m_worker->progress() != 0U) - { - continue; - } - boost::this_fiber::yield(); - } - ucp_request_release(data); -} + // internal point-to-point + static void send(void* request, ucs_status_t status, void* user_data); + static void recv(void* request, ucs_status_t status, const ucp_tag_recv_info_t* msg_info, void* user_data); +}; } // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/client.cpp b/src/internal/data_plane/client.cpp index b72c98661..1bd4a5e08 100644 --- a/src/internal/data_plane/client.cpp +++ b/src/internal/data_plane/client.cpp @@ -17,11 +17,12 @@ #include "internal/data_plane/client.hpp" -#include "internal/data_plane/client_worker.hpp" +#include "internal/data_plane/callbacks.hpp" #include "internal/data_plane/tags.hpp" #include "internal/ucx/common.hpp" #include "internal/ucx/context.hpp" #include "internal/ucx/endpoint.hpp" +#include "internal/ucx/resources.hpp" #include "internal/ucx/worker.hpp" #include "internal/utils/contains.hpp" @@ -46,6 +47,7 @@ #include #include #include +#include #include #include @@ -60,63 +62,12 @@ namespace srf::internal::data_plane { -static void send_completion_handler_with_future(void* request, ucs_status_t status, void* user_data) -{ - auto* promise = static_cast*>(user_data); - - if (status == UCS_OK) - { - promise->set_value(); - } - else - { - promise->set_exception(std::make_exception_ptr(std::runtime_error(ucs_status_string(status)))); - } - - // the request will be released by the progress engine - // we could optimize this a bit more -} - -Client::Client(resources::PartitionResourceBase& provider, std::shared_ptr worker) : - resources::PartitionResourceBase(provider), - m_worker(std::move(worker)) +Client::Client(resources::PartitionResourceBase& base, ucx::Resources& ucx) : + resources::PartitionResourceBase(base), + m_ucx(ucx) {} -Client::~Client() -{ - Service::call_in_destructor(); -} - -void Client::do_service_start() -{ - m_ucx_request_channel = std::make_unique>(); - auto sink = std::make_unique(m_worker); - sink->update_channel(std::make_unique>(256)); - node::make_edge(*m_ucx_request_channel, *sink); - LOG(FATAL) << "get launch control from partition resources"; - m_progress_engine = runnable().launch_control().prepare_launcher(std::move(sink))->ignition(); -} - -void Client::do_service_await_live() -{ - m_progress_engine->await_live(); -} - -void Client::do_service_stop() -{ - m_ucx_request_channel.reset(); -} - -void Client::do_service_kill() -{ - m_ucx_request_channel.reset(); - m_progress_engine->kill(); -} - -void Client::do_service_await_join() -{ - m_progress_engine->await_join(); -} +Client::~Client() = default; void Client::register_instance(InstanceID instance_id, ucx::WorkerAddress worker_address) { @@ -142,8 +93,7 @@ const ucx::Endpoint& Client::endpoint(InstanceID id) const } // lazy instantiation of the endpoint DVLOG(10) << "creating endpoint to instance_id: " << id; - auto endpoint = std::make_shared(m_worker, search_workers->second); - m_worker->progress(); + auto endpoint = m_ucx.make_ep(search_workers->second); m_endpoints[id] = endpoint; return *endpoint; } @@ -151,67 +101,151 @@ const ucx::Endpoint& Client::endpoint(InstanceID id) const return *search_endpoints->second; } -void Client::push_request(void* request) +std::size_t Client::connections() const { - DCHECK(m_ucx_request_channel); - m_ucx_request_channel->await_write(std::move(request)); + return m_endpoints.size(); } -bool Client::is_connected_to(InstanceID instance_id) const +void Client::async_recv(void* addr, std::size_t bytes, std::uint64_t tag, Request& request) { - return contains(m_workers, instance_id); -} + static constexpr std::uint64_t mask = P2P_TAG & TAG_USER_MASK; // NOLINT -void Client::decrement_remote_descriptor(InstanceID id, ObjectID obj_id) -{ - ucp_tag_t tag = obj_id | DESCRIPTOR_TAG; - issue_network_event(id, tag); + CHECK_EQ(request.m_request, nullptr); + CHECK(request.m_state == Request::State::Init); + request.m_state = Request::State::Running; + + ucp_request_param_t params; + params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FLAG_NO_IMM_CMPL; + params.cb.recv = Callbacks::recv; + params.user_data = &request; + + // build tag + CHECK_LE(tag, TAG_USER_MASK); + tag |= P2P_TAG; + + request.m_request = ucp_tag_recv_nbx(m_ucx.worker().handle(), addr, bytes, tag, mask, ¶ms); + CHECK(request.m_request); + CHECK(!UCS_PTR_IS_ERR(request.m_request)); } -void Client::issue_network_event(InstanceID id, ucp_tag_t tag) +void Client::async_send(void* addr, std::size_t bytes, std::uint64_t tag, InstanceID instance_id, Request& request) { - ucp_request_param_t params; - std::memset(¶ms, 0, sizeof(params)); + CHECK_EQ(request.m_request, nullptr); + CHECK(request.m_state == Request::State::Init); + request.m_state = Request::State::Running; - auto* request = ucp_tag_send_nbx(endpoint(id).handle(), nullptr, 0, tag, ¶ms); + CHECK_LE(tag, TAG_USER_MASK); + tag |= P2P_TAG; - if (request == nullptr /* UCS_OK */) - { - // send completed successfully - return; - } - if (UCS_PTR_IS_ERR(request)) - { - LOG(ERROR) << "send failed"; - throw std::runtime_error("send failed"); - } + ucp_request_param_t send_params; + send_params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FLAG_NO_IMM_CMPL; + send_params.cb.send = Callbacks::send; + send_params.user_data = &request; - // send operation was scheduled by the ucx runtime - // adding requests to the channel will ensure the progress engine - // will work to make forward progress on queued network requests - push_request(std::move(request)); + request.m_request = ucp_tag_send_nbx(endpoint(instance_id).handle(), addr, bytes, tag, &send_params); + CHECK(request.m_request); + CHECK(!UCS_PTR_IS_ERR(request.m_request)); } -struct GetUserData +void Client::async_get(void* addr, + std::size_t bytes, + InstanceID instance_id, + void* remote_addr, + const std::string& packed_remote_key, + Request& request) { - Promise promise; - ucp_rkey_h rkey; -}; + CHECK_EQ(request.m_request, nullptr); + CHECK(request.m_state == Request::State::Init); + request.m_state = Request::State::Running; + + const auto& ep = endpoint(instance_id); + + ucp_request_param_t params; + params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FLAG_NO_IMM_CMPL; + params.cb.send = Callbacks::send; + params.user_data = &request; -static void rdma_get_callback(void* request, ucs_status_t status, void* user_data) -{ - DVLOG(1) << "rdma get callback start for request " << request; - auto* data = static_cast(user_data); - if (status != UCS_OK) { - LOG(FATAL) << "rdma get failure occurred"; - // data->promise.set_exception(); + auto rc = + ucp_ep_rkey_unpack(ep.handle(), packed_remote_key.data(), reinterpret_cast(&request.m_rkey)); + if (rc != UCS_OK) + { + LOG(ERROR) << "ucp_ep_rkey_unpack failed - " << ucs_status_string(rc); + throw std::runtime_error("ucp_ep_rkey_unpack failed"); + } } - data->promise.set_value(); - ucp_request_free(request); - ucp_rkey_destroy(data->rkey); + + request.m_request = ucp_get_nbx(ep.handle(), + addr, + bytes, + reinterpret_cast(remote_addr), + reinterpret_cast(request.m_rkey), + ¶ms); + CHECK(request.m_request); + CHECK(!UCS_PTR_IS_ERR(request.m_request)); } +// void Client::push_request(void* request) +// { +// DCHECK(m_ucx_request_channel); +// m_ucx_request_channel->await_write(std::move(request)); +// } + +// bool Client::is_connected_to(InstanceID instance_id) const +// { +// return contains(m_workers, instance_id); +// } + +// void Client::decrement_remote_descriptor(InstanceID id, ObjectID obj_id) +// { +// ucp_tag_t tag = obj_id | DESCRIPTOR_TAG; +// issue_network_event(id, tag); +// } + +// void Client::issue_network_event(InstanceID id, ucp_tag_t tag) +// { +// ucp_request_param_t params; +// std::memset(¶ms, 0, sizeof(params)); + +// auto* request = ucp_tag_send_nbx(endpoint(id).handle(), nullptr, 0, tag, ¶ms); + +// if (request == nullptr /* UCS_OK */) +// { +// // send completed successfully +// return; +// } +// if (UCS_PTR_IS_ERR(request)) +// { +// LOG(ERROR) << "send failed"; +// throw std::runtime_error("send failed"); +// } + +// // send operation was scheduled by the ucx runtime +// // adding requests to the channel will ensure the progress engine +// // will work to make forward progress on queued network requests +// push_request(std::move(request)); +// } + +// struct GetUserData +// { +// Promise promise; +// ucp_rkey_h rkey; +// }; + +// static void rdma_get_callback(void* request, ucs_status_t status, void* user_data) +// { +// DVLOG(1) << "rdma get callback start for request " << request; +// auto* data = static_cast(user_data); +// if (status != UCS_OK) +// { +// LOG(FATAL) << "rdma get failure occurred"; +// // data->promise.set_exception(); +// } +// data->promise.set_value(); +// ucp_request_free(request); +// ucp_rkey_destroy(data->rkey); +// } + /* void Client::get(const protos::RemoteDescriptor& remote_md, Descriptor& buffer) { @@ -274,64 +308,59 @@ void Client::get(const protos::RemoteDescriptor& remote_md, Descriptor& buffer) } */ -void Client::await_send(const InstanceID& instance_id, - const PortAddress& port_address, - const codable::EncodedObject& encoded_object) -{ - Promise promise; - auto future = promise.get_future(); - - ucp_tag_t tag = port_address | INGRESS_TAG; - ucp_request_param_t params; - - params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; - params.cb.send = send_completion_handler_with_future; - params.user_data = &promise; - - // serialize the proto of the encoded object into it's own encoded object - // dogfooding at its best - codable::EncodedObject msg; - codable::encode(encoded_object.proto(), msg); - - // sanity check - // 1) there should be only 1 descriptor, and - // 2) the size of the memory block should be the size of the protos requested - DCHECK_EQ(msg.descriptor_count(), 1); - auto block = msg.memory_block(0); - DCHECK_EQ(block.bytes(), encoded_object.proto().ByteSizeLong()); - - // all encoded_objects are serialized to host memory - // these are small packed remote descriptors, not the actual payload data - params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE; - params.memory_type = UCS_MEMORY_TYPE_HOST; - - // issue send - ucs_status_ptr_t request = - ucp_tag_send_nbx(endpoint(instance_id).handle(), block.data(), block.bytes(), tag, ¶ms); - - if (request == nullptr /* UCS_OK */) - { - return; - } - if (UCS_PTR_IS_ERR(request)) - { - LOG(ERROR) << "send failed - "; - throw std::runtime_error("send failed"); - } - - // if we didn't complete immediate or throw an error, then the message - // is in flight. push the request to the progress engine which will - // wake up a progress fiber to complete the send - push_request(std::move(request)); - - // the caller of this await_send method will block and yield the fiber here - // the caller is calling an "await" method so blocking and yielding is implied - future.get(); -} - -std::size_t Client::connections() const -{ - return m_endpoints.size(); -} +// void Client::await_send(const InstanceID& instance_id, +// const PortAddress& port_address, +// const codable::EncodedObject& encoded_object) +// { +// Promise promise; +// auto future = promise.get_future(); + +// ucp_tag_t tag = port_address | INGRESS_TAG; +// ucp_request_param_t params; + +// params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; +// params.cb.send = send_completion_handler_with_future; +// params.user_data = &promise; + +// // serialize the proto of the encoded object into it's own encoded object +// // dogfooding at its best +// codable::EncodedObject msg; +// codable::encode(encoded_object.proto(), msg); + +// // sanity check +// // 1) there should be only 1 descriptor, and +// // 2) the size of the memory block should be the size of the protos requested +// DCHECK_EQ(msg.descriptor_count(), 1); +// auto block = msg.memory_block(0); +// DCHECK_EQ(block.bytes(), encoded_object.proto().ByteSizeLong()); + +// // all encoded_objects are serialized to host memory +// // these are small packed remote descriptors, not the actual payload data +// params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE; +// params.memory_type = UCS_MEMORY_TYPE_HOST; + +// // issue send +// ucs_status_ptr_t request = +// ucp_tag_send_nbx(endpoint(instance_id).handle(), block.data(), block.bytes(), tag, ¶ms); + +// if (request == nullptr /* UCS_OK */) +// { +// return; +// } +// if (UCS_PTR_IS_ERR(request)) +// { +// LOG(ERROR) << "send failed - "; +// throw std::runtime_error("send failed"); +// } + +// // if we didn't complete immediate or throw an error, then the message +// // is in flight. push the request to the progress engine which will +// // wake up a progress fiber to complete the send +// push_request(std::move(request)); + +// // the caller of this await_send method will block and yield the fiber here +// // the caller is calling an "await" method so blocking and yielding is implied +// future.get(); +// } } // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/client.hpp b/src/internal/data_plane/client.hpp index 5dc0ee6ef..341dc2aa0 100644 --- a/src/internal/data_plane/client.hpp +++ b/src/internal/data_plane/client.hpp @@ -17,6 +17,7 @@ #pragma once +#include "internal/data_plane/request.hpp" #include "internal/resources/partition_resources.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/service.hpp" @@ -42,10 +43,10 @@ namespace srf::internal::data_plane { -class Client final : public Service, public resources::PartitionResourceBase +class Client final : public resources::PartitionResourceBase { public: - Client(resources::PartitionResourceBase& provider, std::shared_ptr worker); + Client(resources::PartitionResourceBase& base, ucx::Resources& ucx); ~Client() final; /** @@ -65,42 +66,46 @@ class Client final : public Service, public resources::PartitionResourceBase * @param port_address * @param encoded_object */ - void await_send(const InstanceID& instance_id, - const PortAddress& port_address, - const codable::EncodedObject& encoded_object); + // void await_send(const InstanceID& instance_id, + // const PortAddress& port_address, + // const codable::EncodedObject& encoded_object); // number of established remote instances std::size_t connections() const; - // determine if connected to a given remote instance - bool is_connected_to(InstanceID) const; + void async_recv(void* addr, std::size_t bytes, std::uint64_t tag, Request& request); + void async_send(void* addr, std::size_t bytes, std::uint64_t tag, InstanceID instance_id, Request& request); - void decrement_remote_descriptor(InstanceID, ObjectID); + /** + * @brief Perform an asynchronous one-side GET from a contiguous block of memory starting at remote_addr on remote + * instance_id. + */ + void async_get(void* addr, + std::size_t bytes, + InstanceID instance_id, + void* remote_addr, + const std::string& packed_remote_key, + Request& request); + + // // determine if connected to a given remote instance + // bool is_connected_to(InstanceID) const; + + // void decrement_remote_descriptor(InstanceID, ObjectID); // void get(const protos::RemoteDescriptor&, void*, size_t); // void get(const protos::RemoteDescriptor&, Descriptor&); protected: - // issue tag only send - no payload data - void issue_network_event(InstanceID, ucp_tag_t); + // // issue tag only send - no payload data + // void issue_network_event(InstanceID, ucp_tag_t); // get endpoint for instance id const ucx::Endpoint& endpoint(InstanceID) const; - void push_request(void* request); + // void push_request(void* request); private: - void do_service_start() final; - void do_service_await_live() final; - void do_service_stop() final; - void do_service_kill() final; - void do_service_await_join() final; - - std::shared_ptr m_worker; - std::shared_ptr m_resources; - std::unique_ptr> m_ucx_request_channel; - std::unique_ptr m_progress_engine; - + ucx::Resources& m_ucx; std::map m_workers; mutable std::map> m_endpoints; }; diff --git a/src/internal/data_plane/instance.cpp b/src/internal/data_plane/instance.cpp deleted file mode 100644 index b8f5ec1a3..000000000 --- a/src/internal/data_plane/instance.cpp +++ /dev/null @@ -1,103 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "internal/data_plane/instance.hpp" - -#include "internal/data_plane/client.hpp" -#include "internal/data_plane/server.hpp" -#include "internal/ucx/context.hpp" - -#include "srf/cuda/common.hpp" -#include "srf/runnable/launch_control.hpp" - -#include - -#include - -namespace srf::internal::data_plane { - -Instance::Instance(std::shared_ptr resources) : m_resources(std::move(resources)) {} - -Instance::~Instance() -{ - call_in_destructor(); -} - -Client& Instance::client() const -{ - CHECK(m_client); - return *m_client; -} - -Server& Instance::server() const -{ - CHECK(m_server); - return *m_server; -} - -void Instance::do_service_start() -{ - m_resources->host() - .main() - .enqueue([this] { - // if the PartitionResource has a GPU, ensure the CUDA context on the main thread is active - // before the ucx context is constructed - if (m_resources->device()) - { - auto device = m_resources->device()->get(); - device.activate(); - - void* addr = nullptr; - SRF_CHECK_CUDA(cudaMalloc(&addr, 1024)); - SRF_CHECK_CUDA(cudaFree(addr)); - } - - m_context = std::make_shared(); - m_server = std::make_unique(m_context, m_resources); - m_client = std::make_unique(m_context, m_resources); - - m_server->service_start(); - m_client->service_start(); - }) - .get(); -} - -void Instance::do_service_await_live() -{ - client().service_await_live(); - server().service_await_live(); -} - -void Instance::do_service_stop() -{ - client().service_stop(); - server().service_stop(); -} - -void Instance::do_service_kill() -{ - client().service_kill(); - server().service_kill(); -} - -void Instance::do_service_await_join() -{ - client().service_await_join(); - server().service_await_join(); -} - -} // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/client_worker.hpp b/src/internal/data_plane/request.cpp similarity index 53% rename from src/internal/data_plane/client_worker.hpp rename to src/internal/data_plane/request.cpp index f39996a09..ac8f9c4e4 100644 --- a/src/internal/data_plane/client_worker.hpp +++ b/src/internal/data_plane/request.cpp @@ -15,31 +15,47 @@ * limitations under the License. */ -#pragma once +#include "internal/data_plane/request.hpp" -#include "internal/ucx/worker.hpp" - -#include "srf/channel/status.hpp" -#include "srf/node/generic_sink.hpp" -#include "srf/runnable/context.hpp" - -#include -#include - -#include -#include +#include +#include namespace srf::internal::data_plane { -class DataPlaneClientWorker : public node::GenericSink +Request::Request() = default; + +Request::~Request() { - public: - DataPlaneClientWorker(std::shared_ptr worker) : m_worker(std::move(worker)) {} + CHECK(m_state == State::Init) << "A Request that is in use is being destroyed"; +} - private: - void on_data(void*&& data) final; +void Request::reset() +{ + m_state = State::Init; + m_request = nullptr; +} - std::shared_ptr m_worker; -}; +bool Request::await_complete() +{ + CHECK(m_state > State::Init); + while (m_state == State::Running) + { + boost::this_fiber::yield(); + } + + if (m_state == State::OK) + { + reset(); + return true; + } + + if (m_state == State::Cancelled) + { + reset(); + return false; + } + + LOG(FATAL) << "error in ucx callback"; +} } // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/request.hpp b/src/internal/data_plane/request.hpp new file mode 100644 index 000000000..1dfccbc83 --- /dev/null +++ b/src/internal/data_plane/request.hpp @@ -0,0 +1,69 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "srf/utils/macros.hpp" + +#include +#include +#include +#include + +#include + +namespace srf::internal::data_plane { + +class Callbacks; +class Client; + +class Request final +{ + public: + Request(); + ~Request(); + + DELETE_COPYABILITY(Request); + DELETE_MOVEABILITY(Request); + + // std::optional is_complete(); + bool await_complete(); + + // attempts to cancel the request + // the request will either be cancelled or completed + // void try_cancel(); + + private: + void reset(); + + enum class State + { + Init, + Running, + OK, + Cancelled, + Error + }; + std::atomic m_state{State::Init}; + void* m_request{nullptr}; + void* m_rkey{nullptr}; + + friend Client; + friend Callbacks; +}; + +} // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/resources.cpp b/src/internal/data_plane/resources.cpp new file mode 100644 index 000000000..fcbaeb0e3 --- /dev/null +++ b/src/internal/data_plane/resources.cpp @@ -0,0 +1,95 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/data_plane/resources.hpp" + +#include "internal/data_plane/client.hpp" +#include "internal/data_plane/server.hpp" +#include "internal/ucx/resources.hpp" + +#include "srf/cuda/common.hpp" +#include "srf/runnable/launch_control.hpp" + +#include + +#include + +namespace srf::internal::data_plane { + +Resources::Resources(resources::PartitionResourceBase& base, ucx::Resources& ucx, memory::HostResources& host) : + resources::PartitionResourceBase(base), + m_ucx(ucx), + m_host(host), + m_server(base, ucx, host), + m_client(base, ucx) +{ + // ensure the data plane progress engine is up and running + m_server.service_start(); + m_server.service_await_live(); +} + +Resources::~Resources() +{ + call_in_destructor(); +} + +Client& Resources::client() +{ + return m_client; +} + +// Server& Resources::server() +// { +// return m_server; +// } + +std::string Resources::ucx_address() const +{ + return m_ucx.worker().address(); +} + +const ucx::RegistrationCache& Resources::registration_cache() const +{ + return m_ucx.registration_cache(); +} + +void Resources::do_service_start() +{ + m_server.service_start(); +} + +void Resources::do_service_await_live() +{ + m_server.service_await_live(); +} + +void Resources::do_service_stop() +{ + m_server.service_stop(); +} + +void Resources::do_service_kill() +{ + m_server.service_kill(); +} + +void Resources::do_service_await_join() +{ + m_server.service_await_join(); +} + +} // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/instance.hpp b/src/internal/data_plane/resources.hpp similarity index 67% rename from src/internal/data_plane/instance.hpp rename to src/internal/data_plane/resources.hpp index 3ecebf04b..5947a4168 100644 --- a/src/internal/data_plane/instance.hpp +++ b/src/internal/data_plane/resources.hpp @@ -19,10 +19,11 @@ #include "internal/data_plane/client.hpp" #include "internal/data_plane/server.hpp" +#include "internal/resources/forward.hpp" +#include "internal/resources/partition_resources_base.hpp" #include "internal/service.hpp" -#include "internal/ucx/context.hpp" -#include "srf/runnable/launch_control.hpp" +#include "srf/protos/codable.pb.h" #include @@ -32,29 +33,31 @@ namespace srf::internal::data_plane { * @brief ArchitectResources hold and is responsible for constructing any object that depending the UCX data plane * */ -class Instance final : public Service +class Resources final : private Service, private resources::PartitionResourceBase { public: - Instance(std::shared_ptr resources); - ~Instance() final; + Resources(resources::PartitionResourceBase& base, ucx::Resources& ucx, memory::HostResources& host); + ~Resources() final; - Client& client() const; - Server& server() const; + Client& client(); - private: - Service& client_service(); - Service& server_service(); + std::string ucx_address() const; + const ucx::RegistrationCache& registration_cache() const; + private: void do_service_start() final; void do_service_await_live() final; void do_service_stop() final; void do_service_kill() final; void do_service_await_join() final; - std::shared_ptr m_resources; - std::shared_ptr m_context; - std::unique_ptr m_client; - std::unique_ptr m_server; + ucx::Resources& m_ucx; + memory::HostResources& m_host; + + Server m_server; + Client m_client; + + friend network::Resources; }; } // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/server.cpp b/src/internal/data_plane/server.cpp index 614df338c..0b05d5cc1 100644 --- a/src/internal/data_plane/server.cpp +++ b/src/internal/data_plane/server.cpp @@ -20,6 +20,7 @@ #include "internal/data_plane/tags.hpp" #include "internal/ucx/common.hpp" #include "internal/ucx/context.hpp" +#include "internal/ucx/resources.hpp" #include "internal/ucx/worker.hpp" #include "srf/channel/status.hpp" @@ -42,6 +43,7 @@ #include #include // IWYU pragma: keep #include +#include #include #include @@ -53,8 +55,6 @@ namespace srf::internal::data_plane { -static thread_local rxcpp::subscriber* static_subscriber = nullptr; - namespace { void zero_bytes_completion_handler(void* request, @@ -69,18 +69,45 @@ void zero_bytes_completion_handler(void* request, ucp_request_free(request); } -void recv_completion_handler(void* request, ucs_status_t status, const ucp_tag_recv_info_t* msg_info, void* user_data) +static void pre_post_recv_issue(detail::PrePostedRecvInfo* info); + +void pre_posted_recv_callback(void* request, ucs_status_t status, const ucp_tag_recv_info_t* msg_info, void* user_data) { - if (status != UCS_OK) + DCHECK(user_data); + auto* info = static_cast(user_data); + if (status == UCS_OK) // cpp20 [[likely]] + { + // grab tag and free request - not sure if there will be a race condition on msg_info + auto tag = msg_info->sender_tag; + ucp_request_free(request); + + // repost recv + pre_post_recv_issue(info); + + // write tag to channel + info->channel->await_write(std::move(tag)); + } + else if (status == UCS_ERR_CANCELED) { - LOG(FATAL) << "recv_completion_handler observed " << ucs_status_string(status); + ucp_request_free(info->request); + info->request = nullptr; // this ensures than cancel will not be called again if a kill is issued after stop } - auto port_address = tag_decode_user_tag(msg_info->sender_tag); - DCHECK(static_subscriber && static_subscriber->is_subscribed()); - auto msg = std::make_pair(port_address, - srf::memory::buffer_view(user_data, msg_info->length, srf::memory::memory_kind::host)); - static_subscriber->on_next(std::move(msg)); - ucp_request_free(request); + else + { + LOG(FATAL) << "data_plane: pre_posted_recv_callback failed with status: " << ucs_status_string(status); + } +} + +void pre_post_recv_issue(detail::PrePostedRecvInfo* info) +{ + ucp_request_param_t params; + params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FLAG_NO_IMM_CMPL; + params.cb.recv = pre_posted_recv_callback; + params.user_data = info; + + info->request = ucp_tag_recv_nbx(info->worker, nullptr, 0, 0, 0, ¶ms); + CHECK(info->request); + CHECK(!UCS_PTR_IS_ERR(info->request)); } } // namespace @@ -88,7 +115,7 @@ void recv_completion_handler(void* request, ucs_status_t status, const ucp_tag_r class DataPlaneServerWorker final : public node::GenericSource { public: - DataPlaneServerWorker(Handle worker); + DataPlaneServerWorker(ucx::Worker& worker); private: void data_source(rxcpp::subscriber& s) final; @@ -97,7 +124,7 @@ class DataPlaneServerWorker final : public node::GenericSource ucp_tag_message_h msg, const ucp_tag_recv_info_t& msg_info); - Handle m_worker; + ucx::Worker& m_worker; // modify these to adjust the tag matching // 0/0 is the equivalent of match all tags @@ -105,9 +132,10 @@ class DataPlaneServerWorker final : public node::GenericSource ucp_tag_t m_tag_mask{0}; }; -Server::Server(resources::PartitionResourceBase& provider, std::shared_ptr worker) : +Server::Server(resources::PartitionResourceBase& provider, ucx::Resources& ucx, memory::HostResources& host) : resources::PartitionResourceBase(provider), - m_worker(std::move(worker)) + m_ucx(ucx), + m_host(host) {} Server::~Server() @@ -117,18 +145,39 @@ Server::~Server() void Server::do_service_start() { - m_deserialize_source = std::make_shared>(); - m_rd_source = std::make_unique>(); - - auto progress_engine = std::make_unique(m_worker); - node::make_edge(*progress_engine, *m_deserialize_source); - - // all network runnables use the `srf_network` engine factory - DVLOG(10) << "launch network event mananger progress engine"; - m_progress_engine = runnable() - .launch_control() - .prepare_launcher(srf::runnable::LaunchOptions("srf_network"), std::move(progress_engine)) - ->ignition(); + m_ucx.network_task_queue() + .enqueue([this] { + // source channel ucx tag recvs masked with the RemoteDescriptor tag + // this recv has no recv payload, we simply write the tag to the channel + m_rd_source = std::make_unique>(); + + // pre-post recv for remote descriptors and remote promise/future + // m_pre_posted_recv_info.resize(m_pre_posted_recv_count); + // for (auto& info : m_pre_posted_recv_info) + // { + // info.worker = m_ucx.worker().handle(); + // info.channel = m_rd_source.get(); + // pre_post_recv_issue(&info); + // } + + // source for ucx tag recvs with data + auto progress_engine = std::make_unique(m_ucx.worker()); + + // router for ucx tag recvs with data + m_deserialize_source = std::make_shared>(); + + // for edge between source and router - on channel operator driven by the source thread + node::make_edge(*progress_engine, *m_deserialize_source); + + // all network runnables use the `srf_network` engine factory + DVLOG(10) << "launch network event mananger progress engine"; + m_progress_engine = + runnable() + .launch_control() + .prepare_launcher(srf::runnable::LaunchOptions("srf_network"), std::move(progress_engine)) + ->ignition(); + }) + .get(); } void Server::do_service_await_live() @@ -138,6 +187,30 @@ void Server::do_service_await_live() void Server::do_service_stop() { + DVLOG(10) << "data_plane server: stop issued"; + + m_ucx.network_task_queue() + .enqueue([this] { + // we need to cancel all preposted recvs before shutting down the progress engine + DVLOG(10) << "data_plane server: cancelling all outstanding pre-posted recvs"; + for (auto& info : m_pre_posted_recv_info) + { + if (info.request != nullptr) + { + ucp_request_cancel(m_ucx.worker().handle(), info.request); + } + + // we are on the network task queue thread, so we can pump the progress engine until + // the cancelled request is complete + while (info.request != nullptr) + { + m_ucx.worker().progress(); + } + } + }) + .get(); + + DVLOG(10) << "data_plane server: issuing stop to progress engine runnable"; m_progress_engine->stop(); } @@ -157,7 +230,7 @@ void Server::do_service_await_join() ucx::WorkerAddress Server::worker_address() const { - return m_worker->address(); + return m_ucx.worker().address(); } node::Router& Server::deserialize_source() @@ -168,7 +241,7 @@ node::Router& Server::deserialize_source( // NetworkEventProgressEngine -DataPlaneServerWorker::DataPlaneServerWorker(Handle worker) : m_worker(std::move(worker)) {} +DataPlaneServerWorker::DataPlaneServerWorker(ucx::Worker& worker) : m_worker(worker) {} void DataPlaneServerWorker::data_source(rxcpp::subscriber& s) { @@ -176,47 +249,36 @@ void DataPlaneServerWorker::data_source(rxcpp::subscriber& s) ucp_tag_recv_info_t msg_info; std::uint32_t backoff = 1; - // set static variable for callbacks - static_subscriber = &s; + DVLOG(10) << "starting data plane server progress engine loop"; + + // the progress loop has tag_probe_nb disabled + // this should be re-enabled to accept tagged messages that have payloads + // larger than the pre-posted recv buffers while (true) { for (;;) { - msg = ucp_tag_probe_nb(m_worker->handle(), m_tag, m_tag_mask, 1, &msg_info); + // msg = ucp_tag_probe_nb(m_worker->handle(), m_tag, m_tag_mask, 1, &msg_info); if (!s.is_subscribed()) { + DVLOG(10) << "exiting data plane server progress engine loop"; return; } - if (msg != nullptr) - { - break; - } - while (m_worker->progress() != 0U) + // if (msg != nullptr) + // { + // break; + // } + while (m_worker.progress() != 0U) { backoff = 1; } boost::this_fiber::yield(); - - /* - if (backoff < 1048576) - { - backoff = backoff << 1; - } - if (backoff < 32768) - { - boost::this_fiber::yield(); - } - else - { - boost::this_fiber::sleep_for(std::chrono::nanoseconds(backoff)); - } - */ } - on_tagged_msg(s, msg, msg_info); - backoff = 1; + // on_tagged_msg(s, msg, msg_info); + // backoff = 1; } } @@ -254,7 +316,7 @@ void DataPlaneServerWorker::on_tagged_msg(rxcpp::subscriber& su recv_bytes = msg_info.length; recv_addr = std::malloc(recv_bytes); params.user_data = recv_addr; - params.cb.recv = recv_completion_handler; + // params.cb.recv = recv_completion_handler; break; } case DESCRIPTOR_TAG: @@ -271,7 +333,7 @@ void DataPlaneServerWorker::on_tagged_msg(rxcpp::subscriber& su LOG(FATAL) << "unknown network event received: " << msg_info.sender_tag; }; - void* status = ucp_tag_msg_recv_nbx(m_worker->handle(), recv_addr, recv_bytes, msg, ¶ms); + void* status = ucp_tag_msg_recv_nbx(m_worker.handle(), recv_addr, recv_bytes, msg, ¶ms); if (UCS_PTR_IS_ERR(status)) { LOG(FATAL) << "ucp_tag_msg_recv_nbx for 0-byte event failed"; diff --git a/src/internal/data_plane/server.hpp b/src/internal/data_plane/server.hpp index 81690fd02..a248e5a9c 100644 --- a/src/internal/data_plane/server.hpp +++ b/src/internal/data_plane/server.hpp @@ -17,12 +17,14 @@ #pragma once +#include "internal/resources/forward.hpp" #include "internal/resources/partition_resources.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/runnable/resources.hpp" #include "internal/service.hpp" #include "internal/ucx/common.hpp" #include "internal/ucx/context.hpp" +#include "internal/ucx/registration_cache.hpp" #include "internal/ucx/worker.hpp" #include "srf/channel/status.hpp" @@ -65,12 +67,22 @@ namespace srf::internal::data_plane { +namespace detail { +struct PrePostedRecvInfo +{ + ucp_worker_h worker; + node::SourceChannelWriteable* channel; + void* request; + // std::array buffer; +}; +} // namespace detail + using network_event_t = std::pair; class Server final : public Service, public resources::PartitionResourceBase { public: - Server(resources::PartitionResourceBase& provider, std::shared_ptr worker); + Server(resources::PartitionResourceBase& provider, ucx::Resources& ucx, memory::HostResources& host); ~Server() final; ucx::WorkerAddress worker_address() const; @@ -84,8 +96,11 @@ class Server final : public Service, public resources::PartitionResourceBase void do_service_kill() final; void do_service_await_join() final; - // host resources + const std::size_t m_pre_posted_recv_count{16}; + // ucx resources + ucx::Resources& m_ucx; + memory::HostResources& m_host; // deserialization nodes will connect to this source wtih their port id // the source for this router is the private GenericSoruce of this object @@ -95,13 +110,11 @@ class Server final : public Service, public resources::PartitionResourceBase // data will be emitted on this source as a conditional branch of data source std::unique_ptr> m_rd_source; - // ucx worker - Handle m_worker; + // pre-posted recv state + std::vector m_pre_posted_recv_info; // runner for the ucx progress engine event source std::unique_ptr m_progress_engine; - - // host resources - probably should be }; } // namespace srf::internal::data_plane diff --git a/src/internal/data_plane/tags.hpp b/src/internal/data_plane/tags.hpp index 82765797b..c6062bb29 100644 --- a/src/internal/data_plane/tags.hpp +++ b/src/internal/data_plane/tags.hpp @@ -33,7 +33,10 @@ static constexpr ucp_tag_t MSG_TYPE_MASK = 0xF000000000000000; // leading 4 bi static constexpr ucp_tag_t INGRESS_TAG = 0x8000000000000000; // leading 4 bits are 1000 // NOLINT static constexpr ucp_tag_t DESCRIPTOR_TAG = 0x4000000000000000; // leading 4 bits are 0100 // NOLINT static constexpr ucp_tag_t FUTURE_TAG = 0x2000000000000000; // leading 4 bits are 0010 // NOLINT +static constexpr ucp_tag_t P2P_TAG = 0x1000000000000000; // leading 4 bits are 0010 // NOLINT +static constexpr ucp_tag_t TAG_CTRL_MASK = 0xFFFF000000000000; // 48-bits // NOLINT +static constexpr ucp_tag_t TAG_USER_MASK = 0x0000FFFFFFFFFFFF; // 48-bits // NOLINT static constexpr ucp_tag_t USR_TYPE_MASK = 0x0000FFFFFFFFFFFF; // 48-bits // NOLINT static ucp_tag_t tag_decode_msg_type(const ucp_tag_t& tag) diff --git a/src/internal/memory/device_resources.cpp b/src/internal/memory/device_resources.cpp index ceeec9f06..36e842c62 100644 --- a/src/internal/memory/device_resources.cpp +++ b/src/internal/memory/device_resources.cpp @@ -47,14 +47,13 @@ namespace srf::internal::memory { -DeviceResources::DeviceResources(runnable::Resources& runnable, - std::size_t partition_id, - std::optional& ucx) : - resources::PartitionResourceBase(runnable, partition_id) +DeviceResources::DeviceResources(resources::PartitionResourceBase& base, std::optional& ucx) : + resources::PartitionResourceBase(base) { CHECK(partition().has_device()); - runnable.main() + runnable() + .main() .enqueue([this, &ucx] { std::stringstream device_prefix; device_prefix << "cuda_malloc:" << cuda_device_id(); diff --git a/src/internal/memory/device_resources.hpp b/src/internal/memory/device_resources.hpp index 034d34013..6a21e0b45 100644 --- a/src/internal/memory/device_resources.hpp +++ b/src/internal/memory/device_resources.hpp @@ -39,7 +39,7 @@ namespace srf::internal::memory { class DeviceResources : private resources::PartitionResourceBase { public: - DeviceResources(runnable::Resources& runnable, std::size_t partition_id, std::optional& ucx); + DeviceResources(resources::PartitionResourceBase& base, std::optional& ucx); int cuda_device_id() const; diff --git a/src/internal/network/resources.cpp b/src/internal/network/resources.cpp index 289fb7fa4..804919a13 100644 --- a/src/internal/network/resources.cpp +++ b/src/internal/network/resources.cpp @@ -17,37 +17,40 @@ #include "internal/network/resources.hpp" +#include "internal/data_plane/resources.hpp" #include "internal/data_plane/server.hpp" #include "internal/memory/host_resources.hpp" +#include "internal/resources/forward.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/ucx/registration_cache.hpp" #include "internal/ucx/resources.hpp" namespace srf::internal::network { -Resources::Resources(runnable::Resources& _runnable_resources, - std::size_t _partition_id, - ucx::Resources& ucx, - memory::HostResources& host) : - resources::PartitionResourceBase(_runnable_resources, _partition_id), - m_ucx(ucx), - m_host(host) +Resources::Resources(resources::PartitionResourceBase& base, ucx::Resources& ucx, memory::HostResources& host) : + resources::PartitionResourceBase(base) { // construct resources on the srf_network task queue thread - m_ucx.network_task_queue() - .enqueue([this] { + ucx.network_task_queue() + .enqueue([this, &base, &ucx, &host] { // initialize data plane services - server / client - m_server = std::make_unique(static_cast(*this), - m_ucx.m_worker_server); + m_data_plane = std::make_unique(base, ucx, host); }) .get(); } -Resources::~Resources() = default; - -const ucx::RegistrationCache& Resources::registration_cache() const +Resources::~Resources() { - return m_ucx.registration_cache(); + if (m_data_plane) + { + m_data_plane->service_stop(); + m_data_plane->service_await_join(); + } } +data_plane::Resources& Resources::data_plane() +{ + CHECK(m_data_plane); + return *m_data_plane; +} } // namespace srf::internal::network diff --git a/src/internal/network/resources.hpp b/src/internal/network/resources.hpp index 075d07c3c..5f454e832 100644 --- a/src/internal/network/resources.hpp +++ b/src/internal/network/resources.hpp @@ -17,44 +17,30 @@ #pragma once -#include "internal/memory/device_resources.hpp" +#include "internal/resources/forward.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/runnable/resources.hpp" #include "internal/ucx/registration_cache.hpp" -#include +#include "srf/utils/macros.hpp" -namespace srf::internal { -namespace ucx { -class Resources; -} // namespace ucx -namespace memory { -class HostResources; -} // namespace memory -namespace data_plane { -class Client; -class Server; -} // namespace data_plane -} // namespace srf::internal +#include namespace srf::internal::network { class Resources final : private resources::PartitionResourceBase { public: - Resources(runnable::Resources& _runnable_resources, - std::size_t _partition_id, - ucx::Resources& ucx, - memory::HostResources& host); + Resources(resources::PartitionResourceBase& base, ucx::Resources& ucx, memory::HostResources& host); ~Resources() final; - const ucx::RegistrationCache& registration_cache() const; + DELETE_COPYABILITY(Resources); + DEFAULT_MOVEABILITY(Resources); + + data_plane::Resources& data_plane(); private: - ucx::Resources& m_ucx; - memory::HostResources& m_host; - std::shared_ptr m_server; - std::shared_ptr m_client; + std::unique_ptr m_data_plane; }; } // namespace srf::internal::network diff --git a/src/internal/resources/forward.hpp b/src/internal/resources/forward.hpp index fcc752732..f097bce79 100644 --- a/src/internal/resources/forward.hpp +++ b/src/internal/resources/forward.hpp @@ -17,10 +17,33 @@ #pragma once -namespace srf::internal::resources { +namespace srf::internal { +namespace resources { class Manager; +class PartitionResourceBase; +} // namespace resources + +namespace runnable { +class Resources; +} // namespace runnable + +namespace memory { class HostResources; class DeviceResources; +} // namespace memory + +// control plane and data plane +namespace network { +class Resources; +} // namespace network + +namespace ucx { +class Resources; +} // namespace ucx + +namespace data_plane { +class Resources; +} // namespace data_plane -} // namespace srf::internal::resources +} // namespace srf::internal diff --git a/src/internal/resources/manager.cpp b/src/internal/resources/manager.cpp index 1b6e295b1..fb4ba59eb 100644 --- a/src/internal/resources/manager.cpp +++ b/src/internal/resources/manager.cpp @@ -17,6 +17,9 @@ #include "internal/resources/manager.hpp" +#include "internal/data_plane/resources.hpp" +#include "internal/resources/forward.hpp" +#include "internal/resources/partition_resources_base.hpp" #include "internal/system/partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" @@ -27,6 +30,7 @@ #include #include +#include #include #include #include @@ -51,19 +55,25 @@ Manager::Manager(std::unique_ptr resources) : m_runnable.emplace_back(*m_system, i); } - // construct ucx resources on each flattened partition - // this provides a ucx context, 2x workers and registration cache per partition - for (std::size_t i = 0; i < partitions.size(); ++i) + std::vector base_partition_resources; + for (int i = 0; i < partitions.size(); i++) { auto host_partition_id = partitions.at(i).host_partition_id(); + base_partition_resources.emplace_back(m_runnable.at(host_partition_id), i); + } + + // construct ucx resources on each flattened partition + // this provides a ucx context, ucx worker and registration cache per partition + for (auto& base : base_partition_resources) + { if (network_enabled) { - VLOG(1) << "building ucx resources for partition " << i; + VLOG(1) << "building ucx resources for partition " << base.partition_id(); auto network_task_queue_cpuset = - host_partitions.at(host_partition_id).engine_factory_cpu_sets().fiber_cpu_sets.at("srf_network"); + base.partition().host().engine_factory_cpu_sets().fiber_cpu_sets.at("srf_network"); auto& network_fiber_queue = m_system->get_task_queue(network_task_queue_cpuset.first()); std::optional ucx; - ucx.emplace(m_runnable.at(host_partition_id), i, network_fiber_queue); + ucx.emplace(base, network_fiber_queue); m_ucx.push_back(std::move(ucx)); } else @@ -91,15 +101,14 @@ Manager::Manager(std::unique_ptr resources) : } // devices - for (std::size_t i = 0; i < partition_count(); ++i) + for (auto& base : base_partition_resources) { - VLOG(1) << "building device resources for partition: " << i; - auto host_partition_id = partitions.at(i).host_partition_id(); - - if (i < device_count()) + VLOG(1) << "building device resources for partition: " << base.partition_id(); + if (base.partition().has_device()) { + DCHECK_LT(base.partition_id(), device_count()); std::optional device; - device.emplace(m_runnable.at(host_partition_id), i, m_ucx.at(i)); + device.emplace(base, m_ucx.at(base.partition_id())); m_device.emplace_back(std::move(device)); } else @@ -109,16 +118,15 @@ Manager::Manager(std::unique_ptr resources) : } // network resources - for (std::size_t i = 0; i < partition_count(); ++i) + for (auto& base : base_partition_resources) { if (network_enabled) { - VLOG(1) << "building network resources for partition: " << i; - CHECK(m_ucx.at(i)); - auto host_partition_id = partitions.at(i).host_partition_id(); + VLOG(1) << "building network resources for partition: " << base.partition_id(); + CHECK(m_ucx.at(base.partition_id())); std::optional network; - network.emplace(m_runnable.at(host_partition_id), i, *m_ucx.at(i), m_host.at(host_partition_id)); - m_network.emplace_back(std::move(network)); + network.emplace(base, *m_ucx.at(base.partition_id()), m_host.at(base.partition().host_partition_id())); + m_network.push_back(std::move(network)); } else { diff --git a/src/internal/resources/partition_resources_base.cpp b/src/internal/resources/partition_resources_base.cpp index 5b35aa7f7..3eb51258d 100644 --- a/src/internal/resources/partition_resources_base.cpp +++ b/src/internal/resources/partition_resources_base.cpp @@ -27,7 +27,7 @@ PartitionResourceBase::PartitionResourceBase(runnable::Resources& runnable, std: system::PartitionProvider(runnable, partition_id), m_runnable(runnable) { - CHECK_EQ(m_runnable.host_partition_id(), partition().host_partition_id()); + CHECK_EQ(runnable.host_partition_id(), partition().host_partition_id()); } runnable::Resources& PartitionResourceBase::runnable() { diff --git a/src/internal/resources/partition_resources_base.hpp b/src/internal/resources/partition_resources_base.hpp index 35bed9988..f1177efbb 100644 --- a/src/internal/resources/partition_resources_base.hpp +++ b/src/internal/resources/partition_resources_base.hpp @@ -20,7 +20,10 @@ #include "internal/runnable/resources.hpp" #include "internal/system/partition_provider.hpp" +#include "srf/utils/macros.hpp" + #include +#include namespace srf::internal::resources { @@ -34,12 +37,11 @@ class PartitionResourceBase : public system::PartitionProvider { public: PartitionResourceBase(runnable::Resources& runnable, std::size_t partition_id); - PartitionResourceBase(const PartitionResourceBase& other) = default; runnable::Resources& runnable(); private: - runnable::Resources& m_runnable; + std::reference_wrapper m_runnable; }; } // namespace srf::internal::resources diff --git a/src/internal/system/partition_provider.hpp b/src/internal/system/partition_provider.hpp index e415cb5d0..a25ddc4c7 100644 --- a/src/internal/system/partition_provider.hpp +++ b/src/internal/system/partition_provider.hpp @@ -41,7 +41,7 @@ class PartitionProvider : public SystemProvider const Partition& partition() const; private: - const std::size_t m_partition_id; + std::size_t m_partition_id; }; } // namespace srf::internal::system diff --git a/src/internal/ucx/memory_block.cpp b/src/internal/ucx/memory_block.cpp index 66c142031..21a09000a 100644 --- a/src/internal/ucx/memory_block.cpp +++ b/src/internal/ucx/memory_block.cpp @@ -17,9 +17,11 @@ #include "internal/ucx/memory_block.hpp" -namespace srf::internal::memory {} -srf::internal::ucx::MemoryBlock::MemoryBlock(void* data, std::size_t bytes) : memory::MemoryBlock(data, bytes) {} -srf::internal::ucx::MemoryBlock::MemoryBlock( +namespace srf::internal::ucx { + +MemoryBlock::MemoryBlock(void* data, std::size_t bytes) : memory::MemoryBlock(data, bytes) {} + +MemoryBlock::MemoryBlock( void* data, std::size_t bytes, ucp_mem_h local_handle, void* remote_handle, std::size_t remote_handle_size) : memory::MemoryBlock(data, bytes), m_local_handle(local_handle), @@ -32,10 +34,11 @@ srf::internal::ucx::MemoryBlock::MemoryBlock( CHECK(m_remote_handle && m_remote_handle_size); } } -srf::internal::ucx::MemoryBlock::MemoryBlock(const MemoryBlock& block, - ucp_mem_h local_handle, - void* remote_handle, - std::size_t remote_handle_size) : + +MemoryBlock::MemoryBlock(const MemoryBlock& block, + ucp_mem_h local_handle, + void* remote_handle, + std::size_t remote_handle_size) : memory::MemoryBlock(block), m_local_handle(local_handle), m_remote_handle(remote_handle), @@ -46,15 +49,28 @@ srf::internal::ucx::MemoryBlock::MemoryBlock(const MemoryBlock& block, CHECK(m_remote_handle && m_remote_handle_size); } } -ucp_mem_h srf::internal::ucx::MemoryBlock::local_handle() const + +ucp_mem_h MemoryBlock::local_handle() const { return m_local_handle; } -void* srf::internal::ucx::MemoryBlock::remote_handle() const + +void* MemoryBlock::remote_handle() const { return m_remote_handle; } -std::size_t srf::internal::ucx::MemoryBlock::remote_handle_size() const + +std::size_t MemoryBlock::remote_handle_size() const { return m_remote_handle_size; } + +std::string MemoryBlock::packed_remote_keys() const +{ + std::string keys; + keys.resize(m_remote_handle_size); + std::memcpy(keys.data(), m_remote_handle, m_remote_handle_size); + return keys; +} + +} // namespace srf::internal::ucx diff --git a/src/internal/ucx/memory_block.hpp b/src/internal/ucx/memory_block.hpp index ac24ae861..e5b95659c 100644 --- a/src/internal/ucx/memory_block.hpp +++ b/src/internal/ucx/memory_block.hpp @@ -55,6 +55,8 @@ struct MemoryBlock : public memory::MemoryBlock */ std::size_t remote_handle_size() const; + std::string packed_remote_keys() const; + private: ucp_mem_h m_local_handle{nullptr}; void* m_remote_handle{nullptr}; diff --git a/src/internal/ucx/registration_resource.hpp b/src/internal/ucx/registration_resource.hpp index 2f9686c70..b64d55998 100644 --- a/src/internal/ucx/registration_resource.hpp +++ b/src/internal/ucx/registration_resource.hpp @@ -28,7 +28,10 @@ namespace srf::internal::ucx { /** - * @brief Memory Resource Adaptor that registers allocated memory with the UCX via the ucx::RegistrationCache + * @brief Memory Resource adaptor to provide UCX registration to allocated blocks. + * + * This is an internal class and used only for constructing device memory resources. A more general implementation might + * separate our the CUDA DeviceID requirement. * * @tparam PointerT */ @@ -60,6 +63,7 @@ class RegistrationResource : public srf::memory::adaptor void do_deallocate(void* ptr, std::size_t bytes) final { + DeviceGuard guard(m_cuda_device_id); auto size = m_registration_cache->drop_block(ptr, bytes); this->resource().deallocate(ptr, size); } diff --git a/src/internal/ucx/resources.cpp b/src/internal/ucx/resources.cpp index 0ca130c6f..4c057da08 100644 --- a/src/internal/ucx/resources.cpp +++ b/src/internal/ucx/resources.cpp @@ -17,9 +17,11 @@ #include "internal/ucx/resources.hpp" +#include "internal/resources/partition_resources_base.hpp" #include "internal/system/device_partition.hpp" #include "internal/system/fiber_task_queue.hpp" #include "internal/system/partition.hpp" +#include "internal/ucx/endpoint.hpp" #include "internal/ucx/worker.hpp" #include "srf/core/task_queue.hpp" @@ -33,10 +35,8 @@ namespace srf::internal::ucx { -Resources::Resources(runnable::Resources& _runnable_resources, - std::size_t _partition_id, - system::FiberTaskQueue& network_task_queue) : - resources::PartitionResourceBase(_runnable_resources, _partition_id), +Resources::Resources(resources::PartitionResourceBase& base, system::FiberTaskQueue& network_task_queue) : + resources::PartitionResourceBase(base), m_network_task_queue(network_task_queue) { VLOG(1) << "constructing network resources for partition: " << partition_id() << " on partitions main task queue"; @@ -57,18 +57,14 @@ Resources::Resources(runnable::Resources& _runnable_resources, DVLOG(10) << "initializing ucx context"; m_ucx_context = std::make_shared(); - DVLOG(10) << "initialize a ucx data_plane worker for server"; - m_worker_server = std::make_shared(m_ucx_context); - - DVLOG(10) << "initialize a ucx data_plane worker for client"; - m_worker_client = std::make_shared(m_ucx_context); + DVLOG(10) << "initialize a ucx data_plane worker"; + m_worker = std::make_shared(m_ucx_context); DVLOG(10) << "initialize the registration cache for this context"; m_registration_cache = std::make_shared(m_ucx_context); // flush any work that needs to be done by the workers - while (m_worker_server->progress() != 0) {} - while (m_worker_client->progress() != 0) {} + while (m_worker->progress() != 0) {} }) .get(); } @@ -87,4 +83,16 @@ const RegistrationCache& Resources::registration_cache() const CHECK(m_registration_cache); return *m_registration_cache; } + +Worker& Resources::worker() +{ + CHECK(m_worker); + return *m_worker; +} + +std::shared_ptr Resources::make_ep(const std::string& worker_address) const +{ + return std::make_shared(m_worker, worker_address); +} + } // namespace srf::internal::ucx diff --git a/src/internal/ucx/resources.hpp b/src/internal/ucx/resources.hpp index c485ec5b7..2fd6b60ed 100644 --- a/src/internal/ucx/resources.hpp +++ b/src/internal/ucx/resources.hpp @@ -21,6 +21,7 @@ #include "internal/runnable/resources.hpp" #include "internal/system/fiber_task_queue.hpp" #include "internal/ucx/context.hpp" +#include "internal/ucx/endpoint.hpp" #include "internal/ucx/registation_callback_builder.hpp" #include "internal/ucx/registration_cache.hpp" #include "internal/ucx/registration_resource.hpp" @@ -39,20 +40,29 @@ class Resources; namespace srf::internal::ucx { +/** + * @brief UCX Resources - if networking is enabled, there should be 1 UCX Resource per "flattened" partition + */ class Resources final : private resources::PartitionResourceBase { public: - Resources(runnable::Resources& _runnable_resources, - std::size_t _partition_id, - system::FiberTaskQueue& network_task_queue); + Resources(resources::PartitionResourceBase& base, system::FiberTaskQueue& network_task_queue); using resources::PartitionResourceBase::partition; + // ucx worker associated with this partitions ucx context + Worker& worker(); + + // task queue used to run the data plane's progress engine srf::core::FiberTaskQueue& network_task_queue(); + + // registration cache to look up local/remote keys for registered blocks of memory const RegistrationCache& registration_cache() const; + // used to build a callback adaptor memory resource for host memory resources void add_registration_cache_to_builder(RegistrationCallbackBuilder& builder); + // used to build device memory resources that are registered with the ucx context template auto adapt_to_registered_resource(UpstreamT upstream, int cuda_device_id) { @@ -60,11 +70,12 @@ class Resources final : private resources::PartitionResourceBase std::move(upstream), m_registration_cache, cuda_device_id); } + std::shared_ptr make_ep(const std::string& worker_address) const; + private: system::FiberTaskQueue& m_network_task_queue; std::shared_ptr m_ucx_context; - std::shared_ptr m_worker_server; - std::shared_ptr m_worker_client; + std::shared_ptr m_worker; std::shared_ptr m_registration_cache; // enable direct access to context and workers diff --git a/src/internal/ucx/worker.cpp b/src/internal/ucx/worker.cpp index 0b1bd3db6..119f581fc 100644 --- a/src/internal/ucx/worker.cpp +++ b/src/internal/ucx/worker.cpp @@ -108,4 +108,10 @@ void Worker::release_address() } } +Context& Worker::context() +{ + CHECK(m_context); + return *m_context; +} + } // namespace srf::internal::ucx diff --git a/src/internal/ucx/worker.hpp b/src/internal/ucx/worker.hpp index c82100c51..048419336 100644 --- a/src/internal/ucx/worker.hpp +++ b/src/internal/ucx/worker.hpp @@ -44,6 +44,8 @@ class Worker : public Primitive Handle create_endpoint(WorkerAddress); + Context& context(); + private: Handle m_context; std::string m_address; diff --git a/src/public/codable/encoded_object.cpp b/src/public/codable/encoded_object.cpp index 694ba4771..83ab7ffeb 100644 --- a/src/public/codable/encoded_object.cpp +++ b/src/public/codable/encoded_object.cpp @@ -67,19 +67,19 @@ static protos::MemoryKind encode_memory_type(memory::memory_kind mem_kind) return protos::MemoryKind::None; } -memory::buffer_view EncodedObject::decode_descriptor(const protos::RemoteDescriptor& desc) +memory::buffer_view EncodedObject::decode_descriptor(const protos::RemoteMemoryDescriptor& desc) { return memory::buffer_view( reinterpret_cast(desc.remote_address()), desc.remote_bytes(), decode_memory_type(desc.memory_kind())); } -protos::RemoteDescriptor EncodedObject::encode_descriptor(memory::const_buffer_view view) +protos::RemoteMemoryDescriptor EncodedObject::encode_descriptor(memory::const_buffer_view view) { - protos::RemoteDescriptor desc; + protos::RemoteMemoryDescriptor desc; desc.set_remote_address(reinterpret_cast(view.data())); desc.set_remote_bytes(view.bytes()); desc.set_memory_kind(encode_memory_type(view.kind())); - + // get ucx registration if applicable return desc; } @@ -197,4 +197,13 @@ void EncodedObject::add_type_index(std::type_index type_index) obj->set_desc_id(descriptor_count()); } +std::size_t EncodedObject::add_buffer(std::shared_ptr mr, std::size_t bytes) +{ + CHECK(m_context_acquired); + memory::buffer buff(bytes, mr); + auto index = add_memory_block(buff); + m_buffers[index] = std::move(buff); + return index; +} + } // namespace srf::codable diff --git a/src/public/runnable/launcher.cpp b/src/public/runnable/launcher.cpp index a1378e9bf..adea98e9d 100644 --- a/src/public/runnable/launcher.cpp +++ b/src/public/runnable/launcher.cpp @@ -36,10 +36,7 @@ Launcher::Launcher(std::unique_ptr runner, m_engines(std::move(engines)) {} -Launcher::~Launcher() -{ - LOG_IF(WARNING, m_runner) << "destroying unused launcher"; -} +Launcher::~Launcher() = default; std::unique_ptr Launcher::ignition() { diff --git a/src/tests/test_network.cpp b/src/tests/test_network.cpp index 58eb0e345..5cebf73f2 100644 --- a/src/tests/test_network.cpp +++ b/src/tests/test_network.cpp @@ -15,6 +15,8 @@ * limitations under the License. */ +#include "internal/data_plane/client.hpp" +#include "internal/data_plane/resources.hpp" #include "internal/memory/device_resources.hpp" #include "internal/memory/host_resources.hpp" #include "internal/network/resources.hpp" @@ -110,8 +112,8 @@ TEST_F(TestNetwork, ResourceManager) auto h_buffer_0 = resources->partition(0).host().make_buffer(1_MiB); auto d_buffer_0 = resources->partition(0).device()->make_buffer(1_MiB); - auto h_ucx_block = resources->partition(0).network()->registration_cache().lookup(h_buffer_0.data()); - auto d_ucx_block = resources->partition(0).network()->registration_cache().lookup(d_buffer_0.data()); + auto h_ucx_block = resources->partition(0).network()->data_plane().registration_cache().lookup(h_buffer_0.data()); + auto d_ucx_block = resources->partition(0).network()->data_plane().registration_cache().lookup(d_buffer_0.data()); EXPECT_EQ(h_ucx_block.bytes(), 32_MiB); EXPECT_EQ(d_ucx_block.bytes(), 64_MiB); @@ -135,6 +137,120 @@ TEST_F(TestNetwork, ResourceManager) d_buffer_0.release(); } +TEST_F(TestNetwork, CommsSendRecv) +{ + // using options.placement().resources_strategy(PlacementResources::Shared) + // will test if cudaSetDevice is being properly called by the network services + // since all network services for potentially multiple devices are colocated on a single thread + auto resources = std::make_unique( + internal::system::SystemProvider(make_system([](Options& options) { + options.architect_url("localhost:13337"); + options.placement().resources_strategy(PlacementResources::Dedicated); + options.resources().enable_device_memory_pool(true); + options.resources().enable_host_memory_pool(true); + options.resources().host_memory_pool().block_size(32_MiB); + options.resources().host_memory_pool().max_aggregate_bytes(128_MiB); + options.resources().device_memory_pool().block_size(64_MiB); + options.resources().device_memory_pool().max_aggregate_bytes(128_MiB); + }))); + + if (resources->partition_count() < 2 && resources->device_count() < 2) + { + GTEST_SKIP() << "this test only works with 2 device partitions"; + } + + EXPECT_TRUE(resources->partition(0).network()); + EXPECT_TRUE(resources->partition(1).network()); + + auto& r0 = resources->partition(0).network()->data_plane(); + auto& r1 = resources->partition(1).network()->data_plane(); + + // here we are exchanging internal ucx worker addresses without the need of the control plane + r0.client().register_instance(1, r1.ucx_address()); // register r1 as instance_id 1 + r1.client().register_instance(0, r0.ucx_address()); // register r0 as instance_id 0 + + int src = 42; + int dst = -1; + + internal::data_plane::Request send_req; + internal::data_plane::Request recv_req; + + r1.client().async_recv(&dst, sizeof(int), 0, recv_req); + r0.client().async_send(&src, sizeof(int), 0, 1, send_req); + + LOG(INFO) << "await recv"; + recv_req.await_complete(); + LOG(INFO) << "await send"; + send_req.await_complete(); + + EXPECT_EQ(src, dst); + + // expect that the buffers are allowed to survive pass the resource manager + resources.reset(); +} + +TEST_F(TestNetwork, CommsGet) +{ + // using options.placement().resources_strategy(PlacementResources::Shared) + // will test if cudaSetDevice is being properly called by the network services + // since all network services for potentially multiple devices are colocated on a single thread + auto resources = std::make_unique( + internal::system::SystemProvider(make_system([](Options& options) { + options.architect_url("localhost:13337"); + options.placement().resources_strategy(PlacementResources::Dedicated); + options.resources().enable_device_memory_pool(true); + options.resources().enable_host_memory_pool(true); + options.resources().host_memory_pool().block_size(32_MiB); + options.resources().host_memory_pool().max_aggregate_bytes(128_MiB); + options.resources().device_memory_pool().block_size(64_MiB); + options.resources().device_memory_pool().max_aggregate_bytes(128_MiB); + }))); + + if (resources->partition_count() < 2 && resources->device_count() < 2) + { + GTEST_SKIP() << "this test only works with 2 device partitions"; + } + + EXPECT_TRUE(resources->partition(0).network()); + EXPECT_TRUE(resources->partition(1).network()); + + auto src = resources->partition(0).host().make_buffer(1_MiB); + auto dst = resources->partition(1).host().make_buffer(1_MiB); + + auto src_keys = + resources->partition(0).network()->data_plane().registration_cache().lookup(src.data()).packed_remote_keys(); + + auto* src_data = static_cast(src.data()); + std::size_t count = 1_MiB / sizeof(std::size_t); + for (std::size_t i = 0; i < count; ++i) + { + src_data[i] = 42; + } + + auto& r0 = resources->partition(0).network()->data_plane(); + auto& r1 = resources->partition(1).network()->data_plane(); + + // here we are exchanging internal ucx worker addresses without the need of the control plane + r0.client().register_instance(1, r1.ucx_address()); // register r1 as instance_id 1 + r1.client().register_instance(0, r0.ucx_address()); // register r0 as instance_id 0 + + internal::data_plane::Request get_req; + + r1.client().async_get(dst.data(), 1_MiB, 0, src.data(), src_keys, get_req); + + LOG(INFO) << "await get"; + get_req.await_complete(); + + auto* dst_data = static_cast(dst.data()); + for (std::size_t i = 0; i < count; ++i) + { + EXPECT_EQ(dst_data[i], 42); + } + + // expect that the buffers are allowed to survive pass the resource manager + resources.reset(); +} + // TEST_F(TestNetwork, NetworkEventsManagerLifeCycle) // { // auto launcher = m_launch_control->prepare_launcher(std::move(m_mutable_nem)); diff --git a/tests/test_codable.cpp b/tests/test_codable.cpp index 481d70070..484d3f50d 100644 --- a/tests/test_codable.cpp +++ b/tests/test_codable.cpp @@ -40,12 +40,12 @@ class CodableObject CodableObject() = default; ~CodableObject() = default; - static CodableObject deserialize(const EncodedObject& buffer, std::size_t) + static CodableObject deserialize(const EncodedObject& buffer, std::size_t /*unused*/) { return CodableObject(); } - void serialize(Encoded&) {} + void serialize(Encoded& /*unused*/) {} }; class CodableObjectWithOptions @@ -54,12 +54,12 @@ class CodableObjectWithOptions CodableObjectWithOptions() = default; ~CodableObjectWithOptions() = default; - static CodableObjectWithOptions deserialize(const EncodedObject& encoding, std::size_t) + static CodableObjectWithOptions deserialize(const EncodedObject& encoding, std::size_t /*unused*/) { return CodableObjectWithOptions(); } - void serialize(Encoded&, const EncodingOptions& opts) {} + void serialize(Encoded& /*unused*/, const EncodingOptions& opts) {} }; class CodableViaExternalStruct @@ -70,7 +70,7 @@ namespace srf::codable { template <> struct codable_protocol { - void serialize(const CodableViaExternalStruct&, Encoded&) {} + void serialize(const CodableViaExternalStruct& /*unused*/, Encoded& /*unused*/) {} }; }; // namespace srf::codable