Skip to content

Commit

Permalink
Explicitly disallow duplicate keys in data documents (#1293)
Browse files Browse the repository at this point in the history
Closes #364.
  • Loading branch information
Elena Grahovac authored Oct 25, 2022
1 parent 38b30b5 commit 552683e
Show file tree
Hide file tree
Showing 27 changed files with 244 additions and 192 deletions.
5 changes: 4 additions & 1 deletion integration/update_field_compat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,17 @@ func TestUpdateFieldCompatUnset(t *testing.T) {
testUpdateCompat(t, testCases)
}

// TestUpdateFieldCompatNull checks that update works correctly for the null values.
func TestUpdateFieldCompatSet(t *testing.T) {
t.Parallel()

testCases := map[string]updateCompatTestCase{
"SetNullInExisingField": {
update: bson.D{{"$set", bson.D{{"v", nil}}}},
},
"DuplicateKeys": {
update: bson.D{{"$set", bson.D{{"v", 42}, {"v", "hello"}}}},
skip: "https://github.com/FerretDB/FerretDB/issues/1263",
},
}

testUpdateCompat(t, testCases)
Expand Down
18 changes: 13 additions & 5 deletions integration/update_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,17 @@ func TestUpdateDocumentFieldsOrderSimplified(t *testing.T) {
_, err := collection.InsertOne(ctx, bson.D{{"_id", "document"}, {"foo", int32(42)}, {"bar", "baz"}})
require.NoError(t, err)

var inserted bson.D
err = collection.FindOne(ctx, bson.D{{"_id", "document"}}).Decode(&inserted)
require.NoError(t, err)

expected := bson.D{
{"_id", "document"},
{"foo", int32(42)},
{"bar", "baz"},
}
AssertEqualDocuments(t, expected, inserted)

_, err = collection.UpdateOne(
ctx,
bson.D{{"_id", "document"}},
Expand All @@ -1416,13 +1427,13 @@ func TestUpdateDocumentFieldsOrderSimplified(t *testing.T) {
require.NoError(t, err)

var updated bson.D

err = collection.FindOne(ctx, bson.D{{"_id", "document"}}).Decode(&updated)
require.NoError(t, err)

expected := bson.D{
expected = bson.D{
{"_id", "document"},
}
AssertEqualDocuments(t, expected, updated)

_, err = collection.UpdateOne(
ctx,
Expand All @@ -1431,8 +1442,6 @@ func TestUpdateDocumentFieldsOrderSimplified(t *testing.T) {
)
require.NoError(t, err)

AssertEqualDocuments(t, expected, updated)

err = collection.FindOne(ctx, bson.D{{"_id", "document"}}).Decode(&updated)
require.NoError(t, err)

Expand All @@ -1441,6 +1450,5 @@ func TestUpdateDocumentFieldsOrderSimplified(t *testing.T) {
{"bar", "baz"},
{"foo", int32(42)},
}

AssertEqualDocuments(t, expected, updated)
}
10 changes: 5 additions & 5 deletions internal/handlers/common/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ func (e *CommandError) Document() *types.Document {
"errmsg", e.err.Error(),
))
if e.code != errUnset {
must.NoError(d.Set("code", int32(e.code)))
must.NoError(d.Set("codeName", e.code.String()))
d.Set("code", int32(e.code))
d.Set("codeName", e.code.String())
}
return d
}
Expand Down Expand Up @@ -243,13 +243,13 @@ func (we *WriteErrors) Document() *types.Document {
doc := must.NotFail(types.NewDocument())

if e.index != nil {
must.NoError(doc.Set("index", *e.index))
doc.Set("index", *e.index)
}

// Fields "code" and "errmsg" must always be filled in so that clients can parse the error message.
// Otherwise, the mongo client would parse it as a CommandError.
must.NoError(doc.Set("code", int32(e.code)))
must.NoError(doc.Set("errmsg", e.err))
doc.Set("code", int32(e.code))
doc.Set("errmsg", e.err)

must.NoError(errs.Append(doc))
}
Expand Down
6 changes: 4 additions & 2 deletions internal/handlers/common/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,13 @@ func applyComplexProjection(k1 string, doc, projectionVal *types.Document) (err
if err != nil {
return err
}

if res == nil {
must.NoError(doc.Set(k1, types.Null))
doc.Set(k1, types.Null)
return nil
}
must.NoError(doc.Set(k1, res))

doc.Set(k1, res)
default:
return NewError(ErrCommandNotFound,
lazyerrors.Errorf("applyComplexProjection: unknown projection operator: %q", projectionType),
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/common/serverstatus.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func ServerStatus(startTime time.Time, cm *connmetrics.ConnMetrics) (*types.Docu
cmdDoc = must.NotFail(types.NewDocument("total", cmdMetrics.Total, "failed", cmdMetrics.Failed))
}

must.NoError(metricsDoc.Set(cmd, cmdDoc))
metricsDoc.Set(cmd, cmdDoc)
}

res := must.NotFail(types.NewDocument(
Expand Down
42 changes: 16 additions & 26 deletions internal/handlers/common/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,18 @@ func UpdateDocument(doc, update *types.Document) (bool, error) {
// Treats the update as a Replacement object.
setDoc := update

sort.Strings(setDoc.Keys())

for _, setKey := range doc.Keys() {
if !setDoc.Has(setKey) && setKey != "_id" {
doc.Remove(setKey)
}
}

for _, setKey := range setDoc.Keys() {
setDocKeys := setDoc.Keys()
sort.Strings(setDocKeys)

for _, setKey := range setDocKeys {
setValue := must.NotFail(setDoc.Get(setKey))
if err := doc.Set(setKey, setValue); err != nil {
return false, err
}
doc.Set(setKey, setValue)
}

changed = true
Expand All @@ -132,9 +131,10 @@ func UpdateDocument(doc, update *types.Document) (bool, error) {
func processSetFieldExpression(doc, setDoc *types.Document, setOnInsert bool) (bool, error) {
var changed bool

sort.Strings(setDoc.Keys())
setDocKeys := setDoc.Keys()
sort.Strings(setDocKeys)

for _, setKey := range setDoc.Keys() {
for _, setKey := range setDocKeys {
setValue := must.NotFail(setDoc.Get(setKey))

path := types.NewPathFromString(setKey)
Expand Down Expand Up @@ -365,9 +365,7 @@ func processMaxFieldExpression(doc *types.Document, updateV any) (bool, error) {
}
}

if err := doc.Set(field, maxVal); err != nil {
return changed, err
}
doc.Set(field, maxVal)
changed = true
}

Expand All @@ -378,45 +376,37 @@ func processMaxFieldExpression(doc *types.Document, updateV any) (bool, error) {
// If the document was changed it returns true.
func processCurrentDateFieldExpression(doc *types.Document, currentDateVal any) (bool, error) {
var changed bool
var err error
currentDateExpression := currentDateVal.(*types.Document)

now := time.Now().UTC()
sort.Strings(currentDateExpression.Keys())
keys := currentDateExpression.Keys()
sort.Strings(keys)

for _, field := range currentDateExpression.Keys() {
for _, field := range keys {
currentDateField := must.NotFail(currentDateExpression.Get(field))

switch currentDateField := currentDateField.(type) {
case *types.Document:
currentDateType, err := currentDateField.Get("$type")
if err != nil { // default is date
if err := doc.Set(field, now); err != nil {
return false, err
}
doc.Set(field, now)
changed = true
continue
}

currentDateType = currentDateType.(string)
switch currentDateType {
case "timestamp":
if err := doc.Set(field, types.NextTimestamp(now)); err != nil {
return false, err
}
doc.Set(field, types.NextTimestamp(now))
changed = true

case "date":
if err := doc.Set(field, now); err != nil {
return false, err
}
doc.Set(field, now)
changed = true
}

case bool:
if err = doc.Set(field, now); err != nil {
return false, err
}
doc.Set(field, now)
changed = true
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/pg/msg_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
replyDoc = delErrors.Document()
}

must.NoError(replyDoc.Set("n", deleted))
replyDoc.Set("n", deleted)

var reply wire.OpMsg

Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/pg/msg_explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (h *Handler) MsgExplain(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
))

cmd := command.DeepCopy()
must.NoError(cmd.Set("$db", sp.DB))
cmd.Set("$db", sp.DB)

var reply wire.OpMsg
err = reply.SetSections(wire.OpMsgSection{
Expand Down
10 changes: 5 additions & 5 deletions internal/handlers/pg/msg_findandmodify.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
upsert = params.Update

if !upsert.Has("_id") {
must.NoError(upsert.Set("_id", must.NotFail(resDocs[0].Get("_id"))))
upsert.Set("_id", must.NotFail(resDocs[0].Get("_id")))
}

_, err = h.update(ctx, tx, &sqlParam, upsert)
Expand All @@ -188,7 +188,7 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire.
))

if upserted {
must.NoError(lastErrorObject.Set("upserted", must.NotFail(resultDoc.Get("_id"))))
lastErrorObject.Set("upserted", must.NotFail(resultDoc.Get("_id")))
}

must.NoError(reply.SetSections(wire.OpMsgSection{
Expand Down Expand Up @@ -265,9 +265,9 @@ func (h *Handler) upsert(ctx context.Context, tx pgx.Tx, docs []*types.Document,

if !upsert.Has("_id") {
if params.query.Has("_id") {
must.NoError(upsert.Set("_id", must.NotFail(params.query.Get("_id"))))
upsert.Set("_id", must.NotFail(params.query.Get("_id")))
} else {
must.NoError(upsert.Set("_id", types.NewObjectID()))
upsert.Set("_id", types.NewObjectID())
}
}

Expand All @@ -288,7 +288,7 @@ func (h *Handler) upsert(ctx context.Context, tx pgx.Tx, docs []*types.Document,
}
} else {
for _, k := range params.update.Keys() {
must.NoError(upsert.Set(k, must.NotFail(params.update.Get(k))))
upsert.Set(k, must.NotFail(params.update.Get(k)))
}
}

Expand Down
16 changes: 4 additions & 12 deletions internal/handlers/pg/msg_getparameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,16 @@ func selectUnit(document, resDB *types.Document, showDetails, allParameters bool
item = val
}
}
err = doc.Set(k, item)
if err != nil {
return nil, err
}

doc.Set(k, item)
}

if doc.Len() < 1 {
err := doc.Set("ok", float64(0))
if err != nil {
return nil, err
}
doc.Set("ok", float64(0))
return doc, nil
}

err = doc.Set("ok", float64(1))
if err != nil {
return nil, err
}
doc.Set("ok", float64(1))
return doc, nil
}

Expand Down
10 changes: 6 additions & 4 deletions internal/handlers/pg/msg_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
return err
}
if !doc.Has("_id") {
must.NoError(doc.Set("_id", types.NewObjectID()))
doc.Set("_id", types.NewObjectID())
}

must.NoError(upserted.Append(must.NotFail(types.NewDocument(
Expand Down Expand Up @@ -233,11 +233,13 @@ func (h *Handler) MsgUpdate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
res := must.NotFail(types.NewDocument(
"n", matched,
))

if upserted.Len() != 0 {
must.NoError(res.Set("upserted", &upserted))
res.Set("upserted", &upserted)
}
must.NoError(res.Set("nModified", modified))
must.NoError(res.Set("ok", float64(1)))

res.Set("nModified", modified)
res.Set("ok", float64(1))

var reply wire.OpMsg
err = reply.SetSections(wire.OpMsgSection{
Expand Down
4 changes: 2 additions & 2 deletions internal/handlers/pg/pgdb/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ func CreateCollection(ctx context.Context, tx pgx.Tx, db, collection string) err
}

// TODO keep "collections" sorted after each update
must.NoError(collections.Set(collection, table))
must.NoError(settings.Set("collections", collections))
collections.Set(collection, table)
settings.Set("collections", collections)

err = updateSettingsTable(ctx, tx, db, settings)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/pg/pgdb/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func convertJSON(value any) any {
keys := maps.Keys(value)
for _, k := range keys {
v := value[k]
must.NoError(d.Set(k, convertJSON(v)))
d.Set(k, convertJSON(v))
}
return d

Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/pg/pgdb/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ func getTableName(ctx context.Context, tx pgx.Tx, db, collection string) (string
}

tableName := formatCollectionName(collection)
must.NoError(collections.Set(collection, tableName))
must.NoError(settings.Set("collections", collections))
collections.Set(collection, tableName)
settings.Set("collections", collections)

err = updateSettingsTable(ctx, tx, db, settings)
if err != nil {
Expand Down Expand Up @@ -195,7 +195,7 @@ func removeTableFromSettings(ctx context.Context, tx pgx.Tx, db, collection stri

collections.Remove(collection)

must.NoError(settings.Set("collections", collections))
settings.Set("collections", collections)

if err := updateSettingsTable(ctx, tx, db, settings); err != nil {
return lazyerrors.Error(err)
Expand Down
4 changes: 1 addition & 3 deletions internal/handlers/pg/pjson/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ func (doc *documentType) UnmarshalJSON(data []byte) error {
return lazyerrors.Error(err)
}

if err = td.Set(key, v); err != nil {
return lazyerrors.Error(err)
}
td.Set(key, v)
}

*doc = documentType(*td)
Expand Down
2 changes: 1 addition & 1 deletion internal/handlers/tigris/msg_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (h *Handler) MsgDelete(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg,
replyDoc = delErrors.Document()
}

must.NoError(replyDoc.Set("n", deleted))
replyDoc.Set("n", deleted)

var reply wire.OpMsg

Expand Down
Loading

1 comment on commit 552683e

@vercel
Copy link

@vercel vercel bot commented on 552683e Oct 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

ferret-db – ./

ferret-db-ferretdb.vercel.app
ferret-db.vercel.app
ferret-db-git-main-ferretdb.vercel.app

Please sign in to comment.