Skip to content

Commit

Permalink
client: support a 1:1 mapping with acbws and addrConns
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley committed May 19, 2023
1 parent 098b2d0 commit 689d5a9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 131 deletions.
70 changes: 8 additions & 62 deletions balancer_conn_wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,6 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
// updateSubConnState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) {
// When updating addresses for a SubConn, if the address in use is not in
// the new addresses, the old ac will be tearDown() and a new ac will be
// created. tearDown() generates a state change with Shutdown state, we
// don't want the balancer to receive this state change. So before
// tearDown() on the old ac, ac.acbw (acWrapper) will be set to nil, and
// this function will be called with (nil, Shutdown). We don't need to call
// balancer method in this case.
//
// TODO: Suppress the above mentioned state change to Shutdown, so we don't
// have to handle it here.
if sc == nil {
return
}
ccb.serializer.Schedule(func(_ context.Context) {
ccb.balancer.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
})
Expand Down Expand Up @@ -193,9 +180,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
return nil, err
}
acbw := &acBalancerWrapper{ac: ac, producers: make(map[balancer.ProducerBuilder]*refCountedProducer)}
acbw.ac.mu.Lock()
ac.acbw = acbw
acbw.ac.mu.Unlock()
return acbw, nil
}

Expand All @@ -204,7 +189,7 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
if !ok {
return
}
ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain)
ccb.cc.removeAddrConn(acbw.ac, errConnDrain)
}

func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
Expand Down Expand Up @@ -236,63 +221,24 @@ func (ccb *ccBalancerWrapper) Target() string {
// acBalancerWrapper is a wrapper on top of ac for balancers.
// It implements balancer.SubConn interface.
type acBalancerWrapper struct {
ac *addrConn // read-only

mu sync.Mutex
ac *addrConn
producers map[balancer.ProducerBuilder]*refCountedProducer
}

func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.mu.Lock()
defer acbw.mu.Unlock()
if len(addrs) <= 0 {
acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain)
return
}
if !acbw.ac.tryUpdateAddrs(addrs) {
cc := acbw.ac.cc
opts := acbw.ac.scopts
acbw.ac.mu.Lock()
// Set old ac.acbw to nil so the Shutdown state update will be ignored
// by balancer.
//
// TODO(bar) the state transition could be wrong when tearDown() old ac
// and creating new ac, fix the transition.
acbw.ac.acbw = nil
acbw.ac.mu.Unlock()
acState := acbw.ac.getState()
acbw.ac.cc.removeAddrConn(acbw.ac, errConnDrain)

if acState == connectivity.Shutdown {
return
}
func (acbw *acBalancerWrapper) String() string {
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int())
}

newAC, err := cc.newAddrConn(addrs, opts)
if err != nil {
channelz.Warningf(logger, acbw.ac.channelzID, "acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err)
return
}
acbw.ac = newAC
newAC.mu.Lock()
newAC.acbw = acbw
newAC.mu.Unlock()
if acState != connectivity.Idle {
go newAC.connect()
}
}
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
acbw.ac.updateAddrs(addrs)
}

func (acbw *acBalancerWrapper) Connect() {
acbw.mu.Lock()
defer acbw.mu.Unlock()
go acbw.ac.connect()
}

func (acbw *acBalancerWrapper) getAddrConn() *addrConn {
acbw.mu.Lock()
defer acbw.mu.Unlock()
return acbw.ac
}

// NewStream begins a streaming RPC on the addrConn. If the addrConn is not
// ready, blocks until it is or ctx expires. Returns an error when the context
// expires or the addrConn is shut down.
Expand Down
122 changes: 60 additions & 62 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,6 @@ func (ac *addrConn) connect() error {
ac.mu.Unlock()
return nil
}
// Update connectivity state within the lock to prevent subsequent or
// concurrent calls from resetting the transport more than once.
ac.updateConnectivityState(connectivity.Connecting, nil)
ac.mu.Unlock()

ac.resetTransport()
Expand All @@ -857,58 +854,53 @@ func equalAddresses(a, b []resolver.Address) bool {
return true
}

// tryUpdateAddrs tries to update ac.addrs with the new addresses list.
//
// If ac is TransientFailure, it updates ac.addrs and returns true. The updated
// addresses will be picked up by retry in the next iteration after backoff.
//
// If ac is Shutdown or Idle, it updates ac.addrs and returns true.
//
// If the addresses is the same as the old list, it does nothing and returns
// true.
//
// If ac is Connecting, it returns false. The caller should tear down the ac and
// create a new one. Note that the backoff will be reset when this happens.
//
// If ac is Ready, it checks whether current connected address of ac is in the
// new addrs list.
// - If true, it updates ac.addrs and returns true. The ac will keep using
// the existing connection.
// - If false, it does nothing and returns false.
func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
// updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.mu.Lock()
defer ac.mu.Unlock()
channelz.Infof(logger, ac.channelzID, "addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs)
if ac.state == connectivity.Shutdown ||
ac.state == connectivity.TransientFailure ||
ac.state == connectivity.Idle {
ac.addrs = addrs
return true
}
channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs)

if equalAddresses(ac.addrs, addrs) {
return true
ac.mu.Unlock()
return
}

if ac.state == connectivity.Connecting {
return false
ac.addrs = addrs

if ac.state == connectivity.Shutdown ||
ac.state == connectivity.TransientFailure ||
ac.state == connectivity.Idle {
// We were not connecting, so do nothing but update the addresses.
ac.mu.Unlock()
return
}

// ac.state is Ready, try to find the connected address.
var curAddrFound bool
for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a)
if reflect.DeepEqual(ac.curAddr, a) {
curAddrFound = true
break
if ac.state == connectivity.Ready {
// try to find the connected address.
for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a)
if reflect.DeepEqual(ac.curAddr, a) {
// We are connected to a valid address, so do nothing bu update
// the addresses.
ac.mu.Unlock()
return
}
}
}
channelz.Infof(logger, ac.channelzID, "addrConn: tryUpdateAddrs curAddrFound: %v", curAddrFound)
if curAddrFound {
ac.addrs = addrs
}

