Skip to content

Commit

Permalink
Merge pull request kubernetes#96061 from tkashem/context-wiring
Browse files Browse the repository at this point in the history
plumb context with request deadline
  • Loading branch information
k8s-ci-robot authored Nov 16, 2020
2 parents d20e324 + 83f869e commit 59ac565
Show file tree
Hide file tree
Showing 16 changed files with 392 additions and 67 deletions.
9 changes: 8 additions & 1 deletion pkg/kubeapiserver/server/insecure_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ func BuildInsecureHandlerChain(apiHandler http.Handler, c *server.Config) http.H
handler = genericapifilters.WithAudit(handler, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc)
handler = genericapifilters.WithAuthentication(handler, server.InsecureSuperuser{}, nil, nil)
handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true")
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout)

// WithTimeoutForNonLongRunningRequests will call the rest of the request handling in a go-routine with the
// context with deadline. The go-routine can keep running, while the timeout logic will return a timeout to the client.
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc)

// WithRequestDeadline sets a deadline for the request context appropriately
handler = genericapifilters.WithRequestDeadline(handler, c.LongRunningFunc, c.RequestTimeout)

handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, requestInfoResolver)
handler = genericapifilters.WithWarningRecorder(handler)
Expand Down
1 change: 1 addition & 0 deletions staging/src/k8s.io/apiserver/pkg/endpoints/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ go_test(
"//staging/src/k8s.io/apiserver/pkg/endpoints/testing:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/features:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/registry/rest:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/server/filters:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library",
"//staging/src/k8s.io/client-go/dynamic:go_default_library",
"//staging/src/k8s.io/client-go/rest:go_default_library",
Expand Down
7 changes: 7 additions & 0 deletions staging/src/k8s.io/apiserver/pkg/endpoints/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ import (
genericapitesting "k8s.io/apiserver/pkg/endpoints/testing"
"k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/registry/rest"
"k8s.io/apiserver/pkg/server/filters"
utilfeature "k8s.io/apiserver/pkg/util/feature"
featuregatetesting "k8s.io/component-base/featuregate/testing"
)
Expand Down Expand Up @@ -286,6 +287,7 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
// simplified long-running check
return requestInfo.Verb == "watch" || requestInfo.Verb == "proxy"
})
handler = genericapifilters.WithRequestDeadline(handler, testLongRunningCheck, 60*time.Second)
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver())

return &defaultAPIServer{handler, container}
Expand All @@ -298,6 +300,11 @@ func testRequestInfoResolver() *request.RequestInfoFactory {
}
}

var testLongRunningCheck = filters.BasicLongRunningRequestCheck(
sets.NewString("watch", "proxy"),
sets.NewString("attach", "exec", "proxy", "log", "portforward"),
)

