Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support $comment query operator #563

Merged
merged 5 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions integration/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,32 @@ func TestInsertFind(t *testing.T) {
})
}
}

func TestFindCommentMethod(t *testing.T) {
t.Parallel()
ctx, collection := setup(t, shareddata.Scalars)
name := collection.Name()
databaseNames, err := collection.Database().Client().ListDatabaseNames(ctx, bson.D{})
require.NoError(t, err)
comment := "*/ 1; DROP SCHEMA " + name + " CASCADE -- "

var doc bson.D
opts := options.FindOne().SetComment(comment)
err = collection.FindOne(ctx, bson.D{{"_id", "string"}}, opts).Decode(&doc)
require.NoError(t, err)
AlekSi marked this conversation as resolved.
Show resolved Hide resolved
assert.Contains(t, databaseNames, name)
}

func TestFindCommentQuery(t *testing.T) {
t.Parallel()
ctx, collection := setup(t, shareddata.Scalars)
name := collection.Name()
databaseNames, err := collection.Database().Client().ListDatabaseNames(ctx, bson.D{})
require.NoError(t, err)
comment := "*/ 1; DROP SCHEMA " + name + " CASCADE -- "

var doc bson.D
err = collection.FindOne(ctx, bson.M{"_id": "string", "$comment": comment}).Decode(&doc)
require.NoError(t, err)
assert.Contains(t, databaseNames, name)
}
3 changes: 3 additions & 0 deletions internal/handlers/common/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ func filterOperator(doc *types.Document, operator string, filterValue any) (bool
}
return true, nil

case "$comment":
return true, nil

default:
msg := fmt.Sprintf(
`unknown top level operator: %s. `+
Expand Down
21 changes: 18 additions & 3 deletions internal/handlers/pg/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ package pg

import (
"context"
"fmt"
"io"
"strings"

"github.com/jackc/pgx/v4"

Expand All @@ -26,11 +26,26 @@ import (
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
)

// sqlParam represents options/parameters used for sql query.
type sqlParam struct {
db string
collection string
comment string
}

// fetch fetches all documents from the given database and collection.
//
// TODO https://github.com/FerretDB/FerretDB/issues/372
func (h *Handler) fetch(ctx context.Context, db, collection string) ([]*types.Document, error) {
sql := fmt.Sprintf(`SELECT _jsonb FROM %s`, pgx.Identifier{db, collection}.Sanitize())
func (h *Handler) fetch(ctx context.Context, param sqlParam) ([]*types.Document, error) {
sql := `SELECT `
if param.comment != "" {
param.comment = strings.ReplaceAll(param.comment, "/*", "/ *")
param.comment = strings.ReplaceAll(param.comment, "*/", "* /")

sql += `/* ` + param.comment + ` */ `
}
sql += `_jsonb FROM ` + pgx.Identifier{param.db, param.collection}.Sanitize()

rows, err := h.pgPool.Query(ctx, sql)
if err != nil {
return nil, lazyerrors.Error(err)
Expand Down
29 changes: 18 additions & 11 deletions internal/handlers/pg/msg_count.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package pg

import (
"context"
"fmt"

"github.com/FerretDB/FerretDB/internal/handlers/common"
"github.com/FerretDB/FerretDB/internal/types"
Expand Down Expand Up @@ -45,16 +46,6 @@ func (h *Handler) MsgCount(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, e
}
common.Ignored(document, h.l, ignoredFields...)

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
return nil, err
}

var filter *types.Document
if filter, err = common.GetOptionalParam(document, "query", filter); err != nil {
return nil, err
Expand All @@ -67,7 +58,23 @@ func (h *Handler) MsgCount(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, e
}
}

fetchedDocs, err := h.fetch(ctx, db, collection)
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down
30 changes: 18 additions & 12 deletions internal/handlers/pg/msg_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}
common.Ignored(document, h.l, "ordered", "writeConcern")

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
return nil, err
}

var deletes *types.Array
if deletes, err = common.GetOptionalParam(document, "deletes", deletes); err != nil {
return nil, err
Expand Down Expand Up @@ -80,7 +70,23 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}
}

fetchedDocs, err := h.fetch(ctx, db, collection)
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -118,7 +124,7 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,

sql := fmt.Sprintf(
"DELETE FROM %s WHERE _jsonb->'_id' IN (%s)",
pgx.Identifier{db, collection}.Sanitize(), strings.Join(placeholders, ", "),
pgx.Identifier{sp.db, sp.collection}.Sanitize(), strings.Join(placeholders, ", "),
)
tag, err := h.pgPool.Exec(ctx, sql, ids...)
if err != nil {
Expand Down
47 changes: 26 additions & 21 deletions internal/handlers/pg/msg_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,13 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er
"hint",
"batchSize",
"singleBatch",
"comment",
"maxTimeMS",
"readConcern",
"max",
"min",
}
common.Ignored(document, h.l, ignoredFields...)

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(command)
if err != nil {
return nil, err
}
collection, ok := collectionParam.(string)
if !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

var filter, sort, projection *types.Document
if filter, err = common.GetOptionalParam(document, "filter", filter); err != nil {
return nil, err
Expand All @@ -96,7 +77,31 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er
}
}

fetchedDocs, err := h.fetch(ctx, db, collection)
seeforschauer marked this conversation as resolved.
Show resolved Hide resolved
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}
// comment set through options.FindOne().SetComment() method
if sp.comment, err = common.GetOptionalParam(document, "comment", sp.comment); err != nil {
return nil, err
}
// comment in query, e.g. db.collection.find({$comment: "test"})
seeforschauer marked this conversation as resolved.
Show resolved Hide resolved
if sp.comment, err = common.GetOptionalParam(filter, "$comment", sp.comment); err != nil {
return nil, err
}

fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -138,7 +143,7 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er
"cursor", must.NotFail(types.NewDocument(
"firstBatch", firstBatch,
"id", int64(0), // TODO
"ns", db+"."+collection,
"ns", sp.db+"."+sp.collection,
)),
"ok", float64(1),
))},
Expand Down
28 changes: 17 additions & 11 deletions internal/handlers/pg/msg_findandmodify.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,32 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
}
common.Ignored(document, h.l, ignoredFields...)

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
var query *types.Document
var remove bool
if query, err = common.GetOptionalParam(document, "query", query); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
if remove, err = common.GetOptionalParam(document, "remove", remove); err != nil {
return nil, err
}

var query *types.Document
var remove bool
if query, err = common.GetOptionalParam(document, "query", query); err != nil {
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if remove, err = common.GetOptionalParam(document, "remove", remove); err != nil {
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

fetchedDocs, err := h.fetch(ctx, db, collection)
fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand All @@ -112,7 +118,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.

if len(resDocs) == 1 && remove {
id := must.NotFail(fjson.Marshal(must.NotFail(resDocs[0].Get("_id"))))
sql := fmt.Sprintf("DELETE FROM %s WHERE _jsonb->'_id' IN ($1)", pgx.Identifier{db, collection}.Sanitize())
sql := fmt.Sprintf("DELETE FROM %s WHERE _jsonb->'_id' IN ($1)", pgx.Identifier{sp.db, sp.collection}.Sanitize())
if _, err := h.pgPool.Exec(ctx, sql, id); err != nil {
return nil, lazyerrors.Error(err)
}
Expand Down
27 changes: 17 additions & 10 deletions internal/handlers/pg/msg_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,33 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}
common.Ignored(document, h.l, "ordered", "writeConcern", "bypassDocumentValidation", "comment")

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

var updates *types.Array
if updates, err = common.GetOptionalParam(document, "updates", updates); err != nil {
return nil, err
}

created, err := h.pgPool.EnsureTableExist(ctx, db, collection)
created, err := h.pgPool.EnsureTableExist(ctx, sp.db, sp.collection)
if err != nil {
return nil, err
}
if created {
h.l.Info("Created table.", zap.String("schema", db), zap.String("table", collection))
h.l.Info("Created table.", zap.String("schema", sp.db), zap.String("table", sp.collection))
}

var matched, modified int32
Expand Down Expand Up @@ -97,7 +103,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
return nil, err
}

fetchedDocs, err := h.fetch(ctx, db, collection)
fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -135,7 +141,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
"_id", must.NotFail(doc.Get("_id")),
))))

sql := fmt.Sprintf("INSERT INTO %s (_jsonb) VALUES ($1)", pgx.Identifier{db, collection}.Sanitize())
sql := fmt.Sprintf("INSERT INTO %s (_jsonb) VALUES ($1)", pgx.Identifier{sp.db, sp.collection}.Sanitize())
b, err := fjson.Marshal(doc)
if err != nil {
return nil, err
Expand All @@ -157,7 +163,8 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
return nil, lazyerrors.Error(err)
}

sql := fmt.Sprintf("UPDATE %s SET _jsonb = $1 WHERE _jsonb->'_id' = $2", pgx.Identifier{db, collection}.Sanitize())
sql := "UPDATE " + pgx.Identifier{sp.db, sp.collection}.Sanitize() +
" SET _jsonb = $1 WHERE _jsonb->'_id' = $2"
id := must.NotFail(doc.Get("_id"))
tag, err := h.pgPool.Exec(ctx, sql, must.NotFail(fjson.Marshal(doc)), must.NotFail(fjson.Marshal(id)))
if err != nil {
Expand Down