Skip to content

Commit

Permalink
Refactor common saslStart code (#2181)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekSi authored Mar 10, 2023
1 parent 566611d commit 3e2a5ae
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 70 deletions.
68 changes: 68 additions & 0 deletions internal/handlers/common/saslstart.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2021 FerretDB Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"bytes"
"encoding/base64"
"fmt"

"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
)

// SASLStartPlain extracts username and password from PLAIN `saslStart` payload.
func SASLStartPlain(doc *types.Document) (string, string, error) {
var payload []byte

// some drivers send payload as a string
stringPayload, err := GetRequiredParam[string](doc, "payload")
if err == nil {
if payload, err = base64.StdEncoding.DecodeString(stringPayload); err != nil {
return "", "", lazyerrors.Error(err)
}
}

// most drivers follow spec and send payload as a binary
binaryPayload, err := GetRequiredParam[types.Binary](doc, "payload")
if err == nil {
payload = binaryPayload.B
}

if payload == nil {
// return error about expected types.Binary, not string
return "", "", err
}

parts := bytes.Split(payload, []byte{0})
if l := len(parts); l != 3 {
return "", "", NewCommandErrorMsgWithArgument(
commonerrors.ErrTypeMismatch,
fmt.Sprintf("Invalid payload (expected 3 parts, got %d)", l),
"payload",
)
}

authzid, authcid, passwd := parts[0], parts[1], parts[2]

// Some drivers (Go) send empty authorization identity (authzid),
// while others (Java) set it to the same value as authentication identity (authcid)
// (see https://www.rfc-editor.org/rfc/rfc4616.html).
// Ignore authzid for now.
_ = authzid

return string(authcid), string(passwd), nil
}
3 changes: 3 additions & 0 deletions internal/handlers/dummy/msg_saslstart.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ import (

// MsgSASLStart implements HandlerInterface.
func (h *Handler) MsgSASLStart(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) {
var emptyPayload types.Binary
var reply wire.OpMsg
must.NoError(reply.SetSections(wire.OpMsgSection{
Documents: []*types.Document{must.NotFail(types.NewDocument(
"conversationId", int32(1),
"done", true,
"payload", emptyPayload,
"ok", float64(1),
))},
}))
Expand Down
3 changes: 3 additions & 0 deletions internal/handlers/hana/msg_saslstart.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ import (

// MsgSASLStart implements HandlerInterface.
func (h *Handler) MsgSASLStart(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) {
var emptyPayload types.Binary
var reply wire.OpMsg
must.NoError(reply.SetSections(wire.OpMsgSection{
Documents: []*types.Document{must.NotFail(types.NewDocument(
"conversationId", int32(1),
"done", true,
"payload", emptyPayload,
"ok", float64(1),
))},
}))
Expand Down
55 changes: 10 additions & 45 deletions internal/handlers/pg/msg_saslstart.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,17 @@
package pg

import (
"bytes"
"context"
"encoding/base64"
"fmt"

"github.com/FerretDB/FerretDB/internal/clientconn/conninfo"
"github.com/FerretDB/FerretDB/internal/handlers/common"
"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
"github.com/FerretDB/FerretDB/internal/util/must"
"github.com/FerretDB/FerretDB/internal/wire"
)

func getPayload(doc *types.Document) ([]byte, error) {
binaryPayload, err := common.GetRequiredParam[types.Binary](doc, "payload")
if err == nil {
return binaryPayload.B, nil
}

payload, err := common.GetRequiredParam[string](doc, "payload")
if err != nil {
return nil, lazyerrors.Error(err)
}

data, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return nil, lazyerrors.Error(err)
}

return data, nil
}

// MsgSASLStart implements HandlerInterface.
func (h *Handler) MsgSASLStart(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, error) {
doc, err := msg.Document()
Expand All @@ -59,37 +38,23 @@ func (h *Handler) MsgSASLStart(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
return nil, lazyerrors.Error(err)
}

if mechanism != "PLAIN" {
return nil, common.NewCommandErrorMsgWithArgument(
common.ErrTypeMismatch,
var username, password string

switch mechanism {
case "PLAIN":
username, password, err = common.SASLStartPlain(doc)
default:
err = commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrTypeMismatch,
"Unsupported mechanism '"+mechanism+"'",
"mechanism",
)
}

payload, err := getPayload(doc)
if err != nil {
return nil, lazyerrors.Error(err)
}

parts := bytes.Split(payload, []byte{0})
if l := len(parts); l != 3 {
return nil, common.NewCommandErrorMsgWithArgument(
common.ErrTypeMismatch,
fmt.Sprintf("Invalid payload (expected 3 parts, got %d)", l),
"payload",
)
}

authzid, authcid, passwd := parts[0], parts[1], parts[2]

// Some drivers (Go) send empty authorization identity (authzid),
// while others (Java) set it to the same value as authentication identity (authcid)
// (see https://www.rfc-editor.org/rfc/rfc4616.html).
// Ignore authzid for now.
_ = authzid

conninfo.Get(ctx).SetAuth(string(authcid), string(passwd))
conninfo.Get(ctx).SetAuth(username, password)

if _, err = h.DBPool(ctx); err != nil {
return nil, lazyerrors.Error(err)
Expand Down
35 changes: 10 additions & 25 deletions internal/handlers/tigris/msg_saslstart.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
package tigris

import (
"bytes"
"context"
"fmt"

"github.com/FerretDB/FerretDB/internal/clientconn/conninfo"
"github.com/FerretDB/FerretDB/internal/handlers/common"
"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
"github.com/FerretDB/FerretDB/internal/util/must"
Expand All @@ -39,37 +38,23 @@ func (h *Handler) MsgSASLStart(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
return nil, lazyerrors.Error(err)
}

if mechanism != "PLAIN" {
return nil, common.NewCommandErrorMsgWithArgument(
common.ErrTypeMismatch,
var username, password string

switch mechanism {
case "PLAIN":
username, password, err = common.SASLStartPlain(doc)
default:
err = commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrTypeMismatch,
"Unsupported mechanism '"+mechanism+"'",
"mechanism",
)
}

payload, err := common.GetRequiredParam[types.Binary](doc, "payload")
if err != nil {
return nil, lazyerrors.Error(err)
}

parts := bytes.Split(payload.B, []byte{0})
if l := len(parts); l != 3 {
return nil, common.NewCommandErrorMsgWithArgument(
common.ErrTypeMismatch,
fmt.Sprintf("Invalid payload (expected 3 parts, got %d)", l),
"payload",
)
}

authzid, authcid, passwd := parts[0], parts[1], parts[2]

// Some drivers (Go) send empty authorization identity (authzid),
// while others (Java) set it to the same value as authentication identity (authcid)
// (see https://www.rfc-editor.org/rfc/rfc4616.html).
// Ignore authzid for now.
_ = authzid

conninfo.Get(ctx).SetAuth(string(authcid), string(passwd))
conninfo.Get(ctx).SetAuth(username, password)

if _, err = h.DBPool(ctx); err != nil {
return nil, lazyerrors.Error(err)
Expand Down

0 comments on commit 3e2a5ae

Please sign in to comment.