Skip to content

Commit

Permalink
fetch has been updated
Browse files Browse the repository at this point in the history
  • Loading branch information
ribaraka committed May 10, 2022
1 parent 2352785 commit 48ea3e4
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 91 deletions.
57 changes: 17 additions & 40 deletions internal/handlers/pg/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package pg

import (
"context"
"fmt"
"io"
"strings"

Expand All @@ -27,11 +26,26 @@ import (
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
)

// sqlParam represents options/parameters used for sql query.
type sqlParam struct {
db string
collection string
comment string
}

// fetch fetches all documents from the given database and collection.
//
// TODO https://github.com/FerretDB/FerretDB/issues/372
func (h *Handler) fetch(ctx context.Context, db, collection string) ([]*types.Document, error) {
sql := fmt.Sprintf(`SELECT _jsonb FROM %s`, pgx.Identifier{db, collection}.Sanitize())
func (h *Handler) fetch(ctx context.Context, param sqlParam) ([]*types.Document, error) {
sql := `SELECT `
if param.comment != "" {
param.comment = strings.ReplaceAll(param.comment, "/*", "/ *")
param.comment = strings.ReplaceAll(param.comment, "*/", "* /")

sql += `/* ` + param.comment + ` */ `
}
sql += `_jsonb FROM ` + pgx.Identifier{param.db, param.collection}.Sanitize()

rows, err := h.pgPool.Query(ctx, sql)
if err != nil {
return nil, lazyerrors.Error(err)
Expand Down Expand Up @@ -74,40 +88,3 @@ func nextRow(rows pgx.Rows) (*types.Document, error) {

return doc.(*types.Document), nil
}

func (h *Handler) protoFetch(ctx context.Context, param fetchParam) ([]*types.Document, error) {
var sql string
if param.comment != "" {
param.comment = strings.ReplaceAll(param.comment, "/*", "/ *")
param.comment = strings.ReplaceAll(param.comment, "*/", "* /")
param.comment = fmt.Sprintf("/* %s */", param.comment)

sql = fmt.Sprintf(`SELECT %s _jsonb FROM %s`, param.comment, pgx.Identifier{param.db, param.collection}.Sanitize())
} else {
sql = fmt.Sprintf(`SELECT _jsonb FROM %s`, pgx.Identifier{param.db, param.collection}.Sanitize())
}

rows, err := h.pgPool.Query(ctx, sql)
if err != nil {
return nil, lazyerrors.Error(err)
}
defer rows.Close()

var res []*types.Document
for {
doc, err := nextRow(rows)
if err == io.EOF {
break
}
if err != nil {
return nil, lazyerrors.Error(err)
}
res = append(res, doc)
}

return res, nil
}

type fetchParam struct {
db, collection, comment string
}
29 changes: 18 additions & 11 deletions internal/handlers/pg/msg_count.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package pg

import (
"context"
"fmt"

"github.com/FerretDB/FerretDB/internal/handlers/common"
"github.com/FerretDB/FerretDB/internal/types"
Expand Down Expand Up @@ -44,16 +45,6 @@ func (h *Handler) MsgCount(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, e
}
common.Ignored(document, h.l, ignoredFields...)

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
return nil, err
}

var filter *types.Document
if filter, err = common.GetOptionalParam(document, "query", filter); err != nil {
return nil, err
Expand All @@ -66,7 +57,23 @@ func (h *Handler) MsgCount(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, e
}
}

fetchedDocs, err := h.fetch(ctx, db, collection)
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down
30 changes: 18 additions & 12 deletions internal/handlers/pg/msg_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}
common.Ignored(document, h.l, "ordered", "writeConcern")

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
return nil, err
}

var deletes *types.Array
if deletes, err = common.GetOptionalParam(document, "deletes", deletes); err != nil {
return nil, err
Expand Down Expand Up @@ -80,7 +70,23 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}
}

fetchedDocs, err := h.fetch(ctx, db, collection)
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -118,7 +124,7 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,

sql := fmt.Sprintf(
"DELETE FROM %s WHERE _jsonb->'_id' IN (%s)",
pgx.Identifier{db, collection}.Sanitize(), strings.Join(placeholders, ", "),
pgx.Identifier{sp.db, sp.collection}.Sanitize(), strings.Join(placeholders, ", "),
)
tag, err := h.pgPool.Exec(ctx, sql, ids...)
if err != nil {
Expand Down
14 changes: 7 additions & 7 deletions internal/handlers/pg/msg_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,31 +76,31 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er
}
}

