Skip to content

Commit

Permalink
Merge branch 'main' into test-with-proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
seeforschauer authored Apr 15, 2022
2 parents 49257d9 + 449e184 commit aa41226
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
57 changes: 43 additions & 14 deletions internal/types/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"fmt"
"strconv"
"unicode/utf8"

"golang.org/x/exp/slices"

"github.com/FerretDB/FerretDB/internal/util/must"
)

// Common interface with bson.Document.
Expand Down Expand Up @@ -61,11 +65,7 @@ func ConvertDocument(d document) (*Document, error) {
//
// Deprecated: use `must.NotFail(ConvertDocument(...))` instead.
func MustConvertDocument(d document) *Document {
doc, err := ConvertDocument(d)
if err != nil {
panic(err)
}
return doc
return must.NotFail(ConvertDocument(d))
}

// NewDocument creates a document with the given key/value pairs.
Expand Down Expand Up @@ -108,11 +108,7 @@ func NewDocument(pairs ...any) (*Document, error) {
//
// Deprecated: use `must.NotFail(NewDocument(...))` instead.
func MustNewDocument(pairs ...any) *Document {
doc, err := NewDocument(pairs...)
if err != nil {
panic(err)
}
return doc
return must.NotFail(NewDocument(pairs...))
}

func (*Document) compositeType() {}
Expand Down Expand Up @@ -149,6 +145,8 @@ func (d *Document) validate() error {
return fmt.Errorf("types.Document.validate: keys and values count mismatch: %d != %d", len(d.m), len(d.keys))
}

// TODO check that _id is not regex or array

prevKeys := make(map[string]struct{}, len(d.keys))
for _, key := range d.keys {
if !isValidKey(key) {
Expand Down Expand Up @@ -213,6 +211,9 @@ func (d *Document) Command() string {
return keys[0]
}

// add adds the value for the given key, returning error if that key is already present.
//
// As a special case, _id always becomes the first key.
func (d *Document) add(key string, value any) error {
if _, ok := d.m[key]; ok {
return fmt.Errorf("types.Document.add: key already present: %q", key)
Expand All @@ -226,12 +227,27 @@ func (d *Document) add(key string, value any) error {
return fmt.Errorf("types.Document.validate: %w", err)
}

d.keys = append(d.keys, key)
// update keys slice
if key == "_id" {
// TODO check that value is not regex or array

// ensure that _id is the first field
d.keys = slices.Insert(d.keys, 0, key)
} else {
d.keys = append(d.keys, key)
}

d.m[key] = value

return nil
}

// Has returns true if the given key is present in the document.
func (d *Document) Has(key string) bool {
_, ok := d.m[key]
return ok
}

// Get returns a value at the given key.
func (d *Document) Get(key string) (any, error) {
if value, ok := d.m[key]; ok {
Expand All @@ -246,7 +262,9 @@ func (d *Document) GetByPath(path ...string) (any, error) {
return getByPath(d, path...)
}

// Set the value of the given key, replacing any existing value.
// Set sets the value for the given key, replacing any existing value.
//
// As a special case, _id always becomes the first key.
func (d *Document) Set(key string, value any) error {
if !isValidKey(key) {
return fmt.Errorf("types.Document.Set: invalid key: %q", key)
Expand All @@ -256,8 +274,19 @@ func (d *Document) Set(key string, value any) error {
return fmt.Errorf("types.Document.validate: %w", err)
}

if _, ok := d.m[key]; !ok {
d.keys = append(d.keys, key)
// update keys slice
if key == "_id" {
// TODO check that value is not regex or array

// ensure that _id is the first field
if i := slices.Index(d.keys, key); i >= 0 {
d.keys = slices.Delete(d.keys, i, i+1)
}
d.keys = slices.Insert(d.keys, 0, key)
} else {
if _, ok := d.m[key]; !ok {
d.keys = append(d.keys, key)
}
}

if d.m == nil {
Expand Down
19 changes: 19 additions & 0 deletions internal/types/document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,25 @@ func TestDocument(t *testing.T) {
assert.Equal(t, int32(42), b.m["foo"])
})

t.Run("SetID", func(t *testing.T) {
t.Parallel()

doc := must.NotFail(NewDocument(
"_id", int32(42),
"foo", "bar",
))
assert.Equal(t, []string{"_id", "foo"}, doc.keys)

doc = must.NotFail(NewDocument(
"foo", "bar",
"_id", int32(42),
))
assert.Equal(t, []string{"_id", "foo"}, doc.keys)

doc.Set("_id", "bar")
assert.Equal(t, []string{"_id", "foo"}, doc.keys)
})

t.Run("Validate", func(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit aa41226

Please sign in to comment.