diff --git a/integration/aggregate_documents_test.go b/integration/aggregate_documents_test.go index 6649d6fe8cb4..1acfd3332d56 100644 --- a/integration/aggregate_documents_test.go +++ b/integration/aggregate_documents_test.go @@ -872,6 +872,46 @@ func TestAggregateUnsetErrors(t *testing.T) { } } +func TestAggregateSortErrors(t *testing.T) { + t.Parallel() + + for name, tc := range map[string]struct { //nolint:vet // used for test only + pipeline bson.A // required, aggregation pipeline stages + + err *mongo.CommandError // required + altMessage string // optional, alternative error message + skip string // optional, skip test with a specified reason + }{ + "DotNotationMissingField": { + pipeline: bson.A{bson.D{{"$sort", bson.D{ + {"v..foo", 1}, + }}}}, + err: &mongo.CommandError{ + Code: 15998, + Name: "Location15998", + Message: "FieldPath field names may not be empty strings.", + }, + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + + t.Parallel() + + require.NotNil(t, tc.pipeline, "pipeline must not be nil") + require.NotNil(t, tc.err, "err must not be nil") + + ctx, collection := setup.Setup(t) + + _, err := collection.Aggregate(ctx, tc.pipeline) + AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err) + }) + } +} + func TestAggregateCommandMaxTimeMSErrors(t *testing.T) { t.Parallel() ctx, collection := setup.Setup(t) diff --git a/integration/compattestcaseresulttype_string.go b/integration/compattestcaseresulttype_string.go index d6f3edcbcd49..ae036f9724a1 100644 --- a/integration/compattestcaseresulttype_string.go +++ b/integration/compattestcaseresulttype_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type compatTestCaseResultType"; DO NOT EDIT. +// Code generated by "stringer -linecomment -type compatTestCaseResultType"; DO NOT EDIT. package integration diff --git a/integration/helpers.go b/integration/helpers.go index 590c8ca410e4..e49442878afa 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -33,7 +33,7 @@ import ( "github.com/FerretDB/FerretDB/internal/util/testutil/testtb" ) -//go:generate ../bin/stringer -type compatTestCaseResultType +//go:generate ../bin/stringer -linecomment -type compatTestCaseResultType // compatTestCaseResultType represents compatibility test case result type. // diff --git a/internal/backends/error.go b/internal/backends/error.go index 727ffa99d963..3b0bd2c13206 100644 --- a/internal/backends/error.go +++ b/internal/backends/error.go @@ -23,7 +23,7 @@ import ( "github.com/FerretDB/FerretDB/internal/util/debugbuild" ) -//go:generate ../../bin/stringer -type ErrorCode +//go:generate ../../bin/stringer -linecomment -type ErrorCode // ErrorCode represent a backend error code. type ErrorCode int diff --git a/internal/backends/errorcode_string.go b/internal/backends/errorcode_string.go index 6ca7d3b1f679..40c7bccaac99 100644 --- a/internal/backends/errorcode_string.go +++ b/internal/backends/errorcode_string.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type ErrorCode"; DO NOT EDIT. +// Code generated by "stringer -linecomment -type ErrorCode"; DO NOT EDIT. package backends diff --git a/internal/handlers/common/aggregations/expression.go b/internal/handlers/common/aggregations/expression.go index 69ba5e8d7ed7..b9a274e1c6b7 100644 --- a/internal/handlers/common/aggregations/expression.go +++ b/internal/handlers/common/aggregations/expression.go @@ -18,6 +18,7 @@ import ( "fmt" "strings" + "github.com/FerretDB/FerretDB/internal/handlers/commonpath" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" @@ -67,30 +68,34 @@ func (e *ExpressionError) Code() ExpressionErrorCode { return e.code } -// Expression is an expression constructed from field value. +// Expression represents a value that needs evaluation. +// +// Expression for access field in document should be prefixed with a dollar sign $ followed by field key. +// For accessing embedded document or array, a dollar sign $ should be followed by dot notation. +// Options can be provided to specify how to access fields in embedded array. type Expression struct { - *ExpressionOpts + opts commonpath.FindValuesOpts path types.Path } -// ExpressionOpts represents options used to modify behavior of Expression functions. -type ExpressionOpts struct { - // TODO https://github.com/FerretDB/FerretDB/issues/2348 - - // IgnoreArrays disables checking arrays for provided key. - // So expression {"$v.foo"} won't match {"v":[{"foo":42}]} - IgnoreArrays bool // defaults to false -} +// NewExpression returns Expression from dollar sign $ prefixed string. +// It can take additional options to specify how to access fields in embedded array. +// +// It returns error if invalid Expression is provided. +func NewExpression(expression string, opts *commonpath.FindValuesOpts) (*Expression, error) { + // for aggregation expression, it does not return value by index of array + if opts == nil { + opts = &commonpath.FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: true, + } + } -// NewExpressionWithOpts creates a new instance by checking expression string. -// It can take additional opts that specify how expressions should be evaluated. -func NewExpressionWithOpts(expression string, opts *ExpressionOpts) (*Expression, error) { - // TODO https://github.com/FerretDB/FerretDB/issues/2348 var val string switch { case strings.HasPrefix(expression, "$$"): - // `$$` indicates field is a variable. + // double dollar sign $$ prefixed string indicates Expression is a variable name v := strings.TrimPrefix(expression, "$$") if v == "" { return nil, newExpressionError(ErrEmptyVariable) @@ -103,7 +108,7 @@ func NewExpressionWithOpts(expression string, opts *ExpressionOpts) (*Expression // TODO https://github.com/FerretDB/FerretDB/issues/2275 return nil, newExpressionError(ErrUndefinedVariable) case strings.HasPrefix(expression, "$"): - // `$` indicates field is a path. + // dollar sign $ prefixed string indicates Expression accesses field or embedded fields val = strings.TrimPrefix(expression, "$") if val == "" { @@ -121,19 +126,17 @@ func NewExpressionWithOpts(expression string, opts *ExpressionOpts) (*Expression } return &Expression{ - path: path, - ExpressionOpts: opts, + path: path, + opts: *opts, }, nil } -// NewExpression creates a new instance by checking expression string. -func NewExpression(expression string) (*Expression, error) { - // TODO https://github.com/FerretDB/FerretDB/issues/2348 - return NewExpressionWithOpts(expression, new(ExpressionOpts)) -} - -// Evaluate gets the value at the path. -// It returns error if the path does not exists. +// Evaluate uses Expression to find a field value or an embedded field value of the document and +// returns found value. If values were found from embedded array, it returns *types.Array +// containing values. +// +// It returns error if field value was not found. With embedded array field being exception, +// that case it returns empty array instead of error. func (e *Expression) Evaluate(doc *types.Document) (any, error) { path := e.path @@ -146,32 +149,35 @@ func (e *Expression) Evaluate(doc *types.Document) (any, error) { return val, nil } - var isPrefixArray bool + var isArrayField bool prefix := path.Prefix() if v, err := doc.Get(prefix); err == nil { if _, isArray := v.(*types.Array); isArray { - isPrefixArray = true + isArrayField = true } } - vals := e.getPathValue(doc, path) + vals, err := commonpath.FindValues(doc, path, &e.opts) + if err != nil { + return nil, lazyerrors.Error(err) + } if len(vals) == 0 { - if isPrefixArray { - // when the prefix is array, return empty array. + if isArrayField { + // embedded array field returns empty array return must.NotFail(types.NewArray()), nil } return nil, fmt.Errorf("no document found under %s path", path) } - if len(vals) == 1 && !isPrefixArray { - // when the prefix is not array, return the value + if len(vals) == 1 && !isArrayField { + // when it is not an embedded array field, return the value return vals[0], nil } - // when the prefix is array, return an array of value. + // embedded array field returns an array of found values arr := types.MakeArray(len(vals)) for _, v := range vals { arr.Append(v) @@ -180,68 +186,7 @@ func (e *Expression) Evaluate(doc *types.Document) (any, error) { return arr, nil } -// GetExpressionSuffix returns suffix of pathExpression. +// GetExpressionSuffix returns field key of Expression, or for dot notation it returns suffix. func (e *Expression) GetExpressionSuffix() string { return e.path.Suffix() } - -// getPathValue go through each key of the path iteratively to -// find values that exist at suffix. -// An array may return multiple values. -// At each key of the path, it checks: -// - if the document has the key. -// - if the array contains documents which have the key. (This check can -// be disabled by setting ExpressionOpts.IgnoreArrays field). -// -// It is different from `common.getDocumentsAtSuffix`, it does not find array item by -// array dot notation `foo.0.bar`. It returns empty array [] because using index -// such as `0` does not match using expression path. -func (e *Expression) getPathValue(doc *types.Document, path types.Path) []any { - // TODO https://github.com/FerretDB/FerretDB/issues/2348 - keys := path.Slice() - vals := []any{doc} - - for _, key := range keys { - // embeddedVals are the values found at current key. - var embeddedVals []any - - for _, valAtKey := range vals { - switch val := valAtKey.(type) { - case *types.Document: - embeddedVal, err := val.Get(key) - if err != nil { - continue - } - - embeddedVals = append(embeddedVals, embeddedVal) - case *types.Array: - if e.IgnoreArrays { - continue - } - // iterate elements to get documents that contain the key. - for j := 0; j < val.Len(); j++ { - elem := must.NotFail(val.Get(j)) - - docElem, isDoc := elem.(*types.Document) - if !isDoc { - continue - } - - embeddedVal, err := docElem.Get(key) - if err != nil { - continue - } - - embeddedVals = append(embeddedVals, embeddedVal) - } - - default: - // not a document or array, do nothing - } - } - - vals = embeddedVals - } - - return vals -} diff --git a/internal/handlers/common/aggregations/operators/accumulators/sum.go b/internal/handlers/common/aggregations/operators/accumulators/sum.go index 9093882858db..14a279b030e9 100644 --- a/internal/handlers/common/aggregations/operators/accumulators/sum.go +++ b/internal/handlers/common/aggregations/operators/accumulators/sum.go @@ -71,7 +71,7 @@ func newSum(accumulation *types.Document) (Accumulator, error) { accumulator.number = expr case string: var err error - if accumulator.expression, err = aggregations.NewExpression(expr); err != nil { + if accumulator.expression, err = aggregations.NewExpression(expr, nil); err != nil { // $sum returns 0 on non-existent field. accumulator.number = int32(0) } diff --git a/internal/handlers/common/aggregations/operators/sum.go b/internal/handlers/common/aggregations/operators/sum.go index ee67a4cd29d7..932b4f13d247 100644 --- a/internal/handlers/common/aggregations/operators/sum.go +++ b/internal/handlers/common/aggregations/operators/sum.go @@ -74,7 +74,7 @@ func newSum(doc *types.Document) (Operator, error) { case float64: operator.numbers = append(operator.numbers, elemExpr) case string: - ex, err := aggregations.NewExpression(elemExpr) + ex, err := aggregations.NewExpression(elemExpr, nil) var exErr *aggregations.ExpressionError if errors.As(err, &exErr) && exErr.Code() == aggregations.ErrNotExpression { @@ -93,7 +93,7 @@ func newSum(doc *types.Document) (Operator, error) { case float64: operator.numbers = []any{expr} case string: - ex, err := aggregations.NewExpression(expr) + ex, err := aggregations.NewExpression(expr, nil) var exErr *aggregations.ExpressionError if errors.As(err, &exErr) && exErr.Code() == aggregations.ErrNotExpression { diff --git a/internal/handlers/common/aggregations/operators/type.go b/internal/handlers/common/aggregations/operators/type.go index d432fef9319e..1c59fea17406 100644 --- a/internal/handlers/common/aggregations/operators/type.go +++ b/internal/handlers/common/aggregations/operators/type.go @@ -107,7 +107,7 @@ func (t *typeOp) Process(doc *types.Document) (any, error) { case string: if strings.HasPrefix(param, "$") { - expression, err := aggregations.NewExpression(param) + expression, err := aggregations.NewExpression(param, nil) if err != nil { return nil, err } diff --git a/internal/handlers/common/aggregations/stages/group.go b/internal/handlers/common/aggregations/stages/group.go index 122288e482d2..0667fa363fe1 100644 --- a/internal/handlers/common/aggregations/stages/group.go +++ b/internal/handlers/common/aggregations/stages/group.go @@ -216,7 +216,7 @@ func validateGroupKey(groupKey any) error { case *types.Document: return validateGroupKey(v) case string: - _, err := aggregations.NewExpression(v) + _, err := aggregations.NewExpression(v, nil) var exprErr *aggregations.ExpressionError if errors.As(err, &exprErr) && exprErr.Code() == aggregations.ErrNotExpression { @@ -251,7 +251,7 @@ func (g *group) groupDocuments(in []*types.Document) ([]groupedDocuments, error) types.Regex, int32, types.Timestamp, int64: m.addOrAppend(groupKey, doc) case string: - expression, err := aggregations.NewExpression(groupKey) + expression, err := aggregations.NewExpression(groupKey, nil) if err != nil { var exprErr *aggregations.ExpressionError if errors.As(err, &exprErr) { @@ -323,7 +323,7 @@ func evaluateDocument(expr, doc *types.Document, nestedField bool) (any, error) evaluatedDocument.Set(k, v) case string: - expression, err := aggregations.NewExpression(exprVal) + expression, err := aggregations.NewExpression(exprVal, nil) var exprErr *aggregations.ExpressionError if errors.As(err, &exprErr) && exprErr.Code() == aggregations.ErrNotExpression { diff --git a/internal/handlers/common/aggregations/stages/projection/projection.go b/internal/handlers/common/aggregations/stages/projection/projection.go index bd29b66252d1..fbb09cdd6b43 100644 --- a/internal/handlers/common/aggregations/stages/projection/projection.go +++ b/internal/handlers/common/aggregations/stages/projection/projection.go @@ -77,6 +77,7 @@ func ValidateProjection(projection *types.Document) (*types.Document, bool, erro ) } + // TODO https://github.com/FerretDB/FerretDB/issues/3127 path, err := types.NewPathFromString(key) if err != nil { if strings.HasSuffix(key, "$") { @@ -328,6 +329,7 @@ func projectDocumentWithoutID(doc *types.Document, projection *types.Document, i return nil, lazyerrors.Error(err) } + // TODO https://github.com/FerretDB/FerretDB/issues/3127 path, err := types.NewPathFromString(key) if err != nil { return nil, lazyerrors.Error(err) diff --git a/internal/handlers/common/aggregations/stages/sort.go b/internal/handlers/common/aggregations/stages/sort.go index 00db7238184b..1bfc04f80d92 100644 --- a/internal/handlers/common/aggregations/stages/sort.go +++ b/internal/handlers/common/aggregations/stages/sort.go @@ -59,12 +59,13 @@ func newSort(stage *types.Document) (aggregations.Stage, error) { // Process implements Stage interface. // -// If sort path is invalid, it returns a possibly wrapped types.DocumentPathError. +// If sort path is invalid, it returns a possibly wrapped types.PathError. func (s *sort) Process(ctx context.Context, iter types.DocumentsIterator, closer *iterator.MultiCloser) (types.DocumentsIterator, error) { //nolint:lll // for readability - var err error - if iter, err = common.SortIterator(iter, closer, s.fields); err != nil { - var pathErr *types.DocumentPathError - if errors.As(err, &pathErr) && pathErr.Code() == types.ErrDocumentPathEmptyKey { + iter, err := common.SortIterator(iter, closer, s.fields) + if err != nil { + // TODO https://github.com/FerretDB/FerretDB/issues/3125 + var pathErr *types.PathError + if errors.As(err, &pathErr) && pathErr.Code() == types.ErrPathElementEmpty { return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrPathContainsEmptyElement, "FieldPath field names may not be empty strings.", diff --git a/internal/handlers/common/aggregations/stages/unset.go b/internal/handlers/common/aggregations/stages/unset.go index 2bfb996f3249..38a0b7db824c 100644 --- a/internal/handlers/common/aggregations/stages/unset.go +++ b/internal/handlers/common/aggregations/stages/unset.go @@ -85,10 +85,10 @@ func newUnset(stage *types.Document) (aggregations.Stage, error) { } err = types.IsConflictPath(visitedPaths, *path) - var pathErr *types.DocumentPathError + var pathErr *types.PathError if errors.As(err, &pathErr) { - if pathErr.Code() == types.ErrDocumentPathConflictOverwrite { + if pathErr.Code() == types.ErrPathConflictOverwrite { // the path overwrites one of visitedPaths. return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrUnsetPathOverwrite, @@ -97,7 +97,7 @@ func newUnset(stage *types.Document) (aggregations.Stage, error) { ) } - if pathErr.Code() == types.ErrDocumentPathConflictCollision { + if pathErr.Code() == types.ErrPathConflictCollision { // the path creates collision at one of visitedPaths. return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrUnsetPathCollision, diff --git a/internal/handlers/common/aggregations/stages/unwind.go b/internal/handlers/common/aggregations/stages/unwind.go index 8549d32b1d45..5e98987b0494 100644 --- a/internal/handlers/common/aggregations/stages/unwind.go +++ b/internal/handlers/common/aggregations/stages/unwind.go @@ -22,6 +22,7 @@ import ( "github.com/FerretDB/FerretDB/internal/handlers/common" "github.com/FerretDB/FerretDB/internal/handlers/common/aggregations" "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" + "github.com/FerretDB/FerretDB/internal/handlers/commonpath" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/iterator" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" @@ -54,10 +55,13 @@ func newUnwind(stage *types.Document) (aggregations.Stage, error) { ) } - opts := aggregations.ExpressionOpts{ - IgnoreArrays: true, - } - expr, err = aggregations.NewExpressionWithOpts(field, &opts) + // For $unwind to deconstruct an array from dot notation, array must be at the suffix. + // It returns empty result if array is found at other parts of dot notation, + // so it does not return value by index of array nor values for given key in array's document. + expr, err = aggregations.NewExpression(field, &commonpath.FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: false, + }) if err != nil { var exprErr *aggregations.ExpressionError diff --git a/internal/handlers/common/distinct.go b/internal/handlers/common/distinct.go index fdc7c5aeb149..eb9889e33d0c 100644 --- a/internal/handlers/common/distinct.go +++ b/internal/handlers/common/distinct.go @@ -17,12 +17,12 @@ package common import ( "errors" "fmt" - "strings" "go.uber.org/zap" "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" "github.com/FerretDB/FerretDB/internal/handlers/commonparams" + "github.com/FerretDB/FerretDB/internal/handlers/commonpath" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/iterator" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" @@ -102,26 +102,22 @@ func FilterDistinctValues(iter types.DocumentsIterator, key string) (*types.Arra return nil, lazyerrors.Error(err) } - // docsAtSuffix contains all documents exist at the suffix key. - docsAtSuffix := []*types.Document{doc} - suffix := key - - if strings.ContainsRune(key, '.') { - path, err := types.NewPathFromString(key) - if err != nil { - return nil, lazyerrors.Error(err) - } - - // Multiple documents may be found at suffix by array dot notation. - suffix, docsAtSuffix = getDocumentsAtSuffix(doc, path) + path, err := types.NewPathFromString(key) + if err != nil { + return nil, lazyerrors.Error(err) } - for _, doc := range docsAtSuffix { - val, err := doc.Get(suffix) - if err != nil { - continue - } + // distinct using dot notation returns the value by valid array index + // or values for the given key in array's document + vals, err := commonpath.FindValues(doc, path, &commonpath.FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: true, + }) + if err != nil { + return nil, lazyerrors.Error(err) + } + for _, val := range vals { switch v := val.(type) { case *types.Array: for i := 0; i < v.Len(); i++ { diff --git a/internal/handlers/common/filter.go b/internal/handlers/common/filter.go index 4c8c24322302..9d399b9c80ab 100644 --- a/internal/handlers/common/filter.go +++ b/internal/handlers/common/filter.go @@ -18,7 +18,6 @@ import ( "errors" "fmt" "math" - "strconv" "strings" "time" @@ -26,6 +25,7 @@ import ( "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" "github.com/FerretDB/FerretDB/internal/handlers/commonparams" + "github.com/FerretDB/FerretDB/internal/handlers/commonpath" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/iterator" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" @@ -97,7 +97,7 @@ func HasQueryOperator(filter *types.Document) (bool, error) { // filterDocumentPair handles a single filter element key/value pair {filterKey: filterValue}. func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) (bool, error) { - docs := []*types.Document{doc} + var vals []any filterSuffix := filterKey if strings.ContainsRune(filterKey, '.') { @@ -106,10 +106,19 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) return false, lazyerrors.Error(err) } - if filterSuffix, docs = getDocumentsAtSuffix(doc, path); len(docs) == 0 { - // When no document is found at suffix, use an empty one. - // So operators such as $nin is applied to the empty document. - docs = append(docs, types.MakeDocument(0)) + filterSuffix = path.Suffix() + + // filter using dot notation returns the value by valid array index + // or values for the given key in array's document + if vals, err = commonpath.FindValues(doc, path, &commonpath.FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: true, + }); err != nil { + return false, lazyerrors.Error(err) + } + } else { + if val, _ := doc.Get(filterKey); val != nil { + vals = []any{val} } } @@ -118,9 +127,19 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) return filterOperator(doc, filterKey, filterValue) } - for _, doc := range docs { - switch filterValue := filterValue.(type) { - case *types.Document: + switch filterValue := filterValue.(type) { + case *types.Document: + var docs []*types.Document + for _, val := range vals { + docs = append(docs, must.NotFail(types.NewDocument(filterSuffix, val))) + } + + if len(docs) == 0 { + // operators like $nin uses empty document to filter non-existent field + docs = append(docs, types.MakeDocument(0)) + } + + for _, doc := range docs { // {field: {expr}} or {field: {document}} ok, err := filterFieldExpr(doc, filterKey, filterSuffix, filterValue) if err != nil { @@ -130,28 +149,21 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) if ok { return true, nil } + } + case types.NullType: + if len(vals) == 0 { + // comparing non-existent field with null returns true + return true, nil + } - // doc did not match filter, continue next iteration. - case *types.Array: - // {field: [array]} - docValue, err := doc.Get(filterSuffix) - if err != nil { - continue // no error - the field is just not present - } - - if result := types.Compare(docValue, filterValue); result == types.Equal { + for _, val := range vals { + if result := types.Compare(val, filterValue); result == types.Equal { return true, nil } - - // doc did not match filter, continue next iteration. - case types.Regex: - // {field: /regex/} - docValue, err := doc.Get(filterSuffix) - if err != nil { - continue // no error - the field is just not present - } - - ok, err := filterFieldRegex(docValue, filterValue) + } + case types.Regex: + for _, val := range vals { + ok, err := filterFieldRegex(val, filterValue) if err != nil { return false, err } @@ -159,21 +171,10 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) if ok { return true, nil } - - // doc did not match filter, continue next iteration. - default: - // {field: value} - docValue, err := doc.Get(filterSuffix) - if err != nil { - // comparing not existent field with null should return true - if _, ok := filterValue.(types.NullType); ok { - return true, nil - } - - continue // no error - the field is just not present - } - - if result := types.Compare(docValue, filterValue); result == types.Equal { + } + default: + for _, val := range vals { + if result := types.Compare(val, filterValue); result == types.Equal { return true, nil } } @@ -183,141 +184,6 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) return false, nil } -// getDocumentsAtSuffix go through each key of the path iteratively to -// find all values that exist at suffix. -// An array dot notation may return multiple documents. -// At each key of the path, it checks: -// -// if the document has the key, -// if the array contains an index that is equal to the key, and -// if the array contains documents which have the key. -// -// It returns: -// -// the suffix key of path; -// a slice of documents at suffix with suffix value document pairs. -// -// Document path example: -// -// docs: {foo: {bar: 1}} -// path: `foo.bar` -// -// returns -// -// suffix: `bar` -// docsAtSuffix: [{bar: 1}] -// -// Array index path example: -// -// docs: {foo: [{bar: 1}]} -// path: `foo.0.bar` -// -// returns -// -// suffix: `bar` -// docsAtSuffix: [{bar: 1}] -// -// Array document example: -// -// docs: {foo: [{bar: 1}, {bar: 2}]} -// path: `foo.bar` -// -// returns -// -// suffix: `bar` -// docsAtSuffix: [{bar: 1}, {bar: 2}] -func getDocumentsAtSuffix(doc *types.Document, path types.Path) (suffix string, docsAtSuffix []*types.Document) { - // TODO https://github.com/FerretDB/FerretDB/issues/2348 - suffix = path.Suffix() - - // docsAtSuffix are the document found at the suffix. - docsAtSuffix = []*types.Document{} - - // keys are each part of the path. - keys := path.Slice() - - // vals are the field values found at each key of the path. - vals := []any{doc} - - for i, key := range keys { - // embeddedVals are the values found at current key. - var embeddedVals []any - - for _, valAtKey := range vals { - switch val := valAtKey.(type) { - case *types.Document: - embeddedVal, err := val.Get(key) - if err != nil { - // document does not contain key, so no embedded value was found. - continue - } - - if i == len(keys)-1 { - // a value was found at suffix. - docsAtSuffix = append(docsAtSuffix, val) - continue - } - - // key exists in the document, add embedded value to next iteration. - embeddedVals = append(embeddedVals, embeddedVal) - case *types.Array: - if index, err := strconv.Atoi(key); err == nil { - // key is an integer, check if that integer is an index of the array. - embeddedVal, err := val.Get(index) - if err != nil { - // index does not exist. - continue - } - - if i == len(keys)-1 { - // a value was found at suffix. - docsAtSuffix = append(docsAtSuffix, must.NotFail(types.NewDocument(suffix, embeddedVal))) - continue - } - - // key is the index of the array, add embedded value to the next iteration. - embeddedVals = append(embeddedVals, embeddedVal) - - continue - } - - // key was not an index, iterate array to get all documents that contain the key. - for j := 0; j < val.Len(); j++ { - valAtIndex := must.NotFail(val.Get(j)) - - embeddedDoc, isDoc := valAtIndex.(*types.Document) - if !isDoc { - // the value is not a document, so it cannot contain the key. - continue - } - - embeddedVal, err := embeddedDoc.Get(key) - if err != nil { - // the document does not contain key, so no embedded value was found. - continue - } - - if i == len(keys)-1 { - // a value was found at suffix. - docsAtSuffix = append(docsAtSuffix, must.NotFail(types.NewDocument(suffix, embeddedVal))) - continue - } - - // key exists in the document, add embedded value to next iteration. - embeddedVals = append(embeddedVals, embeddedVal) - } - - default: - // not a document or array, do nothing - } - } - - vals = embeddedVals - } - - return suffix, docsAtSuffix -} - // filterOperator handles a top-level operator filter {$operator: filterValue}. func filterOperator(doc *types.Document, operator string, filterValue any) (bool, error) { switch operator { diff --git a/internal/handlers/common/projection.go b/internal/handlers/common/projection.go index 5794cfb559f4..a75c98856439 100644 --- a/internal/handlers/common/projection.go +++ b/internal/handlers/common/projection.go @@ -81,6 +81,7 @@ func ValidateProjection(projection *types.Document) (*types.Document, bool, erro positionalProjection := strings.HasSuffix(key, "$") + // TODO https://github.com/FerretDB/FerretDB/issues/3127 path, err := types.NewPathFromString(key) if err != nil { if positionalProjection { diff --git a/internal/handlers/common/sort.go b/internal/handlers/common/sort.go index 80c7f511c06c..e750f93aaa44 100644 --- a/internal/handlers/common/sort.go +++ b/internal/handlers/common/sort.go @@ -29,7 +29,7 @@ import ( // SortDocuments sorts given documents in place according to the given sorting conditions. // -// If sort path is invalid, it returns a possibly wrapped types.DocumentPathError. +// If sort path is invalid, it returns a possibly wrapped types.PathError. func SortDocuments(docs []*types.Document, sortDoc *types.Document) error { if sortDoc.Len() == 0 { return nil diff --git a/internal/handlers/common/update.go b/internal/handlers/common/update.go index b10e87090bd6..cc70f813ef6a 100644 --- a/internal/handlers/common/update.go +++ b/internal/handlers/common/update.go @@ -260,8 +260,8 @@ func processRenameFieldExpression(command string, doc *types.Document, update *t sourcePath, err := types.NewPathFromString(key) if err != nil { - var pathErr *types.DocumentPathError - if errors.As(err, &pathErr) && pathErr.Code() == types.ErrDocumentPathEmptyKey { + var pathErr *types.PathError + if errors.As(err, &pathErr) && pathErr.Code() == types.ErrPathElementEmpty { return false, newUpdateError( commonerrors.ErrEmptyName, fmt.Sprintf( @@ -281,16 +281,16 @@ func processRenameFieldExpression(command string, doc *types.Document, update *t // Get value to move val, err := doc.GetByPath(sourcePath) if err != nil { - var dpe *types.DocumentPathError + var dpe *types.PathError if !errors.As(err, &dpe) { panic("getByPath returned error with invalid type") } - if dpe.Code() == types.ErrDocumentPathKeyNotFound || dpe.Code() == types.ErrDocumentPathIndexOutOfBound { + if dpe.Code() == types.ErrPathKeyNotFound || dpe.Code() == types.ErrPathIndexOutOfBound { continue } - if dpe.Code() == types.ErrDocumentPathArrayInvalidIndex { + if dpe.Code() == types.ErrPathIndexInvalid { return false, newUpdateError( commonerrors.ErrUnsuitableValueType, fmt.Sprintf("cannot use path '%s' to traverse the document", sourcePath), @@ -918,11 +918,11 @@ func validateOperatorKeys(command string, docs ...*types.Document) error { } err = types.IsConflictPath(visitedPaths, nextPath) - var pathErr *types.DocumentPathError + var pathErr *types.PathError if errors.As(err, &pathErr) { - if pathErr.Code() == types.ErrDocumentPathConflictOverwrite || - pathErr.Code() == types.ErrDocumentPathConflictCollision { + if pathErr.Code() == types.ErrPathConflictOverwrite || + pathErr.Code() == types.ErrPathConflictCollision { return newUpdateError( commonerrors.ErrConflictingUpdateOperators, fmt.Sprintf( diff --git a/internal/handlers/commonpath/commonpath.go b/internal/handlers/commonpath/commonpath.go new file mode 100644 index 000000000000..6970588ade9a --- /dev/null +++ b/internal/handlers/commonpath/commonpath.go @@ -0,0 +1,150 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package commonpath contains functions used for path. +package commonpath + +import ( + "errors" + "strconv" + + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/iterator" + "github.com/FerretDB/FerretDB/internal/util/lazyerrors" +) + +// FindValuesOpts sets options for FindValues. +type FindValuesOpts struct { + // If FindArrayDocuments is true, it iterates the array to find documents that have path. + // If FindArrayDocuments is false, it does not find documents from the array. + // Using path `v.foo` and `v` is an array: + // - with FindArrayDocuments true, it finds values of `foo` of found documents; + // - with FindArrayDocuments false, it returns an empty array. + // If `v` is not an array, FindArrayDocuments has no impact. + FindArrayDocuments bool + // If FindArrayIndex is true, it finds value at index of an array. + // If FindArrayIndex is false, it does not find value at index of an array. + // Using path `v.0` and `v` is an array: + // - with FindArrayIndex true, it finds 0-th index value of the array; + // - with FindArrayIndex false, it returns empty array. + // If `v` is not an array, FindArrayIndex has no impact. + FindArrayIndex bool +} + +// FindValues returns values by path, looking up into arrays. +// +// It iterates path elements, at each path element it adds to next values to iterate: +// - if it is a document and has path, it adds the document field value to next values; +// - if it is an array, FindArrayIndex is true and finds value at index, it adds value to next values; +// - if it is an array, FindArrayDocuments is true and documents in the array have path, +// it adds field value of all documents that have path to next values. +// +// It returns next values after iterating path elements. +func FindValues(doc *types.Document, path types.Path, opts *FindValuesOpts) ([]any, error) { + if opts == nil { + opts = new(FindValuesOpts) + } + + nextValues := []any{doc} + + for _, e := range path.Slice() { + values := []any{} + + for _, next := range nextValues { + switch next := next.(type) { + case *types.Document: + v, _ := next.Get(e) + if v == nil { + continue + } + + values = append(values, v) + + case *types.Array: + if opts.FindArrayIndex { + res, err := findArrayIndex(next, e) + if err == nil { + values = append(values, res) + continue + } + } + + if opts.FindArrayDocuments { + res, err := lookupArrayDocuments(next, e) + if err != nil { + return nil, lazyerrors.Error(err) + } + + values = append(values, res...) + } + + default: + // path does not exist in scalar values, nothing to do + } + } + + nextValues = values + } + + return nextValues, nil +} + +// findArrayIndex returns the value by valid array index. +// +// Error is returned if index is not a number or index does not exist in array. +func findArrayIndex(array *types.Array, index string) (any, error) { + i, err := strconv.Atoi(index) + if err != nil { + return nil, lazyerrors.Error(err) + } + + v, err := array.Get(i) + if err != nil { + return nil, lazyerrors.Error(err) + } + + return v, nil +} + +// lookupArrayDocuments returns values for the given key in array's document. +// +// Non-document array values, documents without that key, etc. are skipped. +func lookupArrayDocuments(array *types.Array, documentKey string) ([]any, error) { + iter := array.Iterator() + defer iter.Close() + + res := []any{} + + for { + _, v, err := iter.Next() + if errors.Is(err, iterator.ErrIteratorDone) { + break + } + + if err != nil { + return nil, lazyerrors.Error(err) + } + + doc, ok := v.(*types.Document) + if !ok { + continue + } + + if v, _ = doc.Get(documentKey); v != nil { + res = append(res, v) + } + } + + return res, nil +} diff --git a/internal/handlers/commonpath/commonpath_test.go b/internal/handlers/commonpath/commonpath_test.go new file mode 100644 index 000000000000..465453f1c1ec --- /dev/null +++ b/internal/handlers/commonpath/commonpath_test.go @@ -0,0 +1,225 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package commonpath + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" +) + +func TestFindValues(t *testing.T) { + t.Parallel() + + t.Run("Array", func(t *testing.T) { + array := must.NotFail(types.NewDocument("foo", must.NotFail(types.NewArray( + must.NotFail(types.NewDocument("bar", 0)), + must.NotFail(types.NewDocument("bar", 1)), + )))) + + for name, tc := range map[string]struct { + doc *types.Document + path types.Path + opts *FindValuesOpts + res []any + }{ + "DistinctCommandDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: true, + }, + res: []any{0, 1}, + }, + "DistinctCommandIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: true, + }, + res: []any{must.NotFail(types.NewDocument("bar", 1))}, + }, + "DistinctCommandNestedIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: true, + }, + res: []any{1}, + }, + + "AggregationOperatorDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: true, + }, + res: []any{0, 1}, + }, + "AggregationOperatorIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: true, + }, + res: []any{}, + }, + "AggregationOperatorNestedIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: true, + }, + res: []any{}, + }, + + "UnwindDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: false, + }, + res: []any{}, + }, + "UnwindIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: false, + }, + res: []any{}, + }, + "UnwindNestedIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: false, + }, + res: []any{}, + }, + + "GetByPathDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: false, + }, + res: []any{}, + }, + "GetByPathIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: false, + }, + res: []any{must.NotFail(types.NewDocument("bar", 1))}, + }, + "GetByPathNestedIndexDotNotation": { + doc: array, + path: types.NewStaticPath("foo", "1", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: false, + }, + res: []any{1}, + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + res, err := FindValues(tc.doc, tc.path, tc.opts) + require.NoError(t, err) + require.Equal(t, tc.res, res) + }) + } + }) + + t.Run("Document", func(t *testing.T) { + doc := must.NotFail(types.NewDocument("foo", must.NotFail(types.NewDocument("bar", 0)))) + + for name, tc := range map[string]struct { + doc *types.Document + path types.Path + opts *FindValuesOpts + res []any + }{ + "Empty": { + doc: new(types.Document), + path: types.NewStaticPath("foo", "bar"), + res: []any{}, + }, + "DistinctCommandDotNotation": { + doc: doc, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: true, + }, + res: []any{0}, + }, + "AggregationOperatorDotNotation": { + doc: doc, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: true, + }, + res: []any{0}, + }, + "UnwindDotNotation": { + doc: doc, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: false, + FindArrayDocuments: false, + }, + res: []any{0}, + }, + "GetByPathDotNotation": { + doc: doc, + path: types.NewStaticPath("foo", "bar"), + opts: &FindValuesOpts{ + FindArrayIndex: true, + FindArrayDocuments: false, + }, + res: []any{0}, + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + res, err := FindValues(tc.doc, tc.path, tc.opts) + require.NoError(t, err) + require.Equal(t, tc.res, res) + }) + } + }) +} diff --git a/internal/handlers/pg/msg_find.go b/internal/handlers/pg/msg_find.go index 4094a1e4bdb5..7b967500ceca 100644 --- a/internal/handlers/pg/msg_find.go +++ b/internal/handlers/pg/msg_find.go @@ -110,8 +110,8 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er if !queryRes.SortPushdown { iter, err = common.SortIterator(iter, closer, params.Sort) if err != nil { - var pathErr *types.DocumentPathError - if errors.As(err, &pathErr) && pathErr.Code() == types.ErrDocumentPathEmptyKey { + var pathErr *types.PathError + if errors.As(err, &pathErr) && pathErr.Code() == types.ErrPathElementEmpty { return commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrPathContainsEmptyElement, "Empty field names in path are not allowed", diff --git a/internal/handlers/pg/msg_findandmodify.go b/internal/handlers/pg/msg_findandmodify.go index 359a4800973e..f7dc4dc9bf6d 100644 --- a/internal/handlers/pg/msg_findandmodify.go +++ b/internal/handlers/pg/msg_findandmodify.go @@ -79,8 +79,8 @@ func (h *Handler) MsgFindAndModify(ctx context.Context, msg *wire.OpMsg) (*wire. } if err = common.SortDocuments(resDocs, params.Sort); err != nil { - var pathErr *types.DocumentPathError - if errors.As(err, &pathErr) && pathErr.Code() == types.ErrDocumentPathEmptyKey { + var pathErr *types.PathError + if errors.As(err, &pathErr) && pathErr.Code() == types.ErrPathElementEmpty { return commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrPathContainsEmptyElement, "FieldPath field names may not be empty strings.", diff --git a/internal/handlers/pg/pgdb/query.go b/internal/handlers/pg/pgdb/query.go index d411edb3b88b..31ed61127e32 100644 --- a/internal/handlers/pg/pgdb/query.go +++ b/internal/handlers/pg/pgdb/query.go @@ -254,7 +254,7 @@ func prepareWhereClause(p *Placeholder, sqlFilters *types.Document) (string, []a path, err := types.NewPathFromString(rootKey) - var pe *types.DocumentPathError + var pe *types.PathError switch { case err == nil: @@ -264,11 +264,11 @@ func prepareWhereClause(p *Placeholder, sqlFilters *types.Document) (string, []a } case errors.As(err, &pe): // ignore empty key error, otherwise return error - if pe.Code() != types.ErrDocumentPathEmptyKey { + if pe.Code() != types.ErrPathElementEmpty { return "", nil, lazyerrors.Error(err) } default: - panic("Invalid error type: DocumentPathError expected ") + panic("Invalid error type: PathError expected") } switch v := rootVal.(type) { diff --git a/internal/handlers/sqlite/msg_find.go b/internal/handlers/sqlite/msg_find.go index 8ab545fd6016..1d445a72192c 100644 --- a/internal/handlers/sqlite/msg_find.go +++ b/internal/handlers/sqlite/msg_find.go @@ -91,8 +91,8 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er if err != nil { closer.Close() - var pathErr *types.DocumentPathError - if errors.As(err, &pathErr) && pathErr.Code() == types.ErrDocumentPathEmptyKey { + var pathErr *types.PathError + if errors.As(err, &pathErr) && pathErr.Code() == types.ErrPathElementEmpty { return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrPathContainsEmptyElement, "Empty field names in path are not allowed", diff --git a/internal/types/array.go b/internal/types/array.go index 1b23cc102273..59329c92e57b 100644 --- a/internal/types/array.go +++ b/internal/types/array.go @@ -21,7 +21,7 @@ import ( "github.com/FerretDB/FerretDB/internal/util/must" ) -// Array represents BSON array. +// Array represents BSON array: an ordered collection of BSON values, accessed by 0-based indexes. // // Zero value is a valid empty array. type Array struct { @@ -52,7 +52,7 @@ func (a *Array) DeepCopy() *Array { return deepCopy(a).(*Array) } -// Len returns the number of elements in the array. +// Len returns the number of values in the array. // // It returns 0 for nil Array. func (a *Array) Len() int { @@ -76,7 +76,7 @@ func (a *Array) Get(index int) (any, error) { return a.s[index], nil } -// GetByPath returns a value by path - a sequence of indexes and keys. +// GetByPath returns a value by path. func (a *Array) GetByPath(path Path) (any, error) { return getByPath(a, path) } @@ -105,7 +105,7 @@ func (a *Array) Append(values ...any) { a.s = append(a.s, values...) } -// RemoveByPath removes document by path, doing nothing if the key does not exist. +// RemoveByPath removes (cuts) value by path, doing nothing if path points to nothing. func (a *Array) RemoveByPath(path Path) { removeByPath(a, path) } @@ -145,7 +145,7 @@ func (a *Array) Max() any { } // FilterArrayByType returns a new array which contains -// only elements of the same BSON type as ref. +// only values of the same BSON type as ref. // All numbers are treated as the same type. func (a *Array) FilterArrayByType(ref any) *Array { refType := detectDataType(ref) diff --git a/internal/types/document.go b/internal/types/document.go index d34bdb649980..5a31ca20a1f0 100644 --- a/internal/types/document.go +++ b/internal/types/document.go @@ -33,7 +33,8 @@ type document interface { Values() []any } -// Document represents BSON document. +// Document represents BSON document: an ordered collection of fields +// (key/value pairs where key is a string and value is any BSON value). type Document struct { fields []field } @@ -126,7 +127,7 @@ func (d *Document) DeepCopy() *Document { return deepCopy(d).(*Document) } -// Len returns the number of elements in the document. +// Len returns the number of fields in the document. // // It returns 0 for nil Document. func (d *Document) Len() int { @@ -308,7 +309,7 @@ func (d *Document) HasByPath(path Path) bool { return err == nil } -// GetByPath returns a value by path - a sequence of indexes and keys. +// GetByPath returns a value by path. // If the Path has only one element, it returns the value for the given key. func (d *Document) GetByPath(path Path) (any, error) { return getByPath(d, path) diff --git a/internal/types/documentpatherrorcode_string.go b/internal/types/documentpatherrorcode_string.go deleted file mode 100644 index 0ae1f29c7586..000000000000 --- a/internal/types/documentpatherrorcode_string.go +++ /dev/null @@ -1,31 +0,0 @@ -// Code generated by "stringer -linecomment -type DocumentPathErrorCode"; DO NOT EDIT. - -package types - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[ErrDocumentPathKeyNotFound-1] - _ = x[ErrDocumentPathCannotAccess-2] - _ = x[ErrDocumentPathArrayInvalidIndex-3] - _ = x[ErrDocumentPathIndexOutOfBound-4] - _ = x[ErrDocumentPathCannotCreateField-5] - _ = x[ErrDocumentPathEmptyKey-6] - _ = x[ErrDocumentPathConflictOverwrite-7] - _ = x[ErrDocumentPathConflictCollision-8] -} - -const _DocumentPathErrorCode_name = "ErrDocumentPathKeyNotFoundErrDocumentPathCannotAccessErrDocumentPathArrayInvalidIndexErrDocumentPathIndexOutOfBoundErrDocumentPathCannotCreateFieldErrDocumentPathEmptyKeyErrDocumentPathConflictOverwriteErrDocumentPathConflictCollision" - -var _DocumentPathErrorCode_index = [...]uint8{0, 26, 53, 85, 115, 147, 170, 202, 234} - -func (i DocumentPathErrorCode) String() string { - i -= 1 - if i < 0 || i >= DocumentPathErrorCode(len(_DocumentPathErrorCode_index)-1) { - return "DocumentPathErrorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _DocumentPathErrorCode_name[_DocumentPathErrorCode_index[i]:_DocumentPathErrorCode_index[i+1]] -} diff --git a/internal/types/format.go b/internal/types/format.go new file mode 100644 index 000000000000..100146a89a43 --- /dev/null +++ b/internal/types/format.go @@ -0,0 +1,109 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "fmt" + "math" + "strings" + "time" +) + +// FormatAnyValue formats value for error message output. +func FormatAnyValue(v any) string { + switch v := v.(type) { + case *Document: + return formatDocument(v) + case *Array: + return formatArray(v) + case float64: + switch { + case math.IsNaN(v): + return "nan.0" + + case math.IsInf(v, -1): + return "-inf.0" + case math.IsInf(v, +1): + return "inf.0" + case v == 0 && math.Signbit(v): + return "-0.0" + case v == 0.0: + return "0.0" + case v > 1000 || v < -1000 || v == math.SmallestNonzeroFloat64: + return fmt.Sprintf("%.15e", v) + case math.Trunc(v) == v: + return fmt.Sprintf("%d.0", int64(v)) + default: + res := fmt.Sprintf("%.2f", v) + + return strings.TrimSuffix(res, "0") + } + + case string: + return fmt.Sprintf(`"%v"`, v) + case Binary: + return fmt.Sprintf("BinData(%d, %X)", v.Subtype, v.B) + case ObjectID: + return fmt.Sprintf("ObjectId('%x')", v) + case bool: + return fmt.Sprintf("%v", v) + case time.Time: + return fmt.Sprintf("new Date(%d)", v.UnixMilli()) + case NullType: + return "null" + case Regex: + return fmt.Sprintf("/%s/%s", v.Pattern, v.Options) + case int32: + return fmt.Sprintf("%d", v) + case Timestamp: + return fmt.Sprintf("Timestamp(%v, %v)", int64(v)>>32, int32(v)) + case int64: + return fmt.Sprintf("%d", v) + default: + panic(fmt.Sprintf("unknown type %T", v)) + } +} + +// formatDocument formats Document for error output. +func formatDocument(doc *Document) string { + result := "{ " + + for i, f := range doc.fields { + if i > 0 { + result += ", " + } + + result += fmt.Sprintf("%s: %s", f.key, FormatAnyValue(f.value)) + } + + return result + " }" +} + +// formatArray formats Array for error output. +func formatArray(array *Array) string { + if len(array.s) == 0 { + return "[]" + } + + result := "[ " + + for _, elem := range array.s { + result += fmt.Sprintf("%s, ", FormatAnyValue(elem)) + } + + result = strings.TrimSuffix(result, ", ") + + return result + " ]" +} diff --git a/internal/types/path.go b/internal/types/path.go index 8901ead8b6b0..68cf36ba2f69 100644 --- a/internal/types/path.go +++ b/internal/types/path.go @@ -17,128 +17,145 @@ package types import ( "errors" "fmt" - "math" "strconv" "strings" - "time" "golang.org/x/exp/slices" "github.com/FerretDB/FerretDB/internal/util/must" ) -//go:generate ../../bin/stringer -linecomment -type DocumentPathErrorCode +//go:generate ../../bin/stringer -linecomment -type PathErrorCode -// DocumentPathErrorCode represents DocumentPathError error code. -type DocumentPathErrorCode int +// PathErrorCode represents PathError code. +type PathErrorCode int const ( - _ DocumentPathErrorCode = iota + _ PathErrorCode = iota - // ErrDocumentPathKeyNotFound indicates that key was not found in document. - ErrDocumentPathKeyNotFound + // ErrPathElementEmpty indicates that provided path contains an empty element. + ErrPathElementEmpty - // ErrDocumentPathCannotAccess indicates that path couldn't be accessed. - ErrDocumentPathCannotAccess + // ErrPathElementInvalid indicates that provided path contains an invalid element (other than empty). + ErrPathElementInvalid - // ErrDocumentPathArrayInvalidIndex indicates that provided array index is invalid. - ErrDocumentPathArrayInvalidIndex + // ErrPathKeyNotFound indicates that key was not found in document. + ErrPathKeyNotFound - // ErrDocumentPathIndexOutOfBound indicates that provided array index is out of bound. - ErrDocumentPathIndexOutOfBound + // ErrPathIndexInvalid indicates that provided array index is invalid. + ErrPathIndexInvalid - // ErrDocumentPathCannotCreateField indicates that it's impossible to create a specific field. - ErrDocumentPathCannotCreateField + // ErrPathIndexOutOfBound indicates that provided array index is out of bound. + ErrPathIndexOutOfBound - // ErrDocumentPathEmptyKey indicates that provided path contains empty key. - ErrDocumentPathEmptyKey + // ErrPathCannotAccess indicates that path couldn't be accessed. + ErrPathCannotAccess - // ErrDocumentPathConflictOverwrite indicates a path overwrites another path. - ErrDocumentPathConflictOverwrite + // ErrPathCannotCreateField indicates that it's impossible to create a specific field. + ErrPathCannotCreateField - // ErrDocumentPathConflictCollision indicates a path creates collision at another path. - ErrDocumentPathConflictCollision + // ErrPathConflictOverwrite indicates a path overwrites another path. + ErrPathConflictOverwrite + + // ErrPathConflictCollision indicates a path creates collision at another path. + ErrPathConflictCollision ) -// DocumentPathError describes an error that could occur on document path related operations. -type DocumentPathError struct { - reason error - code DocumentPathErrorCode +// PathError describes an error that could occur on path related operations. +type PathError struct { + err error + code PathErrorCode } // Error implements the error interface. -func (e *DocumentPathError) Error() string { - return e.reason.Error() +func (e *PathError) Error() string { + return e.err.Error() } -// Code returns the DocumentPathError code. -func (e *DocumentPathError) Code() DocumentPathErrorCode { +// Code returns the PathError code. +func (e *PathError) Code() PathErrorCode { return e.code } -// newDocumentPathError creates a new DocumentPathError. -func newDocumentPathError(code DocumentPathErrorCode, reason error) error { - return &DocumentPathError{reason: reason, code: code} -} - -// Path represents the field path type. It should be used wherever we work with paths or dot notation. -// Path should be stored and passed as a value. Its methods return new values, not modifying the receiver's state. -type Path struct { - s []string +// newPathError creates a new PathError. +func newPathError(code PathErrorCode, reason error) error { + return &PathError{err: reason, code: code} } -// NewStaticPath returns Path from a strings slice. +// Path represents a parsed dot notation - a sequence of elements (document keys and array indexes) separated by dots. // -// It panics on invalid paths. For that reason, it should not be used with user-provided paths. -func NewStaticPath(path ...string) Path { - return must.NotFail(NewPathFromString(strings.Join(path, "."))) +// Path's elements can't be empty, include dots or spaces. +// +// Path should be stored and passed as a value. +// Its methods return new values, not modifying the receiver's state. +type Path struct { + e []string } -// NewPathFromString returns Path from path string and error. -// It returns an error if the path is empty or contains empty elements. -// Path string should contain fields separated with '.'. -func NewPathFromString(s string) (Path, error) { +// newPath returns Path from a strings slice. +func newPath(path ...string) (Path, error) { var res Path - path := strings.Split(s, ".") - - for _, s := range path { - if s == "" { - return res, newDocumentPathError(ErrDocumentPathEmptyKey, errors.New("path element must not be empty")) + for _, e := range path { + switch { + case e == "": + return res, newPathError(ErrPathElementEmpty, errors.New("path element must not be empty")) + case strings.TrimSpace(e) != e: + return res, newPathError(ErrPathElementInvalid, errors.New("path element must not contain spaces")) + case strings.Contains(e, "."): + return res, newPathError(ErrPathElementInvalid, errors.New("path element must contain '.'")) + // TODO https://github.com/FerretDB/FerretDB/issues/3127 + // enable validation of `$` prefix and update Path struct comment + // case strings.HasPrefix(e, "$"): + // return res, newPathError(ErrPathElementInvalid, errors.New("path element must not start with '$'")) } } - res = Path{s: make([]string, len(path))} - copy(res.s, path) + res = Path{e: make([]string, len(path))} + copy(res.e, path) return res, nil } -// String returns dot-separated path value. +// NewStaticPath returns Path from a strings slice. +// +// It panics on invalid paths. For that reason, it should not be used with user-provided paths. +func NewStaticPath(path ...string) Path { + return must.NotFail(newPath(path...)) +} + +// NewPathFromString returns Path from a given dot notation. +// +// It returns an error if the path is invalid. +func NewPathFromString(s string) (Path, error) { + return newPath(strings.Split(s, ".")...) +} + +// String returns a dot notation for that path. func (p Path) String() string { - return strings.Join(p.s, ".") + return strings.Join(p.e, ".") } // Len returns path length. func (p Path) Len() int { - return len(p.s) + return len(p.e) } -// Slice returns path values array. +// Slice returns path elements array. func (p Path) Slice() []string { path := make([]string, p.Len()) - copy(path, p.s) + copy(path, p.e) return path } // Suffix returns the last path element. func (p Path) Suffix() string { - return p.s[p.Len()-1] + return p.e[p.Len()-1] } // Prefix returns the first path element. func (p Path) Prefix() string { - return p.s[0] + return p.e[0] } // TrimSuffix returns a path without the last element. @@ -147,7 +164,7 @@ func (p Path) TrimSuffix() Path { panic("path should have more than 1 element") } - return NewStaticPath(p.s[:p.Len()-1]...) + return NewStaticPath(p.e[:p.Len()-1]...) } // TrimPrefix returns a copy of path without the first element. @@ -156,7 +173,7 @@ func (p Path) TrimPrefix() Path { panic("path should have more than 1 element") } - return NewStaticPath(p.s[1:]...) + return NewStaticPath(p.e[1:]...) } // Append returns new Path constructed from the current path and given element. @@ -173,11 +190,11 @@ func RemoveByPath[T CompositeTypeInterface](comp T, path Path) { removeByPath(comp, path) } -// IsConflictPath returns DocumentPathError error if adding a path creates conflict at any of paths. -// Returned DocumentPathError error codes: +// IsConflictPath returns PathError error if adding a path creates conflict at any of paths. +// Returned PathError error codes: // -// - ErrDocumentPathConflictOverwrite when path overwrites any paths: paths = []{{"a","b"}} path = {"a"}; -// - ErrDocumentPathConflictCollision when path creates collision: paths = []{{"a"}} path = {"a","b"}; +// - ErrPathConflictOverwrite when path overwrites any paths: paths = []{{"a","b"}} path = {"a"}; +// - ErrPathConflictCollision when path creates collision: paths = []{{"a"}} path = {"a","b"}; func IsConflictPath(paths []Path, path Path) error { for _, p := range paths { target, prefix := p.Slice(), path.Slice() @@ -204,14 +221,14 @@ func IsConflictPath(paths []Path, path Path) error { } if p.Len() >= path.Len() { - return newDocumentPathError(ErrDocumentPathConflictOverwrite, errors.New("path overwrites previous path")) + return newPathError(ErrPathConflictOverwrite, errors.New("path overwrites previous path")) } // collisionPart is part of the path which creates collision, used in command error message. // If visitedPath is `a.b` and path is `a.b.c`, collisionPart is `b.c`. collisionPart := strings.Join(target[len(prefix):], ".") - return newDocumentPathError(ErrDocumentPathConflictCollision, errors.New(collisionPart)) + return newPathError(ErrPathConflictCollision, errors.New(collisionPart)) } return nil @@ -226,30 +243,30 @@ func getByPath[T CompositeTypeInterface](comp T, path Path) (any, error) { var err error next, err = s.Get(p) if err != nil { - return nil, newDocumentPathError(ErrDocumentPathKeyNotFound, fmt.Errorf("types.getByPath: %w", err)) + return nil, newPathError(ErrPathKeyNotFound, fmt.Errorf("types.getByPath: %w", err)) } case *Array: index, err := strconv.Atoi(p) if err != nil { - return nil, newDocumentPathError(ErrDocumentPathArrayInvalidIndex, fmt.Errorf("types.getByPath: %w", err)) + return nil, newPathError(ErrPathIndexInvalid, fmt.Errorf("types.getByPath: %w", err)) } if index < 0 { - return nil, newDocumentPathError( - ErrDocumentPathArrayInvalidIndex, + return nil, newPathError( + ErrPathIndexInvalid, fmt.Errorf("types.getByPath: array index below zero: %d", index), ) } next, err = s.Get(index) if err != nil { - return nil, newDocumentPathError(ErrDocumentPathIndexOutOfBound, fmt.Errorf("types.getByPath: %w", err)) + return nil, newPathError(ErrPathIndexOutOfBound, fmt.Errorf("types.getByPath: %w", err)) } default: - return nil, newDocumentPathError( - ErrDocumentPathCannotAccess, + return nil, newPathError( + ErrPathCannotAccess, fmt.Errorf("types.getByPath: can't access %T by path %q", next, p), ) } @@ -325,8 +342,8 @@ func insertByPath(doc *Document, path Path) error { case *Array: ind, err := strconv.Atoi(insertedPath.Slice()[suffix]) if err != nil { - return newDocumentPathError( - ErrDocumentPathCannotCreateField, + return newPathError( + ErrPathCannotCreateField, fmt.Errorf( "Cannot create field '%s' in element {%s: %s}", pathElem, @@ -337,8 +354,8 @@ func insertByPath(doc *Document, path Path) error { } if ind < 0 { - return newDocumentPathError( - ErrDocumentPathIndexOutOfBound, + return newPathError( + ErrPathIndexOutOfBound, fmt.Errorf( "Index out of bound: %d", ind, @@ -354,8 +371,8 @@ func insertByPath(doc *Document, path Path) error { v.Append(must.NotFail(NewDocument())) default: - return newDocumentPathError( - ErrDocumentPathCannotCreateField, + return newPathError( + ErrPathCannotCreateField, fmt.Errorf( "Cannot create field '%s' in element {%s: %s}", pathElem, @@ -375,90 +392,3 @@ func insertByPath(doc *Document, path Path) error { return nil } - -// FormatAnyValue formats value for error message output. -func FormatAnyValue(v any) string { - switch v := v.(type) { - case *Document: - return formatDocument(v) - case *Array: - return formatArray(v) - case float64: - switch { - case math.IsNaN(v): - return "nan.0" - - case math.IsInf(v, -1): - return "-inf.0" - case math.IsInf(v, +1): - return "inf.0" - case v == 0 && math.Signbit(v): - return "-0.0" - case v == 0.0: - return "0.0" - case v > 1000 || v < -1000 || v == math.SmallestNonzeroFloat64: - return fmt.Sprintf("%.15e", v) - case math.Trunc(v) == v: - return fmt.Sprintf("%d.0", int64(v)) - default: - res := fmt.Sprintf("%.2f", v) - - return strings.TrimSuffix(res, "0") - } - - case string: - return fmt.Sprintf(`"%v"`, v) - case Binary: - return fmt.Sprintf("BinData(%d, %X)", v.Subtype, v.B) - case ObjectID: - return fmt.Sprintf("ObjectId('%x')", v) - case bool: - return fmt.Sprintf("%v", v) - case time.Time: - return fmt.Sprintf("new Date(%d)", v.UnixMilli()) - case NullType: - return "null" - case Regex: - return fmt.Sprintf("/%s/%s", v.Pattern, v.Options) - case int32: - return fmt.Sprintf("%d", v) - case Timestamp: - return fmt.Sprintf("Timestamp(%v, %v)", int64(v)>>32, int32(v)) - case int64: - return fmt.Sprintf("%d", v) - default: - panic(fmt.Sprintf("unknown type %T", v)) - } -} - -// formatDocument formats Document for error output. -func formatDocument(doc *Document) string { - result := "{ " - - for i, f := range doc.fields { - if i > 0 { - result += ", " - } - - result += fmt.Sprintf("%s: %s", f.key, FormatAnyValue(f.value)) - } - - return result + " }" -} - -// formatArray formats Array for error output. -func formatArray(array *Array) string { - if len(array.s) == 0 { - return "[]" - } - - result := "[ " - - for _, elem := range array.s { - result += fmt.Sprintf("%s, ", FormatAnyValue(elem)) - } - - result = strings.TrimSuffix(result, ", ") - - return result + " ]" -} diff --git a/internal/types/path_test.go b/internal/types/path_test.go index dbb3648f7881..dab697b76015 100644 --- a/internal/types/path_test.go +++ b/internal/types/path_test.go @@ -24,6 +24,44 @@ import ( "github.com/FerretDB/FerretDB/internal/util/must" ) +func TestNewPathFromString(t *testing.T) { + for _, tc := range []struct { //nolint:vet // for readability + s string + p Path + err error + skip string + }{{ + s: "", + err: newPathError(ErrPathElementEmpty, fmt.Errorf("path element must not be empty")), + }, { + s: " ", + err: newPathError(ErrPathElementInvalid, fmt.Errorf("path element must not contain spaces")), + }, { + s: "$var", + err: newPathError(ErrPathElementInvalid, fmt.Errorf("path element must not start with '$'")), + skip: "https://github.com/FerretDB/FerretDB/issues/3127", + }} { + tc := tc + t.Run(tc.s, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + + t.Parallel() + + res, err := NewPathFromString(tc.s) + if tc.err == nil { + require.NoError(t, err) + assert.Equal(t, tc.p, res) + return + } + + assert.Empty(t, res) + assert.Equal(t, tc.err, err) + }) + } +} + func TestRemoveByPath(t *testing.T) { t.Parallel() @@ -233,13 +271,11 @@ func TestGetByPath(t *testing.T) { "loadBalanced", false, )) - type testCase struct { + for _, tc := range []struct { //nolint:vet // for readability path Path res any err string - } - - for _, tc := range []testCase{{ //nolint:paralleltest // false positive + }{{ path: NewStaticPath("compression", "0"), res: "none", }, { @@ -275,10 +311,11 @@ func TestGetByPath(t *testing.T) { if tc.err == "" { require.NoError(t, err) assert.Equal(t, tc.res, res) - } else { - assert.Empty(t, res) - assert.EqualError(t, err, tc.err) + return } + + assert.Empty(t, res) + assert.EqualError(t, err, tc.err) }) } } @@ -287,14 +324,12 @@ func TestPathTrimSuffixPrefix(t *testing.T) { t.Parallel() pathOneElement := NewStaticPath("1") - pathZeroElement := Path{s: []string{}} + pathZeroElement := Path{e: []string{}} - type testCase struct { + for _, tc := range []struct { //nolint:vet // for readability name string f func() Path - } - - for _, tc := range []testCase{{ + }{{ name: "prefixOne", f: pathOneElement.TrimPrefix, }, { @@ -321,15 +356,13 @@ func TestPathTrimSuffixPrefix(t *testing.T) { func TestPathSuffixPrefix(t *testing.T) { t.Parallel() - pathZeroElement := Path{s: []string{}} + pathZeroElement := Path{e: []string{}} - type testCase struct { + // Obtaining prefix and suffix of single value path is harmless. + for _, tc := range []struct { //nolint:vet // for readability name string f func() string - } - - // Obtaining prefix and suffix of single value path is harmless. - for _, tc := range []testCase{{ + }{{ name: "prefixZero", f: pathZeroElement.Prefix, }, { diff --git a/internal/types/patherrorcode_string.go b/internal/types/patherrorcode_string.go new file mode 100644 index 000000000000..b4b0bf6080d6 --- /dev/null +++ b/internal/types/patherrorcode_string.go @@ -0,0 +1,32 @@ +// Code generated by "stringer -linecomment -type PathErrorCode"; DO NOT EDIT. + +package types + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ErrPathElementEmpty-1] + _ = x[ErrPathElementInvalid-2] + _ = x[ErrPathKeyNotFound-3] + _ = x[ErrPathIndexInvalid-4] + _ = x[ErrPathIndexOutOfBound-5] + _ = x[ErrPathCannotAccess-6] + _ = x[ErrPathCannotCreateField-7] + _ = x[ErrPathConflictOverwrite-8] + _ = x[ErrPathConflictCollision-9] +} + +const _PathErrorCode_name = "ErrPathElementEmptyErrPathElementInvalidErrPathKeyNotFoundErrPathIndexInvalidErrPathIndexOutOfBoundErrPathCannotAccessErrPathCannotCreateFieldErrPathConflictOverwriteErrPathConflictCollision" + +var _PathErrorCode_index = [...]uint8{0, 19, 40, 58, 77, 99, 118, 142, 166, 190} + +func (i PathErrorCode) String() string { + i -= 1 + if i < 0 || i >= PathErrorCode(len(_PathErrorCode_index)-1) { + return "PathErrorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _PathErrorCode_name[_PathErrorCode_index[i]:_PathErrorCode_index[i+1]] +}