Skip to content

Commit

Permalink
rpc: make batch stream pool general over Conn
Browse files Browse the repository at this point in the history
This will help prototype drpc stream pooling.
  • Loading branch information
tbg authored and nvanbenschoten committed Dec 10, 2024
1 parent a8c78e5 commit 7f4ffeb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
70 changes: 34 additions & 36 deletions pkg/rpc/stream_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ import (
type streamClient[Req, Resp any] interface {
Send(Req) error
Recv() (Resp, error)
grpc.ClientStream
}

// streamConstructor creates a new gRPC stream client over the provided client
// connection, using the provided call options.
type streamConstructor[Req, Resp any] func(
context.Context, *grpc.ClientConn, ...grpc.CallOption,
type streamConstructor[Req, Resp, Conn any] func(
context.Context, Conn,
) (streamClient[Req, Resp], error)

type result[Resp any] struct {
Expand Down Expand Up @@ -67,8 +66,8 @@ const defaultPooledStreamIdleTimeout = 10 * time.Second
//
// A pooledStream must only be returned to the pool for reuse after a successful
// Send call. If the Send call fails, the pooledStream must not be reused.
type pooledStream[Req, Resp any] struct {
pool *streamPool[Req, Resp]
type pooledStream[Req, Resp any, Conn comparable] struct {
pool *streamPool[Req, Resp, Conn]
stream streamClient[Req, Resp]
streamCtx context.Context
streamCancel context.CancelFunc
Expand All @@ -77,13 +76,13 @@ type pooledStream[Req, Resp any] struct {
respC chan result[Resp]
}

func newPooledStream[Req, Resp any](
pool *streamPool[Req, Resp],
func newPooledStream[Req, Resp any, Conn comparable](
pool *streamPool[Req, Resp, Conn],
stream streamClient[Req, Resp],
streamCtx context.Context,
streamCancel context.CancelFunc,
) *pooledStream[Req, Resp] {
return &pooledStream[Req, Resp]{
) *pooledStream[Req, Resp, Conn] {
return &pooledStream[Req, Resp, Conn]{
pool: pool,
stream: stream,
streamCtx: streamCtx,
Expand All @@ -93,13 +92,13 @@ func newPooledStream[Req, Resp any](
}
}

func (s *pooledStream[Req, Resp]) run(ctx context.Context) {
func (s *pooledStream[Req, Resp, Conn]) run(ctx context.Context) {
defer s.close()
for s.runOnce(ctx) {
}
}

func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) {
func (s *pooledStream[Req, Resp, Conn]) runOnce(ctx context.Context) (loop bool) {
select {
case req := <-s.reqC:
err := s.stream.Send(req)
Expand Down Expand Up @@ -137,7 +136,7 @@ func (s *pooledStream[Req, Resp]) runOnce(ctx context.Context) (loop bool) {
}
}

func (s *pooledStream[Req, Resp]) close() {
func (s *pooledStream[Req, Resp, Conn]) close() {
// Make sure the stream's context is canceled to ensure that we clean up
// resources in idle timeout case.
//
Expand All @@ -156,7 +155,7 @@ func (s *pooledStream[Req, Resp]) close() {

// Send sends a request on the pooled stream and returns the response in a unary
// RPC fashion. Context cancellation is respected.
func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) {
func (s *pooledStream[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) {
var resp result[Resp]
select {
case s.reqC <- req:
Expand Down Expand Up @@ -190,26 +189,26 @@ func (s *pooledStream[Req, Resp]) Send(ctx context.Context, req Req) (Resp, erro
// manner that mimics unary RPC invocation. Pooling these streams allows for
// reuse of gRPC resources across calls, as opposed to native unary RPCs, which
// create a new stream and throw it away for each request (see grpc.invoke).
type streamPool[Req, Resp any] struct {
type streamPool[Req, Resp any, Conn comparable] struct {
stopper *stop.Stopper
idleTimeout time.Duration
newStream streamConstructor[Req, Resp]
newStream streamConstructor[Req, Resp, Conn]

// cc and ccCtx are set on bind, when the gRPC connection is established.
cc *grpc.ClientConn
cc Conn
// Derived from rpc.Context.MasterCtx, canceled on stopper quiesce.
ccCtx context.Context

streams struct {
syncutil.Mutex
s []*pooledStream[Req, Resp]
s []*pooledStream[Req, Resp, Conn]
}
}

func makeStreamPool[Req, Resp any](
stopper *stop.Stopper, newStream streamConstructor[Req, Resp],
) streamPool[Req, Resp] {
return streamPool[Req, Resp]{
func makeStreamPool[Req, Resp any, Conn comparable](
stopper *stop.Stopper, newStream streamConstructor[Req, Resp, Conn],
) streamPool[Req, Resp, Conn] {
return streamPool[Req, Resp, Conn]{
stopper: stopper,
idleTimeout: defaultPooledStreamIdleTimeout,
newStream: newStream,
Expand All @@ -218,18 +217,18 @@ func makeStreamPool[Req, Resp any](

// Bind sets the gRPC connection and context for the streamPool. This must be
// called once before streamPool.Send.
func (p *streamPool[Req, Resp]) Bind(ctx context.Context, cc *grpc.ClientConn) {
func (p *streamPool[Req, Resp, Conn]) Bind(ctx context.Context, cc Conn) {
p.cc = cc
p.ccCtx = ctx
}

// Conn returns the gRPC connection bound to the streamPool.
func (p *streamPool[Req, Resp]) Conn() *grpc.ClientConn {
func (p *streamPool[Req, Resp, Conn]) Conn() Conn {
return p.cc
}

// Close closes all streams in the pool.
func (p *streamPool[Req, Resp]) Close() {
func (p *streamPool[Req, Resp, Conn]) Close() {
p.streams.Lock()
defer p.streams.Unlock()
for _, s := range p.streams.s {
Expand All @@ -238,7 +237,7 @@ func (p *streamPool[Req, Resp]) Close() {
p.streams.s = nil
}

func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] {
func (p *streamPool[Req, Resp, Conn]) get() *pooledStream[Req, Resp, Conn] {
p.streams.Lock()
defer p.streams.Unlock()
if len(p.streams.s) == 0 {
Expand All @@ -253,7 +252,7 @@ func (p *streamPool[Req, Resp]) get() *pooledStream[Req, Resp] {
return s
}

func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) {
func (p *streamPool[Req, Resp, Conn]) putIfNotClosed(s *pooledStream[Req, Resp, Conn]) {
p.streams.Lock()
defer p.streams.Unlock()
if s.streamCtx.Err() != nil {
Expand All @@ -265,7 +264,7 @@ func (p *streamPool[Req, Resp]) putIfNotClosed(s *pooledStream[Req, Resp]) {
p.streams.s = append(p.streams.s, s)
}

func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool {
func (p *streamPool[Req, Resp, Conn]) remove(s *pooledStream[Req, Resp, Conn]) bool {
p.streams.Lock()
defer p.streams.Unlock()
i := slices.Index(p.streams.s, s)
Expand All @@ -278,9 +277,10 @@ func (p *streamPool[Req, Resp]) remove(s *pooledStream[Req, Resp]) bool {
return true
}

func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], error) {
if p.cc == nil {
return nil, errors.AssertionFailedf("streamPool not bound to a grpc.ClientConn")
func (p *streamPool[Req, Resp, Conn]) newPooledStream() (*pooledStream[Req, Resp, Conn], error) {
var zero Conn
if p.cc == zero {
return nil, errors.AssertionFailedf("streamPool not bound to a client conn")
}

ctx, cancel := context.WithCancel(p.ccCtx)
Expand All @@ -305,7 +305,7 @@ func (p *streamPool[Req, Resp]) newPooledStream() (*pooledStream[Req, Resp], err

// Send sends a request on a pooled stream and returns the response in a unary
// RPC fashion. If no stream is available in the pool, a new stream is created.
func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error) {
func (p *streamPool[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) {
s := p.get()
if s == nil {
var err error
Expand All @@ -320,16 +320,14 @@ func (p *streamPool[Req, Resp]) Send(ctx context.Context, req Req) (Resp, error)
}

// BatchStreamPool is a streamPool specialized for BatchStreamClient streams.
type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse]
type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse, *grpc.ClientConn]

// BatchStreamClient is a streamClient specialized for the BatchStream RPC.
//
//go:generate mockgen -destination=mocks_generated_test.go --package=. BatchStreamClient
type BatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse]

// newBatchStream constructs a BatchStreamClient from a grpc.ClientConn.
func newBatchStream(
ctx context.Context, cc *grpc.ClientConn, opts ...grpc.CallOption,
) (BatchStreamClient, error) {
return kvpb.NewInternalClient(cc).BatchStream(ctx, opts...)
func newBatchStream(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) {
return kvpb.NewInternalClient(cc).BatchStream(ctx)
}
4 changes: 2 additions & 2 deletions pkg/rpc/stream_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type mockBatchStreamConstructor struct {
}

func (m *mockBatchStreamConstructor) newStream(
ctx context.Context, conn *grpc.ClientConn, option ...grpc.CallOption,
ctx context.Context, conn *grpc.ClientConn,
) (BatchStreamClient, error) {
m.streamCount++
if m.lastStreamCtx != nil {
Expand Down Expand Up @@ -153,7 +153,7 @@ func TestStreamPool_SendBeforeBind(t *testing.T) {
resp, err := p.Send(ctx, &kvpb.BatchRequest{})
require.Nil(t, resp)
require.Error(t, err)
require.Regexp(t, err, "streamPool not bound to a grpc.ClientConn")
require.Regexp(t, err, "streamPool not bound to a client conn")
require.Equal(t, 0, conn.streamCount)
require.Len(t, p.streams.s, 0)
}
Expand Down

0 comments on commit 7f4ffeb

Please sign in to comment.