Skip to content

Commit

Permalink
testing: Avoid using context.Background (#3949)
Browse files Browse the repository at this point in the history
  • Loading branch information
gauravgahlot authored Nov 5, 2020
1 parent c6fa121 commit d7a7a30
Show file tree
Hide file tree
Showing 24 changed files with 539 additions and 236 deletions.
82 changes: 47 additions & 35 deletions balancer/grpclb/grpclb_test.go

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions balancer_switching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
err error
)
connected := false
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < 5000; i++ {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if err = cc.Invoke(ctx, "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if connected {
// connected is set to false if peer is not server[0]. So if
// connected is true here, this is the second time we saw
Expand All @@ -100,9 +102,10 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
if !connected {
return fmt.Errorf("pickfirst is not in effect after 5 second, EmptyCall() = _, %v, want _, %v", err, servers[0].port)
}

// The following RPCs should all succeed with the first server.
for i := 0; i < 3; i++ {
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
err = cc.Invoke(ctx, "/foo/bar", &req, &reply)
if errorDesc(err) != servers[0].port {
return fmt.Errorf("index %d: want peer %v, got peer %v", i, servers[0].port, err)
}
Expand All @@ -117,14 +120,16 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {
err error
)

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Make sure connections to all servers are up.
for i := 0; i < 2; i++ {
// Do this check twice, otherwise the first RPC's transport may still be
// picked by the closing pickfirst balancer, and the test becomes flaky.
for _, s := range servers {
var up bool
for i := 0; i < 5000; i++ {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == s.port {
if err = cc.Invoke(ctx, "/foo/bar", &req, &reply); errorDesc(err) == s.port {
up = true
break
}
Expand All @@ -138,7 +143,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {

serverCount := len(servers)
for i := 0; i < 3*serverCount; i++ {
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
err = cc.Invoke(ctx, "/foo/bar", &req, &reply)
if errorDesc(err) != servers[i%serverCount].port {
return fmt.Errorf("index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err)
}
Expand Down
6 changes: 4 additions & 2 deletions benchmark/primitives/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"time"
)

const defaultTestTimeout = 10 * time.Second

func BenchmarkCancelContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -72,7 +74,7 @@ func BenchmarkCancelContextChannelGotErr(b *testing.B) {
}

func BenchmarkTimerContextErrNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
for i := 0; i < b.N; i++ {
if err := ctx.Err(); err != nil {
b.Fatal("error")
Expand All @@ -92,7 +94,7 @@ func BenchmarkTimerContextErrGotErr(b *testing.B) {
}

func BenchmarkTimerContextChannelNoErr(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
Expand Down
26 changes: 19 additions & 7 deletions call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ var (
canceled = 0
)

const defaultTestTimeout = 10 * time.Second

type testCodec struct {
}

Expand Down Expand Up @@ -237,7 +239,8 @@ func (s) TestUnaryClientInterceptor(t *testing.T) {
}()

var reply string
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
Expand Down Expand Up @@ -305,7 +308,8 @@ func (s) TestChainUnaryClientInterceptor(t *testing.T) {
}()

var reply string
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
Expand Down Expand Up @@ -346,7 +350,8 @@ func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) {
}()

var reply string
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
Expand Down Expand Up @@ -407,7 +412,8 @@ func (s) TestChainStreamClientInterceptor(t *testing.T) {
server.stop()
}()

ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0)
_, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar")
if err != nil {
Expand All @@ -418,7 +424,9 @@ func (s) TestChainStreamClientInterceptor(t *testing.T) {
func (s) TestInvoke(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
cc.Close()
Expand All @@ -429,7 +437,9 @@ func (s) TestInvokeLargeErr(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "hello"
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
Expand All @@ -445,7 +455,9 @@ func (s) TestInvokeErrorSpecialChars(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "weird error"
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
err := cc.Invoke(ctx, "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
Expand Down
4 changes: 3 additions & 1 deletion channelz/service/service_sktopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ func (s) TestGetSocketOptions(t *testing.T) {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
defer channelz.RemoveEntry(ids[i])
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
resp, _ := svr.GetSocket(ctx, &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
Expand Down
38 changes: 27 additions & 11 deletions channelz/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ var protoToSocketOpt protoToSocketOptFunc
// TODO: Go1.7 is no longer supported - does this need a change?
var emptyTime time.Time

const defaultTestTimeout = 10 * time.Second

type dummyChannel struct {
state connectivity.State
target string
Expand Down Expand Up @@ -327,7 +329,9 @@ func (s) TestGetTopChannels(t *testing.T) {
defer channelz.RemoveEntry(id)
}
s := newCZServer()
resp, _ := s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := s.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
Expand All @@ -340,7 +344,7 @@ func (s) TestGetTopChannels(t *testing.T) {
id := channelz.RegisterChannel(tcs[0], 0, "")
defer channelz.RemoveEntry(id)
}
resp, _ = s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
resp, _ = s.GetTopChannels(ctx, &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
Expand Down Expand Up @@ -374,7 +378,9 @@ func (s) TestGetServers(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetServers(ctx, &channelzpb.GetServersRequest{StartServerId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
Expand All @@ -387,7 +393,7 @@ func (s) TestGetServers(t *testing.T) {
id := channelz.RegisterServer(ss[0], "")
defer channelz.RemoveEntry(id)
}
resp, _ = svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
resp, _ = svr.GetServers(ctx, &channelzpb.GetServersRequest{StartServerId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
Expand All @@ -407,7 +413,9 @@ func (s) TestGetServerSockets(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetServerSockets(ctx, &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
Expand All @@ -424,7 +432,7 @@ func (s) TestGetServerSockets(t *testing.T) {
id := channelz.RegisterNormalSocket(&dummySocket{}, svrID, "")
defer channelz.RemoveEntry(id)
}
resp, _ = svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
resp, _ = svr.GetServerSockets(ctx, &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
Expand All @@ -446,9 +454,11 @@ func (s) TestGetServerSocketsNonZeroStartID(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Make GetServerSockets with startID = ids[1]+1, so socket-1 won't be
// included in the response.
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: ids[1] + 1})
resp, _ := svr.GetServerSockets(ctx, &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: ids[1] + 1})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
Expand Down Expand Up @@ -512,7 +522,9 @@ func (s) TestGetChannel(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[0]})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetChannel(ctx, &channelzpb.GetChannelRequest{ChannelId: ids[0]})
metrics := resp.GetChannel()
subChans := metrics.GetSubchannelRef()
if len(subChans) != 1 || subChans[0].GetName() != refNames[2] || subChans[0].GetSubchannelId() != ids[2] {
Expand Down Expand Up @@ -552,7 +564,7 @@ func (s) TestGetChannel(t *testing.T) {
}
}
}
resp, _ = svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[1]})
resp, _ = svr.GetChannel(ctx, &channelzpb.GetChannelRequest{ChannelId: ids[1]})
metrics = resp.GetChannel()
nestedChans = metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[3] || nestedChans[0].GetChannelId() != ids[3] {
Expand Down Expand Up @@ -598,7 +610,9 @@ func (s) TestGetSubChannel(t *testing.T) {
defer channelz.RemoveEntry(id)
}
svr := newCZServer()
resp, _ := svr.GetSubchannel(context.Background(), &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
resp, _ := svr.GetSubchannel(ctx, &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
metrics := resp.GetSubchannel()
want := map[int64]string{
ids[2]: refNames[2],
Expand Down Expand Up @@ -719,8 +733,10 @@ func (s) TestGetSocket(t *testing.T) {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
defer channelz.RemoveEntry(ids[i])
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
resp, _ := svr.GetSocket(ctx, &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
Expand Down
19 changes: 16 additions & 3 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ var (
}
)

const defaultTestTimeout = 10 * time.Second

// testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object.
type testRPCStream struct {
grpc.ClientStream
Expand Down Expand Up @@ -133,6 +135,10 @@ func (s) TestClientHandshake(t *testing.T) {
} {
errc := make(chan error)
stat.Reset()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
Expand All @@ -155,7 +161,7 @@ func (s) TestClientHandshake(t *testing.T) {
side: core.ClientSide,
}
go func() {
_, context, err := chs.ClientHandshake(context.Background())
_, context, err := chs.ClientHandshake(ctx)
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
Expand Down Expand Up @@ -188,6 +194,10 @@ func (s) TestServerHandshake(t *testing.T) {
} {
errc := make(chan error)
stat.Reset()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
Expand All @@ -207,7 +217,7 @@ func (s) TestServerHandshake(t *testing.T) {
side: core.ServerSide,
}
go func() {
_, context, err := shs.ServerHandshake(context.Background())
_, context, err := shs.ServerHandshake(ctx)
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
Expand Down Expand Up @@ -258,7 +268,10 @@ func (s) TestPeerNotResponding(t *testing.T) {
},
side: core.ClientSide,
}
_, context, err := chs.ClientHandshake(context.Background())

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, context, err := chs.ClientHandshake(ctx)
chs.Close()
if context != nil {
t.Error("expected non-nil ALTS context")
Expand Down
9 changes: 7 additions & 2 deletions credentials/alts/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"os"
"strings"
"testing"
"time"

"google.golang.org/grpc/codes"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
Expand All @@ -37,6 +38,8 @@ const (
testServiceAccount1 = "service_account1"
testServiceAccount2 = "service_account2"
testServiceAccount3 = "service_account3"

defaultTestTimeout = 10 * time.Second
)

func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() {
Expand Down Expand Up @@ -101,7 +104,8 @@ func (s) TestIsRunningOnGCPNoProductNameFile(t *testing.T) {
}

func (s) TestAuthInfoFromContext(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
altsAuthInfo := &fakeALTSAuthInfo{}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
Expand Down Expand Up @@ -158,7 +162,8 @@ func (s) TestAuthInfoFromPeer(t *testing.T) {
}

func (s) TestClientAuthorizationCheck(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
altsAuthInfo := &fakeALTSAuthInfo{testServiceAccount1}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
Expand Down
Loading

0 comments on commit d7a7a30

Please sign in to comment.