Skip to content

Commit

Permalink
feat: redirect to original page after login
Browse files Browse the repository at this point in the history
  • Loading branch information
mentos1386 committed May 25, 2024
1 parent b8c37cf commit 034f530
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 17 deletions.
24 changes: 8 additions & 16 deletions internal/server/handlers/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
jwtInternal "github.com/mentos1386/zdravko/pkg/jwt"
)

const sessionName = "zdravko-hey"
const authenticationSessionName = "zdravko-hey"

type AuthenticatedPrincipal struct {
User *AuthenticatedUser
Expand Down Expand Up @@ -48,7 +48,7 @@ func GetUser(ctx context.Context) *AuthenticatedUser {
}

func (h *BaseHandler) AuthenticateRequestWithCookies(r *http.Request) (*AuthenticatedUser, error) {
session, err := h.store.Get(r, sessionName)
session, err := h.store.Get(r, authenticationSessionName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func (h *BaseHandler) AuthenticateRequestWithToken(r *http.Request) (*Authentica
}

func (h *BaseHandler) SetAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request, user *AuthenticatedUser) error {
session, err := h.store.Get(r, sessionName)
session, err := h.store.Get(r, authenticationSessionName)
if err != nil {
return err
}
Expand All @@ -124,24 +124,16 @@ func (h *BaseHandler) SetAuthenticatedUserForRequest(w http.ResponseWriter, r *h
session.Values["oauth2_refresh_token"] = user.OAuth2RefreshToken
session.Values["oauth2_token_type"] = user.OAuth2TokenType
session.Values["oauth2_expiry"] = user.OAuth2Expiry.Format(time.RFC3339)
err = h.store.Save(r, w, session)
if err != nil {
return err
}
return nil
return h.store.Save(r, w, session)
}

func (h *BaseHandler) ClearAuthenticatedUserForRequest(w http.ResponseWriter, r *http.Request) error {
session, err := h.store.Get(r, sessionName)
session, err := h.store.Get(r, authenticationSessionName)
if err != nil {
return err
}
session.Options.MaxAge = -1
err = h.store.Save(r, w, session)
if err != nil {
return err
}
return nil
return h.store.Save(r, w, session)
}

type AuthenticatedHandler func(http.ResponseWriter, *http.Request, *AuthenticatedPrincipal)
Expand All @@ -159,7 +151,7 @@ func (h *BaseHandler) Authenticated(next echo.HandlerFunc) echo.HandlerFunc {
if user.OAuth2Expiry.Before(time.Now()) {
user, err = h.RefreshToken(c.Response(), c.Request(), user)
if err != nil {
return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login")
return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login?redirect="+c.Request().URL.Path)
}
}

Expand All @@ -173,6 +165,6 @@ func (h *BaseHandler) Authenticated(next echo.HandlerFunc) echo.HandlerFunc {
return next(cc)
}

return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login")
return c.Redirect(http.StatusTemporaryRedirect, "/oauth2/login?redirect="+c.Request().URL.Path)
}
}
63 changes: 62 additions & 1 deletion internal/server/handlers/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,45 @@ import (
"golang.org/x/oauth2"
)

const oauth2RedirectSessionName = "zdravko-hey-oauth2"

func (h *BaseHandler) setOAuth2Redirect(c echo.Context, redirect string) error {
w := c.Response()
r := c.Request()

session, err := h.store.Get(r, oauth2RedirectSessionName)
if err != nil {
return err
}
session.Values["redirect"] = redirect
return h.store.Save(r, w, session)
}

func (h *BaseHandler) getOAuth2Redirect(c echo.Context) (string, error) {
r := c.Request()

session, err := h.store.Get(r, oauth2RedirectSessionName)
if err != nil {
return "", err
}
if session.IsNew {
return "", nil
}
return session.Values["redirect"].(string), nil
}

func (h *BaseHandler) clearOAuth2Redirect(c echo.Context) error {
w := c.Response()
r := c.Request()

session, err := h.store.Get(r, oauth2RedirectSessionName)
if err != nil {
return err
}
session.Options.MaxAge = -1
return h.store.Save(r, w, session)
}

type UserInfo struct {
Id int `json:"id"` // FIXME: This might not always be int?
Sub string `json:"sub"`
Expand Down Expand Up @@ -97,6 +136,14 @@ func (h *BaseHandler) OAuth2LoginGET(c echo.Context) error {

url := conf.AuthCodeURL(state, oauth2.AccessTypeOffline)

redirect := c.QueryParam("redirect")
h.logger.Info("OAuth2LoginGET", "redirect", redirect)

err = h.setOAuth2Redirect(c, redirect)
if err != nil {
return err
}

return c.Redirect(http.StatusTemporaryRedirect, url)
}

Expand Down Expand Up @@ -156,7 +203,21 @@ func (h *BaseHandler) OAuth2CallbackGET(c echo.Context) error {
return err
}

return c.Redirect(http.StatusTemporaryRedirect, "/settings")
redirect, err := h.getOAuth2Redirect(c)
if err != nil {
return err
}
h.logger.Info("OAuth2CallbackGET", "redirect", redirect)
if redirect == "" {
redirect = "/settings"
}

err = h.clearOAuth2Redirect(c)
if err != nil {
return err
}

return c.Redirect(http.StatusTemporaryRedirect, redirect)
}

func (h *BaseHandler) OAuth2LogoutGET(c echo.Context) error {
Expand Down

0 comments on commit 034f530

Please sign in to comment.