test: update client state subscriber test to be not flaky and more stressful about rapid updates (grpc#6512)
dfawley authored Aug 10, 2023
1 parent f3e94ec commit 879faf6
Showing 9 changed files with 156 additions and 176 deletions.
2 changes: 1 addition & 1 deletion balancer_conn_wrappers.go
@@ -209,7 +209,7 @@ func (ccb *ccBalancerWrapper) closeBalancer(m ccbMode) {
 }

 ccb.mode = m
-done := ccb.serializer.Done
+done := ccb.serializer.Done()
 b := ccb.balancer
 ok := ccb.serializer.Schedule(func(_ context.Context) {
 // Close the serializer to ensure that no more calls from gRPC are sent
16 changes: 7 additions & 9 deletions clientconn.go
@@ -189,7 +189,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 // Register ClientConn with channelz.
 cc.channelzRegistration(target)

-cc.csMgr = newConnectivityStateManager(cc.channelzID)
+cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelzID)

 if err := cc.validateTransportCredentials(); err != nil {
 return nil, err
@@ -541,10 +541,10 @@ func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStr

 // newConnectivityStateManager creates an connectivityStateManager with
 // the specified id.
-func newConnectivityStateManager(id *channelz.Identifier) *connectivityStateManager {
+func newConnectivityStateManager(ctx context.Context, id *channelz.Identifier) *connectivityStateManager {
 return &connectivityStateManager{
 channelzID: id,
-pubSub: grpcsync.NewPubSub(),
+pubSub: grpcsync.NewPubSub(ctx),
 }
 }
@@ -600,10 +600,6 @@ func (csm *connectivityStateManager) getNotifyChan() <-chan struct{} {
 return csm.notifyChan
 }

-func (csm *connectivityStateManager) close() {
-csm.pubSub.Stop()
-}
-
 // ClientConnInterface defines the functions clients need to perform unary and
 // streaming RPCs. It is implemented by *ClientConn, and is only intended to
 // be referenced by generated code.
@@ -1234,7 +1230,10 @@ func (cc *ClientConn) ResetConnectBackoff() {

 // Close tears down the ClientConn and all underlying connections.
 func (cc *ClientConn) Close() error {
-defer cc.cancel()
+defer func() {
+cc.cancel()
+<-cc.csMgr.pubSub.Done()
+}()

 cc.mu.Lock()
 if cc.conns == nil {
@@ -1249,7 +1248,6 @@ func (cc *ClientConn) Close() error {
 conns := cc.conns
 cc.conns = nil
 cc.csMgr.updateState(connectivity.Shutdown)
-cc.csMgr.close()

 pWrapper := cc.blockingpicker
 rWrapper := cc.resolverWrapper
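The Close change above replaces an immediate teardown with a two-step one: cancel the context that drives the connectivity-state PubSub, then block on its Done channel so state updates already queued are delivered before Close returns. Below is a minimal, self-contained sketch of that ordering pattern; the publisher type is hypothetical and is not grpc-go code.

package main

import (
	"context"
	"fmt"
)

// publisher is a stand-in for a component (like the connectivity-state
// PubSub) that delivers queued updates from a background goroutine.
type publisher struct {
	updates chan string
	done    chan struct{}
}

func newPublisher(ctx context.Context) *publisher {
	p := &publisher{updates: make(chan string, 16), done: make(chan struct{})}
	go func() {
		defer close(p.done) // signals that delivery has fully stopped
		for {
			select {
			case u := <-p.updates:
				fmt.Println("delivered:", u)
			case <-ctx.Done():
				// Drain whatever was queued before cancellation, then exit.
				for {
					select {
					case u := <-p.updates:
						fmt.Println("delivered:", u)
					default:
						return
					}
				}
			}
		}
	}()
	return p
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	p := newPublisher(ctx)
	p.updates <- "READY"
	p.updates <- "SHUTDOWN"

	// Mirrors the new Close ordering: cancel first, then wait for delivery
	// to finish before returning.
	cancel()
	<-p.done
}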
54 changes: 30 additions & 24 deletions internal/grpcsync/callback_serializer.go
@@ -32,10 +32,10 @@ import (
 //
 // This type is safe for concurrent access.
 type CallbackSerializer struct {
-// Done is closed once the serializer is shut down completely, i.e all
+// done is closed once the serializer is shut down completely, i.e all
 // scheduled callbacks are executed and the serializer has deallocated all
 // its resources.
-Done chan struct{}
+done chan struct{}

 callbacks *buffer.Unbounded
 closedMu sync.Mutex
@@ -48,12 +48,12 @@ type CallbackSerializer struct {
 // callbacks will be added once this context is canceled, and any pending un-run
 // callbacks will be executed before the serializer is shut down.
 func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
-t := &CallbackSerializer{
-Done: make(chan struct{}),
+cs := &CallbackSerializer{
+done: make(chan struct{}),
 callbacks: buffer.NewUnbounded(),
 }
-go t.run(ctx)
-return t
+go cs.run(ctx)
+return cs
 }

 // Schedule adds a callback to be scheduled after existing callbacks are run.
@@ -64,56 +64,62 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
 // Return value indicates if the callback was successfully added to the list of
 // callbacks to be executed by the serializer. It is not possible to add
 // callbacks once the context passed to NewCallbackSerializer is cancelled.
-func (t *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
-t.closedMu.Lock()
-defer t.closedMu.Unlock()
+func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
+cs.closedMu.Lock()
+defer cs.closedMu.Unlock()

-if t.closed {
+if cs.closed {
 return false
 }
-t.callbacks.Put(f)
+cs.callbacks.Put(f)
 return true
 }

-func (t *CallbackSerializer) run(ctx context.Context) {
+func (cs *CallbackSerializer) run(ctx context.Context) {
 var backlog []func(context.Context)

-defer close(t.Done)
+defer close(cs.done)
 for ctx.Err() == nil {
 select {
 case <-ctx.Done():
 // Do nothing here. Next iteration of the for loop will not happen,
 // since ctx.Err() would be non-nil.
-case callback, ok := <-t.callbacks.Get():
+case callback, ok := <-cs.callbacks.Get():
 if !ok {
 return
 }
-t.callbacks.Load()
+cs.callbacks.Load()
 callback.(func(ctx context.Context))(ctx)
 }
 }

 // Fetch pending callbacks if any, and execute them before returning from
-// this method and closing t.Done.
-t.closedMu.Lock()
-t.closed = true
-backlog = t.fetchPendingCallbacks()
-t.callbacks.Close()
-t.closedMu.Unlock()
+// this method and closing cs.done.
+cs.closedMu.Lock()
+cs.closed = true
+backlog = cs.fetchPendingCallbacks()
+cs.callbacks.Close()
+cs.closedMu.Unlock()
 for _, b := range backlog {
 b(ctx)
 }
 }

-func (t *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) {
+func (cs *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) {
 var backlog []func(context.Context)
 for {
 select {
-case b := <-t.callbacks.Get():
+case b := <-cs.callbacks.Get():
 backlog = append(backlog, b.(func(context.Context)))
-t.callbacks.Load()
+cs.callbacks.Load()
 default:
 return backlog
 }
 }
 }
+
+// Done returns a channel that is closed after the context passed to
+// NewCallbackSerializer is canceled and all callbacks have been executed.
+func (cs *CallbackSerializer) Done() <-chan struct{} {
+return cs.done
+}
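With the Done field unexported, callers synchronize on the new Done() accessor instead. Below is a short sketch of how code inside grpc-go might drive the serializer after this change; the main function and the printed strings are hypothetical scaffolding, and internal/grpcsync is an internal package, so the snippet only builds from within the grpc-go module.

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/internal/grpcsync"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cs := grpcsync.NewCallbackSerializer(ctx)

	// Callbacks run one at a time, in the order they were scheduled.
	cs.Schedule(func(context.Context) { fmt.Println("first") })
	cs.Schedule(func(context.Context) { fmt.Println("second") })

	// Shutdown: cancel the serializer's context, then wait on Done(), the
	// accessor that replaces the previously exported Done field, so all
	// pending callbacks have run before moving on.
	cancel()
	<-cs.Done()

	// Once Done() is closed, Schedule reports false: no new callbacks are accepted.
	if !cs.Schedule(func(context.Context) {}) {
		fmt.Println("serializer closed")
	}
}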
2 changes: 1 addition & 1 deletion internal/grpcsync/callback_serializer_test.go
@@ -190,7 +190,7 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) {
 }
 }
 }
-<-cs.Done
+<-cs.Done()

 done := make(chan struct{})
 if cs.Schedule(func(context.Context) { close(done) }) {
37 changes: 11 additions & 26 deletions internal/grpcsync/pubsub.go
@@ -40,25 +40,23 @@ type Subscriber interface {
 // subscribers interested in receiving these messages register a callback
 // via the Subscribe() method.
 //
-// Once a PubSub is stopped, no more messages can be published, and
-// it is guaranteed that no more subscriber callback will be invoked.
+// Once a PubSub is stopped, no more messages can be published, but any pending
+// published messages will be delivered to the subscribers. Done may be used
+// to determine when all published messages have been delivered.
 type PubSub struct {
-cs *CallbackSerializer
-cancel context.CancelFunc
+cs *CallbackSerializer

 // Access to the below fields are guarded by this mutex.
 mu sync.Mutex
 msg interface{}
 subscribers map[Subscriber]bool
-stopped bool
 }

-// NewPubSub returns a new PubSub instance.
-func NewPubSub() *PubSub {
-ctx, cancel := context.WithCancel(context.Background())
+// NewPubSub returns a new PubSub instance. Users should cancel the
+// provided context to shutdown the PubSub.
+func NewPubSub(ctx context.Context) *PubSub {
 return &PubSub{
 cs: NewCallbackSerializer(ctx),
-cancel: cancel,
 subscribers: map[Subscriber]bool{},
 }
 }
@@ -75,10 +73,6 @@ func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) {
 ps.mu.Lock()
 defer ps.mu.Unlock()

-if ps.stopped {
-return func() {}
-}
-
 ps.subscribers[sub] = true

 if ps.msg != nil {
@@ -106,10 +100,6 @@ func (ps *PubSub) Publish(msg interface{}) {
 ps.mu.Lock()
 defer ps.mu.Unlock()

-if ps.stopped {
-return
-}
-
 ps.msg = msg
 for sub := range ps.subscribers {
 s := sub
@@ -124,13 +114,8 @@ func (ps *PubSub) Publish(msg interface{}) {
 }
 }

-// Stop shuts down the PubSub and releases any resources allocated by it.
-// It is guaranteed that no subscriber callbacks would be invoked once this
-// method returns.
-func (ps *PubSub) Stop() {
-ps.mu.Lock()
-defer ps.mu.Unlock()
-ps.stopped = true
-
-ps.cancel()
+// Done returns a channel that is closed after the context passed to NewPubSub
+// is canceled and all updates have been sent to subscribers.
+func (ps *PubSub) Done() <-chan struct{} {
+return ps.cs.Done()
 }
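The PubSub lifecycle thus moves from an explicit Stop() call to context cancellation followed by Done(). Below is a sketch of the new usage, with the caveats that logSubscriber and the main function are hypothetical and that internal/grpcsync can only be imported from within the grpc-go module.

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/internal/grpcsync"
)

// logSubscriber implements the grpcsync.Subscriber interface from this file.
type logSubscriber struct{}

func (logSubscriber) OnMessage(msg interface{}) {
	fmt.Println("got:", msg)
}

func main() {
	// The PubSub now lives exactly as long as this context; there is no
	// longer a Stop method to call.
	ctx, cancel := context.WithCancel(context.Background())
	ps := grpcsync.NewPubSub(ctx)

	unsub := ps.Subscribe(logSubscriber{})
	defer unsub()

	ps.Publish("READY")
	ps.Publish("SHUTDOWN")

	// Shutdown: cancel the context, then wait until every message published
	// before cancellation has been delivered to subscribers.
	cancel()
	<-ps.Done()
}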
18 changes: 12 additions & 6 deletions internal/grpcsync/pubsub_test.go
@@ -19,6 +19,7 @@
 package grpcsync

 import (
+"context"
 "sync"
 "testing"
 "time"
@@ -40,8 +41,9 @@ func (ts *testSubscriber) OnMessage(msg interface{}) {
 }

 func (s) TestPubSub_PublishNoMsg(t *testing.T) {
-pubsub := NewPubSub()
-defer pubsub.Stop()
+ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+defer cancel()
+pubsub := NewPubSub(ctx)

 ts := newTestSubscriber(1)
 pubsub.Subscribe(ts)
@@ -54,7 +56,9 @@ func (s) TestPubSub_PublishNoMsg(t *testing.T) {
 }

 func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) {
-pubsub := NewPubSub()
+ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+defer cancel()
+pubsub := NewPubSub(ctx)

 const numPublished = 10
@@ -148,7 +152,8 @@ func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) {
 t.FailNow()
 }

-pubsub.Stop()
+cancel()
+<-pubsub.Done()

 go func() {
 pubsub.Publish(99)
@@ -165,8 +170,9 @@ func (s) TestPubSub_PublishMsgs_RegisterSubs_And_Stop(t *testing.T) {
 }

 func (s) TestPubSub_PublishMsgs_BeforeRegisterSub(t *testing.T) {
-pubsub := NewPubSub()
-defer pubsub.Stop()
+ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+defer cancel()
+pubsub := NewPubSub(ctx)

 const numPublished = 3
 for i := 0; i < numPublished; i++ {
2 changes: 1 addition & 1 deletion resolver_conn_wrapper.go
@@ -133,7 +133,7 @@ func (ccr *ccResolverWrapper) close() {
 ccr.mu.Unlock()

 // Give enqueued callbacks a chance to finish.
-<-ccr.serializer.Done
+<-ccr.serializer.Done()

 // Spawn a goroutine to close the resolver (since it may block trying to
 // cleanup all allocated resources) and return early.