Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(client): Reduce SQL boilerplate code #1758

Merged
merged 21 commits into from
Mar 14, 2020
Prev Previous commit
Next Next commit
u
  • Loading branch information
aeneasr committed Mar 14, 2020
commit 4d5f80dedddb7da373f514fcbabe4924d2671be5
2 changes: 1 addition & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ type Client struct {
BackChannelLogoutSessionRequired bool `json:"backchannel_logout_session_required,omitempty" db:"backchannel_logout_session_required"`

// Metadata is arbitrary data.
Metadata x.JSONRawMessage `json:"metadata,omitempty" db:"metadata"`
Metadata sqlxx.JSONRawMessage `json:"metadata,omitempty" db:"metadata"`
}

func (c *Client) GetID() string {
Expand Down
6 changes: 3 additions & 3 deletions consent/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/url"
"time"

"github.com/ory/x/sqlxx"
"github.com/ory/x/stringsx"

"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -161,7 +162,6 @@ func (h *Handler) GetConsentSessions(w http.ResponseWriter, r *http.Request, ps
}

var a []PreviousConsentSession

for _, session := range s {
session.ConsentRequest.Client = sanitizeClient(session.ConsentRequest.Client)
a = append(a, PreviousConsentSession(session))
Expand Down Expand Up @@ -337,9 +337,9 @@ func (h *Handler) AcceptLoginRequest(w http.ResponseWriter, r *http.Request, ps

if ar.Skip {
p.Remember = true // If skip is true remember is also true to allow consecutive calls as the same user!
p.AuthenticatedAt = time.Time(ar.AuthenticatedAt)
p.AuthenticatedAt = ar.AuthenticatedAt
} else {
p.AuthenticatedAt = time.Now().UTC()
p.AuthenticatedAt = sqlxx.NullTime(time.Now().UTC())
}
p.RequestedAt = ar.RequestedAt

Expand Down
44 changes: 16 additions & 28 deletions consent/manager_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,28 +286,25 @@ func (m *SQLManager) GetLoginRequest(ctx context.Context, challenge string) (*Lo
}

func (m *SQLManager) HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error) {
d, err := newSQLHandledConsentRequest(r)
if err != nil {
return nil, err
}
r.prepareSQL()

/* #nosec G201 - sqlParamsConsentRequestHandled is a "constant" array */
if _, err := m.DB.NamedExecContext(ctx, fmt.Sprintf(
"INSERT INTO hydra_oauth2_consent_request_handled (%s) VALUES (%s)",
strings.Join(sqlParamsConsentRequestHandled, ", "),
":"+strings.Join(sqlParamsConsentRequestHandled, ", :"),
), d); err != nil {
), r); err != nil {
err = sqlcon.HandleError(err)
if errors.Cause(err) == sqlcon.ErrUniqueViolation {
return m.replaceUnusedConsentRequest(ctx, challenge, d)
return m.replaceUnusedConsentRequest(ctx, challenge, r)
}
return nil, err
}

return m.GetConsentRequest(ctx, challenge)
}

