Skip to content

Commit

Permalink
Fix _id restriction in aggregation $project stage (#3508)
Browse files Browse the repository at this point in the history
Closes #2826.
  • Loading branch information
chilagrow authored Oct 9, 2023
1 parent b932d0f commit be0f7b6
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 22 deletions.
74 changes: 62 additions & 12 deletions integration/aggregate_documents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@ import (
"go.mongodb.org/mongo-driver/mongo"

"github.com/FerretDB/FerretDB/integration/setup"
"github.com/FerretDB/FerretDB/integration/shareddata"
)

func TestAggregateAddFieldsErrors(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t)

for name, tc := range map[string]struct { //nolint:vet // used for test only
pipeline bson.A // required, aggregation pipeline stages

Expand Down Expand Up @@ -70,8 +73,6 @@ func TestAggregateAddFieldsErrors(t *testing.T) {
require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.err, "err must not be nil")

ctx, collection := setup.Setup(t)

_, err := collection.Aggregate(ctx, tc.pipeline)

if tc.altMessage != "" {
Expand All @@ -87,6 +88,8 @@ func TestAggregateAddFieldsErrors(t *testing.T) {
func TestAggregateGroupErrors(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t)

for name, tc := range map[string]struct {
pipeline bson.A // required, aggregation pipeline stages

Expand Down Expand Up @@ -281,8 +284,6 @@ func TestAggregateGroupErrors(t *testing.T) {
require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.err, "err must not be nil")

ctx, collection := setup.Setup(t)

res, err := collection.Aggregate(ctx, tc.pipeline)

assert.Nil(t, res)
Expand All @@ -294,6 +295,8 @@ func TestAggregateGroupErrors(t *testing.T) {
func TestAggregateProjectErrors(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t)

for name, tc := range map[string]struct {
pipeline bson.A // required, aggregation pipeline stages

Expand Down Expand Up @@ -602,17 +605,66 @@ func TestAggregateProjectErrors(t *testing.T) {
require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.err, "err must not be nil")

ctx, collection := setup.Setup(t)

_, err := collection.Aggregate(ctx, tc.pipeline)
AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err)
})
}
}

func TestAggregateProject(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t, shareddata.Scalars)

for name, tc := range map[string]struct { //nolint:vet // used for testing only
pipeline bson.A // required, aggregation pipeline stages

res []bson.D // required, expected response
skip string // optional, skip test with a specified reason
}{
"IDFalseValueTrue": {
pipeline: bson.A{
bson.D{{"$match", bson.D{{"_id", "int32"}}}},
bson.D{{"$project", bson.D{{"_id", false}, {"v", true}}}},
},
res: []bson.D{{{"v", int32(42)}}},
},
"ValueTrueIDFalse": {
pipeline: bson.A{
bson.D{{"$match", bson.D{{"_id", "int32"}}}},
bson.D{{"$project", bson.D{{"v", true}, {"_id", false}}}},
},
res: []bson.D{{{"v", int32(42)}}},
},
} {
name, tc := name, tc
t.Run(name, func(t *testing.T) {
if tc.skip != "" {
t.Skip(tc.skip)
}

t.Parallel()

require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.res, "res must not be nil")

cursor, err := collection.Aggregate(ctx, tc.pipeline)
require.NoError(t, err)
defer cursor.Close(ctx)

var res []bson.D
err = cursor.All(ctx, &res)
require.NoError(t, err)
require.Equal(t, tc.res, res)
})
}
}

func TestAggregateSetErrors(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t)

for name, tc := range map[string]struct { //nolint:vet // used for test only
pipeline bson.A // required, aggregation pipeline stages

Expand Down Expand Up @@ -653,8 +705,6 @@ func TestAggregateSetErrors(t *testing.T) {
require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.err, "err must not be nil")

ctx, collection := setup.Setup(t)

_, err := collection.Aggregate(ctx, tc.pipeline)
AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err)
})
Expand All @@ -664,6 +714,8 @@ func TestAggregateSetErrors(t *testing.T) {
func TestAggregateUnsetErrors(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t)

for name, tc := range map[string]struct { //nolint:vet // used for test only
pipeline bson.A // required, aggregation pipeline stages

Expand Down Expand Up @@ -854,8 +906,6 @@ func TestAggregateUnsetErrors(t *testing.T) {
require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.err, "err must not be nil")

ctx, collection := setup.Setup(t)

_, err := collection.Aggregate(ctx, tc.pipeline)
AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err)
})
Expand All @@ -865,6 +915,8 @@ func TestAggregateUnsetErrors(t *testing.T) {
func TestAggregateSortErrors(t *testing.T) {
t.Parallel()

ctx, collection := setup.Setup(t)

for name, tc := range map[string]struct { //nolint:vet // used for test only
pipeline bson.A // required, aggregation pipeline stages

Expand Down Expand Up @@ -894,8 +946,6 @@ func TestAggregateSortErrors(t *testing.T) {
require.NotNil(t, tc.pipeline, "pipeline must not be nil")
require.NotNil(t, tc.err, "err must not be nil")

ctx, collection := setup.Setup(t)

_, err := collection.Aggregate(ctx, tc.pipeline)
AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err)
})
Expand Down
2 changes: 2 additions & 0 deletions integration/getmore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,8 @@ func TestGetMoreCommandMaxTimeMSCursor(t *testing.T) {
cursor, err := collection.Aggregate(ctx, bson.A{}, opts)
require.NoError(t, err)

defer cursor.Close(ctx)

cursor.SetBatchSize(50000)

// getMore uses maxTimeMS set on aggregate
Expand Down
21 changes: 11 additions & 10 deletions integration/query_projection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,16 @@ func TestQueryProjectionSuccess(t *testing.T) {
ctx, collection := setup.Setup(t, shareddata.Scalars)

for name, tc := range map[string]struct { //nolint:vet // used for testing only
filter bson.D // required
projection any // required
expectedResponse []bson.D // required
filter bson.D // required
projection any // required
res []bson.D // required

skip string // optional, skip test with a specified reason
}{
"QueryProjectionOfFieldV": {
filter: bson.D{{"_id", "int32"}},
projection: bson.D{{"v", true}, {"_id", false}},
expectedResponse: []bson.D{
res: []bson.D{
{{"v", int32(42)}},
},
},
Expand All @@ -282,15 +282,16 @@ func TestQueryProjectionSuccess(t *testing.T) {

require.NotNil(t, tc.filter, "filter should be set")
require.NotNil(t, tc.projection, "projection should be set")
require.NotNil(t, tc.expectedResponse, "expectedResponse should be set")
require.NotNil(t, tc.res, "res should be set")

res, err := collection.Find(ctx, tc.filter, options.Find().SetProjection(tc.projection))
cursor, err := collection.Find(ctx, tc.filter, options.Find().SetProjection(tc.projection))
require.NoError(t, err)
defer res.Close(ctx)
var results []bson.D
err = res.All(ctx, &results)
defer cursor.Close(ctx)

var res []bson.D
err = cursor.All(ctx, &res)
require.NoError(t, err)
assert.Equal(t, tc.expectedResponse, results)
assert.Equal(t, tc.res, res)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ func ValidateProjection(projection *types.Document) (*types.Document, bool, erro
}

if *projectionVal != result {
if key == "_id" {
continue
}

if *projectionVal {
return nil, false, commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrProjectionExIn,
Expand Down

0 comments on commit be0f7b6

Please sign in to comment.