[Core] Fix core worker client pool leak (ray-project#41535)
Currently the core worker client pool doesn't remove clients in most cases (there are only one or two places where Disconnect() might be called), and this causes a memory leak. This PR adds a GC inside the core worker client pool to remove IDLE clients (i.e., clients that no longer have active connections).

Signed-off-by: Jiajun Yao <jeromeyjj@gmail.com>
jjyao authored Dec 9, 2023
1 parent 208e452 commit 1dffb4d
Showing 12 changed files with 176 additions and 50 deletions.
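In outline, the fix gives the pool LRU-style bookkeeping: entries live in a recency-ordered list plus a WorkerID-keyed map of list iterators, every lookup moves the touched entry to the front, and lookups also sweep idle clients off the back of the list. The sketch below restates that idiom in plain C++ as an illustration only; IdlePool, Client, and is_idle() are simplified stand-ins for the CoreWorkerClientPool, CoreWorkerClientInterface, and IsChannelIdleAfterRPCs() pieces changed in the diffs below, not the Ray code itself.

// Sketch of the list-plus-map idiom used for idle-client eviction.
// IdlePool, Client, and is_idle() are illustrative stand-ins.
#include <cstddef>
#include <iterator>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

struct Client {
  bool idle = false;  // stands in for IsChannelIdleAfterRPCs()
  bool is_idle() const { return idle; }
};

class IdlePool {
 public:
  std::shared_ptr<Client> GetOrConnect(const std::string &id) {
    RemoveIdleClients();
    auto it = map_.find(id);
    if (it != map_.end()) {
      // Cache hit: move the entry to the front, marking it most recently used.
      // std::list iterators stay valid across splice, so map_ needs no update.
      list_.splice(list_.begin(), list_, it->second);
    } else {
      // Cache miss: create a new client and record it at the front.
      list_.emplace_front(id, std::make_shared<Client>());
      map_[id] = list_.begin();
    }
    return list_.front().second;
  }

  std::size_t Size() const { return list_.size(); }

 private:
  // Drop idle clients from the least recently used end of the list and stop
  // at the first non-idle one, rotating it to the front so that repeated
  // calls eventually examine every entry.
  void RemoveIdleClients() {
    while (!list_.empty()) {
      if (list_.back().second->is_idle()) {
        map_.erase(list_.back().first);
        list_.pop_back();
      } else {
        list_.splice(list_.begin(), list_, std::prev(list_.end()));
        break;
      }
    }
  }

  using Entry = std::pair<std::string, std::shared_ptr<Client>>;
  std::list<Entry> list_;  // most recently used at the front
  std::unordered_map<std::string, std::list<Entry>::iterator> map_;
};

The actual change keys the map by ray::WorkerID, stores each worker ID together with its client in client_list_, and decides idleness from the gRPC channel state, as the diffs below show.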
13 changes: 13 additions & 0 deletions BUILD.bazel
@@ -1536,6 +1536,19 @@ ray_cc_test(
],
)

ray_cc_test(
name = "core_worker_client_pool_test",
size = "small",
srcs = [
"src/ray/rpc/worker/test/core_worker_client_pool_test.cc",
],
tags = ["team:core"],
deps = [
":worker_rpc",
"@com_google_googletest//:gtest_main",
],
)

ray_cc_test(
name = "gcs_server_rpc_test",
size = "small",
1 change: 0 additions & 1 deletion src/mock/ray/rpc/worker/core_worker_client.h
@@ -18,7 +18,6 @@ namespace rpc {
class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientInterface,
public CoreWorkerClientInterface {
public:
MOCK_METHOD(const rpc::Address &, Addr, (), (const, override));
MOCK_METHOD(void,
PushActorTask,
(std::unique_ptr<PushTaskRequest> request,
2 changes: 2 additions & 0 deletions src/ray/common/grpc_util.h
@@ -158,6 +158,8 @@ inline grpc::ChannelArguments CreateDefaultChannelArguments() {
arguments.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS,
::RayConfig::instance().grpc_client_keepalive_timeout_ms());
}
arguments.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS,
::RayConfig::instance().grpc_client_idle_timeout_ms());
return arguments;
}

2 changes: 2 additions & 0 deletions src/ray/common/ray_config_def.h
@@ -781,6 +781,8 @@ RAY_CONFIG(int64_t, grpc_client_keepalive_time_ms, 300000)
/// grpc keepalive timeout for client.
RAY_CONFIG(int64_t, grpc_client_keepalive_timeout_ms, 120000)

RAY_CONFIG(int64_t, grpc_client_idle_timeout_ms, 1800000)

/// grpc streaming buffer size
/// Set it to 512kb
RAY_CONFIG(int64_t, grpc_stream_buffer_size, 512 * 1024);
23 changes: 5 additions & 18 deletions src/ray/core_worker/transport/direct_task_transport.cc
@@ -221,7 +221,7 @@ void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(
ReturnWorker(addr, was_error, error_detail, worker_exiting, scheduling_key);
}
} else {
auto &client = *client_cache_->GetOrConnect(addr);
auto client = client_cache_->GetOrConnect(addr);

while (!current_queue.empty() && !lease_entry.is_busy) {
auto task_spec = current_queue.front();
@@ -599,7 +599,7 @@ void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(

void CoreWorkerDirectTaskSubmitter::PushNormalTask(
const rpc::Address &addr,
rpc::CoreWorkerClientInterface &client,
shared_ptr<rpc::CoreWorkerClientInterface> client,
const SchedulingKey &scheduling_key,
const TaskSpecification &task_spec,
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry> &assigned_resources) {
@@ -620,7 +620,7 @@ void CoreWorkerDirectTaskSubmitter::PushNormalTask(
task_finisher_->MarkTaskWaitingForExecution(task_id,
NodeID::FromBinary(addr.raylet_id()),
WorkerID::FromBinary(addr.worker_id()));
client.PushNormalTask(
client->PushNormalTask(
std::move(request),
[this,
task_spec,
@@ -801,14 +801,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelTask(TaskSpecification task_spec,
return Status::OK();
}
// Looks for an RPC handle for the worker executing the task.
auto maybe_client =
client_cache_->GetByID(WorkerID::FromBinary(rpc_client->second.worker_id()));
if (!maybe_client.has_value()) {
// If we don't have a connection to that worker, we can't cancel it.
// This case is reached for tasks that have unresolved dependencies.
return Status::OK();
}
client = maybe_client.value();
client = client_cache_->GetOrConnect(rpc_client->second);
}

RAY_CHECK(client != nullptr);
@@ -866,13 +859,7 @@ Status CoreWorkerDirectTaskSubmitter::CancelRemoteTask(const ObjectID &object_id
const rpc::Address &worker_addr,
bool force_kill,
bool recursive) {
auto maybe_client =
client_cache_->GetByID(WorkerID::FromBinary(worker_addr.worker_id()));

if (!maybe_client.has_value()) {
return Status::Invalid("No remote worker found");
}
auto client = maybe_client.value();
auto client = client_cache_->GetOrConnect(worker_addr);
auto request = rpc::RemoteCancelTaskRequest();
request.set_force_kill(force_kill);
request.set_recursive(recursive);
2 changes: 1 addition & 1 deletion src/ray/core_worker/transport/direct_task_transport.h
@@ -220,7 +220,7 @@ class CoreWorkerDirectTaskSubmitter {

/// Push a task to a specific worker.
void PushNormalTask(const rpc::Address &addr,
rpc::CoreWorkerClientInterface &client,
std::shared_ptr<rpc::CoreWorkerClientInterface> client,
const SchedulingKey &task_queue_key,
const TaskSpecification &task_spec,
const google::protobuf::RepeatedPtrField<rpc::ResourceMapEntry>
6 changes: 0 additions & 6 deletions src/ray/raylet/test/local_object_manager_test.cc
@@ -633,14 +633,12 @@ TEST_F(LocalObjectManagerTest, TestSpillObjectsOfSizeZero) {

std::vector<ObjectID> object_ids;
std::vector<std::unique_ptr<RayObject>> objects;
int64_t total_size = 0;
int64_t object_size = 1000;

for (size_t i = 0; i < 3; i++) {
ObjectID object_id = ObjectID::FromRandom();
object_ids.push_back(object_id);
auto data_buffer = std::make_shared<MockObjectBuffer>(object_size, object_id, unpins);
total_size += object_size;
auto object = std::make_unique<RayObject>(
data_buffer, nullptr, std::vector<rpc::ObjectReference>());
objects.push_back(std::move(object));
@@ -1417,14 +1415,12 @@ TEST_F(LocalObjectManagerFusedTest, TestMinSpillingSize) {

std::vector<ObjectID> object_ids;
std::vector<std::unique_ptr<RayObject>> objects;
int64_t total_size = 0;
int64_t object_size = 52;

for (size_t i = 0; i < 3; i++) {
ObjectID object_id = ObjectID::FromRandom();
object_ids.push_back(object_id);
auto data_buffer = std::make_shared<MockObjectBuffer>(object_size, object_id, unpins);
total_size += object_size;
auto object = std::make_unique<RayObject>(
data_buffer, nullptr, std::vector<rpc::ObjectReference>());
objects.push_back(std::move(object));
@@ -1479,7 +1475,6 @@ TEST_F(LocalObjectManagerFusedTest, TestMinSpillingSizeMaxFusionCount) {

std::vector<ObjectID> object_ids;
std::vector<std::unique_ptr<RayObject>> objects;
int64_t total_size = 0;
// 20 of these objects are needed to hit the min spilling size, but
// max_fused_object_count=15.
int64_t object_size = 5;
@@ -1488,7 +1483,6 @@
ObjectID object_id = ObjectID::FromRandom();
object_ids.push_back(object_id);
auto data_buffer = std::make_shared<MockObjectBuffer>(object_size, object_id, unpins);
total_size += object_size;
auto object = std::make_unique<RayObject>(
data_buffer, nullptr, std::vector<rpc::ObjectReference>());
objects.push_back(std::move(object));
12 changes: 12 additions & 0 deletions src/ray/rpc/grpc_client.h
@@ -156,10 +156,20 @@ class GrpcClient {
std::move(call_name),
method_timeout_ms);
RAY_CHECK(call != nullptr);
call_method_invoked_ = true;
}

std::shared_ptr<grpc::Channel> Channel() const { return channel_; }

/// A channel is IDLE when it's first created, before making any RPCs,
/// or after GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS of no activity since the last RPC.
/// This method detects IDLE in the second case.
/// Also see https://grpc.github.io/grpc/core/md_doc_connectivity-semantics-and-api.html
/// for the channel connectivity state machine.
bool IsChannelIdleAfterRPCs() const {
return (channel_->GetState(false) == GRPC_CHANNEL_IDLE) && call_method_invoked_;
}

private:
ClientCallManager &client_call_manager_;
/// The gRPC-generated stub.
@@ -168,6 +178,8 @@ bool use_tls_;
bool use_tls_;
/// The channel of the stub.
std::shared_ptr<grpc::Channel> channel_;
/// Whether CallMethod is invoked.
bool call_method_invoked_ = false;
};

} // namespace rpc
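For context on the two pieces added above: GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS makes a channel fall back to GRPC_CHANNEL_IDLE after that long with no RPC activity, and GetState(false) reads the state without triggering a reconnect, so the pool's check does not itself wake the channel up. Because a channel is also IDLE before its very first RPC, the new call_method_invoked_ flag is needed to tell "never used" apart from "idle after use". A rough standalone sketch of those gRPC calls, with a placeholder target and timeout rather than the values Ray uses:

#include <memory>

#include <grpcpp/grpcpp.h>

// Placeholder address and timeout; in Ray these come from the worker address
// and RayConfig::grpc_client_idle_timeout_ms().
std::shared_ptr<grpc::Channel> MakeChannelWithIdleTimeout() {
  grpc::ChannelArguments args;
  // Channels with no RPC activity for this long transition back to IDLE.
  args.SetInt(GRPC_ARG_CLIENT_IDLE_TIMEOUT_MS, 1000);
  return grpc::CreateCustomChannel(
      "localhost:50051", grpc::InsecureChannelCredentials(), args);
}

// Mirrors the intent of IsChannelIdleAfterRPCs(): only report idle if at
// least one RPC was made, since a freshly created channel is IDLE as well.
bool IsIdleAfterUse(const std::shared_ptr<grpc::Channel> &channel,
                    bool call_method_invoked) {
  return call_method_invoked &&
         channel->GetState(/*try_to_connect=*/false) == GRPC_CHANNEL_IDLE;
}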
6 changes: 6 additions & 0 deletions src/ray/rpc/worker/core_worker_client.h
@@ -82,6 +82,8 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
return empty_addr_;
}

virtual bool IsChannelIdleAfterRPCs() const { return false; }

/// Push an actor task directly from worker to worker.
///
/// \param[in] request The request message.
@@ -214,6 +216,10 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,

const rpc::Address &Addr() const override { return addr_; }

bool IsChannelIdleAfterRPCs() const override {
return grpc_client_->IsChannelIdleAfterRPCs();
}

VOID_RPC_CLIENT_METHOD(CoreWorkerService,
DirectActorCallArgWaitComplete,
grpc_client_,
51 changes: 34 additions & 17 deletions src/ray/rpc/worker/core_worker_client_pool.cc
@@ -17,31 +17,47 @@
namespace ray {
namespace rpc {

optional<shared_ptr<CoreWorkerClientInterface>> CoreWorkerClientPool::GetByID(
ray::WorkerID id) {
absl::MutexLock lock(&mu_);
auto it = client_map_.find(id);
if (it == client_map_.end()) {
return {};
}
return it->second;
}

shared_ptr<CoreWorkerClientInterface> CoreWorkerClientPool::GetOrConnect(
const Address &addr_proto) {
RAY_CHECK(addr_proto.worker_id() != "");
RAY_CHECK_NE(addr_proto.worker_id(), "");
absl::MutexLock lock(&mu_);

RemoveIdleClients();

CoreWorkerClientEntry entry;
auto id = WorkerID::FromBinary(addr_proto.worker_id());
auto it = client_map_.find(id);
if (it != client_map_.end()) {
return it->second;
entry = *it->second;
client_list_.erase(it->second);
} else {
entry = CoreWorkerClientEntry(id, client_factory_(addr_proto));
}
auto connection = client_factory_(addr_proto);
client_map_[id] = connection;
client_list_.emplace_front(entry);
client_map_[id] = client_list_.begin();

RAY_LOG(DEBUG) << "Connected to " << addr_proto.ip_address() << ":"
<< addr_proto.port();
return connection;
RAY_LOG(DEBUG) << "Connected to worker " << id << " with address "
<< addr_proto.ip_address() << ":" << addr_proto.port();
return entry.core_worker_client;
}

void CoreWorkerClientPool::RemoveIdleClients() {
while (!client_list_.empty()) {
auto id = client_list_.back().worker_id;
// The last client in the list is the least recently accessed client.
if (client_list_.back().core_worker_client->IsChannelIdleAfterRPCs()) {
client_map_.erase(id);
client_list_.pop_back();
RAY_LOG(DEBUG) << "Remove idle client to worker " << id
<< " , num of clients is now " << client_list_.size();
} else {
auto entry = client_list_.back();
client_list_.pop_back();
client_list_.emplace_front(entry);
client_map_[id] = client_list_.begin();
break;
}
}
}

void CoreWorkerClientPool::Disconnect(ray::WorkerID id) {
Expand All @@ -50,6 +66,7 @@ void CoreWorkerClientPool::Disconnect(ray::WorkerID id) {
if (it == client_map_.end()) {
return;
}
client_list_.erase(it->second);
client_map_.erase(it);
}

38 changes: 31 additions & 7 deletions src/ray/rpc/worker/core_worker_client_pool.h
@@ -38,11 +38,6 @@ class CoreWorkerClientPool {
CoreWorkerClientPool(ClientFactoryFn client_factory)
: client_factory_(client_factory){};

/// Returns an existing Interface if one exists, or an empty optional
/// otherwise.
/// Any returned pointer is borrowed, and expected to be used briefly.
optional<shared_ptr<CoreWorkerClientInterface>> GetByID(ray::WorkerID id);

/// Returns an open CoreWorkerClientInterface if one exists, and connects to one
/// if it does not. The returned pointer is borrowed, and expected to be used
/// briefly.
@@ -53,6 +48,13 @@
/// be open until it's no longer used, at which time it will disconnect.
void Disconnect(ray::WorkerID id);

/// For testing.
size_t Size() {
absl::MutexLock lock(&mu_);
RAY_CHECK_EQ(client_list_.size(), client_map_.size());
return client_list_.size();
}

private:
/// Provides the default client factory function. Providing this function to the
/// constructor aids migration but is ultimately a thing that should be
@@ -63,17 +65,39 @@
};
};

/// Try to remove some idle clients to free memory.
/// It doesn't go through the entire list and remove all idle clients.
/// Instead, it tries to remove idle clients from the end of the list
/// and stops when it finds the first non-idle client.
/// However, it's guaranteed that all idle clients will eventually be
/// removed as long as the method is called repeatedly.
void RemoveIdleClients() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

/// This factory function does the connection to CoreWorkerClient, and is
/// provided by the constructor (either the default implementation, above, or a
/// provided one)
ClientFactoryFn client_factory_;

absl::Mutex mu_;

struct CoreWorkerClientEntry {
public:
CoreWorkerClientEntry() {}
CoreWorkerClientEntry(ray::WorkerID worker_id,
shared_ptr<CoreWorkerClientInterface> core_worker_client)
: worker_id(worker_id), core_worker_client(core_worker_client) {}

ray::WorkerID worker_id;
shared_ptr<CoreWorkerClientInterface> core_worker_client;
};

/// A list of open connections from the most recent accessed to the least recent
/// accessed. This is used to check and remove idle connections.
std::list<CoreWorkerClientEntry> client_list_ ABSL_GUARDED_BY(mu_);
/// A pool of open connections by WorkerID. Clients can reuse the connection
/// objects in this pool by requesting them.
absl::flat_hash_map<ray::WorkerID, shared_ptr<CoreWorkerClientInterface>> client_map_
ABSL_GUARDED_BY(mu_);
absl::flat_hash_map<ray::WorkerID, std::list<CoreWorkerClientEntry>::iterator>
client_map_ ABSL_GUARDED_BY(mu_);
};

} // namespace rpc
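As a usage-level illustration of the API above: callers obtain clients through GetOrConnect(), the idle sweep runs lazily inside that same call, and Disconnect() still removes an entry explicitly. The sketch below is hypothetical and not part of this PR; it assumes ClientFactoryFn is a std::function taking an rpc::Address and returning a shared_ptr<CoreWorkerClientInterface> (consistent with the factory members above), and that CoreWorkerClientInterface's methods all have default bodies, as the core_worker_client.h excerpt suggests, so a trivial fake client can be instantiated.

#include <memory>

#include "ray/common/id.h"
#include "ray/rpc/worker/core_worker_client_pool.h"

namespace ray {
namespace rpc {

// Trivial fake relying on the interface's default method implementations
// (an assumption; a real test would use a proper mock).
class FakeClient : public CoreWorkerClientInterface {};

void UsePool(const Address &addr) {
  // Hypothetical factory; production code builds a real CoreWorkerClient here.
  CoreWorkerClientPool pool(
      [](const Address &) { return std::make_shared<FakeClient>(); });

  // Connects on first use; later calls for the same worker reuse the cached
  // entry and also sweep idle entries from the tail of client_list_.
  std::shared_ptr<CoreWorkerClientInterface> client = pool.GetOrConnect(addr);

  // Explicit removal, e.g. once the worker process is known to be gone.
  pool.Disconnect(WorkerID::FromBinary(addr.worker_id()));
}

}  // namespace rpc
}  // namespace ray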