diff --git a/plugin/pkg/client/auth/oidc/oidc_test.go b/plugin/pkg/client/auth/oidc/oidc_test.go index 767781e63b260..30d8e4ab20826 100644 --- a/plugin/pkg/client/auth/oidc/oidc_test.go +++ b/plugin/pkg/client/auth/oidc/oidc_test.go @@ -18,14 +18,23 @@ package oidc import ( "encoding/base64" + "errors" + "fmt" "io/ioutil" + "net/http" "os" "path" + "reflect" + "strings" "testing" + "time" "github.com/coreos/go-oidc/jose" + "github.com/coreos/go-oidc/key" + "github.com/coreos/go-oidc/oauth2" "k8s.io/kubernetes/pkg/util/diff" + "k8s.io/kubernetes/pkg/util/wait" oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing" ) @@ -156,6 +165,456 @@ func TestNewOIDCAuthProvider(t *testing.T) { } } +func TestWrapTranport(t *testing.T) { + oldBackoff := backoff + defer func() { + backoff = oldBackoff + }() + backoff = wait.Backoff{ + Duration: 1 * time.Nanosecond, + Steps: 3, + } + + privKey, err := key.GeneratePrivateKey() + if err != nil { + t.Fatalf("can't generate private key: %v", err) + } + + makeToken := func(s string, exp time.Time, count int) *jose.JWT { + jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ + "test": s, + "exp": exp.UTC().Unix(), + "count": count, + }), privKey.Signer()) + if err != nil { + t.Fatalf("Could not create signed JWT %v", err) + } + return jwt + } + + goodToken := makeToken("good", time.Now().Add(time.Hour), 0) + goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1) + expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0) + + str := func(s string) *string { + return &s + } + tests := []struct { + cfgIDToken *jose.JWT + cfgRefreshToken *string + + expectRequests []testRoundTrip + + expectRefreshes []testRefresh + + expectPersists []testPersist + + wantStatus int + wantErr bool + }{ + { + // Initial JWT is set, it is good, it is set as bearer. + cfgIDToken: goodToken, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is set, but it's expired, so it gets refreshed. + cfgIDToken: expiredToken, + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is set, but it's expired, so it gets refreshed - this + // time the refresh token itself is also refreshed + cfgIDToken: expiredToken, + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + RefreshToken: "rt2", + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt2", + }, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is not set, so it gets refreshed. + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + + wantStatus: 200, + }, + { + // Expired token, but no refresh token. + cfgIDToken: expiredToken, + + wantErr: true, + }, + { + // Initial JWT is not set, so it gets refreshed, but the server + // rejects it when it is used, so it refreshes again, which + // succeeds. + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + }, + }, + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken2.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: http.StatusUnauthorized, + }, + { + expectBearerToken: goodToken2.Encode(), + returnHTTPStatus: http.StatusOK, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt1", + }, + }, + { + cfg: map[string]string{ + cfgIDToken: goodToken2.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is but the server rejects it when it is used, so it + // refreshes again, which succeeds. + cfgRefreshToken: str("rt1"), + cfgIDToken: goodToken, + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken2.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: http.StatusUnauthorized, + }, + { + expectBearerToken: goodToken2.Encode(), + returnHTTPStatus: http.StatusOK, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken2.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + wantStatus: 200, + }, + } + + for i, tt := range tests { + client := &testOIDCClient{ + refreshes: tt.expectRefreshes, + } + + persister := &testPersister{ + tt.expectPersists, + } + + cfg := map[string]string{} + if tt.cfgIDToken != nil { + cfg[cfgIDToken] = tt.cfgIDToken.Encode() + } + + if tt.cfgRefreshToken != nil { + cfg[cfgRefreshToken] = *tt.cfgRefreshToken + } + + ap := &oidcAuthProvider{ + refresher: &idTokenRefresher{ + client: client, + cfg: cfg, + persister: persister, + }, + } + + if tt.cfgIDToken != nil { + ap.initialIDToken = *tt.cfgIDToken + } + + tstRT := &testRoundTripper{ + tt.expectRequests, + } + + rt := ap.WrapTransport(tstRT) + + req, err := http.NewRequest("GET", "http://cluster.example.com", nil) + if err != nil { + t.Errorf("case %d: unexpected error making request: %v", i, err) + } + + res, err := rt.RoundTrip(req) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: Expected non-nil error", i) + } + } else if err != nil { + t.Errorf("case %d: unexpected error making round trip: %v", i, err) + + } else { + if res.StatusCode != tt.wantStatus { + t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode) + } + } + + if err = client.verify(); err != nil { + t.Errorf("case %d: %v", i, err) + } + + if err = persister.verify(); err != nil { + t.Errorf("case %d: %v", i, err) + } + + if err = tstRT.verify(); err != nil { + t.Errorf("case %d: %v", i, err) + continue + } + + } +} + +type testRoundTrip struct { + expectBearerToken string + returnHTTPStatus int +} + +type testRoundTripper struct { + trips []testRoundTrip +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if len(t.trips) == 0 { + return nil, errors.New("unexpected RoundTrip call") + } + + var trip testRoundTrip + trip, t.trips = t.trips[0], t.trips[1:] + + var bt string + var parts []string + auth := strings.TrimSpace(req.Header.Get("Authorization")) + if auth == "" { + goto Compare + } + + parts = strings.Split(auth, " ") + if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" { + goto Compare + } + + bt = parts[1] + +Compare: + if trip.expectBearerToken != bt { + return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt) + } + return &http.Response{ + StatusCode: trip.returnHTTPStatus, + }, nil +} + +func (t *testRoundTripper) verify() error { + if l := len(t.trips); l > 0 { + return fmt.Errorf("%d uncalled round trips", l) + } + return nil +} + +type testPersist struct { + cfg map[string]string + returnErr error +} + +type testPersister struct { + persists []testPersist +} + +func (t *testPersister) Persist(cfg map[string]string) error { + if len(t.persists) == 0 { + return errors.New("unexpected persist call") + } + + var persist testPersist + persist, t.persists = t.persists[0], t.persists[1:] + + if !reflect.DeepEqual(persist.cfg, cfg) { + return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg)) + } + + return persist.returnErr +} + +func (t *testPersister) verify() error { + if l := len(t.persists); l > 0 { + return fmt.Errorf("%d uncalled persists", l) + } + return nil +} + +type testRefresh struct { + expectRefreshToken string + + returnErr error + returnTokens oauth2.TokenResponse +} + +type testOIDCClient struct { + refreshes []testRefresh +} + +func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) { + if len(o.refreshes) == 0 { + return oauth2.TokenResponse{}, errors.New("unexpected refresh request") + } + + var refresh testRefresh + refresh, o.refreshes = o.refreshes[0], o.refreshes[1:] + + if rt != refresh.expectRefreshToken { + return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v", + refresh.expectRefreshToken, + rt) + } + + if refresh.returnErr != nil { + return oauth2.TokenResponse{}, refresh.returnErr + } + + return refresh.returnTokens, nil +} + +func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error { + claims, err := jwt.Claims() + if err != nil { + return err + } + claim, _, _ := claims.StringClaim("test") + if claim != "good" { + return errors.New("bad token") + } + return nil +} + +func (t *testOIDCClient) verify() error { + if l := len(t.refreshes); l > 0 { + return fmt.Errorf("%d uncalled refreshes", l) + } + return nil +} + func compareJWTs(a, b jose.JWT) string { if a.Encode() == b.Encode() { return "" @@ -179,5 +638,5 @@ func compareJWTs(a, b jose.JWT) string { } } - return diff.ObjectDiff(a, b) + return diff.ObjectDiff(aClaims, bClaims) }