diff --git a/integration/query_compat_test.go b/integration/query_compat_test.go index 24189fcecb8e..fb92ab522f89 100644 --- a/integration/query_compat_test.go +++ b/integration/query_compat_test.go @@ -41,17 +41,23 @@ type queryCompatTestCase struct { resultType compatTestCaseResultType // defaults to nonEmptyResult resultPushdown bool // defaults to false + skipIDCheck bool // skip check collected IDs, use it when no ids returned from query skip string // skip test for all handlers, must have issue number mentioned skipForTigris string // skip test for Tigris } -// testQueryCompat tests query compatibility test cases. -func testQueryCompat(t *testing.T, testCases map[string]queryCompatTestCase) { +func testQueryCompatWithProviders(t *testing.T, providers shareddata.Providers, testCases map[string]queryCompatTestCase) { t.Helper() + require.NotEmpty(t, providers) + // Use shared setup because find queries can't modify data. // TODO Use read-only user. https://github.com/FerretDB/FerretDB/issues/1025 - ctx, targetCollections, compatCollections := setup.SetupCompat(t) + s := setup.SetupCompatWithOpts(t, &setup.SetupCompatOpts{ + Providers: providers, + }) + + ctx, targetCollections, compatCollections := s.Ctx, s.TargetCollections, s.CompatCollections for name, tc := range testCases { name, tc := name, tc @@ -151,8 +157,11 @@ func testQueryCompat(t *testing.T, testCases map[string]queryCompatTestCase) { require.NoError(t, targetCursor.All(ctx, &targetRes)) require.NoError(t, compatCursor.All(ctx, &compatRes)) - t.Logf("Compat (expected) IDs: %v", CollectIDs(t, compatRes)) - t.Logf("Target (actual) IDs: %v", CollectIDs(t, targetRes)) + if !tc.skipIDCheck { + t.Logf("Compat (expected) IDs: %v", CollectIDs(t, compatRes)) + t.Logf("Target (actual) IDs: %v", CollectIDs(t, targetRes)) + } + AssertEqualDocumentsSlice(t, compatRes, targetRes) if len(targetRes) > 0 || len(compatRes) > 0 { @@ -173,6 +182,13 @@ func testQueryCompat(t *testing.T, testCases map[string]queryCompatTestCase) { } } +// testQueryCompat tests query compatibility test cases. +func testQueryCompat(t *testing.T, testCases map[string]queryCompatTestCase) { + t.Helper() + + testQueryCompatWithProviders(t, shareddata.AllProviders(), testCases) +} + func TestQueryCompatFilter(t *testing.T) { t.Parallel() diff --git a/integration/query_projection_compat_test.go b/integration/query_projection_compat_test.go index fa6bbf92a1e0..e707e1220d0a 100644 --- a/integration/query_projection_compat_test.go +++ b/integration/query_projection_compat_test.go @@ -18,25 +18,120 @@ import ( "testing" "go.mongodb.org/mongo-driver/bson" + + "github.com/FerretDB/FerretDB/integration/shareddata" ) func TestQueryProjectionCompat(t *testing.T) { t.Parallel() + // topLevelFieldsIntegers contains documents with several top level fields with integer values. + topLevelFieldsIntegers := shareddata.NewTopLevelFieldsProvider[string]( + "TopLevelFieldsIntegers", + []string{"ferretdb-pg", "ferretdb-tigris", "mongodb"}, + map[string]map[string]any{ + "ferretdb-tigris": { + "$tigrisSchemaString": `{ + "title": "%%collection%%", + "primary_key": ["_id"], + "properties": { + "foo": {"type": "integer", "format": "int32"}, + "bar": {"type": "integer", "format": "int32"}, + "_id": {"type": "string"} + } + }`, + }, + }, + map[string]shareddata.Fields{ + "int32-two": { + {Key: "foo", Value: int32(1)}, + {Key: "bar", Value: int32(2)}, + }, + }, + ) + + providers := append(shareddata.AllProviders(), topLevelFieldsIntegers) + testCases := map[string]queryCompatTestCase{ - "FindProjectionInclusions": { - filter: bson.D{{"_id", "document-composite"}}, - projection: bson.D{{"foo", int32(1)}, {"42", true}}, - skipForTigris: "Tigris does not support field names started from numbers (`42`)", - resultPushdown: true, - }, - "FindProjectionExclusions": { - filter: bson.D{{"_id", "document-composite"}}, - projection: bson.D{{"foo", int32(0)}, {"array", false}}, - skipForTigris: "Tigris does not support language keyword 'array' as field name", - resultPushdown: true, + "Include1Field": { + filter: bson.D{}, + projection: bson.D{{"v", int32(1)}}, + }, + "Exclude1Field": { + filter: bson.D{}, + projection: bson.D{{"v", int32(0)}}, + }, + "Include2Fields": { + filter: bson.D{}, + projection: bson.D{{"foo", 1.24}, {"bar", true}}, + }, + "Exclude2Fields": { + filter: bson.D{}, + projection: bson.D{{"foo", int32(0)}, {"bar", false}}, + }, + "Include1FieldExclude1Field": { + filter: bson.D{}, + projection: bson.D{{"foo", int32(1)}, {"bar", true}}, + }, + "IncludeID": { + filter: bson.D{}, + projection: bson.D{{"_id", int64(-1)}}, + }, + "ExcludeID": { + filter: bson.D{}, + projection: bson.D{{"_id", false}}, + skipIDCheck: true, + }, + "IncludeFieldExcludeID": { + filter: bson.D{}, + projection: bson.D{{"_id", false}, {"v", true}}, + skipIDCheck: true, + }, + "ExcludeFieldIncludeID": { + filter: bson.D{}, + projection: bson.D{{"_id", true}, {"v", false}}, + }, + "ExcludeFieldExcludeID": { + filter: bson.D{}, + projection: bson.D{{"_id", false}, {"v", false}}, + skipIDCheck: true, + }, + "IncludeFieldIncludeID": { + filter: bson.D{}, + projection: bson.D{{"_id", true}, {"v", true}}, + }, + "DotNotationInclude": { + filter: bson.D{}, + projection: bson.D{{"v.foo", true}}, + skip: "https://github.com/FerretDB/FerretDB/issues/2430", + }, + "DotNotationIncludeTwo": { + filter: bson.D{}, + projection: bson.D{{"v.foo", true}, {"v.array", true}}, + skip: "https://github.com/FerretDB/FerretDB/issues/2430", + }, + "DotNotationExclude": { + filter: bson.D{}, + projection: bson.D{{"v.foo", false}}, + skip: "https://github.com/FerretDB/FerretDB/issues/2430", + }, + "DotNotationExcludeTwo": { + filter: bson.D{}, + projection: bson.D{{"v.foo", false}, {"v.array", false}}, + skip: "https://github.com/FerretDB/FerretDB/issues/2430", + }, + "DotNotationExcludeSecondLevel": { + filter: bson.D{}, + projection: bson.D{{"v.array.42", false}}, + skip: "https://github.com/FerretDB/FerretDB/issues/2430", + }, + "DotNotationIncludeExclude": { + filter: bson.D{}, + projection: bson.D{{"v.foo", true}, {"v.array", false}}, + resultType: emptyResult, + skip: "https://github.com/FerretDB/FerretDB/issues/2430", }, } - testQueryCompat(t, testCases) + testQueryCompatWithProviders(t, providers, testCases) } diff --git a/integration/query_projection_test.go b/integration/query_projection_test.go deleted file mode 100644 index aa3460cf455f..000000000000 --- a/integration/query_projection_test.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package integration - -import ( - "testing" - - "github.com/stretchr/testify/require" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/options" - - "github.com/FerretDB/FerretDB/integration/setup" - "github.com/FerretDB/FerretDB/integration/shareddata" -) - -func TestQueryProjection(t *testing.T) { - setup.SkipForTigris(t) - - t.Parallel() - ctx, collection := setup.Setup(t, shareddata.Composites) - - for name, tc := range map[string]struct { - projection any - filter any - expected bson.D - }{ - "FindProjectionIDExclusion": { - filter: bson.D{{"_id", "document-composite"}}, - // TODO: https://github.com/FerretDB/FerretDB/issues/537 - projection: bson.D{{"_id", false}, {"array", int32(1)}}, - expected: bson.D{}, - }, - } { - name, tc := name, tc - t.Run(name, func(t *testing.T) { - t.Parallel() - - cursor, err := collection.Find(ctx, tc.filter, options.Find().SetProjection(tc.projection)) - require.NoError(t, err) - - var actual []bson.D - err = cursor.All(ctx, &actual) - require.NoError(t, err) - require.Len(t, actual, 1) - AssertEqualDocuments(t, tc.expected, actual[0]) - }) - } -} diff --git a/integration/shareddata/shareddata.go b/integration/shareddata/shareddata.go index 5bb6bdab60ce..6583660c33fa 100644 --- a/integration/shareddata/shareddata.go +++ b/integration/shareddata/shareddata.go @@ -242,8 +242,85 @@ func (b *BenchmarkValues) Docs() iterator.Interface[struct{}, bson.D] { return b.iter } +// field represents a field in a document. +type field struct { + Value any + Key string +} + +// Fields is a slice of ordered field name value pair. To avoid fields being inserted in different order between compat and target, use a slice instead of a map. +type Fields []field + +// NewTopLevelFieldsProvider creates a new TopLevelValues provider. +func NewTopLevelFieldsProvider[id comparable](name string, backends []string, validators map[string]map[string]any, data map[id]Fields) *TopLevelValues[id] { + return &TopLevelValues[id]{ + name: name, + backends: backends, + validators: validators, + data: data, + } +} + +// TopLevelValues stores shared data documents as {"_id": key, "field1": value1, "field2": value2, ...} documents. +// +//nolint:vet // for readability +type TopLevelValues[id comparable] struct { + name string + backends []string + validators map[string]map[string]any // backend -> validator name -> validator + data map[id]Fields +} + +// Name implements Provider interface. +func (t *TopLevelValues[id]) Name() string { + return t.name +} + +// Validators implements Provider interface. +func (t *TopLevelValues[id]) Validators(backend, collection string) map[string]any { + switch backend { + case "ferretdb-tigris": + validators := make(map[string]any, len(t.validators[backend])) + + for key, value := range t.validators[backend] { + validators[key] = strings.ReplaceAll(value.(string), "%%collection%%", collection) + } + + return validators + default: + return t.validators[backend] + } +} + +// Docs implements Provider interface. +func (t *TopLevelValues[id]) Docs() []bson.D { + ids := maps.Keys(t.data) + + res := make([]bson.D, 0, len(t.data)) + + for _, id := range ids { + doc := bson.D{{"_id", id}} + + fields := t.data[id] + + for _, field := range fields { + doc = append(doc, bson.E{Key: field.Key, Value: field.Value}) + } + + res = append(res, doc) + } + + return res +} + +// IsCompatible implements Provider interface. +func (t *TopLevelValues[id]) IsCompatible(backend string) bool { + return slices.Contains(t.backends, backend) +} + // check interfaces var ( _ Provider = (*Values[string])(nil) _ BenchmarkProvider = (*BenchmarkValues)(nil) + _ Provider = (*TopLevelValues[string])(nil) ) diff --git a/internal/handlers/common/projection.go b/internal/handlers/common/projection.go index 130bbb42b5ae..c9cd1aac3549 100644 --- a/internal/handlers/common/projection.go +++ b/internal/handlers/common/projection.go @@ -15,160 +15,206 @@ package common import ( + "errors" "fmt" "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/iterator" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" ) -// ProjectDocuments modifies given documents in places according to the given projection. -func ProjectDocuments(docs []*types.Document, projection *types.Document) error { - if projection.Len() == 0 { - return nil - } +var errProjectionEmpty = errors.New("projection is empty") - inclusion, err := isProjectionInclusion(projection) - if err != nil { - return err - } +// validateProjection check projection document. +// Document fields could be either included or excluded but not both. +// Exception is for the _id field that could be included or excluded. +func validateProjection(projection *types.Document) (*types.Document, bool, error) { + validated := types.MakeDocument(0) - for i := 0; i < len(docs); i++ { - err = projectDocument(inclusion, docs[i], projection) - if err != nil { - return err - } + if projection == nil { + return nil, false, errProjectionEmpty } - return nil -} + var projectionVal *bool -// isProjectionInclusion: projection can be only inclusion or exclusion. Validate and return true if inclusion. -// Exception for the _id field. -func isProjectionInclusion(projection *types.Document) (inclusion bool, err error) { - var exclusion bool - for _, k := range projection.Keys() { - if k == "_id" { // _id is a special case and can be both - continue + iter := projection.Iterator() + defer iter.Close() + + for { + key, value, err := iter.Next() + if errors.Is(err, iterator.ErrIteratorDone) { + break } - v := must.NotFail(projection.Get(k)) - switch v := v.(type) { - case *types.Document: - for _, projectionType := range v.Keys() { - err = commonerrors.NewCommandError( - commonerrors.ErrNotImplemented, - fmt.Errorf("projection of %s is not supported", projectionType), - ) - return - } + if err != nil { + return nil, false, lazyerrors.Error(err) + } + + var result bool + switch value := value.(type) { + case *types.Document, *types.Array, string: + return nil, false, commonerrors.NewCommandErrorMsg( + commonerrors.ErrNotImplemented, + fmt.Sprintf("projection expression %s is not supported", types.FormatAnyValue(value)), + ) case float64, int32, int64: - result := types.Compare(v, int32(0)) - if result == types.Equal { - if inclusion { - err = commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrProjectionExIn, - fmt.Sprintf("Cannot do exclusion on field %s in inclusion projection", k), - "projection", - ) - return - } - exclusion = true - } else { - if exclusion { - err = commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrProjectionInEx, - fmt.Sprintf("Cannot do inclusion on field %s in exclusion projection", k), - "projection", - ) - return - } - inclusion = true - } + // projection treats 0 as false and any other value as true + comparison := types.Compare(value, int32(0)) - case bool: - if v { - if exclusion { - err = commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrProjectionInEx, - fmt.Sprintf("Cannot do inclusion on field %s in exclusion projection", k), - "projection", - ) - return - } - inclusion = true - } else { - if inclusion { - err = commonerrors.NewCommandErrorMsgWithArgument(commonerrors.ErrProjectionExIn, - fmt.Sprintf("Cannot do exclusion on field %s in inclusion projection", k), - "projection", - ) - return - } - exclusion = true + if comparison != types.Equal { + result = true } - + case bool: + result = value default: - err = lazyerrors.Errorf("unsupported operation %s %v (%T)", k, v, v) - return + return nil, false, lazyerrors.Errorf("unsupported operation %s %value (%T)", key, value, value) } - } - return -} -func projectDocument(inclusion bool, doc *types.Document, projection *types.Document) error { - projectionMap := projection.Map() + // set the value with boolean result to omit type assertion when we will apply projection + validated.Set(key, result) - for k1 := range doc.Map() { - projectionVal, ok := projectionMap[k1] - if !ok { - if k1 == "_id" { // if _id is not in projection map, do not do anything with it + if projection.Len() == 1 && key == "_id" { + return validated, result, nil + } + + // if projectionVal is nil we are processing the first field + if projectionVal == nil { + if key == "_id" { continue } - if inclusion { // k1 from doc is absent in projection, remove from doc only if projection type inclusion - doc.Remove(k1) - } + + projectionVal = &result + continue } - switch projectionVal := projectionVal.(type) { // found in the projection - case *types.Document: // field: { $elemMatch: { field2: value }} - if err := applyComplexProjection(projectionVal); err != nil { - return err + if *projectionVal != result { + if *projectionVal { + return nil, false, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrProjectionExIn, + fmt.Sprintf("Cannot do exclusion on field %s in inclusion projection", key), + "projection", + ) } - case float64, int32, int64: // field: number - result := types.Compare(projectionVal, int32(0)) - if result == types.Equal { - doc.Remove(k1) - } + return nil, false, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrProjectionInEx, + fmt.Sprintf("Cannot do inclusion on field %s in exclusion projection", key), + "projection", + ) + } + } - case bool: // field: bool - if !projectionVal { - doc.Remove(k1) - } + return validated, *projectionVal, nil +} +// projectDocument applies projection to the copy of the document. +func projectDocument(doc, projection *types.Document, inclusion bool) (*types.Document, error) { + projected, err := types.NewDocument("_id", must.NotFail(doc.Get("_id"))) + if err != nil { + return nil, err + } + + if projection.Has("_id") { + idValue := must.NotFail(projection.Get("_id")) + + var set bool + + switch idValue := idValue.(type) { + case *types.Document: // field: { $elemMatch: { field2: value }} + return nil, commonerrors.NewCommandErrorMsg( + commonerrors.ErrCommandNotFound, + fmt.Sprintf("projection %s is not supported", + types.FormatAnyValue(idValue), + ), + ) + case bool: + set = idValue default: - return lazyerrors.Errorf("unsupported operation %s %v (%T)", k1, projectionVal, projectionVal) + return nil, lazyerrors.Errorf("unsupported operation %s %v (%T)", "_id", idValue, idValue) + } + + if !set { + projected.Remove("_id") } } - return nil + + projectedWithoutID, err := projectDocumentWithoutID(doc, projection, inclusion) + if err != nil { + return nil, err + } + + for _, key := range projectedWithoutID.Keys() { + projected.Set(key, must.NotFail(projectedWithoutID.Get(key))) + } + + return projected, nil } -func applyComplexProjection(projectionVal *types.Document) error { - for _, projectionType := range projectionVal.Keys() { - switch projectionType { - case "$elemMatch", "$slice": - return commonerrors.NewCommandError( - commonerrors.ErrNotImplemented, - fmt.Errorf("projection of %s is not supported", projectionType), - ) - default: - return commonerrors.NewCommandError( +// projectDocumentWithoutID applies projection to the copy of the document and returns projected document. +// It ignores _id field in the projection. +func projectDocumentWithoutID(doc *types.Document, projection *types.Document, inclusion bool) (*types.Document, error) { + projectionWithoutID := projection.DeepCopy() + projectionWithoutID.Remove("_id") + + docWithoutID := doc.DeepCopy() + docWithoutID.Remove("_id") + + projected := types.MakeDocument(0) + + if !inclusion { + projected = docWithoutID.DeepCopy() + } + + iter := projectionWithoutID.Iterator() + defer iter.Close() + + for { + key, value, err := iter.Next() + if errors.Is(err, iterator.ErrIteratorDone) { + break + } + + if err != nil { + return nil, lazyerrors.Error(err) + } + + path, err := types.NewPathFromString(key) + if err != nil { + return nil, lazyerrors.Error(err) + } + + switch value := value.(type) { // found in the projection + case *types.Document: // field: { $elemMatch: { field2: value }} + return nil, commonerrors.NewCommandErrorMsg( commonerrors.ErrCommandNotFound, - fmt.Errorf("projection of %s is not supported", projectionType), + fmt.Sprintf("projection %s is not supported", + types.FormatAnyValue(value), + ), ) + + case bool: // field: bool + // process top level fields + if path.Len() == 1 { + if inclusion { + if docWithoutID.Has(key) { + projected.Set(key, must.NotFail(docWithoutID.Get(key))) + } + + continue + } + + projected.Remove(key) + } + + // TODO: process dot notation here https://github.com/FerretDB/FerretDB/issues/2430 + default: + return nil, lazyerrors.Errorf("unsupported operation %s %v (%T)", key, value, value) } } - return nil + return projected, nil } diff --git a/internal/handlers/common/projection_iterator.go b/internal/handlers/common/projection_iterator.go index 2bbb83de22c6..c19a4b98d256 100644 --- a/internal/handlers/common/projection_iterator.go +++ b/internal/handlers/common/projection_iterator.go @@ -15,6 +15,8 @@ package common import ( + "errors" + "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/iterator" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" @@ -27,14 +29,18 @@ import ( // // Close method closes the underlying iterator. func ProjectionIterator(iter types.DocumentsIterator, closer *iterator.MultiCloser, projection *types.Document) (types.DocumentsIterator, error) { //nolint:lll // for readability - inclusion, err := isProjectionInclusion(projection) + projectionValidated, inclusion, err := validateProjection(projection) + if errors.Is(err, errProjectionEmpty) { + return iter, nil + } + if err != nil { return nil, lazyerrors.Error(err) } res := &projectionIterator{ iter: iter, - projection: projection, + projection: projectionValidated, inclusion: inclusion, } closer.Add(res) @@ -58,12 +64,12 @@ func (iter *projectionIterator) Next() (struct{}, *types.Document, error) { return unused, nil, lazyerrors.Error(err) } - err = projectDocument(iter.inclusion, doc, iter.projection) + projected, err := projectDocument(doc, iter.projection, iter.inclusion) if err != nil { return unused, nil, lazyerrors.Error(err) } - return unused, doc, nil + return unused, projected, nil } // Close implements iterator.Interface. See ProjectionIterator for details.