Skip to content

Commit

Permalink
[call-v3] Convert message size filter to new API (grpc#35233)
Browse files Browse the repository at this point in the history
Closes grpc#35233

COPYBARA_INTEGRATE_REVIEW=grpc#35233 from ctiller:cg-msg-size cce51d8
PiperOrigin-RevId: 588793125
  • Loading branch information
ctiller authored and copybara-github committed Dec 7, 2023
1 parent 6c816a4 commit 5f92a67
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 98 deletions.
122 changes: 51 additions & 71 deletions src/core/ext/filters/message_size/message_size_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@

namespace grpc_core {

const NoInterceptor ClientMessageSizeFilter::Call::OnClientInitialMetadata;
const NoInterceptor ClientMessageSizeFilter::Call::OnServerInitialMetadata;
const NoInterceptor ClientMessageSizeFilter::Call::OnServerTrailingMetadata;
const NoInterceptor ServerMessageSizeFilter::Call::OnClientInitialMetadata;
const NoInterceptor ServerMessageSizeFilter::Call::OnServerInitialMetadata;
const NoInterceptor ServerMessageSizeFilter::Call::OnServerTrailingMetadata;

//
// MessageSizeParsedConfig
//
Expand Down Expand Up @@ -138,60 +145,6 @@ const grpc_channel_filter ServerMessageSizeFilter::kFilter =
kFilterExaminesOutboundMessages |
kFilterExaminesInboundMessages>("message_size");

class MessageSizeFilter::CallBuilder {
private:
auto Interceptor(uint32_t max_length, bool is_send) {
return [max_length, is_send,
err = err_](MessageHandle msg) -> absl::optional<MessageHandle> {
if (grpc_call_trace.enabled()) {
gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d",
Activity::current()->DebugTag().c_str(),
is_send ? "send" : "recv", msg->payload()->Length(),
max_length);
}
if (msg->payload()->Length() > max_length) {
if (err->is_set()) return std::move(msg);
auto r = GetContext<Arena>()->MakePooled<ServerMetadata>(
GetContext<Arena>());
r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED);
r->Set(GrpcMessageMetadata(),
Slice::FromCopiedString(
absl::StrFormat("%s message larger than max (%u vs. %d)",
is_send ? "Sent" : "Received",
msg->payload()->Length(), max_length)));
err->Set(std::move(r));
return absl::nullopt;
}
return std::move(msg);
};
}

public:
explicit CallBuilder(const MessageSizeParsedConfig& limits)
: limits_(limits) {}

template <typename T>
void AddSend(T* pipe_end) {
if (!limits_.max_send_size().has_value()) return;
pipe_end->InterceptAndMap(Interceptor(*limits_.max_send_size(), true));
}
template <typename T>
void AddRecv(T* pipe_end) {
if (!limits_.max_recv_size().has_value()) return;
pipe_end->InterceptAndMap(Interceptor(*limits_.max_recv_size(), false));
}

ArenaPromise<ServerMetadataHandle> Run(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
return Race(err_->Wait(), next_promise_factory(std::move(call_args)));
}

private:
Latch<ServerMetadataHandle>* const err_ =
GetContext<Arena>()->ManagedNew<Latch<ServerMetadataHandle>>();
MessageSizeParsedConfig limits_;
};

absl::StatusOr<ClientMessageSizeFilter> ClientMessageSizeFilter::Create(
const ChannelArgs& args, ChannelFilter::Args) {
return ClientMessageSizeFilter(args);
Expand All @@ -202,20 +155,40 @@ absl::StatusOr<ServerMessageSizeFilter> ServerMessageSizeFilter::Create(
return ServerMessageSizeFilter(args);
}

ArenaPromise<ServerMetadataHandle> ClientMessageSizeFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
namespace {
ServerMetadataHandle CheckPayload(const Message& msg,
absl::optional<uint32_t> max_length,
bool is_send) {
if (!max_length.has_value()) return nullptr;
if (GRPC_TRACE_FLAG_ENABLED(grpc_call_trace)) {
gpr_log(GPR_INFO, "%s[message_size] %s len:%" PRIdPTR " max:%d",
Activity::current()->DebugTag().c_str(), is_send ? "send" : "recv",
msg.payload()->Length(), *max_length);
}
if (msg.payload()->Length() <= *max_length) return nullptr;
auto r = GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
r->Set(GrpcStatusMetadata(), GRPC_STATUS_RESOURCE_EXHAUSTED);
r->Set(GrpcMessageMetadata(), Slice::FromCopiedString(absl::StrFormat(
"%s message larger than max (%u vs. %d)",
is_send ? "Sent" : "Received",
msg.payload()->Length(), *max_length)));
return r;
}
} // namespace

ClientMessageSizeFilter::Call::Call(ClientMessageSizeFilter* filter)
: limits_(filter->parsed_config_) {
// Get max sizes from channel data, then merge in per-method config values.
// Note: Per-method config is only available on the client, so we
// apply the max request size to the send limit and the max response
// size to the receive limit.
MessageSizeParsedConfig limits = this->limits();
const MessageSizeParsedConfig* config_from_call_context =
MessageSizeParsedConfig::GetFromCallContext(
GetContext<grpc_call_context_element>(),
service_config_parser_index_);
filter->service_config_parser_index_);
if (config_from_call_context != nullptr) {
absl::optional<uint32_t> max_send_size = limits.max_send_size();
absl::optional<uint32_t> max_recv_size = limits.max_recv_size();
absl::optional<uint32_t> max_send_size = limits_.max_send_size();
absl::optional<uint32_t> max_recv_size = limits_.max_recv_size();
if (config_from_call_context->max_send_size().has_value() &&
(!max_send_size.has_value() ||
*config_from_call_context->max_send_size() < *max_send_size)) {
Expand All @@ -226,21 +199,28 @@ ArenaPromise<ServerMetadataHandle> ClientMessageSizeFilter::MakeCallPromise(
*config_from_call_context->max_recv_size() < *max_recv_size)) {
max_recv_size = *config_from_call_context->max_recv_size();
}
limits = MessageSizeParsedConfig(max_send_size, max_recv_size);
limits_ = MessageSizeParsedConfig(max_send_size, max_recv_size);
}
}

