From a3f5ed6931c89f28795ab4ab8e43dcce75cad3d0 Mon Sep 17 00:00:00 2001 From: Aditya Sood Date: Thu, 1 Feb 2024 03:53:27 +0530 Subject: [PATCH] interop: Replace context.Background() with passed ctx (#6827) --- interop/client/client.go | 58 +++++++------ interop/http2/negative_http2_client.go | 23 ++--- interop/observability/client/client.go | 7 +- interop/stress/client/main.go | 23 ++--- interop/test_utils.go | 114 ++++++++++++------------- interop/xds_federation/client.go | 6 +- 6 files changed, 121 insertions(+), 110 deletions(-) diff --git a/interop/client/client.go b/interop/client/client.go index 7f04be0c152e..c9245797ce39 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -169,6 +169,8 @@ func main() { logger.Fatalf("only one of TLS, ALTS, google default creds, or compute engine creds can be used") } + ctx := context.Background() + var credsChosen credsMode switch { case *useTLS: @@ -242,7 +244,7 @@ func main() { } opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds)) } else if *testCase == "oauth2_auth_token" { - opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: oauth2.StaticTokenSource(interop.GetToken(*serviceAccountKeyFile, *oauthScope))})) + opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: oauth2.StaticTokenSource(interop.GetToken(ctx, *serviceAccountKeyFile, *oauthScope))})) } } if len(*serviceConfigJSON) > 0 { @@ -265,105 +267,107 @@ func main() { } defer conn.Close() tc := testgrpc.NewTestServiceClient(conn) + ctxWithDeadline, cancel := context.WithTimeout(ctx, time.Duration(*soakOverallTimeoutSeconds)*time.Second) + defer cancel() switch *testCase { case "empty_unary": - interop.DoEmptyUnaryCall(tc) + interop.DoEmptyUnaryCall(ctx, tc) logger.Infoln("EmptyUnaryCall done") case "large_unary": - interop.DoLargeUnaryCall(tc) + interop.DoLargeUnaryCall(ctx, tc) logger.Infoln("LargeUnaryCall done") case "client_streaming": - interop.DoClientStreaming(tc) + interop.DoClientStreaming(ctx, tc) logger.Infoln("ClientStreaming done") case "server_streaming": - interop.DoServerStreaming(tc) + interop.DoServerStreaming(ctx, tc) logger.Infoln("ServerStreaming done") case "ping_pong": - interop.DoPingPong(tc) + interop.DoPingPong(ctx, tc) logger.Infoln("Pingpong done") case "empty_stream": - interop.DoEmptyStream(tc) + interop.DoEmptyStream(ctx, tc) logger.Infoln("Emptystream done") case "timeout_on_sleeping_server": - interop.DoTimeoutOnSleepingServer(tc) + interop.DoTimeoutOnSleepingServer(ctx, tc) logger.Infoln("TimeoutOnSleepingServer done") case "compute_engine_creds": if credsChosen != credsTLS { logger.Fatalf("TLS credentials need to be set for compute_engine_creds test case.") } - interop.DoComputeEngineCreds(tc, *defaultServiceAccount, *oauthScope) + interop.DoComputeEngineCreds(ctx, tc, *defaultServiceAccount, *oauthScope) logger.Infoln("ComputeEngineCreds done") case "service_account_creds": if credsChosen != credsTLS { logger.Fatalf("TLS credentials need to be set for service_account_creds test case.") } - interop.DoServiceAccountCreds(tc, *serviceAccountKeyFile, *oauthScope) + interop.DoServiceAccountCreds(ctx, tc, *serviceAccountKeyFile, *oauthScope) logger.Infoln("ServiceAccountCreds done") case "jwt_token_creds": if credsChosen != credsTLS { logger.Fatalf("TLS credentials need to be set for jwt_token_creds test case.") } - interop.DoJWTTokenCreds(tc, *serviceAccountKeyFile) + interop.DoJWTTokenCreds(ctx, tc, *serviceAccountKeyFile) logger.Infoln("JWTtokenCreds done") case "per_rpc_creds": if credsChosen != credsTLS { logger.Fatalf("TLS credentials need to be set for per_rpc_creds test case.") } - interop.DoPerRPCCreds(tc, *serviceAccountKeyFile, *oauthScope) + interop.DoPerRPCCreds(ctx, tc, *serviceAccountKeyFile, *oauthScope) logger.Infoln("PerRPCCreds done") case "oauth2_auth_token": if credsChosen != credsTLS { logger.Fatalf("TLS credentials need to be set for oauth2_auth_token test case.") } - interop.DoOauth2TokenCreds(tc, *serviceAccountKeyFile, *oauthScope) + interop.DoOauth2TokenCreds(ctx, tc, *serviceAccountKeyFile, *oauthScope) logger.Infoln("Oauth2TokenCreds done") case "google_default_credentials": if credsChosen != credsGoogleDefaultCreds { logger.Fatalf("GoogleDefaultCredentials need to be set for google_default_credentials test case.") } - interop.DoGoogleDefaultCredentials(tc, *defaultServiceAccount) + interop.DoGoogleDefaultCredentials(ctx, tc, *defaultServiceAccount) logger.Infoln("GoogleDefaultCredentials done") case "compute_engine_channel_credentials": if credsChosen != credsComputeEngineCreds { logger.Fatalf("ComputeEngineCreds need to be set for compute_engine_channel_credentials test case.") } - interop.DoComputeEngineChannelCredentials(tc, *defaultServiceAccount) + interop.DoComputeEngineChannelCredentials(ctx, tc, *defaultServiceAccount) logger.Infoln("ComputeEngineChannelCredentials done") case "cancel_after_begin": - interop.DoCancelAfterBegin(tc) + interop.DoCancelAfterBegin(ctx, tc) logger.Infoln("CancelAfterBegin done") case "cancel_after_first_response": - interop.DoCancelAfterFirstResponse(tc) + interop.DoCancelAfterFirstResponse(ctx, tc) logger.Infoln("CancelAfterFirstResponse done") case "status_code_and_message": - interop.DoStatusCodeAndMessage(tc) + interop.DoStatusCodeAndMessage(ctx, tc) logger.Infoln("StatusCodeAndMessage done") case "special_status_message": - interop.DoSpecialStatusMessage(tc) + interop.DoSpecialStatusMessage(ctx, tc) logger.Infoln("SpecialStatusMessage done") case "custom_metadata": - interop.DoCustomMetadata(tc) + interop.DoCustomMetadata(ctx, tc) logger.Infoln("CustomMetadata done") case "unimplemented_method": - interop.DoUnimplementedMethod(conn) + interop.DoUnimplementedMethod(conn, ctx) logger.Infoln("UnimplementedMethod done") case "unimplemented_service": - interop.DoUnimplementedService(testgrpc.NewUnimplementedServiceClient(conn)) + interop.DoUnimplementedService(testgrpc.NewUnimplementedServiceClient(conn), ctx) logger.Infoln("UnimplementedService done") case "pick_first_unary": - interop.DoPickFirstUnary(tc) + interop.DoPickFirstUnary(ctx, tc) logger.Infoln("PickFirstUnary done") case "rpc_soak": - interop.DoSoakTest(tc, serverAddr, opts, false /* resetChannel */, *soakIterations, *soakMaxFailures, *soakRequestSize, *soakResponseSize, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Duration(*soakMinTimeMsBetweenRPCs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + interop.DoSoakTest(ctxWithDeadline, tc, serverAddr, opts, false /* resetChannel */, *soakIterations, *soakMaxFailures, *soakRequestSize, *soakResponseSize, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Duration(*soakMinTimeMsBetweenRPCs)*time.Millisecond) logger.Infoln("RpcSoak done") case "channel_soak": - interop.DoSoakTest(tc, serverAddr, opts, true /* resetChannel */, *soakIterations, *soakMaxFailures, *soakRequestSize, *soakResponseSize, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Duration(*soakMinTimeMsBetweenRPCs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + interop.DoSoakTest(ctxWithDeadline, tc, serverAddr, opts, true /* resetChannel */, *soakIterations, *soakMaxFailures, *soakRequestSize, *soakResponseSize, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Duration(*soakMinTimeMsBetweenRPCs)*time.Millisecond) logger.Infoln("ChannelSoak done") case "orca_per_rpc": - interop.DoORCAPerRPCTest(tc) + interop.DoORCAPerRPCTest(ctx, tc) logger.Infoln("ORCAPerRPC done") case "orca_oob": - interop.DoORCAOOBTest(tc) + interop.DoORCAOOBTest(ctx, tc) logger.Infoln("ORCAOOB done") default: logger.Fatal("Unsupported test case: ", *testCase) diff --git a/interop/http2/negative_http2_client.go b/interop/http2/negative_http2_client.go index b8c1d522009e..771845309c29 100644 --- a/interop/http2/negative_http2_client.go +++ b/interop/http2/negative_http2_client.go @@ -69,12 +69,12 @@ func largeSimpleRequest() *testpb.SimpleRequest { } // sends two unary calls. The server asserts that the calls use different connections. -func goaway(tc testgrpc.TestServiceClient) { - interop.DoLargeUnaryCall(tc) +func goaway(ctx context.Context, tc testgrpc.TestServiceClient) { + interop.DoLargeUnaryCall(ctx, tc) // sleep to ensure that the client has time to recv the GOAWAY. // TODO(ncteisen): make this less hacky. time.Sleep(1 * time.Second) - interop.DoLargeUnaryCall(tc) + interop.DoLargeUnaryCall(ctx, tc) } func rstAfterHeader(tc testgrpc.TestServiceClient) { @@ -110,19 +110,19 @@ func rstAfterData(tc testgrpc.TestServiceClient) { } } -func ping(tc testgrpc.TestServiceClient) { +func ping(ctx context.Context, tc testgrpc.TestServiceClient) { // The server will assert that every ping it sends was ACK-ed by the client. - interop.DoLargeUnaryCall(tc) + interop.DoLargeUnaryCall(ctx, tc) } -func maxStreams(tc testgrpc.TestServiceClient) { - interop.DoLargeUnaryCall(tc) +func maxStreams(ctx context.Context, tc testgrpc.TestServiceClient) { + interop.DoLargeUnaryCall(ctx, tc) var wg sync.WaitGroup for i := 0; i < 15; i++ { wg.Add(1) go func() { defer wg.Done() - interop.DoLargeUnaryCall(tc) + interop.DoLargeUnaryCall(ctx, tc) }() } wg.Wait() @@ -139,9 +139,10 @@ func main() { } defer conn.Close() tc := testgrpc.NewTestServiceClient(conn) + ctx := context.Background() switch *testCase { case "goaway": - goaway(tc) + goaway(ctx, tc) logger.Infoln("goaway done") case "rst_after_header": rstAfterHeader(tc) @@ -153,10 +154,10 @@ func main() { rstAfterData(tc) logger.Infoln("rst_after_data done") case "ping": - ping(tc) + ping(ctx, tc) logger.Infoln("ping done") case "max_streams": - maxStreams(tc) + maxStreams(ctx, tc) logger.Infoln("max_streams done") default: logger.Fatal("Unsupported test case: ", *testCase) diff --git a/interop/observability/client/client.go b/interop/observability/client/client.go index d8cf72fa76c9..5478efdbe035 100644 --- a/interop/observability/client/client.go +++ b/interop/observability/client/client.go @@ -58,13 +58,14 @@ func main() { } defer conn.Close() tc := testgrpc.NewTestServiceClient(conn) + ctx := context.Background() for i := 0; i < *numTimes; i++ { if *testCase == "ping_pong" { - interop.DoPingPong(tc) + interop.DoPingPong(ctx, tc) } else if *testCase == "large_unary" { - interop.DoLargeUnaryCall(tc) + interop.DoLargeUnaryCall(ctx, tc) } else if *testCase == "custom_metadata" { - interop.DoCustomMetadata(tc) + interop.DoCustomMetadata(ctx, tc) } else { log.Fatalf("Invalid test case: %s", *testCase) } diff --git a/interop/stress/client/main.go b/interop/stress/client/main.go index 0055c561c557..9467ed67285c 100644 --- a/interop/stress/client/main.go +++ b/interop/stress/client/main.go @@ -226,32 +226,33 @@ func startServer(server *server, port int) { func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) { client := testgrpc.NewTestServiceClient(conn) var numCalls int64 + ctx := context.Background() startTime := time.Now() for { test := selector.getNextTest() switch test { case "empty_unary": - interop.DoEmptyUnaryCall(client) + interop.DoEmptyUnaryCall(ctx, client) case "large_unary": - interop.DoLargeUnaryCall(client) + interop.DoLargeUnaryCall(ctx, client) case "client_streaming": - interop.DoClientStreaming(client) + interop.DoClientStreaming(ctx, client) case "server_streaming": - interop.DoServerStreaming(client) + interop.DoServerStreaming(ctx, client) case "ping_pong": - interop.DoPingPong(client) + interop.DoPingPong(ctx, client) case "empty_stream": - interop.DoEmptyStream(client) + interop.DoEmptyStream(ctx, client) case "timeout_on_sleeping_server": - interop.DoTimeoutOnSleepingServer(client) + interop.DoTimeoutOnSleepingServer(ctx, client) case "cancel_after_begin": - interop.DoCancelAfterBegin(client) + interop.DoCancelAfterBegin(ctx, client) case "cancel_after_first_response": - interop.DoCancelAfterFirstResponse(client) + interop.DoCancelAfterFirstResponse(ctx, client) case "status_code_and_message": - interop.DoStatusCodeAndMessage(client) + interop.DoStatusCodeAndMessage(ctx, client) case "custom_metadata": - interop.DoCustomMetadata(client) + interop.DoCustomMetadata(ctx, client) } numCalls++ defer func() { atomic.AddInt64(&totalNumCalls, numCalls) }() diff --git a/interop/test_utils.go b/interop/test_utils.go index 83e656832e28..f075f0753bbd 100644 --- a/interop/test_utils.go +++ b/interop/test_utils.go @@ -79,8 +79,8 @@ func ClientNewPayload(t testpb.PayloadType, size int) *testpb.Payload { } // DoEmptyUnaryCall performs a unary RPC with empty request and response messages. -func DoEmptyUnaryCall(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, args...) +func DoEmptyUnaryCall(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + reply, err := tc.EmptyCall(ctx, &testpb.Empty{}, args...) if err != nil { logger.Fatal("/TestService/EmptyCall RPC failed: ", err) } @@ -90,14 +90,14 @@ func DoEmptyUnaryCall(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoLargeUnaryCall performs a unary RPC with large payload in the request and response. -func DoLargeUnaryCall(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { +func DoLargeUnaryCall(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, ResponseSize: int32(largeRespSize), Payload: pl, } - reply, err := tc.UnaryCall(context.Background(), req, args...) + reply, err := tc.UnaryCall(ctx, req, args...) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -109,8 +109,8 @@ func DoLargeUnaryCall(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoClientStreaming performs a client streaming RPC. -func DoClientStreaming(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - stream, err := tc.StreamingInputCall(context.Background(), args...) +func DoClientStreaming(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + stream, err := tc.StreamingInputCall(ctx, args...) if err != nil { logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err) } @@ -135,7 +135,7 @@ func DoClientStreaming(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoServerStreaming performs a server streaming RPC. -func DoServerStreaming(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { +func DoServerStreaming(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { respParam := make([]*testpb.ResponseParameters, len(respSizes)) for i, s := range respSizes { respParam[i] = &testpb.ResponseParameters{ @@ -146,7 +146,7 @@ func DoServerStreaming(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { ResponseType: testpb.PayloadType_COMPRESSABLE, ResponseParameters: respParam, } - stream, err := tc.StreamingOutputCall(context.Background(), req, args...) + stream, err := tc.StreamingOutputCall(ctx, req, args...) if err != nil { logger.Fatalf("%v.StreamingOutputCall(_) = _, %v", tc, err) } @@ -179,8 +179,8 @@ func DoServerStreaming(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoPingPong performs ping-pong style bi-directional streaming RPC. -func DoPingPong(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - stream, err := tc.FullDuplexCall(context.Background(), args...) +func DoPingPong(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + stream, err := tc.FullDuplexCall(ctx, args...) if err != nil { logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) } @@ -223,8 +223,8 @@ func DoPingPong(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoEmptyStream sets up a bi-directional streaming with zero message. -func DoEmptyStream(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - stream, err := tc.FullDuplexCall(context.Background(), args...) +func DoEmptyStream(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + stream, err := tc.FullDuplexCall(ctx, args...) if err != nil { logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) } @@ -237,8 +237,8 @@ func DoEmptyStream(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoTimeoutOnSleepingServer performs an RPC on a sleep server which causes RPC timeout. -func DoTimeoutOnSleepingServer(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) +func DoTimeoutOnSleepingServer(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond) defer cancel() stream, err := tc.FullDuplexCall(ctx, args...) if err != nil { @@ -261,7 +261,7 @@ func DoTimeoutOnSleepingServer(tc testgrpc.TestServiceClient, args ...grpc.CallO } // DoComputeEngineCreds performs a unary RPC with compute engine auth. -func DoComputeEngineCreds(tc testgrpc.TestServiceClient, serviceAccount, oauthScope string) { +func DoComputeEngineCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccount, oauthScope string) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, @@ -270,7 +270,7 @@ func DoComputeEngineCreds(tc testgrpc.TestServiceClient, serviceAccount, oauthSc FillUsername: true, FillOauthScope: true, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -293,7 +293,7 @@ func getServiceAccountJSONKey(keyFile string) []byte { } // DoServiceAccountCreds performs a unary RPC with service account auth. -func DoServiceAccountCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) { +func DoServiceAccountCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, @@ -302,7 +302,7 @@ func DoServiceAccountCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, FillUsername: true, FillOauthScope: true, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -318,7 +318,7 @@ func DoServiceAccountCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, } // DoJWTTokenCreds performs a unary RPC with JWT token auth. -func DoJWTTokenCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile string) { +func DoJWTTokenCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile string) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, @@ -326,7 +326,7 @@ func DoJWTTokenCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile string Payload: pl, FillUsername: true, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -338,13 +338,13 @@ func DoJWTTokenCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile string } // GetToken obtains an OAUTH token from the input. -func GetToken(serviceAccountKeyFile string, oauthScope string) *oauth2.Token { +func GetToken(ctx context.Context, serviceAccountKeyFile string, oauthScope string) *oauth2.Token { jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile) config, err := google.JWTConfigFromJSON(jsonKey, oauthScope) if err != nil { logger.Fatalf("Failed to get the config: %v", err) } - token, err := config.TokenSource(context.Background()).Token() + token, err := config.TokenSource(ctx).Token() if err != nil { logger.Fatalf("Failed to get the token: %v", err) } @@ -352,7 +352,7 @@ func GetToken(serviceAccountKeyFile string, oauthScope string) *oauth2.Token { } // DoOauth2TokenCreds performs a unary RPC with OAUTH2 token auth. -func DoOauth2TokenCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) { +func DoOauth2TokenCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, @@ -361,7 +361,7 @@ func DoOauth2TokenCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oa FillUsername: true, FillOauthScope: true, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -377,7 +377,7 @@ func DoOauth2TokenCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oa } // DoPerRPCCreds performs a unary RPC with per RPC OAUTH2 token. -func DoPerRPCCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) { +func DoPerRPCCreds(ctx context.Context, tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthScope string) { jsonKey := getServiceAccountJSONKey(serviceAccountKeyFile) pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ @@ -387,9 +387,9 @@ func DoPerRPCCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthSc FillUsername: true, FillOauthScope: true, } - token := GetToken(serviceAccountKeyFile, oauthScope) + token := GetToken(ctx, serviceAccountKeyFile, oauthScope) kv := map[string]string{"authorization": token.Type() + " " + token.AccessToken} - ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}}) + ctx = metadata.NewOutgoingContext(ctx, metadata.MD{"authorization": []string{kv["authorization"]}}) reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) @@ -405,7 +405,7 @@ func DoPerRPCCreds(tc testgrpc.TestServiceClient, serviceAccountKeyFile, oauthSc } // DoGoogleDefaultCredentials performs an unary RPC with google default credentials -func DoGoogleDefaultCredentials(tc testgrpc.TestServiceClient, defaultServiceAccount string) { +func DoGoogleDefaultCredentials(ctx context.Context, tc testgrpc.TestServiceClient, defaultServiceAccount string) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, @@ -414,7 +414,7 @@ func DoGoogleDefaultCredentials(tc testgrpc.TestServiceClient, defaultServiceAcc FillUsername: true, FillOauthScope: true, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -424,7 +424,7 @@ func DoGoogleDefaultCredentials(tc testgrpc.TestServiceClient, defaultServiceAcc } // DoComputeEngineChannelCredentials performs an unary RPC with compute engine channel credentials -func DoComputeEngineChannelCredentials(tc testgrpc.TestServiceClient, defaultServiceAccount string) { +func DoComputeEngineChannelCredentials(ctx context.Context, tc testgrpc.TestServiceClient, defaultServiceAccount string) { pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) req := &testpb.SimpleRequest{ ResponseType: testpb.PayloadType_COMPRESSABLE, @@ -433,7 +433,7 @@ func DoComputeEngineChannelCredentials(tc testgrpc.TestServiceClient, defaultSer FillUsername: true, FillOauthScope: true, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(ctx, req) if err != nil { logger.Fatal("/TestService/UnaryCall RPC failed: ", err) } @@ -448,8 +448,8 @@ var testMetadata = metadata.MD{ } // DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent. -func DoCancelAfterBegin(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(context.Background(), testMetadata)) +func DoCancelAfterBegin(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(ctx, testMetadata)) stream, err := tc.StreamingInputCall(ctx, args...) if err != nil { logger.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err) @@ -462,8 +462,8 @@ func DoCancelAfterBegin(tc testgrpc.TestServiceClient, args ...grpc.CallOption) } // DoCancelAfterFirstResponse cancels the RPC after receiving the first message from the server. -func DoCancelAfterFirstResponse(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { - ctx, cancel := context.WithCancel(context.Background()) +func DoCancelAfterFirstResponse(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { + ctx, cancel := context.WithCancel(ctx) stream, err := tc.FullDuplexCall(ctx, args...) if err != nil { logger.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) @@ -516,7 +516,7 @@ func validateMetadata(header, trailer metadata.MD) { } // DoCustomMetadata checks that metadata is echoed back to the client. -func DoCustomMetadata(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { +func DoCustomMetadata(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { // Testing with UnaryCall. pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 1) req := &testpb.SimpleRequest{ @@ -524,7 +524,7 @@ func DoCustomMetadata(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { ResponseSize: int32(1), Payload: pl, } - ctx := metadata.NewOutgoingContext(context.Background(), customMetadata) + ctx = metadata.NewOutgoingContext(ctx, customMetadata) var header, trailer metadata.MD args = append(args, grpc.Header(&header), grpc.Trailer(&trailer)) reply, err := tc.UnaryCall( @@ -578,7 +578,7 @@ func DoCustomMetadata(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { } // DoStatusCodeAndMessage checks that the status code is propagated back to the client. -func DoStatusCodeAndMessage(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { +func DoStatusCodeAndMessage(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { var code int32 = 2 msg := "test status message" expectedErr := status.Error(codes.Code(code), msg) @@ -590,11 +590,11 @@ func DoStatusCodeAndMessage(tc testgrpc.TestServiceClient, args ...grpc.CallOpti req := &testpb.SimpleRequest{ ResponseStatus: respStatus, } - if _, err := tc.UnaryCall(context.Background(), req, args...); err == nil || err.Error() != expectedErr.Error() { + if _, err := tc.UnaryCall(ctx, req, args...); err == nil || err.Error() != expectedErr.Error() { logger.Fatalf("%v.UnaryCall(_, %v) = _, %v, want _, %v", tc, req, err, expectedErr) } // Test FullDuplexCall. - stream, err := tc.FullDuplexCall(context.Background(), args...) + stream, err := tc.FullDuplexCall(ctx, args...) if err != nil { logger.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } @@ -614,7 +614,7 @@ func DoStatusCodeAndMessage(tc testgrpc.TestServiceClient, args ...grpc.CallOpti // DoSpecialStatusMessage verifies Unicode and whitespace is correctly processed // in status message. -func DoSpecialStatusMessage(tc testgrpc.TestServiceClient, args ...grpc.CallOption) { +func DoSpecialStatusMessage(ctx context.Context, tc testgrpc.TestServiceClient, args ...grpc.CallOption) { const ( code int32 = 2 msg string = "\t\ntest with whitespace\r\nand Unicode BMP ☺ and non-BMP 😈\t\n" @@ -626,7 +626,7 @@ func DoSpecialStatusMessage(tc testgrpc.TestServiceClient, args ...grpc.CallOpti Message: msg, }, } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() if _, err := tc.UnaryCall(ctx, req, args...); err == nil || err.Error() != expectedErr.Error() { logger.Fatalf("%v.UnaryCall(_, %v) = _, %v, want _, %v", tc, req, err, expectedErr) @@ -634,24 +634,24 @@ func DoSpecialStatusMessage(tc testgrpc.TestServiceClient, args ...grpc.CallOpti } // DoUnimplementedService attempts to call a method from an unimplemented service. -func DoUnimplementedService(tc testgrpc.UnimplementedServiceClient) { - _, err := tc.UnimplementedCall(context.Background(), &testpb.Empty{}) +func DoUnimplementedService(tc testgrpc.UnimplementedServiceClient, ctx context.Context) { + _, err := tc.UnimplementedCall(ctx, &testpb.Empty{}) if status.Code(err) != codes.Unimplemented { logger.Fatalf("%v.UnimplementedCall() = _, %v, want _, %v", tc, status.Code(err), codes.Unimplemented) } } // DoUnimplementedMethod attempts to call an unimplemented method. -func DoUnimplementedMethod(cc *grpc.ClientConn) { +func DoUnimplementedMethod(cc *grpc.ClientConn, ctx context.Context) { var req, reply proto.Message - if err := cc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", req, reply); err == nil || status.Code(err) != codes.Unimplemented { + if err := cc.Invoke(ctx, "/grpc.testing.TestService/UnimplementedCall", req, reply); err == nil || status.Code(err) != codes.Unimplemented { logger.Fatalf("ClientConn.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented) } } // DoPickFirstUnary runs multiple RPCs (rpcCount) and checks that all requests // are sent to the same backend. -func DoPickFirstUnary(tc testgrpc.TestServiceClient) { +func DoPickFirstUnary(ctx context.Context, tc testgrpc.TestServiceClient) { const rpcCount = 100 pl := ClientNewPayload(testpb.PayloadType_COMPRESSABLE, 1) @@ -662,7 +662,7 @@ func DoPickFirstUnary(tc testgrpc.TestServiceClient) { FillServerId: true, } // TODO(mohanli): Revert the timeout back to 10s once TD migrates to xdstp. - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() var serverID string for i := 0; i < rpcCount; i++ { @@ -724,10 +724,9 @@ func doOneSoakIteration(ctx context.Context, tc testgrpc.TestServiceClient, rese // If resetChannel is false, then each RPC will be performed on tc. Otherwise, each RPC will be performed on a new // stub that is created with the provided server address and dial options. // TODO(mohanli-ml): Create SoakTestOptions as a parameter for this method. -func DoSoakTest(tc testgrpc.TestServiceClient, serverAddr string, dopts []grpc.DialOption, resetChannel bool, soakIterations int, maxFailures int, soakRequestSize int, soakResponseSize int, perIterationMaxAcceptableLatency time.Duration, minTimeBetweenRPCs time.Duration, overallDeadline time.Time) { +func DoSoakTest(ctx context.Context, tc testgrpc.TestServiceClient, serverAddr string, dopts []grpc.DialOption, resetChannel bool, soakIterations int, maxFailures int, soakRequestSize int, soakResponseSize int, perIterationMaxAcceptableLatency time.Duration, minTimeBetweenRPCs time.Duration) { start := time.Now() - ctx, cancel := context.WithDeadline(context.Background(), overallDeadline) - defer cancel() + var elapsedTime float64 iterationsDone := 0 totalFailures := 0 hopts := stats.HistogramOptions{ @@ -738,7 +737,8 @@ func DoSoakTest(tc testgrpc.TestServiceClient, serverAddr string, dopts []grpc.D } h := stats.NewHistogram(hopts) for i := 0; i < soakIterations; i++ { - if time.Now().After(overallDeadline) { + if ctx.Err() != nil { + elapsedTime = time.Since(start).Seconds() break } earliestNextStart := time.After(minTimeBetweenRPCs) @@ -771,7 +771,7 @@ func DoSoakTest(tc testgrpc.TestServiceClient, serverAddr string, dopts []grpc.D fmt.Fprintf(os.Stderr, "(server_uri: %s) histogram of per-iteration latencies in milliseconds: %s\n", serverAddr, b.String()) fmt.Fprintf(os.Stderr, "(server_uri: %s) soak test ran: %d / %d iterations. total failures: %d. max failures threshold: %d. See breakdown above for which iterations succeeded, failed, and why for more info.\n", serverAddr, iterationsDone, soakIterations, totalFailures, maxFailures) if iterationsDone < soakIterations { - logger.Fatalf("(server_uri: %s) soak test consumed all %f seconds of time and quit early, only having ran %d out of desired %d iterations.", serverAddr, overallDeadline.Sub(start).Seconds(), iterationsDone, soakIterations) + logger.Fatalf("(server_uri: %s) soak test consumed all %f seconds of time and quit early, only having ran %d out of desired %d iterations.", serverAddr, elapsedTime, iterationsDone, soakIterations) } if totalFailures > maxFailures { logger.Fatalf("(server_uri: %s) soak test total failures: %d exceeds max failures threshold: %d.", serverAddr, totalFailures, maxFailures) @@ -989,8 +989,8 @@ func (s *testServer) HalfDuplexCall(stream testgrpc.TestService_HalfDuplexCallSe // DoORCAPerRPCTest performs a unary RPC that enables ORCA per-call reporting // and verifies the load report sent back to the LB policy's Done callback. -func DoORCAPerRPCTest(tc testgrpc.TestServiceClient) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +func DoORCAPerRPCTest(ctx context.Context, tc testgrpc.TestServiceClient) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() orcaRes := &v3orcapb.OrcaLoadReport{} _, err := tc.UnaryCall(contextWithORCAResult(ctx, &orcaRes), &testpb.SimpleRequest{ @@ -1017,8 +1017,8 @@ func DoORCAPerRPCTest(tc testgrpc.TestServiceClient) { // DoORCAOOBTest performs a streaming RPC that enables ORCA OOB reporting and // verifies the load report sent to the LB policy's OOB listener. -func DoORCAOOBTest(tc testgrpc.TestServiceClient) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) +func DoORCAOOBTest(ctx context.Context, tc testgrpc.TestServiceClient) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() stream, err := tc.FullDuplexCall(ctx) if err != nil { diff --git a/interop/xds_federation/client.go b/interop/xds_federation/client.go index eee5ba747af3..56572e4a35c3 100644 --- a/interop/xds_federation/client.go +++ b/interop/xds_federation/client.go @@ -20,6 +20,7 @@ package main import ( + "context" "flag" "strings" "sync" @@ -115,10 +116,13 @@ func main() { // run soak tests with the different clients logger.Infof("Clients running with test case %q", *testCase) var wg sync.WaitGroup + ctx := context.Background() for i := range clients { wg.Add(1) go func(c clientConfig) { - interop.DoSoakTest(c.tc, c.uri, c.opts, resetChannel, *soakIterations, *soakMaxFailures, *soakRequestSize, *soakResponseSize, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Duration(*soakMinTimeMsBetweenRPCs)*time.Millisecond, time.Now().Add(time.Duration(*soakOverallTimeoutSeconds)*time.Second)) + ctxWithDeadline, cancel := context.WithTimeout(ctx, time.Duration(*soakOverallTimeoutSeconds)*time.Second) + defer cancel() + interop.DoSoakTest(ctxWithDeadline, c.tc, c.uri, c.opts, resetChannel, *soakIterations, *soakMaxFailures, *soakRequestSize, *soakResponseSize, time.Duration(*soakPerIterationMaxAcceptableLatencyMs)*time.Millisecond, time.Duration(*soakMinTimeMsBetweenRPCs)*time.Millisecond) logger.Infof("%s test done for server: %s", *testCase, c.uri) wg.Done() }(clients[i])