From 5ae9b189aa72b183976f88ac0f44352f3145048b Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 8 May 2020 16:57:40 +0100 Subject: [PATCH 1/3] Drop SessionStateJSON wrapper --- oauthproxy_test.go | 38 +++++++++------ pkg/apis/sessions/session_state.go | 62 +++++++------------------ pkg/apis/sessions/session_state_test.go | 39 +++++++++------- pkg/sessions/cookie/session_store.go | 7 +-- pkg/sessions/redis/redis_store.go | 7 +-- pkg/sessions/session_store_test.go | 18 ++++--- providers/azure.go | 6 ++- providers/gitlab.go | 7 +-- providers/google.go | 12 +++-- providers/logingov.go | 7 ++- providers/oidc.go | 9 ++-- providers/oidc_test.go | 8 ++-- providers/provider_default.go | 5 +- providers/provider_default_test.go | 4 +- 14 files changed, 118 insertions(+), 111 deletions(-) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 845dcef1c4..3238532369 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -484,8 +484,7 @@ func TestBasicAuthPassword(t *testing.T) { }) rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", - strings.NewReader("")) + req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) proxy.ServeHTTP(rw, req) if rw.Code >= 400 { @@ -541,11 +540,12 @@ func TestBasicAuthWithEmail(t *testing.T) { expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword)) expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword)) + created := time.Now() session := &sessions.SessionState{ User: userName, Email: emailAddress, AccessToken: "oauth_token", - CreatedAt: time.Now(), + CreatedAt: &created, } { rw := httptest.NewRecorder() @@ -582,11 +582,12 @@ func TestPassUserHeadersWithEmail(t *testing.T) { const emailAddress = "john.doe@example.com" const userName = "9fcab5c9b889a557" + created := time.Now() session := &sessions.SessionState{ User: userName, Email: emailAddress, AccessToken: "oauth_token", - CreatedAt: time.Now(), + CreatedAt: &created, } { rw := httptest.NewRecorder() @@ -959,7 +960,8 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) func TestLoadCookiedSession(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()} + created := time.Now() + startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created} pcTest.SaveSession(startSession) session, err := pcTest.LoadCookiedSession() @@ -985,7 +987,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { }) reference := time.Now().Add(time.Duration(-2) * time.Hour) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} pcTest.SaveSession(startSession) session, err := pcTest.LoadCookiedSession() @@ -1001,7 +1003,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} pcTest.SaveSession(startSession) session, err := pcTest.LoadCookiedSession() @@ -1016,7 +1018,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} pcTest.SaveSession(startSession) pcTest.proxy.CookieRefresh = time.Hour @@ -1062,8 +1064,9 @@ func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { func TestAuthOnlyEndpointAccepted(t *testing.T) { test := NewAuthOnlyEndpointTest() + created := time.Now() startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} test.SaveSession(startSession) test.proxy.ServeHTTP(test.rw, test.req) @@ -1087,7 +1090,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference} test.SaveSession(startSession) test.proxy.ServeHTTP(test.rw, test.req) @@ -1098,8 +1101,9 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test := NewAuthOnlyEndpointTest() + created := time.Now() startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created} test.SaveSession(startSession) test.validateUser = false @@ -1129,8 +1133,9 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) + created := time.Now() startSession := &sessions.SessionState{ - User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} + User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} pcTest.SaveSession(startSession) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) @@ -1160,8 +1165,9 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) + created := time.Now() startSession := &sessions.SessionState{ - User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} + User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} pcTest.SaveSession(startSession) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) @@ -1193,8 +1199,9 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) + created := time.Now() startSession := &sessions.SessionState{ - User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} + User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created} pcTest.SaveSession(startSession) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) @@ -1569,10 +1576,11 @@ func TestGetJwtSession(t *testing.T) { } // Bearer + expires := time.Unix(1912151821, 0) session, _ := test.proxy.GetJwtSession(test.req) assert.Equal(t, session.User, "john@example.com") assert.Equal(t, session.Email, "john@example.com") - assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0)) + assert.Equal(t, session.ExpiresOn, &expires) assert.Equal(t, session.IDToken, goodJwt) test.proxy.ServeHTTP(test.rw, test.req) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index a09f01c132..d665303c85 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -2,7 +2,6 @@ package sessions import ( "encoding/json" - "errors" "fmt" "time" @@ -11,26 +10,19 @@ import ( // SessionState is used to store information about the currently authenticated user session type SessionState struct { - AccessToken string `json:",omitempty"` - IDToken string `json:",omitempty"` - CreatedAt time.Time `json:"-"` - ExpiresOn time.Time `json:"-"` - RefreshToken string `json:",omitempty"` - Email string `json:",omitempty"` - User string `json:",omitempty"` - PreferredUsername string `json:",omitempty"` -} - -// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value -type SessionStateJSON struct { - *SessionState - CreatedAt *time.Time `json:",omitempty"` - ExpiresOn *time.Time `json:",omitempty"` + AccessToken string `json:",omitempty"` + IDToken string `json:",omitempty"` + CreatedAt *time.Time `json:",omitempty"` + ExpiresOn *time.Time `json:",omitempty"` + RefreshToken string `json:",omitempty"` + Email string `json:",omitempty"` + User string `json:",omitempty"` + PreferredUsername string `json:",omitempty"` } // IsExpired checks whether the session has expired func (s *SessionState) IsExpired() bool { - if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { + if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { return true } return false @@ -38,8 +30,8 @@ func (s *SessionState) IsExpired() bool { // Age returns the age of a session func (s *SessionState) Age() time.Duration { - if !s.CreatedAt.IsZero() { - return time.Now().Truncate(time.Second).Sub(s.CreatedAt) + if s.CreatedAt != nil && !s.CreatedAt.IsZero() { + return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) } return 0 } @@ -113,42 +105,22 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) } } } - // Embed SessionState and ExpiresOn pointer into SessionStateJSON - ssj := &SessionStateJSON{SessionState: &ss} - if !ss.CreatedAt.IsZero() { - ssj.CreatedAt = &ss.CreatedAt - } - if !ss.ExpiresOn.IsZero() { - ssj.ExpiresOn = &ss.ExpiresOn - } - b, err := json.Marshal(ssj) + + b, err := json.Marshal(ss) return string(b), err } // DecodeSessionState decodes the session cookie string into a SessionState func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { - var ssj SessionStateJSON - var ss *SessionState - err := json.Unmarshal([]byte(v), &ssj) + var ss SessionState + err := json.Unmarshal([]byte(v), &ss) if err != nil { return nil, fmt.Errorf("error unmarshalling session: %w", err) } - if ssj.SessionState == nil { - return nil, errors.New("expected session state to not be nil") - } - - // Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON - ss = ssj.SessionState - if ssj.CreatedAt != nil { - ss.CreatedAt = *ssj.CreatedAt - } - if ssj.ExpiresOn != nil { - ss.ExpiresOn = *ssj.ExpiresOn - } if c == nil { // Load only Email and User when cipher is unavailable - ss = &SessionState{ + ss = SessionState{ Email: ss.Email, User: ss.User, PreferredUsername: ss.PreferredUsername, @@ -193,5 +165,5 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { } } } - return ss, nil + return &ss, nil } diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 94c624bfa2..529656fe4b 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -13,6 +13,10 @@ import ( const secret = "0123456789abcdefghijklmnopqrstuv" const altSecret = "0000000000abcdefghijklmnopqrstuv" +func timePtr(t time.Time) *time.Time { + return &t +} + func TestSessionStateSerialization(t *testing.T) { c, err := encryption.NewCipher([]byte(secret)) assert.Equal(t, nil, err) @@ -23,8 +27,8 @@ func TestSessionStateSerialization(t *testing.T) { PreferredUsername: "user", AccessToken: "token1234", IDToken: "rawtoken1234", - CreatedAt: time.Now(), - ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + CreatedAt: timePtr(time.Now()), + ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(c) @@ -66,8 +70,8 @@ func TestSessionStateSerializationWithUser(t *testing.T) { PreferredUsername: "ju", Email: "user@domain.com", AccessToken: "token1234", - CreatedAt: time.Now(), - ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + CreatedAt: timePtr(time.Now()), + ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(c) @@ -102,8 +106,8 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { Email: "user@domain.com", PreferredUsername: "user", AccessToken: "token1234", - CreatedAt: time.Now(), - ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + CreatedAt: timePtr(time.Now()), + ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(nil) @@ -125,8 +129,8 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { Email: "user@domain.com", PreferredUsername: "user", AccessToken: "token1234", - CreatedAt: time.Now(), - ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + CreatedAt: timePtr(time.Now()), + ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(nil) @@ -143,10 +147,10 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } func TestExpired(t *testing.T) { - s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} + s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))} assert.Equal(t, true, s.IsExpired()) - s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} + s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))} assert.Equal(t, false, s.IsExpired()) s = &sessions.SessionState{} @@ -182,8 +186,8 @@ func TestEncodeSessionState(t *testing.T) { User: "just-user", AccessToken: "token1234", IDToken: "rawtoken1234", - CreatedAt: c, - ExpiresOn: e, + CreatedAt: &c, + ExpiresOn: &e, RefreshToken: "refresh4321", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, @@ -249,8 +253,8 @@ func TestDecodeSessionState(t *testing.T) { User: "just-user", AccessToken: "token1234", IDToken: "rawtoken1234", - CreatedAt: created, - ExpiresOn: e, + CreatedAt: &created, + ExpiresOn: &e, RefreshToken: "refresh4321", }, Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), @@ -291,7 +295,10 @@ func TestDecodeSessionState(t *testing.T) { assert.Equal(t, tc.AccessToken, ss.AccessToken) assert.Equal(t, tc.RefreshToken, ss.RefreshToken) assert.Equal(t, tc.IDToken, ss.IDToken) - assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + if tc.ExpiresOn != nil { + assert.NotEqual(t, nil, ss.ExpiresOn) + assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + } } } } @@ -303,6 +310,6 @@ func TestSessionStateAge(t *testing.T) { assert.Equal(t, time.Duration(0), ss.Age()) // Set CreatedAt to 1 hour ago - ss.CreatedAt = time.Now().Add(-1 * time.Hour) + ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour)) assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) } diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index d97125c3e2..1b88e02762 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -34,14 +34,15 @@ type SessionStore struct { // Save takes a sessions.SessionState and stores the information from it // within Cookies set on the HTTP response writer func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { - if ss.CreatedAt.IsZero() { - ss.CreatedAt = time.Now() + if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { + now := time.Now() + ss.CreatedAt = &now } value, err := cookieForSession(ss, s.CookieCipher) if err != nil { return err } - s.setSessionCookie(rw, req, value, ss.CreatedAt) + s.setSessionCookie(rw, req, value, *ss.CreatedAt) return nil } diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index 7737b960b4..b737b6c62c 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -111,8 +111,9 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) { // Save takes a sessions.SessionState and stores the information from it // to redies, and adds a new ticket cookie on the HTTP response writer func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { - if s.CreatedAt.IsZero() { - s.CreatedAt = time.Now() + if s.CreatedAt == nil || s.CreatedAt.IsZero() { + now := time.Now() + s.CreatedAt = &now } // Old sessions that we are refreshing would have a request cookie @@ -132,7 +133,7 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se req, ticketString, store.CookieOptions.Expire, - s.CreatedAt, + *s.CreatedAt, ) http.SetCookie(rw, ticketCookie) diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 68cfd125c1..60a86cef8a 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -15,6 +15,7 @@ import ( sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" cookiesapi "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis" @@ -23,6 +24,8 @@ import ( ) func TestSessionStore(t *testing.T) { + logger.SetOutput(GinkgoWriter) + RegisterFailHandler(Fail) RunSpecs(t, "SessionStore") } @@ -253,16 +256,16 @@ var _ = Describe("NewSessionStore", func() { // Can't compare time.Time using Equal() so remove ExpiresOn from sessions l := *loadedSession - l.CreatedAt = time.Time{} - l.ExpiresOn = time.Time{} + l.CreatedAt = nil + l.ExpiresOn = nil s := *session - s.CreatedAt = time.Time{} - s.ExpiresOn = time.Time{} + s.CreatedAt = nil + s.ExpiresOn = nil Expect(l).To(Equal(s)) // Compare time.Time separately - Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) - Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) + Expect(loadedSession.CreatedAt.Equal(*session.CreatedAt)).To(BeTrue()) + Expect(loadedSession.ExpiresOn.Equal(*session.ExpiresOn)).To(BeTrue()) } }) } @@ -392,10 +395,11 @@ var _ = Describe("NewSessionStore", func() { SameSite: "", } + expires := time.Now().Add(1 * time.Hour) session = &sessionsapi.SessionState{ AccessToken: "AccessToken", IDToken: "IDToken", - ExpiresOn: time.Now().Add(1 * time.Hour), + ExpiresOn: &expires, RefreshToken: "RefreshToken", Email: "john.doe@example.com", User: "john.doe", diff --git a/providers/azure.go b/providers/azure.go index 961ff908c9..aea1b0e54c 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -126,11 +126,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s return } + created := time.Now() + expires := time.Unix(jsonResponse.ExpiresOn, 0) s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: time.Now(), - ExpiresOn: time.Unix(jsonResponse.ExpiresOn, 0), + CreatedAt: &created, + ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, } return diff --git a/providers/gitlab.go b/providers/gitlab.go index beeb6b9810..17c5df88e8 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -67,7 +67,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) ( // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { return false, nil } @@ -209,12 +209,13 @@ func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.T return nil, fmt.Errorf("could not verify id_token: %v", err) } + created := time.Now() return &sessions.SessionState{ AccessToken: token.AccessToken, IDToken: rawIDToken, RefreshToken: token.RefreshToken, - CreatedAt: time.Now(), - ExpiresOn: idToken.Expiry, + CreatedAt: &created, + ExpiresOn: &idToken.Expiry, }, nil } diff --git a/providers/google.go b/providers/google.go index 1406855ba0..5aeb6e2d0a 100644 --- a/providers/google.go +++ b/providers/google.go @@ -153,11 +153,14 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( if err != nil { return } + + created := time.Now() + expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: time.Now(), - ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), + CreatedAt: &created, + ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, Email: c.Email, User: c.Subject, @@ -245,7 +248,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { return false, nil } @@ -260,9 +263,10 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions } origExpiration := s.ExpiresOn + expires := time.Now().Add(duration).Truncate(time.Second) s.AccessToken = newToken s.IDToken = newIDToken - s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) + s.ExpiresOn = &expires logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) return true, nil } diff --git a/providers/logingov.go b/providers/logingov.go index 6f98d0cca0..460271726b 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -250,12 +250,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) return } + created := time.Now() + expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) + // Store the data that we found in the session state s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, - CreatedAt: time.Now(), - ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), + CreatedAt: &created, + ExpiresOn: &expires, Email: email, } return diff --git a/providers/oidc.go b/providers/oidc.go index 1b6758b9d5..f419b5d5fd 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -72,7 +72,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new Access Token (and optional ID token) if required func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { return false, nil } @@ -163,10 +163,11 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok } } + created := time.Now() newSession.AccessToken = token.AccessToken newSession.RefreshToken = token.RefreshToken - newSession.CreatedAt = time.Now() - newSession.ExpiresOn = token.Expiry + newSession.CreatedAt = &created + newSession.ExpiresOn = &token.Expiry return newSession, nil } @@ -179,7 +180,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, ra newSession.AccessToken = rawIDToken newSession.IDToken = rawIDToken newSession.RefreshToken = "" - newSession.ExpiresOn = idToken.Expiry + newSession.ExpiresOn = &idToken.Expiry return newSession, nil } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 823af30c56..c5d6b5212d 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -204,8 +204,8 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { existingSession := &sessions.SessionState{ AccessToken: "changeit", IDToken: idToken, - CreatedAt: time.Time{}, - ExpiresOn: time.Time{}, + CreatedAt: nil, + ExpiresOn: nil, RefreshToken: refreshToken, Email: "janedoe@example.com", User: "11223344", @@ -238,8 +238,8 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { existingSession := &sessions.SessionState{ AccessToken: "changeit", IDToken: "changeit", - CreatedAt: time.Time{}, - ExpiresOn: time.Time{}, + CreatedAt: nil, + ExpiresOn: nil, RefreshToken: refreshToken, Email: "changeit", User: "changeit", diff --git a/providers/provider_default.go b/providers/provider_default.go index 141261a16f..74335e11c3 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -81,7 +81,8 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s return } if a := v.Get("access_token"); a != "" { - s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} + created := time.Now() + s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} } else { err = fmt.Errorf("no access token found %s", body) } @@ -168,7 +169,7 @@ func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, ra newSession.AccessToken = rawIDToken newSession.IDToken = rawIDToken newSession.RefreshToken = "" - newSession.ExpiresOn = idToken.Expiry + newSession.ExpiresOn = &idToken.Expiry return newSession, nil } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index 4d8a8306d2..658918c4fd 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -11,8 +11,10 @@ import ( func TestRefresh(t *testing.T) { p := &ProviderData{} + + expires := time.Now().Add(time.Duration(-11) * time.Minute) refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ - ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), + ExpiresOn: &expires, }) assert.Equal(t, false, refreshed) assert.Equal(t, nil, err) From d995817d9fdd0c4300ffea351220c668ccd3b2ca Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 8 May 2020 18:11:36 +0100 Subject: [PATCH 2/3] Use EncrpytInto/DecryptInto to reduce sessionstate --- pkg/apis/sessions/session_state.go | 70 ++++++++---------------------- pkg/encryption/cipher.go | 27 ++++++++++++ pkg/encryption/cipher_test.go | 22 ++++++++++ 3 files changed, 66 insertions(+), 53 deletions(-) diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index d665303c85..1bb8ff035e 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -67,39 +67,15 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) ss.PreferredUsername = s.PreferredUsername } else { ss = *s - var err error - if ss.Email != "" { - ss.Email, err = c.Encrypt(ss.Email) - if err != nil { - return "", err - } - } - if ss.User != "" { - ss.User, err = c.Encrypt(ss.User) - if err != nil { - return "", err - } - } - if ss.PreferredUsername != "" { - ss.PreferredUsername, err = c.Encrypt(ss.PreferredUsername) - if err != nil { - return "", err - } - } - if ss.AccessToken != "" { - ss.AccessToken, err = c.Encrypt(ss.AccessToken) - if err != nil { - return "", err - } - } - if ss.IDToken != "" { - ss.IDToken, err = c.Encrypt(ss.IDToken) - if err != nil { - return "", err - } - } - if ss.RefreshToken != "" { - ss.RefreshToken, err = c.Encrypt(ss.RefreshToken) + for _, s := range []*string{ + &ss.Email, + &ss.User, + &ss.PreferredUsername, + &ss.AccessToken, + &ss.IDToken, + &ss.RefreshToken, + } { + err := c.EncryptInto(s) if err != nil { return "", err } @@ -140,26 +116,14 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { ss.User = decryptedUser } } - if ss.PreferredUsername != "" { - ss.PreferredUsername, err = c.Decrypt(ss.PreferredUsername) - if err != nil { - return nil, err - } - } - if ss.AccessToken != "" { - ss.AccessToken, err = c.Decrypt(ss.AccessToken) - if err != nil { - return nil, err - } - } - if ss.IDToken != "" { - ss.IDToken, err = c.Decrypt(ss.IDToken) - if err != nil { - return nil, err - } - } - if ss.RefreshToken != "" { - ss.RefreshToken, err = c.Decrypt(ss.RefreshToken) + + for _, s := range []*string{ + &ss.PreferredUsername, + &ss.AccessToken, + &ss.IDToken, + &ss.RefreshToken, + } { + err := c.DecryptInto(s) if err != nil { return nil, err } diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go index 2dcbee6bf2..4eb42b0333 100644 --- a/pkg/encryption/cipher.go +++ b/pkg/encryption/cipher.go @@ -156,3 +156,30 @@ func (c *Cipher) Decrypt(s string) (string, error) { return string(encrypted), nil } + +// EncryptInto encrypts the value and stores it back in the string pointer +func (c *Cipher) EncryptInto(s *string) error { + return into(c.Encrypt, s) +} + +// DecryptInto decrypts the value and stores it back in the string pointer +func (c *Cipher) DecryptInto(s *string) error { + return into(c.Decrypt, s) +} + +// codecFunc is a function that takes a string and encodes/decodes it +type codecFunc func(string) (string, error) + +func into(f codecFunc, s *string) error { + // Do not encrypt/decrypt nil or empty strings + if s == nil || *s == "" { + return nil + } + + d, err := f(*s) + if err != nil { + return err + } + *s = d + return nil +} diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go index 76bfc1bc31..aed529f354 100644 --- a/pkg/encryption/cipher_test.go +++ b/pkg/encryption/cipher_test.go @@ -133,3 +133,25 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { assert.NotEqual(t, token, encoded) assert.Equal(t, token, decoded) } + +func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) { + const secret = "0123456789abcdefghijklmnopqrstuv" + c, err := NewCipher([]byte(secret)) + assert.Equal(t, nil, err) + + token := "my access token" + originalToken := token + + assert.Equal(t, nil, c.EncryptInto(&token)) + assert.NotEqual(t, originalToken, token) + + assert.Equal(t, nil, c.DecryptInto(&token)) + assert.Equal(t, originalToken, token) + + // Check no errors with empty or nil strings + empty := "" + assert.Equal(t, nil, c.EncryptInto(&empty)) + assert.Equal(t, nil, c.DecryptInto(&empty)) + assert.Equal(t, nil, c.EncryptInto(nil)) + assert.Equal(t, nil, c.DecryptInto(nil)) +} From 1a0945ea11fc8bd911231d4ba98aed83aa0fc52c Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 10 May 2020 10:37:23 +0100 Subject: [PATCH 3/3] Add Session State Improvements Changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fea049529..4905a81b39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,7 @@ ## Changes since v5.1.1 +- [#536](https://github.com/oauth2-proxy/oauth2-proxy/pull/536) Improvements to Session State code (@JoelSpeed) - [#574](https://github.com/oauth2-proxy/oauth2-proxy/pull/574) render error page on 502 proxy status (@amnay-mo) - [#559](https://github.com/oauth2-proxy/oauth2-proxy/pull/559) Rename cookie-domain config to cookie-domains (@JoelSpeed) - [#569](https://github.com/oauth2-proxy/oauth2-proxy/pull/569) Updated autocompletion for `--` long options. (@Izzette)