ServerMetadataHandle ServerMessageSizeFilter::Call::OnClientToServerMessage(
const Message& message, ServerMessageSizeFilter* filter) {
return CheckPayload(message, filter->parsed_config_.max_recv_size(), false);
}

ServerMetadataHandle ServerMessageSizeFilter::Call::OnServerToClientMessage(
const Message& message, ServerMessageSizeFilter* filter) {
return CheckPayload(message, filter->parsed_config_.max_send_size(), true);
}

CallBuilder b(limits);
b.AddSend(call_args.client_to_server_messages);
b.AddRecv(call_args.server_to_client_messages);
return b.Run(std::move(call_args), std::move(next_promise_factory));
ServerMetadataHandle ClientMessageSizeFilter::Call::OnClientToServerMessage(
const Message& message) {
return CheckPayload(message, limits_.max_send_size(), true);
}

ArenaPromise<ServerMetadataHandle> ServerMessageSizeFilter::MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) {
CallBuilder b(limits());
b.AddSend(call_args.server_to_client_messages);
b.AddRecv(call_args.client_to_server_messages);
return b.Run(std::move(call_args), std::move(next_promise_factory));
ServerMetadataHandle ClientMessageSizeFilter::Call::OnServerToClientMessage(
const Message& message) {
return CheckPayload(message, limits_.max_recv_size(), false);
}

namespace {
Expand Down
56 changes: 33 additions & 23 deletions src/core/ext/filters/message_size/message_size_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,48 +86,58 @@ class MessageSizeParser : public ServiceConfigParser::Parser {
absl::optional<uint32_t> GetMaxRecvSizeFromChannelArgs(const ChannelArgs& args);
absl::optional<uint32_t> GetMaxSendSizeFromChannelArgs(const ChannelArgs& args);

class MessageSizeFilter : public ChannelFilter {
protected:
explicit MessageSizeFilter(const ChannelArgs& args)
: limits_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}

class CallBuilder;

const MessageSizeParsedConfig& limits() const { return limits_; }

private:
MessageSizeParsedConfig limits_;
};

class ServerMessageSizeFilter final : public MessageSizeFilter {
class ServerMessageSizeFilter final
: public ImplementChannelFilter<ServerMessageSizeFilter> {
public:
static const grpc_channel_filter kFilter;

static absl::StatusOr<ServerMessageSizeFilter> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args);

// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
class Call {
public:
static const NoInterceptor OnClientInitialMetadata;
static const NoInterceptor OnServerInitialMetadata;
static const NoInterceptor OnServerTrailingMetadata;
ServerMetadataHandle OnClientToServerMessage(
const Message& message, ServerMessageSizeFilter* filter);
ServerMetadataHandle OnServerToClientMessage(
const Message& message, ServerMessageSizeFilter* filter);
};

private:
using MessageSizeFilter::MessageSizeFilter;
explicit ServerMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
const MessageSizeParsedConfig parsed_config_;
};

class ClientMessageSizeFilter final : public MessageSizeFilter {
class ClientMessageSizeFilter final
: public ImplementChannelFilter<ClientMessageSizeFilter> {
public:
static const grpc_channel_filter kFilter;

static absl::StatusOr<ClientMessageSizeFilter> Create(
const ChannelArgs& args, ChannelFilter::Args filter_args);

// Construct a promise for one call.
ArenaPromise<ServerMetadataHandle> MakeCallPromise(
CallArgs call_args, NextPromiseFactory next_promise_factory) override;
class Call {
public:
explicit Call(ClientMessageSizeFilter* filter);

static const NoInterceptor OnClientInitialMetadata;
static const NoInterceptor OnServerInitialMetadata;
static const NoInterceptor OnServerTrailingMetadata;
ServerMetadataHandle OnClientToServerMessage(const Message& message);
ServerMetadataHandle OnServerToClientMessage(const Message& message);

private:
MessageSizeParsedConfig limits_;
};

private:
explicit ClientMessageSizeFilter(const ChannelArgs& args)
: parsed_config_(MessageSizeParsedConfig::GetFromChannelArgs(args)) {}
const size_t service_config_parser_index_{MessageSizeParser::ParserIndex()};
using MessageSizeFilter::MessageSizeFilter;
const MessageSizeParsedConfig parsed_config_;
};

} // namespace grpc_core
Expand Down
Loading

0 comments on commit 5f92a67

Please sign in to comment.