Skip to content

Commit

Permalink
Refactor test oidc provider into its own package
Browse files Browse the repository at this point in the history
This makes it easier to test other OIDC code.
  • Loading branch information
Bobby Rullo committed May 19, 2016
1 parent f2135bd commit c990462
Show file tree
Hide file tree
Showing 3 changed files with 283 additions and 174 deletions.
210 changes: 36 additions & 174 deletions plugin/pkg/auth/authenticator/token/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,60 +17,23 @@ limitations under the License.
package oidc

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"path/filepath"
"reflect"
"strings"
"testing"
"time"

"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oidc"

"k8s.io/kubernetes/pkg/auth/user"
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
)

type oidcProvider struct {
mux *http.ServeMux
pcfg oidc.ProviderConfig
privKey *key.PrivateKey
}

func newOIDCProvider(t *testing.T) *oidcProvider {
privKey, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("Cannot create OIDC Provider: %v", err)
return nil
}

op := &oidcProvider{
mux: http.NewServeMux(),
privKey: privKey,
}

op.mux.HandleFunc("/.well-known/openid-configuration", op.handleConfig)
op.mux.HandleFunc("/keys", op.handleKeys)

return op

}

func mustParseURL(t *testing.T, s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
Expand All @@ -79,37 +42,8 @@ func mustParseURL(t *testing.T, s string) *url.URL {
return u
}

func (op *oidcProvider) handleConfig(w http.ResponseWriter, req *http.Request) {
b, err := json.Marshal(&op.pcfg)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(b)
}

func (op *oidcProvider) handleKeys(w http.ResponseWriter, req *http.Request) {
keys := struct {
Keys []jose.JWK `json:"keys"`
}{
Keys: []jose.JWK{op.privKey.JWK()},
}

b, err := json.Marshal(keys)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(time.Hour.Seconds())))
w.Header().Set("Expires", time.Now().Add(time.Hour).Format(time.RFC1123))
w.Header().Set("Content-Type", "application/json")
w.Write(b)
}

func (op *oidcProvider) generateToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string, iat, exp time.Time) string {
signer := op.privKey.Signer()
func generateToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string, iat, exp time.Time) string {
signer := op.PrivKey.Signer()
claims := oidc.NewClaims(iss, sub, aud, iat, exp)
claims.Add(usernameClaim, value)
if groups != nil && groupsClaim != "" {
Expand All @@ -124,79 +58,16 @@ func (op *oidcProvider) generateToken(t *testing.T, iss, sub, aud string, userna
return jwt.Encode()
}

func (op *oidcProvider) generateGoodToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string {
return op.generateToken(t, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour))
func generateGoodToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string {
return generateToken(t, op, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour))
}

func (op *oidcProvider) generateMalformedToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string {
return op.generateToken(t, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour)) + "randombits"
func generateMalformedToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string {
return generateToken(t, op, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour)) + "randombits"
}

func (op *oidcProvider) generateExpiredToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string {
return op.generateToken(t, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour))
}

// generateSelfSignedCert generates a self-signed cert/key pairs and writes to the certPath/keyPath.
// This method is mostly identical to crypto.GenerateSelfSignedCert except for the 'IsCA' and 'KeyUsage'
// in the certificate template. (Maybe we can merge these two methods).
func generateSelfSignedCert(t *testing.T, host, certPath, keyPath string) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}

template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()),
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 365),

KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}

if ip := net.ParseIP(host); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, host)
}

derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
t.Fatal(err)
}

// Generate cert
certBuffer := bytes.Buffer{}
if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
t.Fatal(err)
}

// Generate key
keyBuffer := bytes.Buffer{}
if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
t.Fatal(err)
}

// Write cert
if err := os.MkdirAll(filepath.Dir(certPath), os.FileMode(0755)); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(certPath, certBuffer.Bytes(), os.FileMode(0644)); err != nil {
t.Fatal(err)
}

// Write key
if err := os.MkdirAll(filepath.Dir(keyPath), os.FileMode(0755)); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(keyPath, keyBuffer.Bytes(), os.FileMode(0600)); err != nil {
t.Fatal(err)
}
func generateExpiredToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string {
return generateToken(t, op, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour))
}

func TestOIDCDiscoveryTimeout(t *testing.T) {
Expand All @@ -217,19 +88,16 @@ func TestOIDCDiscoveryNoKeyEndpoint(t *testing.T) {
defer os.Remove(cert)
defer os.Remove(key)

generateSelfSignedCert(t, "127.0.0.1", cert, key)
oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key)

op := newOIDCProvider(t)
srv := httptest.NewUnstartedServer(op.mux)
srv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)}
srv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert, key)
op := oidctesting.NewOIDCProvider(t)
srv, err := op.ServeTLSWithKeyPair(cert, key)
if err != nil {
t.Fatalf("Cannot load cert/key pair: %v", err)
t.Fatalf("Cannot start server %v", err)
}
srv.StartTLS()
defer srv.Close()