func (m *SQLManager) replaceUnusedConsentRequest(ctx context.Context, challenge string, d *sqlHandledConsentRequest) (*ConsentRequest, error) {
func (m *SQLManager) replaceUnusedConsentRequest(ctx context.Context, challenge string, d *HandledConsentRequest) (*ConsentRequest, error) {
/* #nosec G201 - sqlParamsConsentRequestHandledUpdate is a "constant" array */
if _, err := m.DB.NamedExecContext(ctx, fmt.Sprintf(
"UPDATE hydra_oauth2_consent_request_handled SET %s WHERE challenge=:challenge AND was_used=false",
Expand All @@ -320,7 +317,7 @@ func (m *SQLManager) replaceUnusedConsentRequest(ctx context.Context, challenge
}

func (m *SQLManager) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error) {
var d sqlHandledConsentRequest
var d HandledConsentRequest
var challenge string

// This can be solved more elegantly with a join statement, but it works for now
Expand All @@ -346,29 +343,25 @@ func (m *SQLManager) VerifyAndInvalidateConsentRequest(ctx context.Context, veri
return nil, sqlcon.HandleError(err)
}

return d.toHandledConsentRequest(r)
return d.postSQL(r), nil
}

func (m *SQLManager) HandleLoginRequest(ctx context.Context, challenge string, r *HandledLoginRequest) (*LoginRequest, error) {
d, err := newSQLHandledLoginRequest(r)
if err != nil {
return nil, err
}

/* #nosec G201 - sqlParamsAuthenticationRequestHandled is a "constant" array */
if _, err := m.DB.NamedExecContext(ctx, fmt.Sprintf(
"INSERT INTO hydra_oauth2_authentication_request_handled (%s) VALUES (%s)",
strings.Join(sqlParamsAuthenticationRequestHandled, ", "),
":"+strings.Join(sqlParamsAuthenticationRequestHandled, ", :"),
), d); err != nil {
), r.prepareSQL()); err != nil {
return nil, sqlcon.HandleError(err)
}

return m.GetLoginRequest(ctx, challenge)
}

func (m *SQLManager) VerifyAndInvalidateLoginRequest(ctx context.Context, verifier string) (*HandledLoginRequest, error) {
var d sqlHandledLoginRequest
var d HandledLoginRequest
var challenge string

// This can be solved more elegantly with a join statement, but it works for now
Expand All @@ -394,7 +387,7 @@ func (m *SQLManager) VerifyAndInvalidateLoginRequest(ctx context.Context, verifi
return nil, err
}

return d.toHandledLoginRequest(r)
return d.postSQL(r), nil
}

func (m *SQLManager) GetRememberedLoginSession(ctx context.Context, id string) (*LoginSession, error) {
Expand Down Expand Up @@ -438,7 +431,7 @@ func (m *SQLManager) DeleteLoginSession(ctx context.Context, id string) error {
}

func (m *SQLManager) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) ([]HandledConsentRequest, error) {
var a []sqlHandledConsentRequest
var a []HandledConsentRequest

if err := m.DB.SelectContext(ctx, &a, m.DB.Rebind(`SELECT h.* FROM
hydra_oauth2_consent_request_handled as h
Expand All @@ -460,7 +453,7 @@ LIMIT 1`), subject, client); err != nil {
}

func (m *SQLManager) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) ([]HandledConsentRequest, error) {
var a []sqlHandledConsentRequest
var a []HandledConsentRequest

if err := m.DB.SelectContext(ctx, &a, m.DB.Rebind(`SELECT h.* FROM
hydra_oauth2_consent_request_handled as h
Expand Down Expand Up @@ -497,8 +490,8 @@ WHERE
return n, nil
}

func (m *SQLManager) resolveHandledConsentRequests(ctx context.Context, requests []sqlHandledConsentRequest) ([]HandledConsentRequest, error) {
var aa []HandledConsentRequest
func (m *SQLManager) resolveHandledConsentRequests(ctx context.Context, requests []HandledConsentRequest) ([]HandledConsentRequest, error) {
var result []HandledConsentRequest
for _, v := range requests {
r, err := m.GetConsentRequest(ctx, v.Challenge)
if err != nil {
Expand All @@ -511,19 +504,14 @@ func (m *SQLManager) resolveHandledConsentRequests(ctx context.Context, requests
continue
}

va, err := v.toHandledConsentRequest(r)
if err != nil {
return nil, err
}

aa = append(aa, *va)
result = append(result, *v.postSQL(r))
}

if len(aa) == 0 {
if len(requests) == 0 {
return nil, errors.WithStack(ErrNoPreviousConsentFound)
}

return aa, nil
return result, nil
}

func (m *SQLManager) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
Expand Down
6 changes: 3 additions & 3 deletions consent/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo
ACR: "1",
AuthenticatedAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Hour)),
RequestedAt: time.Now().UTC().Add(-time.Hour),
Context: x.JSONRawMessage(`{"foo": "bar` + key + `"}`),
Context: sqlxx.JSONRawMessage(`{"foo": "bar` + key + `"}`),
}

var err *RequestDeniedError
Expand Down Expand Up @@ -156,7 +156,7 @@ func MockAuthRequest(key string, authAt bool) (c *LoginRequest, h *HandledLoginR
Remember: true,
Challenge: "challenge" + key,
RequestedAt: time.Now().UTC().Add(-time.Minute),
AuthenticatedAt: authenticatedAt,
AuthenticatedAt: sqlxx.NullTime(authenticatedAt),
Error: err,
Subject: c.Subject,
ACR: "acr",
Expand Down Expand Up @@ -219,7 +219,7 @@ func SaneMockConsentRequest(t *testing.T, m Manager, ar *LoginRequest, skip bool
ACR: "1",
AuthenticatedAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Hour)),
RequestedAt: time.Now().UTC().Add(-time.Hour),
Context: x.JSONRawMessage(`{"foo": "bar"}`),
Context: sqlxx.JSONRawMessage(`{"foo": "bar"}`),

Challenge: uuid.New().String(),
Verifier: uuid.New().String(),
Expand Down
Loading