Skip to content

Commit

Permalink
Use ErrNotImplemented in default refresh implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Nick Meves committed Jun 23, 2021
1 parent baf6cf3 commit ff914d7
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 29 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@

## Important Notes

- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session
deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade
to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret`
value and force all sessions to reauthenticate.

## Breaking Changes

## Changes since v7.1.3

- [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves)
- [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed)
- [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed)
- [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi)
Expand Down
29 changes: 18 additions & 11 deletions pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
)

// StoredSessionLoaderOptions cotnains all of the requirements to construct
// StoredSessionLoaderOptions contains all of the requirements to construct
// a stored session loader.
// All options must be provided.
type StoredSessionLoaderOptions struct {
// Session storage basckend
// Session storage backend
SessionStore sessionsapi.SessionStore

// How often should sessions be refreshed
RefreshPeriod time.Duration

// Provider based sesssion refreshing
// Provider based session refreshing
RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error)

// Provider based session validation.
Expand Down Expand Up @@ -115,7 +116,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
return nil
}

logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod)
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
err := s.refreshSession(rw, req, session)
if err != nil {
// If a preemptive refresh fails, we still keep the session
Expand All @@ -131,21 +132,27 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
// and will save the session if it was updated.
func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error {
refreshed, err := s.sessionRefresher(req.Context(), session)
if err != nil {
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return fmt.Errorf("error refreshing tokens: %v", err)
}

// HACK:
// Providers that don't implement `RefreshSession` use the default
// implementation which returns `ErrNotImplemented`.
// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
// isn't triggered every subsequent request and is only called once during
// this request.
if errors.Is(err, providers.ErrNotImplemented) {
refreshed = true
}

// Session not refreshed, nothing to persist.
if !refreshed {
return nil
}

// If we refreshed, update the `CreatedAt` time to reset the refresh timer
//
// HACK:
// Providers that don't implement `RefreshSession` use the default
// implementation. It always returns `refreshed == true`, so the
// `session.CreatedAt` is updated and doesn't trigger `ValidateSession`
// every subsequent request.
// (In case underlying provider implementations forget)
session.CreatedAtNow()

// Because the session was refreshed, make sure to save it
Expand Down
27 changes: 25 additions & 2 deletions pkg/middleware/stored_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ import (
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)

var _ = Describe("Stored Session Suite", func() {
const (
refresh = "Refresh"
noRefresh = "NoRefresh"
refresh = "Refresh"
noRefresh = "NoRefresh"
notImplemented = "NotImplemented"
)

var ctx = context.Background()
Expand Down Expand Up @@ -293,6 +295,8 @@ var _ = Describe("Stored Session Suite", func() {
return true, nil
case noRefresh:
return false, nil
case notImplemented:
return false, providers.ErrNotImplemented
default:
return false, errors.New("error refreshing session")
}
Expand Down Expand Up @@ -364,6 +368,16 @@ var _ = Describe("Stored Session Suite", func() {
expectRefreshed: true,
expectValidated: true,
}),
Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{
RefreshToken: notImplemented,
CreatedAt: &createdPast,
},
expectedErr: nil,
expectRefreshed: true,
expectValidated: true,
}),
Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{
refreshPeriod: 1 * time.Minute,
session: &sessionsapi.SessionState{
Expand Down Expand Up @@ -418,6 +432,8 @@ var _ = Describe("Stored Session Suite", func() {
return true, nil
case noRefresh:
return false, nil
case notImplemented:
return false, providers.ErrNotImplemented
default:
return false, errors.New("error refreshing session")
}
Expand Down Expand Up @@ -448,6 +464,13 @@ var _ = Describe("Stored Session Suite", func() {
expectedErr: nil,
expectSaved: true,
}),
Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: notImplemented,
},
expectedErr: nil,
expectSaved: true,
}),
Entry("when the provider returns an error", refreshSessionWithProviderTableInput{
session: &sessionsapi.SessionState{
RefreshToken: "RefreshError",
Expand Down
15 changes: 2 additions & 13 deletions providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,8 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS
}

// RefreshSession refreshes the user's session
func (p *ProviderData) RefreshSession(_ context.Context, s *sessions.SessionState) (bool, error) {
if s == nil {
return false, nil
}

// HACK:
// Pretend `RefreshSession` occurred so `ValidateSession` isn't called
// on every request after any potential set refresh period elapses.
// See `middleware.refreshSession` for detailed logic & explanation.
//
// Intentionally doesn't use `ErrNotImplemented` since all providers will
// call this and we don't want to force them to implement this dummy logic.
return true, nil
func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) {
return false, ErrNotImplemented
}

// CreateSessionFromToken converts Bearer IDTokens into sessions
Expand Down
6 changes: 3 additions & 3 deletions providers/provider_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ func TestRefresh(t *testing.T) {
ss.SetExpiresOn(expires)

refreshed, err := p.RefreshSession(context.Background(), ss)
assert.True(t, refreshed)
assert.NoError(t, err)
assert.False(t, refreshed)
assert.Equal(t, ErrNotImplemented, err)

refreshed, err = p.RefreshSession(context.Background(), nil)
assert.False(t, refreshed)
assert.NoError(t, err)
assert.Equal(t, ErrNotImplemented, err)
}

func TestAcrValuesNotConfigured(t *testing.T) {
Expand Down

0 comments on commit ff914d7

Please sign in to comment.