Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to Session State code #536

Merged
merged 4 commits into from
May 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

## Changes since v5.1.1

- [#536](https://github.com/oauth2-proxy/oauth2-proxy/pull/536) Improvements to Session State code (@JoelSpeed)
- [#573](https://github.com/oauth2-proxy/oauth2-proxy/pull/573) Properly parse redis urls for cluster and sentinel connections (@amnay-mo)
- [#574](https://github.com/oauth2-proxy/oauth2-proxy/pull/574) render error page on 502 proxy status (@amnay-mo)
- [#559](https://github.com/oauth2-proxy/oauth2-proxy/pull/559) Rename cookie-domain config to cookie-domains (@JoelSpeed)
Expand Down
38 changes: 23 additions & 15 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,7 @@ func TestBasicAuthPassword(t *testing.T) {
})

rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
strings.NewReader(""))
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader(""))
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
proxy.ServeHTTP(rw, req)
if rw.Code >= 400 {
Expand Down Expand Up @@ -541,11 +540,12 @@ func TestBasicAuthWithEmail(t *testing.T) {
expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword))

created := time.Now()
session := &sessions.SessionState{
User: userName,
Email: emailAddress,
AccessToken: "oauth_token",
CreatedAt: time.Now(),
CreatedAt: &created,
}
{
rw := httptest.NewRecorder()
Expand Down Expand Up @@ -582,11 +582,12 @@ func TestPassUserHeadersWithEmail(t *testing.T) {
const emailAddress = "john.doe@example.com"
const userName = "9fcab5c9b889a557"

created := time.Now()
session := &sessions.SessionState{
User: userName,
Email: emailAddress,
AccessToken: "oauth_token",
CreatedAt: time.Now(),
CreatedAt: &created,
}
{
rw := httptest.NewRecorder()
Expand Down Expand Up @@ -959,7 +960,8 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error)
func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults()

startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
created := time.Now()
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
pcTest.SaveSession(startSession)

session, err := pcTest.LoadCookiedSession()
Expand All @@ -985,7 +987,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
})
reference := time.Now().Add(time.Duration(-2) * time.Hour)

startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)

session, err := pcTest.LoadCookiedSession()
Expand All @@ -1001,7 +1003,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)

session, err := pcTest.LoadCookiedSession()
Expand All @@ -1016,7 +1018,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)

pcTest.proxy.CookieRefresh = time.Hour
Expand Down Expand Up @@ -1062,8 +1064,9 @@ func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {

func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest()
created := time.Now()
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
test.SaveSession(startSession)

test.proxy.ServeHTTP(test.rw, test.req)
Expand All @@ -1087,7 +1090,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
test.SaveSession(startSession)

test.proxy.ServeHTTP(test.rw, test.req)
Expand All @@ -1098,8 +1101,9 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {

func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest()
created := time.Now()
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
test.SaveSession(startSession)
test.validateUser = false

Expand Down Expand Up @@ -1129,8 +1133,9 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)

created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)

pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
Expand Down Expand Up @@ -1160,8 +1165,9 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)

created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)

pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
Expand Down Expand Up @@ -1193,8 +1199,9 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)

created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)

pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
Expand Down Expand Up @@ -1569,10 +1576,11 @@ func TestGetJwtSession(t *testing.T) {
}

// Bearer
expires := time.Unix(1912151821, 0)
session, _ := test.proxy.GetJwtSession(test.req)
assert.Equal(t, session.User, "john@example.com")
assert.Equal(t, session.Email, "john@example.com")
assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0))
assert.Equal(t, session.ExpiresOn, &expires)
assert.Equal(t, session.IDToken, goodJwt)

test.proxy.ServeHTTP(test.rw, test.req)
Expand Down
132 changes: 34 additions & 98 deletions pkg/apis/sessions/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sessions

import (
"encoding/json"
"errors"
"fmt"
"time"

Expand All @@ -11,35 +10,28 @@ import (

// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
CreatedAt time.Time `json:"-"`
ExpiresOn time.Time `json:"-"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
PreferredUsername string `json:",omitempty"`
}

// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct {
*SessionState
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
PreferredUsername string `json:",omitempty"`
}

// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
}
return false
}

// Age returns the age of a session
func (s *SessionState) Age() time.Duration {
if !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
}
return 0
}
Expand Down Expand Up @@ -75,80 +67,36 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
ss.PreferredUsername = s.PreferredUsername
} else {
ss = *s
var err error
if ss.Email != "" {
ss.Email, err = c.Encrypt(ss.Email)
if err != nil {
return "", err
}
}
if ss.User != "" {
ss.User, err = c.Encrypt(ss.User)
if err != nil {
return "", err
}
}
if ss.PreferredUsername != "" {
ss.PreferredUsername, err = c.Encrypt(ss.PreferredUsername)
if err != nil {
return "", err
}
}
if ss.AccessToken != "" {
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
if err != nil {
return "", err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Encrypt(ss.IDToken)
if err != nil {
return "", err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
for _, s := range []*string{
&ss.Email,
&ss.User,
&ss.PreferredUsername,
&ss.AccessToken,
&ss.IDToken,
&ss.RefreshToken,
} {
err := c.EncryptInto(s)
if err != nil {
return "", err
}
}
}
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss}
if !ss.CreatedAt.IsZero() {
ssj.CreatedAt = &ss.CreatedAt
}
if !ss.ExpiresOn.IsZero() {
ssj.ExpiresOn = &ss.ExpiresOn
}
b, err := json.Marshal(ssj)

b, err := json.Marshal(ss)
return string(b), err
}

// DecodeSessionState decodes the session cookie string into a SessionState
func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
var ssj SessionStateJSON
var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
var ss SessionState
err := json.Unmarshal([]byte(v), &ss)
if err != nil {
return nil, fmt.Errorf("error unmarshalling session: %w", err)
}
if ssj.SessionState == nil {
return nil, errors.New("expected session state to not be nil")
}

// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
ss = ssj.SessionState
if ssj.CreatedAt != nil {
ss.CreatedAt = *ssj.CreatedAt
}
if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn
}

if c == nil {
// Load only Email and User when cipher is unavailable
ss = &SessionState{
ss = SessionState{
Email: ss.Email,
User: ss.User,
PreferredUsername: ss.PreferredUsername,
Expand All @@ -168,30 +116,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
ss.User = decryptedUser
}
}
if ss.PreferredUsername != "" {
ss.PreferredUsername, err = c.Decrypt(ss.PreferredUsername)
if err != nil {
return nil, err
}
}
if ss.AccessToken != "" {
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
if err != nil {
return nil, err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Decrypt(ss.IDToken)
if err != nil {
return nil, err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)

for _, s := range []*string{
&ss.PreferredUsername,
&ss.AccessToken,
&ss.IDToken,
&ss.RefreshToken,
} {
err := c.DecryptInto(s)
if err != nil {
return nil, err
}
}
}
return ss, nil
return &ss, nil
}
Loading