Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: expose API to set send compressor #5744

Merged
merged 13 commits into from
Jan 31, 2023
28 changes: 17 additions & 11 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,31 +262,36 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
t.Errorf("stream method = %q; want %q", s.method, want)
}

err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
if err != nil {
if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
t.Error(err)
}
err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
if err != nil {

if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}

if err := s.SetSendCompress("gzip"); err != nil {
t.Error(err)
}

md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
delete(md, "custom-header")
if err != nil {
if err := s.SendHeader(md); err != nil {
t.Error(err)
}
delete(md, "custom-header")

err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
if err == nil {
if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
t.Error("expected SetHeader call after SendHeader to fail")
}
err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
if err == nil {

if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
t.Error("expected second SendHeader call to fail")
}

if err := s.SetSendCompress("snappy"); err == nil {
t.Error("expected second SetSendCompress call to fail")
}

st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
Expand All @@ -299,6 +304,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
"Grpc-Encoding": {"gzip"},
}
wantTrailer := http.Header{
"Grpc-Status": {"0"},
Expand Down
4 changes: 4 additions & 0 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
s.contentSubtype = contentSubtype
isGRPC = true

case "grpc-accept-encoding":
s.clientAdvertisedCompressors = hf.Value
dfawley marked this conversation as resolved.
Show resolved Hide resolved
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
dfawley marked this conversation as resolved.
Show resolved Hide resolved
case "grpc-encoding":
s.recvCompress = hf.Value
case ":method":
Expand Down
23 changes: 21 additions & 2 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ type Stream struct {
fc *inFlow
wq *writeQuota

// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
Expand Down Expand Up @@ -341,8 +344,24 @@ func (s *Stream) RecvCompress() string {
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = name
return nil
}

// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}

// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}

// Done returns a channel which is closed when it receives the final status
Expand Down
99 changes: 95 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -1267,6 +1268,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
var sendCompressorName string

// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
Expand All @@ -1287,12 +1289,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}

if sendCompressorName != "" {
if err := stream.SetSendCompress(sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}

Expand Down Expand Up @@ -1379,6 +1387,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if stream.SendCompress() != sendCompressorName {
jronak marked this conversation as resolved.
Show resolved Hide resolved
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
Expand Down Expand Up @@ -1606,12 +1619,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}

if ss.sendCompressorName != "" {
if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}

Expand Down Expand Up @@ -1944,6 +1963,59 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}

// SetSendCompressor sets a compressor for outbound messages.
dfawley marked this conversation as resolved.
Show resolved Hide resolved
dfawley marked this conversation as resolved.
Show resolved Hide resolved
// It must not be called after any event that causes headers to be sent
// (see _ServerStream_.SetHeader for a complete list). Provided compressor is used when below
jronak marked this conversation as resolved.
Show resolved Hide resolved
// conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name exists in the client advertised compressor names sent in
// grpc-accept-encoding header. Use _ServerStream_.ClientAdvertisedCompressors
jronak marked this conversation as resolved.
Show resolved Hide resolved
// to get client advertised compressor names.
//
// The context provided must be the context passed to the server's handler.
// It must be noted that compressor name "identity" disables the outbound compression.
jronak marked this conversation as resolved.
Show resolved Hide resolved
// By default, server messages will be sent using the same compressor with which
// request messages were sent.
//
// It is not safe to call SetSendCompressor concurrently with SendHeader and
// SendMsg.
//
easwars marked this conversation as resolved.
Show resolved Hide resolved
// # Experimental
jronak marked this conversation as resolved.
Show resolved Hide resolved
//
// Notice: This _function_ is EXPERIMENTAL and may be changed or removed in a
jronak marked this conversation as resolved.
Show resolved Hide resolved
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
easwars marked this conversation as resolved.
Show resolved Hide resolved
if !ok || stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}

if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
return fmt.Errorf("unable to set send compressor: %w", err)
}

return stream.SetSendCompress(name)
}

// ClientAdvertisedCompressors returns compressor names advertised by the client
// via grpc-accept-encoding header.
//
// The context provided must be the context passed to the server's handler.
//
// # Experimental
//
// Notice: This _function_ is EXPERIMENTAL and may be changed or removed in a
jronak marked this conversation as resolved.
Show resolved Hide resolved
// later release.
func ClientAdvertisedCompressors(ctx context.Context) ([]string, error) {
jronak marked this conversation as resolved.
Show resolved Hide resolved
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}

return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil
}

// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
Expand Down Expand Up @@ -1978,3 +2050,22 @@ type channelzServer struct {
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}

// validateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func validateSendCompressor(name, clientCompressors string) error {
if name == encoding.Identity {
return nil
}

if !grpcutil.IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered %q", name)
}

for _, c := range strings.Split(clientCompressors, ",") {
if c == name {
return nil // found match
}
}
return fmt.Errorf("client does not support compressor %q", name)
}
9 changes: 9 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor

sendCompressorName string
easwars marked this conversation as resolved.
Show resolved Hide resolved

maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
Expand Down Expand Up @@ -1573,6 +1575,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()

// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
}

// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {
Expand Down
Loading