From 5ae9b189aa72b183976f88ac0f44352f3145048b Mon Sep 17 00:00:00 2001
From: Joel Speed
Date: Fri, 8 May 2020 16:57:40 +0100
Subject: [PATCH 1/3] Drop SessionStateJSON wrapper
---
oauthproxy_test.go | 38 +++++++++------
pkg/apis/sessions/session_state.go | 62 +++++++------------------
pkg/apis/sessions/session_state_test.go | 39 +++++++++-------
pkg/sessions/cookie/session_store.go | 7 +--
pkg/sessions/redis/redis_store.go | 7 +--
pkg/sessions/session_store_test.go | 18 ++++---
providers/azure.go | 6 ++-
providers/gitlab.go | 7 +--
providers/google.go | 12 +++--
providers/logingov.go | 7 ++-
providers/oidc.go | 9 ++--
providers/oidc_test.go | 8 ++--
providers/provider_default.go | 5 +-
providers/provider_default_test.go | 4 +-
14 files changed, 118 insertions(+), 111 deletions(-)
diff --git a/oauthproxy_test.go b/oauthproxy_test.go
index 845dcef1c4..3238532369 100644
--- a/oauthproxy_test.go
+++ b/oauthproxy_test.go
@@ -484,8 +484,7 @@ func TestBasicAuthPassword(t *testing.T) {
})
rw := httptest.NewRecorder()
- req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
- strings.NewReader(""))
+ req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader(""))
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
proxy.ServeHTTP(rw, req)
if rw.Code >= 400 {
@@ -541,11 +540,12 @@ func TestBasicAuthWithEmail(t *testing.T) {
expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword))
+ created := time.Now()
session := &sessions.SessionState{
User: userName,
Email: emailAddress,
AccessToken: "oauth_token",
- CreatedAt: time.Now(),
+ CreatedAt: &created,
}
{
rw := httptest.NewRecorder()
@@ -582,11 +582,12 @@ func TestPassUserHeadersWithEmail(t *testing.T) {
const emailAddress = "john.doe@example.com"
const userName = "9fcab5c9b889a557"
+ created := time.Now()
session := &sessions.SessionState{
User: userName,
Email: emailAddress,
AccessToken: "oauth_token",
- CreatedAt: time.Now(),
+ CreatedAt: &created,
}
{
rw := httptest.NewRecorder()
@@ -959,7 +960,8 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error)
func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults()
- startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
+ created := time.Now()
+ startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
@@ -985,7 +987,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
})
reference := time.Now().Add(time.Duration(-2) * time.Hour)
- startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
+ startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
@@ -1001,7 +1003,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
- startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
+ startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
@@ -1016,7 +1018,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
- startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
+ startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)
pcTest.proxy.CookieRefresh = time.Hour
@@ -1062,8 +1064,9 @@ func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {
func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest()
+ created := time.Now()
startSession := &sessions.SessionState{
- Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
+ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req)
@@ -1087,7 +1090,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{
- Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
+ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req)
@@ -1098,8 +1101,9 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest()
+ created := time.Now()
startSession := &sessions.SessionState{
- Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
+ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
test.SaveSession(startSession)
test.validateUser = false
@@ -1129,8 +1133,9 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
+ created := time.Now()
startSession := &sessions.SessionState{
- User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
+ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
@@ -1160,8 +1165,9 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
+ created := time.Now()
startSession := &sessions.SessionState{
- User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
+ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
@@ -1193,8 +1199,9 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
+ created := time.Now()
startSession := &sessions.SessionState{
- User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
+ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
@@ -1569,10 +1576,11 @@ func TestGetJwtSession(t *testing.T) {
}
// Bearer
+ expires := time.Unix(1912151821, 0)
session, _ := test.proxy.GetJwtSession(test.req)
assert.Equal(t, session.User, "john@example.com")
assert.Equal(t, session.Email, "john@example.com")
- assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0))
+ assert.Equal(t, session.ExpiresOn, &expires)
assert.Equal(t, session.IDToken, goodJwt)
test.proxy.ServeHTTP(test.rw, test.req)
diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go
index a09f01c132..d665303c85 100644
--- a/pkg/apis/sessions/session_state.go
+++ b/pkg/apis/sessions/session_state.go
@@ -2,7 +2,6 @@ package sessions
import (
"encoding/json"
- "errors"
"fmt"
"time"
@@ -11,26 +10,19 @@ import (
// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
- AccessToken string `json:",omitempty"`
- IDToken string `json:",omitempty"`
- CreatedAt time.Time `json:"-"`
- ExpiresOn time.Time `json:"-"`
- RefreshToken string `json:",omitempty"`
- Email string `json:",omitempty"`
- User string `json:",omitempty"`
- PreferredUsername string `json:",omitempty"`
-}
-
-// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
-type SessionStateJSON struct {
- *SessionState
- CreatedAt *time.Time `json:",omitempty"`
- ExpiresOn *time.Time `json:",omitempty"`
+ AccessToken string `json:",omitempty"`
+ IDToken string `json:",omitempty"`
+ CreatedAt *time.Time `json:",omitempty"`
+ ExpiresOn *time.Time `json:",omitempty"`
+ RefreshToken string `json:",omitempty"`
+ Email string `json:",omitempty"`
+ User string `json:",omitempty"`
+ PreferredUsername string `json:",omitempty"`
}
// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
- if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
+ if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
}
return false
@@ -38,8 +30,8 @@ func (s *SessionState) IsExpired() bool {
// Age returns the age of a session
func (s *SessionState) Age() time.Duration {
- if !s.CreatedAt.IsZero() {
- return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
+ if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
+ return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
}
return 0
}
@@ -113,42 +105,22 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
}
}
}
- // Embed SessionState and ExpiresOn pointer into SessionStateJSON
- ssj := &SessionStateJSON{SessionState: &ss}
- if !ss.CreatedAt.IsZero() {
- ssj.CreatedAt = &ss.CreatedAt
- }
- if !ss.ExpiresOn.IsZero() {
- ssj.ExpiresOn = &ss.ExpiresOn
- }
- b, err := json.Marshal(ssj)
+
+ b, err := json.Marshal(ss)
return string(b), err
}
// DecodeSessionState decodes the session cookie string into a SessionState
func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
- var ssj SessionStateJSON
- var ss *SessionState
- err := json.Unmarshal([]byte(v), &ssj)
+ var ss SessionState
+ err := json.Unmarshal([]byte(v), &ss)
if err != nil {
return nil, fmt.Errorf("error unmarshalling session: %w", err)
}
- if ssj.SessionState == nil {
- return nil, errors.New("expected session state to not be nil")
- }
-
- // Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
- ss = ssj.SessionState
- if ssj.CreatedAt != nil {
- ss.CreatedAt = *ssj.CreatedAt
- }
- if ssj.ExpiresOn != nil {
- ss.ExpiresOn = *ssj.ExpiresOn
- }
if c == nil {
// Load only Email and User when cipher is unavailable
- ss = &SessionState{
+ ss = SessionState{
Email: ss.Email,
User: ss.User,
PreferredUsername: ss.PreferredUsername,
@@ -193,5 +165,5 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
}
}
}
- return ss, nil
+ return &ss, nil
}
diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go
index 94c624bfa2..529656fe4b 100644
--- a/pkg/apis/sessions/session_state_test.go
+++ b/pkg/apis/sessions/session_state_test.go
@@ -13,6 +13,10 @@ import (
const secret = "0123456789abcdefghijklmnopqrstuv"
const altSecret = "0000000000abcdefghijklmnopqrstuv"
+func timePtr(t time.Time) *time.Time {
+ return &t
+}
+
func TestSessionStateSerialization(t *testing.T) {
c, err := encryption.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
@@ -23,8 +27,8 @@ func TestSessionStateSerialization(t *testing.T) {
PreferredUsername: "user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
- CreatedAt: time.Now(),
- ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
+ CreatedAt: timePtr(time.Now()),
+ ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
@@ -66,8 +70,8 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
PreferredUsername: "ju",
Email: "user@domain.com",
AccessToken: "token1234",
- CreatedAt: time.Now(),
- ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
+ CreatedAt: timePtr(time.Now()),
+ ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
@@ -102,8 +106,8 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
- CreatedAt: time.Now(),
- ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
+ CreatedAt: timePtr(time.Now()),
+ ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
@@ -125,8 +129,8 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
- CreatedAt: time.Now(),
- ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
+ CreatedAt: timePtr(time.Now()),
+ ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
@@ -143,10 +147,10 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
}
func TestExpired(t *testing.T) {
- s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
+ s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
assert.Equal(t, true, s.IsExpired())
- s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
+ s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
assert.Equal(t, false, s.IsExpired())
s = &sessions.SessionState{}
@@ -182,8 +186,8 @@ func TestEncodeSessionState(t *testing.T) {
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
- CreatedAt: c,
- ExpiresOn: e,
+ CreatedAt: &c,
+ ExpiresOn: &e,
RefreshToken: "refresh4321",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
@@ -249,8 +253,8 @@ func TestDecodeSessionState(t *testing.T) {
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
- CreatedAt: created,
- ExpiresOn: e,
+ CreatedAt: &created,
+ ExpiresOn: &e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
@@ -291,7 +295,10 @@ func TestDecodeSessionState(t *testing.T) {
assert.Equal(t, tc.AccessToken, ss.AccessToken)
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
assert.Equal(t, tc.IDToken, ss.IDToken)
- assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
+ if tc.ExpiresOn != nil {
+ assert.NotEqual(t, nil, ss.ExpiresOn)
+ assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
+ }
}
}
}
@@ -303,6 +310,6 @@ func TestSessionStateAge(t *testing.T) {
assert.Equal(t, time.Duration(0), ss.Age())
// Set CreatedAt to 1 hour ago
- ss.CreatedAt = time.Now().Add(-1 * time.Hour)
+ ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
}
diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go
index d97125c3e2..1b88e02762 100644
--- a/pkg/sessions/cookie/session_store.go
+++ b/pkg/sessions/cookie/session_store.go
@@ -34,14 +34,15 @@ type SessionStore struct {
// Save takes a sessions.SessionState and stores the information from it
// within Cookies set on the HTTP response writer
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
- if ss.CreatedAt.IsZero() {
- ss.CreatedAt = time.Now()
+ if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
+ now := time.Now()
+ ss.CreatedAt = &now
}
value, err := cookieForSession(ss, s.CookieCipher)
if err != nil {
return err
}
- s.setSessionCookie(rw, req, value, ss.CreatedAt)
+ s.setSessionCookie(rw, req, value, *ss.CreatedAt)
return nil
}
diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go
index 7737b960b4..b737b6c62c 100644
--- a/pkg/sessions/redis/redis_store.go
+++ b/pkg/sessions/redis/redis_store.go
@@ -111,8 +111,9 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) {
// Save takes a sessions.SessionState and stores the information from it
// to redies, and adds a new ticket cookie on the HTTP response writer
func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
- if s.CreatedAt.IsZero() {
- s.CreatedAt = time.Now()
+ if s.CreatedAt == nil || s.CreatedAt.IsZero() {
+ now := time.Now()
+ s.CreatedAt = &now
}
// Old sessions that we are refreshing would have a request cookie
@@ -132,7 +133,7 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se
req,
ticketString,
store.CookieOptions.Expire,
- s.CreatedAt,
+ *s.CreatedAt,
)
http.SetCookie(rw, ticketCookie)
diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go
index 68cfd125c1..60a86cef8a 100644
--- a/pkg/sessions/session_store_test.go
+++ b/pkg/sessions/session_store_test.go
@@ -15,6 +15,7 @@ import (
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
cookiesapi "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
+ "github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis"
@@ -23,6 +24,8 @@ import (
)
func TestSessionStore(t *testing.T) {
+ logger.SetOutput(GinkgoWriter)
+
RegisterFailHandler(Fail)
RunSpecs(t, "SessionStore")
}
@@ -253,16 +256,16 @@ var _ = Describe("NewSessionStore", func() {
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
l := *loadedSession
- l.CreatedAt = time.Time{}
- l.ExpiresOn = time.Time{}
+ l.CreatedAt = nil
+ l.ExpiresOn = nil
s := *session
- s.CreatedAt = time.Time{}
- s.ExpiresOn = time.Time{}
+ s.CreatedAt = nil
+ s.ExpiresOn = nil
Expect(l).To(Equal(s))
// Compare time.Time separately
- Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue())
- Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
+ Expect(loadedSession.CreatedAt.Equal(*session.CreatedAt)).To(BeTrue())
+ Expect(loadedSession.ExpiresOn.Equal(*session.ExpiresOn)).To(BeTrue())
}
})
}
@@ -392,10 +395,11 @@ var _ = Describe("NewSessionStore", func() {
SameSite: "",
}
+ expires := time.Now().Add(1 * time.Hour)
session = &sessionsapi.SessionState{
AccessToken: "AccessToken",
IDToken: "IDToken",
- ExpiresOn: time.Now().Add(1 * time.Hour),
+ ExpiresOn: &expires,
RefreshToken: "RefreshToken",
Email: "john.doe@example.com",
User: "john.doe",
diff --git a/providers/azure.go b/providers/azure.go
index 961ff908c9..aea1b0e54c 100644
--- a/providers/azure.go
+++ b/providers/azure.go
@@ -126,11 +126,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
return
}
+ created := time.Now()
+ expires := time.Unix(jsonResponse.ExpiresOn, 0)
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
- CreatedAt: time.Now(),
- ExpiresOn: time.Unix(jsonResponse.ExpiresOn, 0),
+ CreatedAt: &created,
+ ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken,
}
return
diff --git a/providers/gitlab.go b/providers/gitlab.go
index beeb6b9810..17c5df88e8 100644
--- a/providers/gitlab.go
+++ b/providers/gitlab.go
@@ -67,7 +67,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
- if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
+ if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@@ -209,12 +209,13 @@ func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.T
return nil, fmt.Errorf("could not verify id_token: %v", err)
}
+ created := time.Now()
return &sessions.SessionState{
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
- CreatedAt: time.Now(),
- ExpiresOn: idToken.Expiry,
+ CreatedAt: &created,
+ ExpiresOn: &idToken.Expiry,
}, nil
}
diff --git a/providers/google.go b/providers/google.go
index 1406855ba0..5aeb6e2d0a 100644
--- a/providers/google.go
+++ b/providers/google.go
@@ -153,11 +153,14 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
if err != nil {
return
}
+
+ created := time.Now()
+ expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
- CreatedAt: time.Now(),
- ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
+ CreatedAt: &created,
+ ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken,
Email: c.Email,
User: c.Subject,
@@ -245,7 +248,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
- if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
+ if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@@ -260,9 +263,10 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
}
origExpiration := s.ExpiresOn
+ expires := time.Now().Add(duration).Truncate(time.Second)
s.AccessToken = newToken
s.IDToken = newIDToken
- s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
+ s.ExpiresOn = &expires
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
return true, nil
}
diff --git a/providers/logingov.go b/providers/logingov.go
index 6f98d0cca0..460271726b 100644
--- a/providers/logingov.go
+++ b/providers/logingov.go
@@ -250,12 +250,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
return
}
+ created := time.Now()
+ expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
+
// Store the data that we found in the session state
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
- CreatedAt: time.Now(),
- ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
+ CreatedAt: &created,
+ ExpiresOn: &expires,
Email: email,
}
return
diff --git a/providers/oidc.go b/providers/oidc.go
index 1b6758b9d5..f419b5d5fd 100644
--- a/providers/oidc.go
+++ b/providers/oidc.go
@@ -72,7 +72,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access Token (and optional ID token) if required
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
- if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
+ if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@@ -163,10 +163,11 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
}
}
+ created := time.Now()
newSession.AccessToken = token.AccessToken
newSession.RefreshToken = token.RefreshToken
- newSession.CreatedAt = time.Now()
- newSession.ExpiresOn = token.Expiry
+ newSession.CreatedAt = &created
+ newSession.ExpiresOn = &token.Expiry
return newSession, nil
}
@@ -179,7 +180,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, ra
newSession.AccessToken = rawIDToken
newSession.IDToken = rawIDToken
newSession.RefreshToken = ""
- newSession.ExpiresOn = idToken.Expiry
+ newSession.ExpiresOn = &idToken.Expiry
return newSession, nil
}
diff --git a/providers/oidc_test.go b/providers/oidc_test.go
index 823af30c56..c5d6b5212d 100644
--- a/providers/oidc_test.go
+++ b/providers/oidc_test.go
@@ -204,8 +204,8 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
existingSession := &sessions.SessionState{
AccessToken: "changeit",
IDToken: idToken,
- CreatedAt: time.Time{},
- ExpiresOn: time.Time{},
+ CreatedAt: nil,
+ ExpiresOn: nil,
RefreshToken: refreshToken,
Email: "janedoe@example.com",
User: "11223344",
@@ -238,8 +238,8 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
existingSession := &sessions.SessionState{
AccessToken: "changeit",
IDToken: "changeit",
- CreatedAt: time.Time{},
- ExpiresOn: time.Time{},
+ CreatedAt: nil,
+ ExpiresOn: nil,
RefreshToken: refreshToken,
Email: "changeit",
User: "changeit",
diff --git a/providers/provider_default.go b/providers/provider_default.go
index 141261a16f..74335e11c3 100644
--- a/providers/provider_default.go
+++ b/providers/provider_default.go
@@ -81,7 +81,8 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
return
}
if a := v.Get("access_token"); a != "" {
- s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
+ created := time.Now()
+ s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else {
err = fmt.Errorf("no access token found %s", body)
}
@@ -168,7 +169,7 @@ func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, ra
newSession.AccessToken = rawIDToken
newSession.IDToken = rawIDToken
newSession.RefreshToken = ""
- newSession.ExpiresOn = idToken.Expiry
+ newSession.ExpiresOn = &idToken.Expiry
return newSession, nil
}
diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go
index 4d8a8306d2..658918c4fd 100644
--- a/providers/provider_default_test.go
+++ b/providers/provider_default_test.go
@@ -11,8 +11,10 @@ import (
func TestRefresh(t *testing.T) {
p := &ProviderData{}
+
+ expires := time.Now().Add(time.Duration(-11) * time.Minute)
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
- ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
+ ExpiresOn: &expires,
})
assert.Equal(t, false, refreshed)
assert.Equal(t, nil, err)
From d995817d9fdd0c4300ffea351220c668ccd3b2ca Mon Sep 17 00:00:00 2001
From: Joel Speed
Date: Fri, 8 May 2020 18:11:36 +0100
Subject: [PATCH 2/3] Use EncrpytInto/DecryptInto to reduce sessionstate
---
pkg/apis/sessions/session_state.go | 70 ++++++++----------------------
pkg/encryption/cipher.go | 27 ++++++++++++
pkg/encryption/cipher_test.go | 22 ++++++++++
3 files changed, 66 insertions(+), 53 deletions(-)
diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go
index d665303c85..1bb8ff035e 100644
--- a/pkg/apis/sessions/session_state.go
+++ b/pkg/apis/sessions/session_state.go
@@ -67,39 +67,15 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
ss.PreferredUsername = s.PreferredUsername
} else {
ss = *s
- var err error
- if ss.Email != "" {
- ss.Email, err = c.Encrypt(ss.Email)
- if err != nil {
- return "", err
- }
- }
- if ss.User != "" {
- ss.User, err = c.Encrypt(ss.User)
- if err != nil {
- return "", err
- }
- }
- if ss.PreferredUsername != "" {
- ss.PreferredUsername, err = c.Encrypt(ss.PreferredUsername)
- if err != nil {
- return "", err
- }
- }
- if ss.AccessToken != "" {
- ss.AccessToken, err = c.Encrypt(ss.AccessToken)
- if err != nil {
- return "", err
- }
- }
- if ss.IDToken != "" {
- ss.IDToken, err = c.Encrypt(ss.IDToken)
- if err != nil {
- return "", err
- }
- }
- if ss.RefreshToken != "" {
- ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
+ for _, s := range []*string{
+ &ss.Email,
+ &ss.User,
+ &ss.PreferredUsername,
+ &ss.AccessToken,
+ &ss.IDToken,
+ &ss.RefreshToken,
+ } {
+ err := c.EncryptInto(s)
if err != nil {
return "", err
}
@@ -140,26 +116,14 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
ss.User = decryptedUser
}
}
- if ss.PreferredUsername != "" {
- ss.PreferredUsername, err = c.Decrypt(ss.PreferredUsername)
- if err != nil {
- return nil, err
- }
- }
- if ss.AccessToken != "" {
- ss.AccessToken, err = c.Decrypt(ss.AccessToken)
- if err != nil {
- return nil, err
- }
- }
- if ss.IDToken != "" {
- ss.IDToken, err = c.Decrypt(ss.IDToken)
- if err != nil {
- return nil, err
- }
- }
- if ss.RefreshToken != "" {
- ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
+
+ for _, s := range []*string{
+ &ss.PreferredUsername,
+ &ss.AccessToken,
+ &ss.IDToken,
+ &ss.RefreshToken,
+ } {
+ err := c.DecryptInto(s)
if err != nil {
return nil, err
}
diff --git a/pkg/encryption/cipher.go b/pkg/encryption/cipher.go
index 2dcbee6bf2..4eb42b0333 100644
--- a/pkg/encryption/cipher.go
+++ b/pkg/encryption/cipher.go
@@ -156,3 +156,30 @@ func (c *Cipher) Decrypt(s string) (string, error) {
return string(encrypted), nil
}
+
+// EncryptInto encrypts the value and stores it back in the string pointer
+func (c *Cipher) EncryptInto(s *string) error {
+ return into(c.Encrypt, s)
+}
+
+// DecryptInto decrypts the value and stores it back in the string pointer
+func (c *Cipher) DecryptInto(s *string) error {
+ return into(c.Decrypt, s)
+}
+
+// codecFunc is a function that takes a string and encodes/decodes it
+type codecFunc func(string) (string, error)
+
+func into(f codecFunc, s *string) error {
+ // Do not encrypt/decrypt nil or empty strings
+ if s == nil || *s == "" {
+ return nil
+ }
+
+ d, err := f(*s)
+ if err != nil {
+ return err
+ }
+ *s = d
+ return nil
+}
diff --git a/pkg/encryption/cipher_test.go b/pkg/encryption/cipher_test.go
index 76bfc1bc31..aed529f354 100644
--- a/pkg/encryption/cipher_test.go
+++ b/pkg/encryption/cipher_test.go
@@ -133,3 +133,25 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
assert.NotEqual(t, token, encoded)
assert.Equal(t, token, decoded)
}
+
+func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) {
+ const secret = "0123456789abcdefghijklmnopqrstuv"
+ c, err := NewCipher([]byte(secret))
+ assert.Equal(t, nil, err)
+
+ token := "my access token"
+ originalToken := token
+
+ assert.Equal(t, nil, c.EncryptInto(&token))
+ assert.NotEqual(t, originalToken, token)
+
+ assert.Equal(t, nil, c.DecryptInto(&token))
+ assert.Equal(t, originalToken, token)
+
+ // Check no errors with empty or nil strings
+ empty := ""
+ assert.Equal(t, nil, c.EncryptInto(&empty))
+ assert.Equal(t, nil, c.DecryptInto(&empty))
+ assert.Equal(t, nil, c.EncryptInto(nil))
+ assert.Equal(t, nil, c.DecryptInto(nil))
+}
From 1a0945ea11fc8bd911231d4ba98aed83aa0fc52c Mon Sep 17 00:00:00 2001
From: Joel Speed
Date: Sun, 10 May 2020 10:37:23 +0100
Subject: [PATCH 3/3] Add Session State Improvements Changelog entry
---
CHANGELOG.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9fea049529..4905a81b39 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,7 @@
## Changes since v5.1.1
+- [#536](https://github.com/oauth2-proxy/oauth2-proxy/pull/536) Improvements to Session State code (@JoelSpeed)
- [#574](https://github.com/oauth2-proxy/oauth2-proxy/pull/574) render error page on 502 proxy status (@amnay-mo)
- [#559](https://github.com/oauth2-proxy/oauth2-proxy/pull/559) Rename cookie-domain config to cookie-domains (@JoelSpeed)
- [#569](https://github.com/oauth2-proxy/oauth2-proxy/pull/569) Updated autocompletion for `--` long options. (@Izzette)