Skip to content

Commit

Permalink
create pool upon registry creation
Browse files Browse the repository at this point in the history
  • Loading branch information
chilagrow committed Feb 19, 2024
1 parent c3583d1 commit 8d72103
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
30 changes: 22 additions & 8 deletions internal/backends/postgresql/metadata/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"hash/fnv"
"net/url"
"regexp"
"slices"
"sort"
Expand Down Expand Up @@ -86,6 +87,9 @@ type Registry struct {
}

// NewRegistry creates a registry for PostgreSQL databases with a given base URI.
//
// It gets a pool using the user and password from the base URI, which is later used
// by connections that by passes backend authentication.
func NewRegistry(u string, l *zap.Logger, sp *state.Provider) (*Registry, error) {
p, err := pool.New(u, l, sp)
if err != nil {
Expand All @@ -97,6 +101,23 @@ func NewRegistry(u string, l *zap.Logger, sp *state.Provider) (*Registry, error)
l: l,
}

baseURI, err := url.Parse(u)
if err != nil {
return nil, lazyerrors.Error(err)
}

username := baseURI.User.Username()
pwd, _ := baseURI.User.Password()

c := conninfo.New()
c.SetAuth(username, pwd)

ctx := conninfo.Ctx(context.Background(), c)
_, err = r.getPool(ctx)
if err != nil {
return nil, lazyerrors.Error(err)
}

return r, nil
}

Expand All @@ -122,14 +143,7 @@ func (r *Registry) getPool(ctx context.Context) (*pgxpool.Pool, error) {

if connInfo.BypassBackendAuth() {
if p = r.p.GetAny(); p == nil {
// no connection pool has been created yet and authentication
// is bypassed, attempt to use credentials to connect
username, password := connInfo.Auth()

var err error
if p, err = r.p.Get(username, password); err != nil {
return nil, lazyerrors.Error(err)
}
return nil, lazyerrors.New("no connection pool")
}
} else {
username, password := connInfo.Auth()
Expand Down
10 changes: 2 additions & 8 deletions internal/handler/msg_saslstart.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,8 @@ func (h *Handler) MsgSASLStart(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
return nil, err
}

if h.EnableNewAuth {
// If new auth is enabled and the database does not contain any user,
// backend authentication is bypassed.
conninfo.Get(ctx).SetAuth(username, password)
conninfo.Get(ctx).SetBypassBackendAuth()
} else {
conninfo.Get(ctx).SetAuth(username, password)
}
conninfo.Get(ctx).SetBypassBackendAuth()
conninfo.Get(ctx).SetAuth(username, password)

var emptyPayload types.Binary
must.NoError(reply.SetSections(wire.MakeOpMsgSection(
Expand Down

0 comments on commit 8d72103

Please sign in to comment.