var fp fetchParam
if fp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if fp.collection, ok = collectionParam.(string); !ok {
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}
// comment set through options.FindOne().SetComment() method
if fp.comment, err = common.GetOptionalParam(document, "comment", fp.comment); err != nil {
if sp.comment, err = common.GetOptionalParam(document, "comment", sp.comment); err != nil {
return nil, err
}
// comment in query, e.g. db.collection.find({$comment: "test"})
if fp.comment, err = common.GetOptionalParam(filter, "$comment", fp.comment); err != nil {
if sp.comment, err = common.GetOptionalParam(filter, "$comment", sp.comment); err != nil {
return nil, err
}

fetchedDocs, err := h.protoFetch(ctx, fp)
fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er
"cursor", types.MustNewDocument(
"firstBatch", firstBatch,
"id", int64(0), // TODO
"ns", fp.db+"."+fp.collection,
"ns", sp.db+"."+sp.collection,
),
"ok", float64(1),
)},
Expand Down
28 changes: 17 additions & 11 deletions internal/handlers/pg/msg_findandmodify.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,32 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
}
common.Ignored(document, h.l, ignoredFields...)

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
var query *types.Document
var remove bool
if query, err = common.GetOptionalParam(document, "query", query); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
if remove, err = common.GetOptionalParam(document, "remove", remove); err != nil {
return nil, err
}

var query *types.Document
var remove bool
if query, err = common.GetOptionalParam(document, "query", query); err != nil {
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if remove, err = common.GetOptionalParam(document, "remove", remove); err != nil {
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

fetchedDocs, err := h.fetch(ctx, db, collection)
fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand All @@ -112,7 +118,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.

if len(resDocs) == 1 && remove {
id := must.NotFail(fjson.Marshal(must.NotFail(resDocs[0].Get("_id"))))
sql := fmt.Sprintf("DELETE FROM %s WHERE _jsonb->'_id' IN ($1)", pgx.Identifier{db, collection}.Sanitize())
sql := fmt.Sprintf("DELETE FROM %s WHERE _jsonb->'_id' IN ($1)", pgx.Identifier{sp.db, sp.collection}.Sanitize())
if _, err := h.pgPool.Exec(ctx, sql, id); err != nil {
return nil, lazyerrors.Error(err)
}
Expand Down
26 changes: 16 additions & 10 deletions internal/handlers/pg/msg_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,33 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
}
common.Ignored(document, h.l, "ordered", "writeConcern", "bypassDocumentValidation", "comment")

command := document.Command()

var db, collection string
if db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
var sp sqlParam
if sp.db, err = common.GetRequiredParam[string](document, "$db"); err != nil {
return nil, err
}
if collection, err = common.GetRequiredParam[string](document, command); err != nil {
collectionParam, err := document.Get(document.Command())
if err != nil {
return nil, err
}
var ok bool
if sp.collection, ok = collectionParam.(string); !ok {
return nil, common.NewErrorMsg(
common.ErrBadValue,
fmt.Sprintf("collection name has invalid type %s", common.AliasFromType(collectionParam)),
)
}

var updates *types.Array
if updates, err = common.GetOptionalParam(document, "updates", updates); err != nil {
return nil, err
}

created, err := h.pgPool.EnsureTableExist(ctx, db, collection)
created, err := h.pgPool.EnsureTableExist(ctx, sp.db, sp.collection)
if err != nil {
return nil, err
}
if created {
h.l.Info("Created table.", zap.String("schema", db), zap.String("table", collection))
h.l.Info("Created table.", zap.String("schema", sp.db), zap.String("table", sp.collection))
}

var matched, modified int32
Expand Down Expand Up @@ -97,7 +103,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
return nil, err
}

fetchedDocs, err := h.fetch(ctx, db, collection)
fetchedDocs, err := h.fetch(ctx, sp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -135,7 +141,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
"_id", must.NotFail(doc.Get("_id")),
))))

sql := fmt.Sprintf("INSERT INTO %s (_jsonb) VALUES ($1)", pgx.Identifier{db, collection}.Sanitize())
sql := fmt.Sprintf("INSERT INTO %s (_jsonb) VALUES ($1)", pgx.Identifier{sp.db, sp.collection}.Sanitize())
b, err := fjson.Marshal(doc)
if err != nil {
return nil, err
Expand All @@ -157,7 +163,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
return nil, lazyerrors.Error(err)
}

sql := fmt.Sprintf("UPDATE %s SET _jsonb = $1 WHERE _jsonb->'_id' = $2", pgx.Identifier{db, collection}.Sanitize())
sql := fmt.Sprintf("UPDATE %s SET _jsonb = $1 WHERE _jsonb->'_id' = $2", pgx.Identifier{sp.db, sp.collection}.Sanitize())
id := must.NotFail(doc.Get("_id"))
tag, err := h.pgPool.Exec(ctx, sql, must.NotFail(fjson.Marshal(doc)), must.NotFail(fjson.Marshal(id)))
if err != nil {
Expand Down

0 comments on commit 48ea3e4

Please sign in to comment.