op.pcfg = oidc.ProviderConfig{
op.PCFG = oidc.ProviderConfig{
Issuer: mustParseURL(t, srv.URL), // An invalid ProviderConfig. Keys endpoint is required.
}

Expand All @@ -241,11 +109,11 @@ func TestOIDCDiscoveryNoKeyEndpoint(t *testing.T) {

func TestOIDCDiscoverySecureConnection(t *testing.T) {
// Verify that plain HTTP issuer URL is forbidden.
op := newOIDCProvider(t)
srv := httptest.NewServer(op.mux)
op := oidctesting.NewOIDCProvider(t)
srv := httptest.NewServer(op.Mux)
defer srv.Close()

op.pcfg = oidc.ProviderConfig{
op.PCFG = oidc.ProviderConfig{
Issuer: mustParseURL(t, srv.URL),
KeysEndpoint: mustParseURL(t, srv.URL+"/keys"),
}
Expand All @@ -268,20 +136,17 @@ func TestOIDCDiscoverySecureConnection(t *testing.T) {
defer os.Remove(cert2)
defer os.Remove(key2)

generateSelfSignedCert(t, "127.0.0.1", cert1, key1)
generateSelfSignedCert(t, "127.0.0.1", cert2, key2)
oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert1, key1)
oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert2, key2)

// Create a TLS server using cert/key pair 1.
tlsSrv := httptest.NewUnstartedServer(op.mux)
tlsSrv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)}
tlsSrv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert1, key1)
tlsSrv, err := op.ServeTLSWithKeyPair(cert1, key1)
if err != nil {
t.Fatalf("Cannot load cert/key pair: %v", err)
t.Fatalf("Cannot start server: %v", err)
}
tlsSrv.StartTLS()
defer tlsSrv.Close()

op.pcfg = oidc.ProviderConfig{
op.PCFG = oidc.ProviderConfig{
Issuer: mustParseURL(t, tlsSrv.URL),
KeysEndpoint: mustParseURL(t, tlsSrv.URL+"/keys"),
}
Expand All @@ -303,21 +168,18 @@ func TestOIDCAuthentication(t *testing.T) {
defer os.Remove(cert)
defer os.Remove(key)

generateSelfSignedCert(t, "127.0.0.1", cert, key)
oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key)

// Create a TLS server and a client.
op := newOIDCProvider(t)
srv := httptest.NewUnstartedServer(op.mux)
srv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)}
srv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert, key)
op := oidctesting.NewOIDCProvider(t)
srv, err := op.ServeTLSWithKeyPair(cert, key)
if err != nil {
t.Fatalf("Cannot load cert/key pair: %v", err)
t.Fatalf("Cannot start server: %v", err)
}
srv.StartTLS()
defer srv.Close()

// A provider config with all required fields.
op.pcfg = oidc.ProviderConfig{
op.PCFG = oidc.ProviderConfig{
Issuer: mustParseURL(t, srv.URL),
AuthEndpoint: mustParseURL(t, srv.URL+"/auth"),
TokenEndpoint: mustParseURL(t, srv.URL+"/token"),
Expand All @@ -338,7 +200,7 @@ func TestOIDCAuthentication(t *testing.T) {
{
"sub",
"",
op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil),
generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil),
&user.DefaultInfo{Name: fmt.Sprintf("%s#%s", srv.URL, "user-foo")},
true,
"",
Expand All @@ -347,7 +209,7 @@ func TestOIDCAuthentication(t *testing.T) {
// Use user defined claim (email here).
"email",
"",
op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "", nil),
generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "", nil),
&user.DefaultInfo{Name: "foo@example.com"},
true,
"",
Expand All @@ -356,7 +218,7 @@ func TestOIDCAuthentication(t *testing.T) {
// Use user defined claim (email here).
"email",
"",
op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}),
generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}),
&user.DefaultInfo{Name: "foo@example.com"},
true,
"",
Expand All @@ -365,15 +227,15 @@ func TestOIDCAuthentication(t *testing.T) {
// Use user defined claim (email here).
"email",
"groups",
op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}),
generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}),
&user.DefaultInfo{Name: "foo@example.com", Groups: []string{"group1", "group2"}},
true,
"",
},
{
"sub",
"",
op.generateMalformedToken(t, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil),
generateMalformedToken(t, op, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil),
nil,
false,
"oidc: unable to verify JWT signature: no matching keys",
Expand All @@ -382,7 +244,7 @@ func TestOIDCAuthentication(t *testing.T) {
// Invalid 'aud'.
"sub",
"",
op.generateGoodToken(t, srv.URL, "client-foo", "client-bar", "sub", "user-foo", "", nil),
generateGoodToken(t, op, srv.URL, "client-foo", "client-bar", "sub", "user-foo", "", nil),
nil,
false,
"oidc: JWT claims invalid: invalid claims, 'aud' claim and 'client_id' do not match",
Expand All @@ -391,15 +253,15 @@ func TestOIDCAuthentication(t *testing.T) {
// Invalid issuer.
"sub",
"",
op.generateGoodToken(t, "http://foo-bar.com", "client-foo", "client-foo", "sub", "user-foo", "", nil),
generateGoodToken(t, op, "http://foo-bar.com", "client-foo", "client-foo", "sub", "user-foo", "", nil),
nil,
false,
"oidc: JWT claims invalid: invalid claim value: 'iss'.",
},
{
"sub",
"",
op.generateExpiredToken(t, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil),
generateExpiredToken(t, op, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil),
nil,
false,
"oidc: JWT claims invalid: token is expired",
Expand Down
Loading

0 comments on commit c990462

Please sign in to comment.