return curAddrFound
// We are either connected to the wrong address or currently connecting.
// Stop the current iteration and restart.

ac.cancel()
ac.ctx, ac.cancel = context.WithCancel(ac.cc.ctx)

curTr := ac.transport
ac.transport = nil
ac.mu.Unlock()
curTr.GracefulClose()
// Since we were connecting/connected, we should start a new connection
// attempt.
go ac.resetTransport()
}

// getServerName determines the serverName to be used in the connection
Expand Down Expand Up @@ -1166,7 +1158,7 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {

func (ac *addrConn) resetTransport() {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
if ac.ctx.Err() != nil {
ac.mu.Unlock()
return
}
Expand All @@ -1192,17 +1184,17 @@ func (ac *addrConn) resetTransport() {
connectDeadline := time.Now().Add(dialDuration)

ac.updateConnectivityState(connectivity.Connecting, nil)
acCtx := ac.ctx
ac.mu.Unlock()

if err := ac.tryAllAddrs(addrs, connectDeadline); err != nil {
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
ac.cc.resolveNow(resolver.ResolveNowOptions{})
// After exhausting all addresses, the addrConn enters
// TRANSIENT_FAILURE.
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
if acCtx.Err() != nil {
return
}
ac.mu.Lock()
ac.updateConnectivityState(connectivity.TransientFailure, err)

// Backoff.
Expand All @@ -1217,13 +1209,13 @@ func (ac *addrConn) resetTransport() {
ac.mu.Unlock()
case <-b:
timer.Stop()
case <-ac.ctx.Done():
case <-acCtx.Done():
timer.Stop()
return
}

ac.mu.Lock()
if ac.state != connectivity.Shutdown {
if acCtx.Err() == nil {
ac.updateConnectivityState(connectivity.Idle, err)
}
ac.mu.Unlock()
Expand All @@ -1238,11 +1230,11 @@ func (ac *addrConn) resetTransport() {
// tryAllAddrs tries to creates a connection to the addresses, and stop when at
// the first successful one. It returns an error if no address was successfully
// connected, or updates ac appropriately with the new transport.
func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.Time) error {
func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error {
var firstConnErr error
for _, addr := range addrs {
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
if ac.ctx.Err() != nil {
ac.mu.Unlock()
return errConnClosing
}
Expand All @@ -1259,7 +1251,7 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T

channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr)

err := ac.createTransport(addr, copts, connectDeadline)
err := ac.createTransport(ctx, addr, copts, connectDeadline)
if err == nil {
return nil
}
Expand All @@ -1276,19 +1268,20 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
// createTransport creates a connection to addr. It returns an error if the
// address was not successfully connected, or updates ac appropriately with the
// new transport.
func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
addr.ServerName = ac.cc.getServerName(addr)
hctx, hcancel := context.WithCancel(ac.ctx)
hctx, hcancel := context.WithCancel(ctx)

onClose := func(r transport.GoAwayReason) {
ac.mu.Lock()
defer ac.mu.Unlock()
// adjust params based on GoAwayReason
ac.adjustParams(r)
if ac.state == connectivity.Shutdown {
// Already shut down. tearDown() already cleared the transport and
// canceled hctx via ac.ctx, and we expected this connection to be
// closed, so do nothing here.
if ctx.Err() != nil {
// Already shut down or connection attempt canceled. tearDown() or
// updateAddrs() already cleared the transport and canceled hctx
// via ac.ctx, and we expected this connection to be closed, so do
// nothing here.
return
}
hcancel()
Expand All @@ -1307,7 +1300,7 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
ac.updateConnectivityState(connectivity.Idle, nil)
}

connectCtx, cancel := context.WithDeadline(ac.ctx, connectDeadline)
connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
defer cancel()
copts.ChannelzParentID = ac.channelzID

Expand Down Expand Up @@ -1346,6 +1339,11 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
ac.updateConnectivityState(connectivity.Idle, nil)
return nil
}
if ctx.Err() != nil {
// updateAddrs stopped this connection attempt just after it completed.
// Pretend it didn't happen.
return nil
}
ac.curAddr = addr
ac.transport = newTr
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
Expand Down
12 changes: 5 additions & 7 deletions picker_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
// - wraps the done function in the passed in result to increment the calls
// failed or calls succeeded channelz counter before invoking the actual
// done function.
func doneChannelzWrapper(acw *acBalancerWrapper, result *balancer.PickResult) {
acw.mu.Lock()
ac := acw.ac
acw.mu.Unlock()
func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) {
ac := acbw.ac
ac.incrCallsStarted()
done := result.Done
result.Done = func(b balancer.DoneInfo) {
Expand Down Expand Up @@ -152,14 +150,14 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error())
}

acw, ok := pickResult.SubConn.(*acBalancerWrapper)
acbw, ok := pickResult.SubConn.(*acBalancerWrapper)
if !ok {
logger.Errorf("subconn returned from pick is type %T, not *acBalancerWrapper", pickResult.SubConn)
continue
}
if t := acw.getAddrConn().getReadyTransport(); t != nil {
if t := acbw.ac.getReadyTransport(); t != nil {
if channelz.IsOn() {
doneChannelzWrapper(acw, &pickResult)
doneChannelzWrapper(acbw, &pickResult)
return t, pickResult, nil
}
return t, pickResult, nil
Expand Down

0 comments on commit 689d5a9

Please sign in to comment.