Skip to content

Commit

Permalink
Upgrade websocket failure add extra error info
Browse files Browse the repository at this point in the history
  • Loading branch information
seans3 committed Oct 4, 2024
1 parent fccbbf3 commit 491611a
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 28 deletions.
44 changes: 41 additions & 3 deletions staging/src/k8s.io/client-go/transport/websocket/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"

gwebsocket "github.com/gorilla/websocket"

apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
Expand All @@ -37,6 +43,17 @@ var (
_ http.RoundTripper = &RoundTripper{}
)

var (
statusScheme = runtime.NewScheme()
statusCodecs = serializer.NewCodecFactory(statusScheme)
)

func init() {
statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion,
&metav1.Status{},
)
}

// ConnectionHolder defines functions for structure providing
// access to the websocket connection.
type ConnectionHolder interface {
Expand Down Expand Up @@ -110,12 +127,33 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
}
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
if err != nil {
// BadHandshake error becomes an "UpgradeFailureError" (used for streaming fallback).
if errors.Is(err, gwebsocket.ErrBadHandshake) {
// Enhance the error message with the response status if possible.
cause := err
// Enhance the error message with the error response if possible.
if resp != nil && len(resp.Status) > 0 {
err = fmt.Errorf("%w (%s)", err, resp.Status)
defer resp.Body.Close() //nolint:errcheck
cause = fmt.Errorf("%w (%s)", err, resp.Status) // Always add the response status
responseError := ""
responseErrorBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
if readErr != nil {
cause = fmt.Errorf("%w (unable to read error from server response)", cause)
} else {
// If returned error can be decoded as "metav1.Status", return a "StatusError".
responseError = strings.TrimSpace(string(responseErrorBytes))
if len(responseError) > 0 {
if obj, _, decodeErr := statusCodecs.UniversalDecoder().Decode(responseErrorBytes, nil, &metav1.Status{}); decodeErr == nil {
if status, ok := obj.(*metav1.Status); ok {
cause = &apierrors.StatusError{ErrStatus: *status}
}
} else {
// Otherwise, append the responseError string.
cause = fmt.Errorf("%w: %s", cause, responseError)
}
}
}
}
return nil, &httpstream.UpgradeFailureError{Cause: err}
return nil, &httpstream.UpgradeFailureError{Cause: cause}
}
return nil, err
}
Expand Down
109 changes: 84 additions & 25 deletions staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package websocket

import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -28,6 +29,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
"k8s.io/apimachinery/pkg/util/remotecommand"
Expand Down Expand Up @@ -64,31 +68,86 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
}

func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
// Create fake WebSocket server.
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Bad handshake means websocket server will not completely initialize.
_, err := webSocketServerStreams(req, w)
require.Error(t, err)
assert.ErrorContains(t, err, "websocket server finished before becoming ready")
}))
defer websocketServer.Close()

// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
require.NoError(t, err)
rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)
// Requested subprotocol version 1 is not supported by test websocket server.
requestedProtocol := remotecommand.StreamProtocolV1Name
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
_, err = rt.RoundTrip(req)
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
require.Error(t, err)
assert.ErrorContains(t, err, "websocket: bad handshake")
assert.ErrorContains(t, err, "403 Forbidden")
assert.True(t, httpstream.IsUpgradeFailure(err))
testCases := map[string]struct {
statusCode int
body string
status *metav1.Status
expectedError string
}{
"Empty response status still returns basic websocket error": {
statusCode: -1,
body: "",
expectedError: "websocket: bad handshake",
},
"Empty response body still returns status": {
statusCode: http.StatusForbidden,
body: "",
expectedError: "(403 Forbidden)",
},
"Error response body returned as string when can not be cast as metav1.Status": {
statusCode: http.StatusForbidden,
body: "RBAC violated",
expectedError: "(403 Forbidden): RBAC violated",
},
"Error returned as metav1.Status within response body": {
statusCode: http.StatusBadRequest,
body: "",
status: &metav1.Status{
TypeMeta: metav1.TypeMeta{
APIVersion: "meta.k8s.io/v1",
Kind: "Status",
},
Status: "Failure",
Reason: "Unable to negotiate sub-protocol",
Code: http.StatusBadRequest,
},
},
}
encoder := statusCodecs.LegacyCodec(metav1.SchemeGroupVersion)
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
// Create fake WebSocket server.
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if testCase.statusCode > 0 {
w.WriteHeader(testCase.statusCode)
}
if testCase.status != nil {
statusBytes, err := runtime.Encode(encoder, testCase.status)
require.NoError(t, err)
_, err = w.Write(statusBytes)
require.NoError(t, err)
} else if len(testCase.body) > 0 {
_, err := w.Write([]byte(testCase.body))
require.NoError(t, err)
}
}))
defer websocketServer.Close()

// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
websocketLocation, err := url.Parse(websocketServer.URL)
require.NoError(t, err)
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
require.NoError(t, err)
rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)
_, err = rt.RoundTrip(req)
require.Error(t, err)
assert.True(t, httpstream.IsUpgradeFailure(err))
if testCase.status != nil {
upgradeErr := &httpstream.UpgradeFailureError{}
validErr := errors.As(err, &upgradeErr)
assert.True(t, validErr, "could not cast error as httpstream.UpgradeFailureError")
statusErr := upgradeErr.Cause
apiErr := &apierrors.StatusError{}
validErr = errors.As(statusErr, &apiErr)
assert.True(t, validErr, "could not cast error as apierrors.StatusError")
assert.Equal(t, *testCase.status, apiErr.ErrStatus)
} else {
assert.Contains(t, err.Error(), testCase.expectedError,
"expected (%s), got (%s)", testCase.expectedError, err.Error())
}
})
}
}

func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
Expand Down

0 comments on commit 491611a

Please sign in to comment.