Skip to content

Commit

Permalink
scripts: add linter rule for using context.WithTimeout on tests (grpc…
Browse files Browse the repository at this point in the history
  • Loading branch information
hasson82 authored Jul 3, 2024
1 parent 4e9b596 commit bdd707e
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 28 deletions.
8 changes: 6 additions & 2 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ func (s) TestNewClientHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ClientHandshakerOptions{}
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
hs, err := NewClientHandshaker(ctx, clientConn, conn, opts)
if err != nil {
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
}
Expand Down Expand Up @@ -341,7 +343,9 @@ func (s) TestNewServerHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ServerHandshakerOptions{}
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
hs, err := NewServerHandshaker(ctx, clientConn, conn, opts)
if err != nil {
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
}
Expand Down
24 changes: 18 additions & 6 deletions gcp/observability/observability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ func (s) TestRefuseStartWithInvalidPatterns(t *testing.T) {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
// If there is at least one invalid pattern, which should not be silently tolerated.
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand Down Expand Up @@ -220,7 +222,9 @@ func (s) TestRefuseStartWithExcludeAndWildCardAll(t *testing.T) {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
// If there is at least one invalid pattern, which should not be silently tolerated.
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand Down Expand Up @@ -316,7 +320,9 @@ func (s) TestBothConfigEnvVarsSet(t *testing.T) {
defer func() {
envconfig.ObservabilityConfig = oldObservabilityConfig
}()
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand All @@ -331,7 +337,9 @@ func (s) TestErrInFileSystemEnvVar(t *testing.T) {
defer func() {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid file system path not triggering error")
}
}
Expand All @@ -346,7 +354,9 @@ func (s) TestNoEnvSet(t *testing.T) {
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
// If there is no observability config set at all, the Start should return an error.
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
}
Expand Down Expand Up @@ -540,7 +550,9 @@ func (s) TestStartErrorsThenEnd(t *testing.T) {
envconfig.ObservabilityConfig = oldObservabilityConfig
envconfig.ObservabilityConfigFile = oldObservabilityConfigFile
}()
if err := Start(context.Background()); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := Start(ctx); err == nil {
t.Fatalf("Invalid patterns not triggering error")
}
End()
Expand Down
6 changes: 5 additions & 1 deletion internal/binarylog/method_logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (
"google.golang.org/protobuf/types/known/durationpb"
)

const defaultTestTimeout = 10 * time.Second

func (s) TestLog(t *testing.T) {
idGen.reset()
ml := NewTruncatingMethodLogger(10, 10)
Expand Down Expand Up @@ -333,10 +335,12 @@ func (s) TestLog(t *testing.T) {
},
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i, tc := range testCases {
buf.Reset()
tc.want.SequenceIdWithinCall = uint64(i + 1)
ml.Log(context.Background(), tc.config)
ml.Log(ctx, tc.config)
inSink := new(binlogpb.GrpcLogEntry)
if err := proto.Unmarshal(buf.Bytes()[4:], inSink); err != nil {
t.Errorf("failed to unmarshal bytes in sink to proto: %v", err)
Expand Down
4 changes: 3 additions & 1 deletion internal/xds/bootstrap/tlscreds/bundle_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ func (s) TestMTLS(t *testing.T) {
}
defer conn.Close()
client := testgrpc.NewTestServiceClient(conn)
if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Errorf("EmptyCall(): got error %v when expected to succeed", err)
}
}
7 changes: 6 additions & 1 deletion internal/xds/bootstrap/tlscreds/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials/tls/certprovider"
Expand All @@ -37,6 +38,8 @@ import (
testpb "google.golang.org/grpc/interop/grpc_testing"
)

const defaultTestTimeout = 5 * time.Second

type s struct {
grpctest.Tester
}
Expand Down Expand Up @@ -86,7 +89,9 @@ func (s) TestFailingProvider(t *testing.T) {
defer conn.Close()

client := testgrpc.NewTestServiceClient(conn)
_, err = client.EmptyCall(context.Background(), &testpb.Empty{})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
_, err = client.EmptyCall(ctx, &testpb.Empty{})
if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) {
t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr)
}
Expand Down
8 changes: 6 additions & 2 deletions internal/xds/rbac/rbac_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net/url"
"reflect"
"testing"
"time"

v1xdsudpatypepb "github.com/cncf/xds/go/udpa/type/v1"
v3xdsxdstypepb "github.com/cncf/xds/go/xds/type/v3"
Expand All @@ -48,6 +49,8 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
)

const defaultTestTimeout = 10 * time.Second

type s struct {
grpctest.Tester
}
Expand Down Expand Up @@ -1742,14 +1745,15 @@ func (s) TestChainEngine(t *testing.T) {
}
// Query the created chain of RBAC Engines with different args to see
// if the chain of RBAC Engines configured as such works as intended.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, data := range test.rbacQueries {
func() {
// Construct the context with three data points that have enough
// information to represent incoming RPC's. This will be how a
// user uses this API. A user will have to put MD, PeerInfo, and
// the connection the RPC is sent on in the context.
ctx := metadata.NewIncomingContext(context.Background(), data.rpcData.md)

ctx = metadata.NewIncomingContext(ctx, data.rpcData.md)
// Make a TCP connection with a certain destination port. The
// address/port of this connection will be used to populate the
// destination ip/port in RPCData struct. This represents what
Expand Down
17 changes: 13 additions & 4 deletions metadata/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ func (s) TestFromIncomingContext(t *testing.T) {
md := Pairs(
"X-My-Header-1", "42",
)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Verify that we lowercase if callers directly modify md
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

result, found := FromIncomingContext(ctx)
if !found {
Expand Down Expand Up @@ -238,9 +240,11 @@ func (s) TestValueFromIncomingContext(t *testing.T) {
"X-My-Header-2", "43-2",
"x-my-header-3", "44",
)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Verify that we lowercase if callers directly modify md
md["X-INCORRECT-UPPERCASE"] = []string{"foo"}
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

for _, test := range []struct {
key string
Expand Down Expand Up @@ -376,17 +380,22 @@ func BenchmarkFromOutgoingContext(b *testing.B) {
}

func BenchmarkFromIncomingContext(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := Pairs("X-My-Header-1", "42")
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

b.ResetTimer()
for n := 0; n < b.N; n++ {
FromIncomingContext(ctx)
}
}

func BenchmarkValueFromIncomingContext(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
md := Pairs("X-My-Header-1", "42")
ctx := NewIncomingContext(context.Background(), md)
ctx = NewIncomingContext(ctx, md)

b.Run("key-found", func(b *testing.B) {
for n := 0; n < b.N; n++ {
Expand Down
8 changes: 7 additions & 1 deletion peer/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ import (
"context"
"fmt"
"testing"
"time"

"google.golang.org/grpc/credentials"
)

const defaultTestTimeout = 10 * time.Second

// A struct that implements AuthInfo interface and implements CommonAuthInfo() method.
type testAuthInfo struct {
credentials.CommonAuthInfo
Expand Down Expand Up @@ -80,9 +83,12 @@ func TestPeerStringer(t *testing.T) {
want: "Peer<nil>",
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx := NewContext(context.Background(), tc.peer)
ctx = NewContext(ctx, tc.peer)

p, ok := FromContext(ctx)
if !ok {
t.Fatalf("Unable to get peer from context")
Expand Down
16 changes: 12 additions & 4 deletions picker_wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ func (s) TestBlockingPick(t *testing.T) {
bp := newPickerWrapper(nil)
// All goroutines should block because picker is nil in bp.
var finishedCount uint64
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand All @@ -97,10 +99,12 @@ func (s) TestBlockingPickNoSubAvailable(t *testing.T) {
bp := newPickerWrapper(nil)
var finishedCount uint64
bp.updatePicker(&testingPicker{err: balancer.ErrNoSubConnAvailable, maxCalled: goroutineCount})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// All goroutines should block because picker returns no subConn available.
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand All @@ -117,11 +121,13 @@ func (s) TestBlockingPickTransientWaitforready(t *testing.T) {
bp := newPickerWrapper(nil)
bp.updatePicker(&testingPicker{err: balancer.ErrTransientFailure, maxCalled: goroutineCount})
var finishedCount uint64
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// All goroutines should block because picker returns transientFailure and
// picks are not failfast.
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), false, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, false, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand All @@ -138,10 +144,12 @@ func (s) TestBlockingPickSCNotReady(t *testing.T) {
bp := newPickerWrapper(nil)
bp.updatePicker(&testingPicker{sc: testSCNotReady, maxCalled: goroutineCount})
var finishedCount uint64
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// All goroutines should block because subConn is not ready.
for i := goroutineCount; i > 0; i-- {
go func() {
if tr, _, err := bp.pick(context.Background(), true, balancer.PickInfo{}); err != nil || tr != testT {
if tr, _, err := bp.pick(ctx, true, balancer.PickInfo{}); err != nil || tr != testT {
t.Errorf("bp.pick returned non-nil error: %v", err)
}
atomic.AddUint64(&finishedCount, 1)
Expand Down
7 changes: 7 additions & 0 deletions scripts/vet.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" --
# - Ensure all xds proto imports are renamed to *pb or *grpc.
git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "'

# - Ensure all context usages are done with timeout.
# Context tests under benchmark are excluded as they are testing the performance of context.Background() and context.TODO().
# TODO: Remove the exclusions once the tests are updated to use context.WithTimeout().
# See https://github.com/grpc/grpc-go/issues/7304
git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v "credential
s/google" | grep -v "internal/transport/" | grep -v "xds/internal/" | grep -v "security/advancedtls" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('

misspell -error .

# - gofmt, goimports, go vet, go mod tidy.
Expand Down
15 changes: 12 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,22 @@ func (s) TestRetryChainedInterceptor(t *testing.T) {
handler := func(ctx context.Context, req any) (any, error) {
return nil, nil
}
ii(context.Background(), nil, nil, handler)

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

ii(ctx, nil, nil, handler)
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
t.Fatalf("retry failed on chained interceptors: %v", records)
}
}

func (s) TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{}
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = NewContextWithServerTransportStream(ctx, expectedStream)

s := ServerTransportStreamFromContext(ctx)
stream, ok := s.(*transport.Stream)
if !ok || expectedStream != stream {
Expand All @@ -170,6 +177,8 @@ func (s) TestStreamContext(t *testing.T) {
}

func BenchmarkChainUnaryInterceptor(b *testing.B) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, n := range []int{1, 3, 5, 10} {
n := n
b.Run(strconv.Itoa(n), func(b *testing.B) {
Expand All @@ -186,7 +195,7 @@ func BenchmarkChainUnaryInterceptor(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := s.opts.unaryInt(context.Background(), nil, nil,
if _, err := s.opts.unaryInt(ctx, nil, nil,
func(ctx context.Context, req any) (any, error) {
return nil, nil
},
Expand Down
5 changes: 4 additions & 1 deletion stats/opentelemetry/csm/observability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,9 @@ func (s) TestXDSLabels(t *testing.T) {
// without error. The actual functionality of this function will be verified in
// interop tests.
func (s) TestObservability(t *testing.T) {
cleanup := EnableObservability(context.Background(), opentelemetry.Options{})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

cleanup := EnableObservability(ctx, opentelemetry.Options{})
cleanup()
}
Loading

0 comments on commit bdd707e

Please sign in to comment.