diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index fbd8058b79fb..a6eb20285787 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -280,31 +280,36 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { t.Errorf("stream method = %q; want %q", s.method, want) } - err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")) - if err != nil { + if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil { t.Error(err) } - err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")) - if err != nil { + + if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil { + t.Error(err) + } + + if err := s.SetSendCompress("gzip"); err != nil { t.Error(err) } md := metadata.Pairs("custom-header", "Another custom header value") - err = s.SendHeader(md) - delete(md, "custom-header") - if err != nil { + if err := s.SendHeader(md); err != nil { t.Error(err) } + delete(md, "custom-header") - err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")) - if err == nil { + if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil { t.Error("expected SetHeader call after SendHeader to fail") } - err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")) - if err == nil { + + if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil { t.Error("expected second SendHeader call to fail") } + if err := s.SetSendCompress("snappy"); err == nil { + t.Error("expected second SetSendCompress call to fail") + } + st.bodyw.Close() // no body st.ht.WriteStatus(s, status.New(codes.OK, "")) } @@ -317,6 +322,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) { "Content-Type": {"application/grpc"}, "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, "Custom-Header": {"Custom header value", "Another custom header value"}, + "Grpc-Encoding": {"gzip"}, } wantTrailer := http.Header{ "Grpc-Status": {"0"}, diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index bc3da706726d..7dee882bf663 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -404,6 +404,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( mdata[hf.Name] = append(mdata[hf.Name], hf.Value) s.contentSubtype = contentSubtype isGRPC = true + + case "grpc-accept-encoding": + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + if hf.Value == "" { + continue + } + compressors := hf.Value + if s.clientAdvertisedCompressors != "" { + compressors = s.clientAdvertisedCompressors + "," + compressors + } + s.clientAdvertisedCompressors = compressors case "grpc-encoding": s.recvCompress = hf.Value case ":method": diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 0ac77ea4f8c7..1b7d7fabc512 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -257,6 +257,9 @@ type Stream struct { fc *inFlow wq *writeQuota + // Holds compressor names passed in grpc-accept-encoding metadata from the + // client. This is empty for the client side stream. + clientAdvertisedCompressors string // Callback to state application's intentions to read data. This // is used to adjust flow control, if needed. requestRead func(int) @@ -345,8 +348,24 @@ func (s *Stream) RecvCompress() string { } // SetSendCompress sets the compression algorithm to the stream. -func (s *Stream) SetSendCompress(str string) { - s.sendCompress = str +func (s *Stream) SetSendCompress(name string) error { + if s.isHeaderSent() || s.getState() == streamDone { + return errors.New("transport: set send compressor called after headers sent or stream done") + } + + s.sendCompress = name + return nil +} + +// SendCompress returns the send compressor name. +func (s *Stream) SendCompress() string { + return s.sendCompress +} + +// ClientAdvertisedCompressors returns the compressor names advertised by the +// client via grpc-accept-encoding header. +func (s *Stream) ClientAdvertisedCompressors() string { + return s.clientAdvertisedCompressors } // Done returns a channel which is closed when it receives the final status diff --git a/server.go b/server.go index d5a6e78be44d..c225f044037a 100644 --- a/server.go +++ b/server.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" @@ -1263,6 +1264,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. var comp, decomp encoding.Compressor var cp Compressor var dc Decompressor + var sendCompressorName string // If dc is set and matches the stream's compression, use it. Otherwise, try // to find a matching registered compressor for decomp. @@ -1283,12 +1285,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { cp = s.opts.cp - stream.SetSendCompress(cp.Type()) + sendCompressorName = cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. comp = encoding.GetCompressor(rc) if comp != nil { - stream.SetSendCompress(rc) + sendCompressorName = comp.Name() + } + } + + if sendCompressorName != "" { + if err := stream.SetSendCompress(sendCompressorName); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) } } @@ -1375,6 +1383,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } opts := &transport.Options{Last: true} + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. + if stream.SendCompress() != sendCompressorName { + comp = encoding.GetCompressor(stream.SendCompress()) + } if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). @@ -1597,12 +1610,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. if s.opts.cp != nil { ss.cp = s.opts.cp - stream.SetSendCompress(s.opts.cp.Type()) + ss.sendCompressorName = s.opts.cp.Type() } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { // Legacy compressor not specified; attempt to respond with same encoding. ss.comp = encoding.GetCompressor(rc) if ss.comp != nil { - stream.SetSendCompress(rc) + ss.sendCompressorName = rc + } + } + + if ss.sendCompressorName != "" { + if err := stream.SetSendCompress(ss.sendCompressorName); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) } } @@ -1935,6 +1954,60 @@ func SendHeader(ctx context.Context, md metadata.MD) error { return nil } +// SetSendCompressor sets a compressor for outbound messages from the server. +// It must not be called after any event that causes headers to be sent +// (see ServerStream.SetHeader for the complete list). Provided compressor is +// used when below conditions are met: +// +// - compressor is registered via encoding.RegisterCompressor +// - compressor name must exist in the client advertised compressor names +// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to +// get client supported compressor names. +// +// The context provided must be the context passed to the server's handler. +// It must be noted that compressor name encoding.Identity disables the +// outbound compression. +// By default, server messages will be sent using the same compressor with +// which request messages were sent. +// +// It is not safe to call SetSendCompressor concurrently with SendHeader and +// SendMsg. +// +// # Experimental +// +// Notice: This function is EXPERIMENTAL and may be changed or removed in a +// later release. +func SetSendCompressor(ctx context.Context, name string) error { + stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) + if !ok || stream == nil { + return fmt.Errorf("failed to fetch the stream from the given context") + } + + if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { + return fmt.Errorf("unable to set send compressor: %w", err) + } + + return stream.SetSendCompress(name) +} + +// ClientSupportedCompressors returns compressor names advertised by the client +// via grpc-accept-encoding header. +// +// The context provided must be the context passed to the server's handler. +// +// # Experimental +// +// Notice: This function is EXPERIMENTAL and may be changed or removed in a +// later release. +func ClientSupportedCompressors(ctx context.Context) ([]string, error) { + stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) + if !ok || stream == nil { + return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) + } + + return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil +} + // SetTrailer sets the trailer metadata that will be sent when an RPC returns. // When called more than once, all the provided metadata will be merged. // @@ -1969,3 +2042,22 @@ type channelzServer struct { func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { return c.s.channelzMetric() } + +// validateSendCompressor returns an error when given compressor name cannot be +// handled by the server or the client based on the advertised compressors. +func validateSendCompressor(name, clientCompressors string) error { + if name == encoding.Identity { + return nil + } + + if !grpcutil.IsCompressorNameRegistered(name) { + return fmt.Errorf("compressor not registered %q", name) + } + + for _, c := range strings.Split(clientCompressors, ",") { + if c == name { + return nil // found match + } + } + return fmt.Errorf("client does not support compressor %q", name) +} diff --git a/stream.go b/stream.go index 93231af2ac56..89936a4f1665 100644 --- a/stream.go +++ b/stream.go @@ -1511,6 +1511,8 @@ type serverStream struct { comp encoding.Compressor decomp encoding.Compressor + sendCompressorName string + maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -1603,6 +1605,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } }() + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. + if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName { + ss.comp = encoding.GetCompressor(sendCompressorsName) + ss.sendCompressorName = sendCompressorsName + } + // load hdr, payload, data hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp) if err != nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index 0f5cbc345774..d3c339ccb8d9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -59,6 +59,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" @@ -5080,6 +5081,340 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) { } } +// wrapCompressor is a wrapper of encoding.Compressor which maintains count of +// Compressor method invokes. +type wrapCompressor struct { + encoding.Compressor + compressInvokes int32 +} + +func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) { + atomic.AddInt32(&wc.compressInvokes, 1) + return wc.Compressor.Compress(w) +} + +func setupGzipWrapCompressor(t *testing.T) *wrapCompressor { + oldC := encoding.GetCompressor("gzip") + c := &wrapCompressor{Compressor: oldC} + encoding.RegisterCompressor(c) + t.Cleanup(func() { + encoding.RegisterCompressor(oldC) + }) + return c +} + +func (s) TestSetSendCompressorSuccess(t *testing.T) { + for _, tt := range []struct { + name string + desc string + dialOpts []grpc.DialOption + resCompressor string + wantCompressInvokes int32 + }{ + { + name: "identity_request_and_gzip_response", + desc: "request is uncompressed and response is gzip compressed", + resCompressor: "gzip", + wantCompressInvokes: 1, + }, + { + name: "gzip_request_and_identity_response", + desc: "request is gzip compressed and response is uncompressed with identity", + resCompressor: "identity", + dialOpts: []grpc.DialOption{ + // Use WithCompressor instead of UseCompressor to avoid counting + // the client's compressor usage. + grpc.WithCompressor(grpc.NewGZIPCompressor()), + }, + wantCompressInvokes: 0, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Run("unary", func(t *testing.T) { + testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) + }) + + t.Run("stream", func(t *testing.T) { + testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) + }) + }) + } +} + +func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { + wc := setupGzipWrapCompressor(t) + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { + return nil, err + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil, dialOpts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) + } + + compressInvokes := atomic.LoadInt32(&wc.compressInvokes) + if compressInvokes != wantCompressInvokes { + t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes) + } +} + +func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { + wc := setupGzipWrapCompressor(t) + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err + } + + if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil { + return err + } + + return stream.Send(&testpb.StreamingOutputCallResponse{}) + }, + } + if err := ss.Start(nil, dialOpts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) + } + + if _, err := s.Recv(); err != nil { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err) + } + + compressInvokes := atomic.LoadInt32(&wc.compressInvokes) + if compressInvokes != wantCompressInvokes { + t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes) + } +} + +func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) { + resCompressor := "snappy2" + wantErr := status.Error(codes.Unknown, "unable to set send compressor: compressor not registered \"snappy2\"") + + t.Run("unary", func(t *testing.T) { + testUnarySetSendCompressorFailure(t, resCompressor, wantErr) + }) + + t.Run("stream", func(t *testing.T) { + testStreamSetSendCompressorFailure(t, resCompressor, wantErr) + }) +} + +func (s) TestUnadvertisedSetSendCompressorFailure(t *testing.T) { + // Disable client compressor advertisement. + defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors) + envconfig.AdvertiseCompressors = false + + resCompressor := "gzip" + wantErr := status.Error(codes.Unknown, "unable to set send compressor: client does not support compressor \"gzip\"") + + t.Run("unary", func(t *testing.T) { + testUnarySetSendCompressorFailure(t, resCompressor, wantErr) + }) + + t.Run("stream", func(t *testing.T) { + testStreamSetSendCompressorFailure(t, resCompressor, wantErr) + }) +} + +func testUnarySetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { + return nil, err + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) + } +} + +func testStreamSetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err + } + + if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil { + return err + } + + return stream.Send(&testpb.StreamingOutputCallResponse{}) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v, want: nil", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err) + } + + if _, err := s.Recv(); !equalError(err, wantErr) { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err) + } +} + +func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + // Send headers early and then set send compressor. + grpc.SendHeader(ctx, metadata.MD{}) + err := grpc.SetSendCompressor(ctx, "gzip") + if err == nil { + t.Error("Wanted set send compressor error") + return &testpb.Empty{}, nil + } + return nil, err + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done") + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) { + t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr) + } +} + +func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + // Send headers early and then set send compressor. + grpc.SendHeader(stream.Context(), metadata.MD{}) + err := grpc.SetSendCompressor(stream.Context(), "gzip") + if err == nil { + t.Error("Wanted set send compressor error") + } + return err + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done") + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err) + } + + if _, err := s.Recv(); !equalError(err, wantErr) { + t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr) + } +} + +func (s) TestClientSupportedCompressors(t *testing.T) { + for _, tt := range []struct { + desc string + ctx context.Context + want []string + }{ + { + desc: "No additional grpc-accept-encoding header", + ctx: context.Background(), + want: []string{"gzip"}, + }, + { + desc: "With additional grpc-accept-encoding header", + ctx: metadata.AppendToOutgoingContext(context.Background(), + "grpc-accept-encoding", "test-compressor-1", + "grpc-accept-encoding", "test-compressor-2", + ), + want: []string{"gzip", "test-compressor-1", "test-compressor-2"}, + }, + { + desc: "With additional empty grpc-accept-encoding header", + ctx: metadata.AppendToOutgoingContext(context.Background(), + "grpc-accept-encoding", "", + ), + want: []string{"gzip"}, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + got, err := grpc.ClientSupportedCompressors(ctx) + if err != nil { + return nil, err + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unexpected client compressors got: %v, want: %v", got, tt.want) + } + + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v, want: nil", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(tt.ctx, defaultTestTimeout) + defer cancel() + + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) + } + }) + } +} + func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { const mdkey = "somedata"