func TestSimpleSetupRight(t *testing.T) {
s := &genericapitesting.Simple{ObjectMeta: metav1.ObjectMeta{Name: "aName"}}
wire, err := runtime.Encode(codec, s)
Expand Down
2 changes: 2 additions & 0 deletions staging/src/k8s.io/apiserver/pkg/endpoints/filters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ go_test(
"cachecontrol_test.go",
"impersonation_test.go",
"metrics_test.go",
"request_deadline_test.go",
"request_received_time_test.go",
"requestinfo_test.go",
"warning_test.go",
Expand Down Expand Up @@ -56,6 +57,7 @@ go_library(
"doc.go",
"impersonation.go",
"metrics.go",
"request_deadline.go",
"request_received_time.go",
"requestinfo.go",
"storageversion.go",
Expand Down
105 changes: 105 additions & 0 deletions staging/src/k8s.io/apiserver/pkg/endpoints/filters/request_deadline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package filters

import (
"context"
"errors"
"fmt"
"k8s.io/klog/v2"
"net/http"
"time"

"k8s.io/apiserver/pkg/endpoints/request"
)

var (
// The 'timeout' query parameter in the request URL has an invalid timeout specifier
errInvalidTimeoutInURL = errors.New("invalid timeout specified in the request URL")

// The timeout specified in the request URL exceeds the global maximum timeout allowed by the apiserver.
errTimeoutExceedsMaximumAllowed = errors.New("timeout specified in the request URL exceeds the maximum timeout allowed by the server")
)

// WithRequestDeadline determines the deadline of the given request and sets a new context with the appropriate timeout.
// requestTimeoutMaximum specifies the default request timeout value
func WithRequestDeadline(handler http.Handler, longRunning request.LongRunningRequestCheck, requestTimeoutMaximum time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()

requestInfo, ok := request.RequestInfoFrom(ctx)
if !ok {
handleError(w, req, http.StatusInternalServerError, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong"))
return
}
if longRunning(req, requestInfo) {
handler.ServeHTTP(w, req)
return
}

userSpecifiedTimeout, ok, err := parseTimeout(req)
if err != nil {
statusCode := http.StatusInternalServerError
if err == errInvalidTimeoutInURL {
statusCode = http.StatusBadRequest
}

handleError(w, req, statusCode, err)
return
}

timeout := requestTimeoutMaximum
if ok {
if userSpecifiedTimeout > requestTimeoutMaximum {
handleError(w, req, http.StatusBadRequest, errTimeoutExceedsMaximumAllowed)
return
}

timeout = userSpecifiedTimeout
}

ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}

// parseTimeout parses the given HTTP request URL and extracts the timeout query parameter
// value if specified by the user.
// If a timeout is not specified the function returns false and err is set to nil
// If the value specified is malformed then the function returns false and err is set
func parseTimeout(req *http.Request) (time.Duration, bool, error) {
value := req.URL.Query().Get("timeout")
if value == "" {
return 0, false, nil
}

timeout, err := time.ParseDuration(value)
if err != nil {
return 0, false, errInvalidTimeoutInURL
}

return timeout, true, nil
}

func handleError(w http.ResponseWriter, r *http.Request, code int, err error) {
errorMsg := fmt.Sprintf("Error - %s: %#v", err.Error(), r.RequestURI)
http.Error(w, errorMsg, code)
klog.Errorf(err.Error())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package filters

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"k8s.io/apiserver/pkg/endpoints/request"
)

func TestParseTimeout(t *testing.T) {
tests := []struct {
name string
url string
expected bool
timeoutExpected time.Duration
errExpected error
}{
{
name: "the user does not specify a timeout",
url: "/api/v1/namespaces",
},
{
name: "the user specifies a valid timeout",
url: "/api/v1/namespaces?timeout=10s",
expected: true,
timeoutExpected: 10 * time.Second,
},
{
name: "the use specifies an invalid timeout",
url: "/api/v1/namespaces?timeout=foo",
errExpected: errInvalidTimeoutInURL,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, test.url, nil)
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}

timeoutGot, ok, err := parseTimeout(request)

if test.expected != ok {
t.Errorf("expected: %t, but got: %t", test.expected, ok)
}
if test.errExpected != err {
t.Errorf("expected err: %v, but got: %v", test.errExpected, err)
}
if test.timeoutExpected != timeoutGot {
t.Errorf("expected timeout: %s, but got: %s", test.timeoutExpected, timeoutGot)
}
})
}
}

func TestWithRequestDeadline(t *testing.T) {
const requestTimeoutMaximum = 60 * time.Second

tests := []struct {
name string
requestURL string
longRunning bool
hasDeadlineExpected bool
deadlineExpected time.Duration
handlerCallCountExpected int
statusCodeExpected int
}{
{
name: "the user specifies a valid request timeout",
requestURL: "/api/v1/namespaces?timeout=15s",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: 14 * time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the user does not specify any request timeout, default deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the request is long running, no deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=10s",
longRunning: true,
hasDeadlineExpected: false,
handlerCallCountExpected: 1,
statusCodeExpected: http.StatusOK,
},
{
name: "the timeout specified is malformed, the request is aborted with HTTP 400",
requestURL: "/api/v1/namespaces?timeout=foo",
longRunning: false,
statusCodeExpected: http.StatusBadRequest,
},
{
name: "the timeout specified exceeds the maximum deadline allowed, the request is aborted with HTTP 400",
requestURL: fmt.Sprintf("/api/v1/namespaces?timeout=%s", requestTimeoutMaximum+time.Second),
longRunning: false,
statusCodeExpected: http.StatusBadRequest,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
callCount int
hasDeadlineGot bool
deadlineGot time.Duration
)
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
callCount++
deadlineGot, hasDeadlineGot = deadline(req)
})

withDeadline := WithRequestDeadline(
handler, func(_ *http.Request, _ *request.RequestInfo) bool { return test.longRunning }, requestTimeoutMaximum)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})

testRequest, err := http.NewRequest(http.MethodGet, test.requestURL, nil)
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}

// make sure a default request does not have any deadline set
remaning, ok := deadline(testRequest)
if ok {
t.Fatalf("test setup failed, expected the new HTTP request context to have no deadline but got: %s", remaning)
}

w := httptest.NewRecorder()
withDeadline.ServeHTTP(w, testRequest)

if test.handlerCallCountExpected != callCount {
t.Errorf("expected the request handler to be invoked %d times, but was actually invoked %d times", test.handlerCallCountExpected, callCount)
}

if test.hasDeadlineExpected != hasDeadlineGot {
t.Errorf("expected the request context to have deadline set: %t but got: %t", test.hasDeadlineExpected, hasDeadlineGot)
}

deadlineGot = deadlineGot.Truncate(time.Second)
if test.deadlineExpected != deadlineGot {
t.Errorf("expected a request context with a deadline of %s but got: %s", test.deadlineExpected, deadlineGot)
}

statusCodeGot := w.Result().StatusCode
if test.statusCodeExpected != statusCodeGot {
t.Errorf("expected status code %d but got: %d", test.statusCodeExpected, statusCodeGot)
}
})
}
}

type fakeRequestResolver struct{}

func (r fakeRequestResolver) NewRequestInfo(req *http.Request) (*request.RequestInfo, error) {
return &request.RequestInfo{}, nil
}

func deadline(r *http.Request) (time.Duration, bool) {
if deadline, ok := r.Context().Deadline(); ok {
remaining := time.Until(deadline)
return remaining, ok
}

return 0, false
}
7 changes: 2 additions & 5 deletions staging/src/k8s.io/apiserver/pkg/endpoints/handlers/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
return
}

// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))

namespace, name, err := scope.Namer.Name(req)
if err != nil {
if includeName {
Expand All @@ -76,7 +73,7 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
}
}

ctx, cancel := context.WithTimeout(req.Context(), timeout)
ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
defer cancel()
outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope)
if err != nil {
Expand Down Expand Up @@ -155,7 +152,7 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
options,
)
}
result, err := finishRequest(timeout, func() (runtime.Object, error) {
result, err := finishRequest(ctx, func() (runtime.Object, error) {
if scope.FieldManager != nil {
liveObj, err := scope.Creater.New(scope.Kind)
if err != nil {
Expand Down
Loading

0 comments on commit 59ac565

Please sign in to comment.