diff --git a/benchmark/benchmain/main.go b/benchmark/benchmain/main.go index 78ca59363841..1366c18c972b 100644 --- a/benchmark/benchmain/main.go +++ b/benchmark/benchmain/main.go @@ -112,6 +112,7 @@ var ( serverWriteBufferSize = flags.IntSlice("serverWriteBufferSize", []int{-1}, "Configures the server write buffer size in bytes. If negative, use the default - may be a a comma-separated list") sleepBetweenRPCs = flags.DurationSlice("sleepBetweenRPCs", []time.Duration{0}, "Configures the maximum amount of time the client should sleep between consecutive RPCs - may be a a comma-separated list") connections = flag.Int("connections", 1, "The number of connections. Each connection will handle maxConcurrentCalls RPC streams") + recvBufferPool = flags.StringWithAllowedValues("recvBufferPool", recvBufferPoolNil, "Configures the shared receive buffer pool. One of: nil, simple, all", allRecvBufferPools) logger = grpclog.Component("benchmark") ) @@ -136,6 +137,10 @@ const ( networkModeLAN = "LAN" networkModeWAN = "WAN" networkLongHaul = "Longhaul" + // Shared recv buffer pool + recvBufferPoolNil = "nil" + recvBufferPoolSimple = "simple" + recvBufferPoolAll = "all" numStatsBuckets = 10 warmupCallCount = 10 @@ -147,6 +152,7 @@ var ( allCompModes = []string{compModeOff, compModeGzip, compModeNop, compModeAll} allToggleModes = []string{toggleModeOff, toggleModeOn, toggleModeBoth} allNetworkModes = []string{networkModeNone, networkModeLocal, networkModeLAN, networkModeWAN, networkLongHaul} + allRecvBufferPools = []string{recvBufferPoolNil, recvBufferPoolSimple, recvBufferPoolAll} defaultReadLatency = []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay. defaultReadKbps = []int{0, 10240} // if non-positive, infinite defaultReadMTU = []int{0} // if non-positive, infinite @@ -330,6 +336,15 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func()) if bf.ServerWriteBufferSize >= 0 { sopts = append(sopts, grpc.WriteBufferSize(bf.ServerWriteBufferSize)) } + switch bf.RecvBufferPool { + case recvBufferPoolNil: + // Do nothing. + case recvBufferPoolSimple: + opts = append(opts, grpc.WithRecvBufferPool(grpc.NewSharedBufferPool())) + sopts = append(sopts, grpc.RecvBufferPool(grpc.NewSharedBufferPool())) + default: + logger.Fatalf("Unknown shared recv buffer pool type: %v", bf.RecvBufferPool) + } sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(bf.MaxConcurrentCalls+1))) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -573,6 +588,7 @@ type featureOpts struct { serverReadBufferSize []int serverWriteBufferSize []int sleepBetweenRPCs []time.Duration + recvBufferPools []string } // makeFeaturesNum returns a slice of ints of size 'maxFeatureIndex' where each @@ -619,6 +635,8 @@ func makeFeaturesNum(b *benchOpts) []int { featuresNum[i] = len(b.features.serverWriteBufferSize) case stats.SleepBetweenRPCs: featuresNum[i] = len(b.features.sleepBetweenRPCs) + case stats.RecvBufferPool: + featuresNum[i] = len(b.features.recvBufferPools) default: log.Fatalf("Unknown feature index %v in generateFeatures. maxFeatureIndex is %v", i, stats.MaxFeatureIndex) } @@ -687,6 +705,7 @@ func (b *benchOpts) generateFeatures(featuresNum []int) []stats.Features { ServerReadBufferSize: b.features.serverReadBufferSize[curPos[stats.ServerReadBufferSize]], ServerWriteBufferSize: b.features.serverWriteBufferSize[curPos[stats.ServerWriteBufferSize]], SleepBetweenRPCs: b.features.sleepBetweenRPCs[curPos[stats.SleepBetweenRPCs]], + RecvBufferPool: b.features.recvBufferPools[curPos[stats.RecvBufferPool]], } if len(b.features.reqPayloadCurves) == 0 { f.ReqSizeBytes = b.features.reqSizeBytes[curPos[stats.ReqSizeBytesIndex]] @@ -759,6 +778,7 @@ func processFlags() *benchOpts { serverReadBufferSize: append([]int(nil), *serverReadBufferSize...), serverWriteBufferSize: append([]int(nil), *serverWriteBufferSize...), sleepBetweenRPCs: append([]time.Duration(nil), *sleepBetweenRPCs...), + recvBufferPools: setRecvBufferPool(*recvBufferPool), }, } @@ -834,6 +854,19 @@ func setCompressorMode(val string) []string { } } +func setRecvBufferPool(val string) []string { + switch val { + case recvBufferPoolNil, recvBufferPoolSimple: + return []string{val} + case recvBufferPoolAll: + return []string{recvBufferPoolNil, recvBufferPoolSimple} + default: + // This should never happen because a wrong value passed to this flag would + // be caught during flag.Parse(). + return []string{} + } +} + func main() { opts := processFlags() before(opts) diff --git a/benchmark/stats/stats.go b/benchmark/stats/stats.go index 74070fd76c07..3989e25dbf4b 100644 --- a/benchmark/stats/stats.go +++ b/benchmark/stats/stats.go @@ -57,6 +57,7 @@ const ( ServerReadBufferSize ServerWriteBufferSize SleepBetweenRPCs + RecvBufferPool // MaxFeatureIndex is a place holder to indicate the total number of feature // indices we have. Any new feature indices should be added above this. @@ -126,6 +127,8 @@ type Features struct { ServerWriteBufferSize int // SleepBetweenRPCs configures optional delay between RPCs. SleepBetweenRPCs time.Duration + // RecvBufferPool represents the shared recv buffer pool used. + RecvBufferPool string } // String returns all the feature values as a string. @@ -145,12 +148,13 @@ func (f Features) String() string { "trace_%v-latency_%v-kbps_%v-MTU_%v-maxConcurrentCalls_%v-%s-%s-"+ "compressor_%v-channelz_%v-preloader_%v-clientReadBufferSize_%v-"+ "clientWriteBufferSize_%v-serverReadBufferSize_%v-serverWriteBufferSize_%v-"+ - "sleepBetweenRPCs_%v-connections_%v-", + "sleepBetweenRPCs_%v-connections_%v-recvBufferPool_%v-", f.NetworkMode, f.UseBufConn, f.EnableKeepalive, f.BenchTime, f.EnableTrace, f.Latency, f.Kbps, f.MTU, f.MaxConcurrentCalls, reqPayloadString, respPayloadString, f.ModeCompressor, f.EnableChannelz, f.EnablePreloader, f.ClientReadBufferSize, f.ClientWriteBufferSize, f.ServerReadBufferSize, - f.ServerWriteBufferSize, f.SleepBetweenRPCs, f.Connections) + f.ServerWriteBufferSize, f.SleepBetweenRPCs, f.Connections, + f.RecvBufferPool) } // SharedFeatures returns the shared features as a pretty printable string. @@ -224,6 +228,8 @@ func (f Features) partialString(b *bytes.Buffer, wantFeatures []bool, sep, delim b.WriteString(fmt.Sprintf("ServerWriteBufferSize%v%v%v", sep, f.ServerWriteBufferSize, delim)) case SleepBetweenRPCs: b.WriteString(fmt.Sprintf("SleepBetweenRPCs%v%v%v", sep, f.SleepBetweenRPCs, delim)) + case RecvBufferPool: + b.WriteString(fmt.Sprintf("RecvBufferPool%v%v%v", sep, f.RecvBufferPool, delim)) default: log.Fatalf("Unknown feature index %v. maxFeatureIndex is %v", i, MaxFeatureIndex) } diff --git a/dialoptions.go b/dialoptions.go index 15a3d5102a9a..23ea95237ea0 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -78,6 +78,7 @@ type dialOptions struct { defaultServiceConfigRawJSON *string resolvers []resolver.Builder idleTimeout time.Duration + recvBufferPool SharedBufferPool } // DialOption configures how we set up the connection. @@ -628,6 +629,7 @@ func defaultDialOptions() dialOptions { ReadBufferSize: defaultReadBufSize, UseProxy: true, }, + recvBufferPool: nopBufferPool{}, } } @@ -676,3 +678,24 @@ func WithIdleTimeout(d time.Duration) DialOption { o.idleTimeout = d }) } + +// WithRecvBufferPool returns a DialOption that configures the ClientConn +// to use the provided shared buffer pool for parsing incoming messages. Depending +// on the application's workload, this could result in reduced memory allocation. +// +// If you are unsure about how to implement a memory pool but want to utilize one, +// begin with grpc.NewSharedBufferPool. +// +// Note: The shared buffer pool feature will not be active if any of the following +// options are used: WithStatsHandler, EnableTracing, or binary logging. In such +// cases, the shared buffer pool will be ignored. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.recvBufferPool = bufferPool + }) +} diff --git a/rpc_util.go b/rpc_util.go index 2030736a306b..a844d28f49d0 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -577,6 +577,9 @@ type parser struct { // The header of a gRPC message. Find more detail at // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md header [5]byte + + // recvBufferPool is the pool of shared receive buffers. + recvBufferPool SharedBufferPool } // recvMsg reads a complete gRPC message from the stream. @@ -610,9 +613,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt if int(length) > maxReceiveMessageSize { return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) } - // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead - // of making it for each message: - msg = make([]byte, int(length)) + msg = p.recvBufferPool.Get(int(length)) if _, err := p.r.Read(msg); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF @@ -726,12 +727,12 @@ type payloadInfo struct { } func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) { - pf, d, err := p.recvMsg(maxReceiveMessageSize) + pf, buf, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return nil, err } if payInfo != nil { - payInfo.compressedLength = len(d) + payInfo.compressedLength = len(buf) } if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { @@ -743,10 +744,10 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // use this decompressor as the default. if dc != nil { - d, err = dc.Do(bytes.NewReader(d)) - size = len(d) + buf, err = dc.Do(bytes.NewReader(buf)) + size = len(buf) } else { - d, size, err = decompress(compressor, d, maxReceiveMessageSize) + buf, size, err = decompress(compressor, buf, maxReceiveMessageSize) } if err != nil { return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) @@ -757,7 +758,7 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) } } - return d, nil + return buf, nil } // Using compressor, decompress d, returning data and size. @@ -792,15 +793,17 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error { - d, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) + buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) if err != nil { return err } - if err := c.Unmarshal(d, m); err != nil { + if err := c.Unmarshal(buf, m); err != nil { return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err) } if payInfo != nil { - payInfo.uncompressedBytes = d + payInfo.uncompressedBytes = buf + } else { + p.recvBufferPool.Put(&buf) } return nil } diff --git a/rpc_util_test.go b/rpc_util_test.go index 90912d52a226..84f2348655b9 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -65,7 +65,7 @@ func (s) TestSimpleParsing(t *testing.T) { {append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone}, } { buf := fullReader{bytes.NewReader(test.p)} - parser := &parser{r: buf} + parser := &parser{r: buf, recvBufferPool: nopBufferPool{}} pt, b, err := parser.recvMsg(math.MaxInt32) if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) @@ -77,7 +77,7 @@ func (s) TestMultipleParsing(t *testing.T) { // Set a byte stream consists of 3 messages with their headers. p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'} b := fullReader{bytes.NewReader(p)} - parser := &parser{r: b} + parser := &parser{r: b, recvBufferPool: nopBufferPool{}} wantRecvs := []struct { pt payloadFormat diff --git a/server.go b/server.go index 81969e7c15a9..e076ec7143bb 100644 --- a/server.go +++ b/server.go @@ -174,6 +174,7 @@ type serverOptions struct { maxHeaderListSize *uint32 headerTableSize *uint32 numServerWorkers uint32 + recvBufferPool SharedBufferPool } var defaultServerOptions = serverOptions{ @@ -182,6 +183,7 @@ var defaultServerOptions = serverOptions{ connectionTimeout: 120 * time.Second, writeBufferSize: defaultWriteBufSize, readBufferSize: defaultReadBufSize, + recvBufferPool: nopBufferPool{}, } var globalServerOptions []ServerOption @@ -552,6 +554,27 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption { }) } +// RecvBufferPool returns a ServerOption that configures the server +// to use the provided shared buffer pool for parsing incoming messages. Depending +// on the application's workload, this could result in reduced memory allocation. +// +// If you are unsure about how to implement a memory pool but want to utilize one, +// begin with grpc.NewSharedBufferPool. +// +// Note: The shared buffer pool feature will not be active if any of the following +// options are used: StatsHandler, EnableTracing, or binary logging. In such +// cases, the shared buffer pool will be ignored. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func RecvBufferPool(bufferPool SharedBufferPool) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.recvBufferPool = bufferPool + }) +} + // serverWorkerResetThreshold defines how often the stack must be reset. Every // N requests, by spawning a new goroutine in its place, a worker can reset its // stack so that large stacks don't live in memory forever. 2^16 should allow @@ -1296,7 +1319,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if len(shs) != 0 || len(binlogs) != 0 { payInfo = &payloadInfo{} } - d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) + d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) if err != nil { if e := t.WriteStatus(stream, status.Convert(err)); e != nil { channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) @@ -1506,7 +1529,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ctx: ctx, t: t, s: stream, - p: &parser{r: stream}, + p: &parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, codec: s.getCodec(stream.ContentSubtype()), maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, diff --git a/shared_buffer_pool.go b/shared_buffer_pool.go new file mode 100644 index 000000000000..c3a5a9ac1f19 --- /dev/null +++ b/shared_buffer_pool.go @@ -0,0 +1,154 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import "sync" + +// SharedBufferPool is a pool of buffers that can be shared, resulting in +// decreased memory allocation. Currently, in gRPC-go, it is only utilized +// for parsing incoming messages. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +type SharedBufferPool interface { + // Get returns a buffer with specified length from the pool. + // + // The returned byte slice may be not zero initialized. + Get(length int) []byte + + // Put returns a buffer to the pool. + Put(*[]byte) +} + +// NewSharedBufferPool creates a simple SharedBufferPool with buckets +// of different sizes to optimize memory usage. This prevents the pool from +// wasting large amounts of memory, even when handling messages of varying sizes. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func NewSharedBufferPool() SharedBufferPool { + return &simpleSharedBufferPool{ + pools: [poolArraySize]simpleSharedBufferChildPool{ + newBytesPool(level0PoolMaxSize), + newBytesPool(level1PoolMaxSize), + newBytesPool(level2PoolMaxSize), + newBytesPool(level3PoolMaxSize), + newBytesPool(level4PoolMaxSize), + newBytesPool(0), + }, + } +} + +// simpleSharedBufferPool is a simple implementation of SharedBufferPool. +type simpleSharedBufferPool struct { + pools [poolArraySize]simpleSharedBufferChildPool +} + +func (p *simpleSharedBufferPool) Get(size int) []byte { + return p.pools[p.poolIdx(size)].Get(size) +} + +func (p *simpleSharedBufferPool) Put(bs *[]byte) { + p.pools[p.poolIdx(cap(*bs))].Put(bs) +} + +func (p *simpleSharedBufferPool) poolIdx(size int) int { + switch { + case size <= level0PoolMaxSize: + return level0PoolIdx + case size <= level1PoolMaxSize: + return level1PoolIdx + case size <= level2PoolMaxSize: + return level2PoolIdx + case size <= level3PoolMaxSize: + return level3PoolIdx + case size <= level4PoolMaxSize: + return level4PoolIdx + default: + return levelMaxPoolIdx + } +} + +const ( + level0PoolMaxSize = 16 // 16 B + level1PoolMaxSize = level0PoolMaxSize * 16 // 256 B + level2PoolMaxSize = level1PoolMaxSize * 16 // 4 KB + level3PoolMaxSize = level2PoolMaxSize * 16 // 64 KB + level4PoolMaxSize = level3PoolMaxSize * 16 // 1 MB +) + +const ( + level0PoolIdx = iota + level1PoolIdx + level2PoolIdx + level3PoolIdx + level4PoolIdx + levelMaxPoolIdx + poolArraySize +) + +type simpleSharedBufferChildPool interface { + Get(size int) []byte + Put(interface{}) +} + +type bufferPool struct { + sync.Pool + + defaultSize int +} + +func (p *bufferPool) Get(size int) []byte { + bs := p.Pool.Get().(*[]byte) + + if cap(*bs) < size { + p.Pool.Put(bs) + + return make([]byte, size) + } + + return (*bs)[:size] +} + +func newBytesPool(size int) simpleSharedBufferChildPool { + return &bufferPool{ + Pool: sync.Pool{ + New: func() interface{} { + bs := make([]byte, size) + return &bs + }, + }, + defaultSize: size, + } +} + +// nopBufferPool is a buffer pool just makes new buffer without pooling. +type nopBufferPool struct { +} + +func (nopBufferPool) Get(length int) []byte { + return make([]byte, length) +} + +func (nopBufferPool) Put(*[]byte) { +} diff --git a/shared_buffer_pool_test.go b/shared_buffer_pool_test.go new file mode 100644 index 000000000000..f5ed7c8314f1 --- /dev/null +++ b/shared_buffer_pool_test.go @@ -0,0 +1,48 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import "testing" + +func (s) TestSharedBufferPool(t *testing.T) { + pools := []SharedBufferPool{ + nopBufferPool{}, + NewSharedBufferPool(), + } + + lengths := []int{ + level4PoolMaxSize + 1, + level4PoolMaxSize, + level3PoolMaxSize, + level2PoolMaxSize, + level1PoolMaxSize, + level0PoolMaxSize, + } + + for _, p := range pools { + for _, l := range lengths { + bs := p.Get(l) + if len(bs) != l { + t.Fatalf("Expected buffer of length %d, got %d", l, len(bs)) + } + + p.Put(&bs) + } + } +} diff --git a/stream.go b/stream.go index 10092685b228..de32a7597145 100644 --- a/stream.go +++ b/stream.go @@ -507,7 +507,7 @@ func (a *csAttempt) newStream() error { return toRPCErr(nse.Err) } a.s = s - a.p = &parser{r: s} + a.p = &parser{r: s, recvBufferPool: a.cs.cc.dopts.recvBufferPool} return nil } @@ -1270,7 +1270,7 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin return nil, err } as.s = s - as.p = &parser{r: s} + as.p = &parser{r: s, recvBufferPool: ac.dopts.recvBufferPool} ac.incrCallsStarted() if desc != unaryStreamDesc { // Listen on stream context to cleanup when the stream context is diff --git a/test/recv_buffer_pool_test.go b/test/recv_buffer_pool_test.go new file mode 100644 index 000000000000..8bb6db4a77af --- /dev/null +++ b/test/recv_buffer_pool_test.go @@ -0,0 +1,90 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "bytes" + "context" + "io" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/internal/stubserver" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +func (s) TestRecvBufferPool(t *testing.T) { + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + for i := 0; i < 10; i++ { + preparedMsg := &grpc.PreparedMsg{} + err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{ + Body: []byte{'0' + uint8(i)}, + }, + }) + if err != nil { + return err + } + stream.SendMsg(preparedMsg) + } + return nil + }, + } + if err := ss.Start( + []grpc.ServerOption{grpc.RecvBufferPool(grpc.NewSharedBufferPool())}, + grpc.WithRecvBufferPool(grpc.NewSharedBufferPool()), + ); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.FullDuplexCall failed: %f", err) + } + + var ngot int + var buf bytes.Buffer + for { + reply, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + ngot++ + if buf.Len() > 0 { + buf.WriteByte(',') + } + buf.Write(reply.GetPayload().GetBody()) + } + if want := 10; ngot != want { + t.Errorf("Got %d replies, want %d", ngot, want) + } + if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want { + t.Errorf("Got replies %q; want %q", got, want) + } +}