Skip to content

Commit

Permalink
rpc,kv: use drpc for BatchRequest:
Browse files Browse the repository at this point in the history
I cherry-picked the stream reuse PR cockroachdb#136648 so that both are
more likely to evolve in unison, and because I anticipated
"piggybacking" on top of it. This did not materialize, but
the important integration point is now side by side and should
lend itself well to a future change that adds stream reuse for
drpc as well.

This commit roughly encompasses the following:

1.  we change `*rpc.Connection` to also maintain a `drpcpool.Conn`.

It took me a while to grok how the `drpcpool.Pool` is architected.
Essentially, its `Get()` method gives you a handle to a `drpcpool.Conn`
that reflects a specific use case of the pool. When a client is created
on it, it gets assigned an actual `drpc.Conn` from the pool (dialing if
necessary), and when the client is closed, this conn is returned to the
pool.

So in a sense, `drpcpool.Conn` is an actual pool for some keyed use
case; `drpcpool.Pool` is simply the bucket in which the actual
connections get pooled, but they're never pulled from or returned to it
directly. We don't currently use the key since we create a pool per
remote node, and if we're not sharing TCP conns they all look the same
to us anyway (i.e. there's no point in DefaultClass vs SystemClass).

If we squint, a `drpcoool.Conn` parallels `*grpc.ClientConn` in the sense
that you can "just" make multiple clients on top of it.  Of course
internally a `*grpc.ClientConn` multiplexes multiple clients over one
HTTP2 connection, whereas a `drpcpool.Conn` represents multiple
independent TCP connections.

The lifecycle of these pools is currently completely broken, and I was
honestly surprised TestTestServerDRPC passes! Future commits will need
to clean this up to the point where we can at least run a representative
benchmark.
  • Loading branch information
tbg committed Dec 5, 2024
1 parent c4c8bda commit fd84492
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 30 deletions.
3 changes: 2 additions & 1 deletion pkg/kv/kvclient/kvtenant/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,8 @@ func (c *connector) dialAddrs(ctx context.Context) (*client, error) {

func (c *connector) dialAddr(ctx context.Context, addr string) (conn *grpc.ClientConn, err error) {
if c.rpcDialTimeout == 0 {
return c.rpcContext.GRPCUnvalidatedDial(addr, roachpb.Locality{}).Connect(ctx)
cc, err := c.rpcContext.GRPCUnvalidatedDial(addr, roachpb.Locality{}).Connect(ctx)
return cc, err
}
err = timeutil.RunWithTimeout(ctx, "dial addr", c.rpcDialTimeout, func(ctx context.Context) error {
conn, err = c.rpcContext.GRPCUnvalidatedDial(addr, roachpb.Locality{}).Connect(ctx)
Expand Down
4 changes: 2 additions & 2 deletions pkg/kv/kvserver/loqrecovery/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ func visitNodeWithRetry(
// Note that we use ConnectNoBreaker here to avoid any race with probe
// running on current node and target node restarting. Errors from circuit
// breaker probes could confuse us and present node as unavailable.
conn, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx)
conn, _, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx)
// Nodes would contain dead nodes that we don't need to visit. We can skip
// them and let caller handle incomplete info.
if err != nil {
Expand Down Expand Up @@ -803,7 +803,7 @@ func makeVisitNode(g *gossip.Gossip, loc roachpb.Locality, rpcCtx *rpc.Context)
// Note that we use ConnectNoBreaker here to avoid any race with probe
// running on current node and target node restarting. Errors from circuit
// breaker probes could confuse us and present node as unavailable.
conn, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx)
conn, _, err = rpcCtx.GRPCDialNode(addr.String(), node.NodeID, node.Locality, rpc.DefaultClass).ConnectNoBreaker(ctx)
if err != nil {
if grpcutil.IsClosedConnection(err) {
log.Infof(ctx, "can't dial node n%d because connection is permanently closed: %s",
Expand Down
1 change: 1 addition & 0 deletions pkg/rpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ go_library(
"@com_github_vividcortex_ewma//:ewma",
"@io_opentelemetry_go_otel//attribute",
"@io_storj_drpc//drpcmux",
"@io_storj_drpc//drpcpool",
"@io_storj_drpc//drpcserver",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//backoff",
Expand Down
41 changes: 29 additions & 12 deletions pkg/rpc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"storj.io/drpc/drpcpool"
)

// Connection is a wrapper around grpc.ClientConn. It prevents the underlying
Expand Down Expand Up @@ -61,14 +62,14 @@ func newConnectionToNodeID(k peerKey, breakerSignal func() circuit.Signal) *Conn
// block but fall back to defErr in this case.
func (c *Connection) waitOrDefault(
ctx context.Context, defErr error, sig circuit.Signal,
) (*grpc.ClientConn, error) {
) (*grpc.ClientConn, drpcpool.Conn, error) {
// Check the circuit breaker first. If it is already tripped now, we
// want it to take precedence over connFuture below (which is closed in
// the common case of a connection going bad after having been healthy
// for a while).
select {
case <-sig.C():
return nil, sig.Err()
return nil, nil, sig.Err()
default:
}

Expand All @@ -79,26 +80,26 @@ func (c *Connection) waitOrDefault(
select {
case <-c.connFuture.C():
case <-sig.C():
return nil, sig.Err()
return nil, nil, sig.Err()
case <-ctx.Done():
return nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
return nil, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
}
} else {
select {
case <-c.connFuture.C():
case <-sig.C():
return nil, sig.Err()
return nil, nil, sig.Err()
case <-ctx.Done():
return nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
return nil, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
default:
return nil, defErr
return nil, nil, defErr
}
}

// Done waiting, c.connFuture has resolved, return the result. Note that this
// conn could be unhealthy (or there may not even be a conn, i.e. Err() !=
// nil), if that's what the caller wanted (ConnectNoBreaker).
return c.connFuture.Conn(), c.connFuture.Err()
return c.connFuture.Conn(), c.connFuture.DRPCConn(), c.connFuture.Err()
}

// Connect returns the underlying grpc.ClientConn after it has been validated,
Expand All @@ -108,6 +109,11 @@ func (c *Connection) waitOrDefault(
// an error. In rare cases, this behavior is undesired and ConnectNoBreaker may
// be used instead.
func (c *Connection) Connect(ctx context.Context) (*grpc.ClientConn, error) {
cc, _, err := c.waitOrDefault(ctx, nil /* defErr */, c.breakerSignalFn())
return cc, err
}

func (c *Connection) Connect2(ctx context.Context) (*grpc.ClientConn, drpcpool.Conn, error) {
return c.waitOrDefault(ctx, nil /* defErr */, c.breakerSignalFn())
}

Expand All @@ -129,7 +135,9 @@ func (s *neverTripSignal) IsTripped() bool {
// that it will latch onto (or start) an existing connection attempt even if
// previous attempts have not succeeded. This may be preferable to Connect
// if the caller is already certain that a peer is available.
func (c *Connection) ConnectNoBreaker(ctx context.Context) (*grpc.ClientConn, error) {
func (c *Connection) ConnectNoBreaker(
ctx context.Context,
) (*grpc.ClientConn, drpcpool.Conn, error) {
// For ConnectNoBreaker we don't use the default Signal but pass a dummy one
// that never trips. (The probe tears down the Conn on quiesce so we don't rely
// on the Signal for that).
Expand All @@ -153,7 +161,7 @@ func (c *Connection) ConnectNoBreaker(ctx context.Context) (*grpc.ClientConn, er
// latest heartbeat. Returns ErrNotHeartbeated if the peer was just contacted for
// the first time and the first heartbeat has not occurred yet.
func (c *Connection) Health() error {
_, err := c.waitOrDefault(context.Background(), ErrNotHeartbeated, c.breakerSignalFn())
_, _, err := c.waitOrDefault(context.Background(), ErrNotHeartbeated, c.breakerSignalFn())
return err
}

Expand All @@ -164,6 +172,7 @@ func (c *Connection) Signal() circuit.Signal {
type connFuture struct {
ready chan struct{}
cc *grpc.ClientConn
dc drpcpool.Conn
err error
}

Expand All @@ -190,6 +199,14 @@ func (s *connFuture) Conn() *grpc.ClientConn {
return s.cc
}

// DRPCConn must only be called after C() has been closed.
func (s *connFuture) DRPCConn() drpcpool.Conn {
if s.err != nil {
return nil
}
return s.dc
}

func (s *connFuture) Resolved() bool {
select {
case <-s.ready:
Expand All @@ -201,12 +218,12 @@ func (s *connFuture) Resolved() bool {

// Resolve is idempotent. Only the first call has any effect.
// Not thread safe.
func (s *connFuture) Resolve(cc *grpc.ClientConn, err error) {
func (s *connFuture) Resolve(cc *grpc.ClientConn, dc drpcpool.Conn, err error) {
select {
case <-s.ready:
// Already resolved, noop.
default:
s.cc, s.err = cc, err
s.cc, s.dc, s.err = cc, dc, err
close(s.ready)
}
}
Expand Down
43 changes: 33 additions & 10 deletions pkg/rpc/nodedialer/nodedialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/tracing"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"storj.io/drpc/drpcpool"
)

// An AddressResolver translates NodeIDs into addresses.
Expand Down Expand Up @@ -96,7 +97,7 @@ func (n *Dialer) Dial(
err = errors.Wrapf(err, "failed to resolve n%d", nodeID)
return nil, err
}
conn, _, err := n.dial(ctx, nodeID, addr, locality, true, class)
conn, _, _, err := n.dial(ctx, nodeID, addr, locality, true, class)
return conn, err
}

Expand All @@ -113,7 +114,7 @@ func (n *Dialer) DialNoBreaker(
if err != nil {
return nil, err
}
conn, _, err := n.dial(ctx, nodeID, addr, locality, false, class)
conn, _, _, err := n.dial(ctx, nodeID, addr, locality, false, class)
return conn, err
}

Expand Down Expand Up @@ -143,16 +144,38 @@ func (n *Dialer) DialInternalClient(
return nil, errors.Wrap(err, "resolver error")
}
log.VEventf(ctx, 2, "sending request to %s", addr)
conn, pool, err := n.dial(ctx, nodeID, addr, locality, true, class)
conn, pool, dconn, err := n.dial(ctx, nodeID, addr, locality, true, class)
if err != nil {
return nil, err
}

const useDRPC = true

if useDRPC {
grpcClient := kvpb.NewInternalClient(conn) // used for RangeFeed only
client := &unaryDRPCBatchServiceToInternalAdapter{
InternalClient: grpcClient,
drpcClient: kvpb.NewDRPCDRPCBatchServiceClient(dconn),
}
return client, nil
}
client := kvpb.NewInternalClient(conn)
client = &BatchStreamerClient{InternalClient: client, pool: pool}
client = &TracingInternalClient{InternalClient: client}
return client, nil
}

type unaryDRPCBatchServiceToInternalAdapter struct {
kvpb.InternalClient
drpcClient kvpb.DRPCDRPCBatchServiceClient
}

func (a *unaryDRPCBatchServiceToInternalAdapter) Batch(
ctx context.Context, in *kvpb.BatchRequest, opts ...grpc.CallOption,
) (*kvpb.BatchResponse, error) {
return a.drpcClient.Batch(ctx, in)
}

// dial performs the dialing of the remote connection. If checkBreaker
// is set (which it usually is), circuit breakers for the peer will be
// checked.
Expand All @@ -163,28 +186,28 @@ func (n *Dialer) dial(
locality roachpb.Locality,
checkBreaker bool,
class rpc.ConnectionClass,
) (_ *grpc.ClientConn, _ *rpc.BatchStreamPool, err error) {
) (_ *grpc.ClientConn, _ *rpc.BatchStreamPool, _ drpcpool.Conn, err error) {
const ctxWrapMsg = "dial"
// Don't trip the breaker if we're already canceled.
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg)
return nil, nil, nil, errors.Wrap(ctxErr, ctxWrapMsg)
}
rpcConn := n.rpcContext.GRPCDialNode(addr.String(), nodeID, locality, class)
connect := rpcConn.Connect
connect := rpcConn.Connect2
if !checkBreaker {
connect = rpcConn.ConnectNoBreaker
}
conn, err := connect(ctx)
conn, dconn, err := connect(ctx)
if err != nil {
// If we were canceled during the dial, don't trip the breaker.
if ctxErr := ctx.Err(); ctxErr != nil {
return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg)
return nil, nil, nil, errors.Wrap(ctxErr, ctxWrapMsg)
}
err = errors.Wrapf(err, "failed to connect to n%d at %v", nodeID, addr)
return nil, nil, err
return nil, nil, nil, err
}

return conn, &rpcConn.BatchStreamPool, nil
return conn, &rpcConn.BatchStreamPool, dconn, nil
}

// ConnHealth returns nil if we have an open connection of the request
Expand Down
45 changes: 40 additions & 5 deletions pkg/rpc/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package rpc

import (
"context"
"crypto/tls"
"fmt"
"runtime/pprof"
"time"
Expand All @@ -27,6 +28,9 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/status"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmigrate"
"storj.io/drpc/drpcpool"
)

type peerStatus int
Expand Down Expand Up @@ -125,6 +129,7 @@ type peer struct {
heartbeatInterval time.Duration
heartbeatTimeout time.Duration
dial func(ctx context.Context, target string, class ConnectionClass) (*grpc.ClientConn, error)
dialDRPC func(ctx context.Context, target string) (drpcpool.Conn, error)
// b maintains connection health. This breaker's async probe is always
// active - it is the heartbeat loop and manages `mu.c.` (including
// recreating it after the connection fails and has to be redialed).
Expand Down Expand Up @@ -245,6 +250,35 @@ func (rpcCtx *Context) newPeer(k peerKey, locality roachpb.Locality) *peer {
additionalDialOpts = append(additionalDialOpts, rpcCtx.testingDialOpts...)
return rpcCtx.grpcDialRaw(ctx, target, class, additionalDialOpts...)
},
dialDRPC: func(ctx context.Context, target string) (drpcpool.Conn, error) {
// TODO(server): could use connection class instead of empty key here.
pool := drpcpool.New[struct{}, drpcpool.Conn](drpcpool.Options{})
pooledConn := pool.Get(context.Background() /* unused */, struct{}{}, func(ctx context.Context,
_ struct{}) (drpcpool.Conn, error) {
rawconn, err := drpcmigrate.DialWithHeader(ctx, "tcp", target, drpcmigrate.DRPCHeader)
if err != nil {
return nil, err
}

var conn *drpcconn.Conn
if rpcCtx.ContextOptions.Insecure {
conn = drpcconn.New(rawconn)
} else {
tlsConfig, err := rpcCtx.GetClientTLSConfig()
if err != nil {
return nil, err
}
tlsConfig.InsecureSkipVerify = true // HACK
tlsConn := tls.Client(rawconn, tlsConfig)
conn = drpcconn.New(tlsConn)
}
// TODO(tbg): if we remove gRPC, this is where we'd do an initial ping
// to ascertain that the peer is compatible with us before returning
// the conn for general use.
return conn, err
})
return pooledConn, nil
},
heartbeatInterval: rpcCtx.RPCHeartbeatInterval,
heartbeatTimeout: rpcCtx.RPCHeartbeatTimeout,
}
Expand Down Expand Up @@ -378,6 +412,7 @@ func (p *peer) runOnce(ctx context.Context, report func(error)) error {
if err != nil {
return err
}
dc, err := p.dialDRPC(ctx, p.k.TargetAddr)
defer func() {
_ = cc.Close() // nolint:grpcconnclose
}()
Expand All @@ -399,7 +434,7 @@ func (p *peer) runOnce(ctx context.Context, report func(error)) error {
return err
}

p.onInitialHeartbeatSucceeded(ctx, p.opts.Clock.Now(), cc, report)
p.onInitialHeartbeatSucceeded(ctx, p.opts.Clock.Now(), cc, dc, report)

return p.runHeartbeatUntilFailure(ctx, connFailedCh)
}
Expand Down Expand Up @@ -563,7 +598,7 @@ func logOnHealthy(ctx context.Context, disconnected, now time.Time) {
}

func (p *peer) onInitialHeartbeatSucceeded(
ctx context.Context, now time.Time, cc *grpc.ClientConn, report func(err error),
ctx context.Context, now time.Time, cc *grpc.ClientConn, dc drpcpool.Conn, report func(err error),
) {
// First heartbeat succeeded. By convention we update the breaker
// before updating the peer. The other way is fine too, just the
Expand All @@ -584,7 +619,7 @@ func (p *peer) onInitialHeartbeatSucceeded(

// Close the channel last which is helpful for unit tests that
// first waitOrDefault for a healthy conn to then check metrics.
p.mu.c.connFuture.Resolve(cc, nil /* err */)
p.mu.c.connFuture.Resolve(cc, dc, nil /* err */)

logOnHealthy(ctx, p.mu.disconnected, now)
}
Expand Down Expand Up @@ -701,7 +736,7 @@ func (p *peer) onHeartbeatFailed(
// someone might be waiting on it in ConnectNoBreaker who is not paying
// attention to the circuit breaker.
err = &netutil.InitialHeartbeatFailedError{WrappedErr: err}
ls.c.connFuture.Resolve(nil /* cc */, err)
ls.c.connFuture.Resolve(nil, nil /* cc */, err)
}
// By convention, we stick to updating breaker before updating peer
// to make it easier to write non-flaky tests.
Expand Down Expand Up @@ -737,7 +772,7 @@ func (p *peer) onQuiesce(report func(error)) {
// NB: it's important that connFuture is resolved, or a caller sitting on
// `c.ConnectNoBreaker` would never be unblocked; after all, the probe won't
// start again in the future.
p.snap().c.connFuture.Resolve(nil, errQuiescing)
p.snap().c.connFuture.Resolve(nil, nil, errQuiescing)
}

func (p PeerSnap) deletable(now time.Time) bool {
Expand Down

0 comments on commit fd84492

Please sign in to comment.