From eb2943e67f5af9377c5f68a5f7bd9738520d0ece Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Mon, 10 Jul 2023 18:58:44 +0900 Subject: [PATCH] Add MongoDB integration tests for `maxTimeMS` in `find`, `aggregate` and `getMore` (#2953) Closes #1808. --- .../aggregate_documents_compat_test.go | 12 +- integration/aggregate_documents_test.go | 193 +++++++++ integration/getmore_test.go | 403 ++++++++++++++++++ internal/handlers/common/getmore.go | 55 ++- internal/handlers/pg/msg_aggregate.go | 62 ++- internal/handlers/pg/msg_find.go | 1 + 6 files changed, 714 insertions(+), 12 deletions(-) diff --git a/integration/aggregate_documents_compat_test.go b/integration/aggregate_documents_compat_test.go index 65594db2fa04..6aba23ed3e4d 100644 --- a/integration/aggregate_documents_compat_test.go +++ b/integration/aggregate_documents_compat_test.go @@ -204,7 +204,7 @@ func testAggregateCommandCompat(t *testing.T, testCases map[string]aggregateComm t.Run(targetCollection.Name(), func(t *testing.T) { t.Helper() - var targetRes, compatRes []bson.D + var targetRes, compatRes bson.D targetErr := targetCollection.Database().RunCommand(ctx, command).Decode(&targetRes) compatErr := compatCollection.Database().RunCommand(ctx, command).Decode(&compatRes) @@ -223,8 +223,7 @@ func testAggregateCommandCompat(t *testing.T, testCases map[string]aggregateComm return } require.NoError(t, compatErr, "compat error; target returned no error") - - AssertEqualDocumentsSlice(t, compatRes, targetRes) + AssertEqualDocuments(t, compatRes, targetRes) if len(targetRes) > 0 || len(compatRes) > 0 { nonEmptyResults = true @@ -280,16 +279,13 @@ func TestAggregateCommandCompat(t *testing.T) { }, resultType: emptyResult, }, - "MaxTimeMSNegative": { + "MaxTimeMSDoubleWholeNumber": { command: bson.D{ {"aggregate", "collection-name"}, {"pipeline", bson.A{}}, - {"maxTimeMS", int64(-1)}, {"cursor", bson.D{}}, + {"maxTimeMS", float64(1000)}, }, - resultType: emptyResult, - // compat and target return an error from the driver - // > cannot decode document into []primitive.D }, } diff --git a/integration/aggregate_documents_test.go b/integration/aggregate_documents_test.go index 0caade03adc2..9ddaebe9934f 100644 --- a/integration/aggregate_documents_test.go +++ b/integration/aggregate_documents_test.go @@ -15,6 +15,7 @@ package integration import ( + "math" "testing" "github.com/stretchr/testify/assert" @@ -661,6 +662,198 @@ func TestAggregateUnsetErrors(t *testing.T) { } } +func TestAggregateCommandMaxTimeMSErrors(t *testing.T) { + t.Parallel() + ctx, collection := setup.Setup(t) + + for name, tc := range map[string]struct { //nolint:vet // used for testing only + command bson.D // required, command to run + + err *mongo.CommandError // required, expected error from MongoDB + altMessage string // optional, alternative error message for FerretDB, ignored if empty + skip string // optional, skip test with a specified reason + }{ + "NegativeLong": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", int64(-1)}, + }, + err: &mongo.CommandError{ + Code: 51024, + Name: "Location51024", + Message: "BSON field 'maxTimeMS' value must be >= 0, actual value '-1'", + }, + }, + "MaxLong": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", math.MaxInt64}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "9223372036854775807 value for maxTimeMS is out of range", + }, + }, + "Double": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", 1000.5}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "maxTimeMS has non-integral value", + }, + }, + "NegativeDouble": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", -14245345234123245.55}, + }, + err: &mongo.CommandError{ + Code: 51024, + Name: "Location51024", + Message: "BSON field 'maxTimeMS' value must be >= 0, actual value '-14245345234123246'", + }, + altMessage: "BSON field 'maxTimeMS' value must be >= 0, actual value '-1.424534523412325e+16'", + }, + "BigDouble": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", math.MaxFloat64}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "9223372036854775807 value for maxTimeMS is out of range", + }, + altMessage: "1.797693134862316e+308 value for maxTimeMS is out of range", + }, + "BigNegativeDouble": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", -math.MaxFloat64}, + }, + err: &mongo.CommandError{ + Code: 51024, + Name: "Location51024", + Message: "BSON field 'maxTimeMS' value must be >= 0, actual value '-9223372036854775808'", + }, + altMessage: "BSON field 'maxTimeMS' value must be >= 0, actual value '-1.797693134862316e+308'", + }, + "NegativeInt32": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", -1123123}, + }, + err: &mongo.CommandError{ + Code: 51024, + Name: "Location51024", + Message: "BSON field 'maxTimeMS' value must be >= 0, actual value '-1123123'", + }, + }, + "MaxIntPlus": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", math.MaxInt32 + 1}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "2147483648 value for maxTimeMS is out of range", + }, + }, + "Null": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", nil}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "maxTimeMS must be a number", + }, + }, + "String": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", "string"}, + }, + err: &mongo.CommandError{ + Code: 14, + Name: "TypeMismatch", + Message: "BSON field 'aggregate.maxTimeMS' is the wrong type 'string', expected types '[long, int, decimal, double']", + }, + altMessage: "BSON field 'aggregate.maxTimeMS' is the wrong type 'string', expected types '[long, int, decimal, double]'", + }, + "Array": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", bson.A{int32(42), "foo", nil}}, + }, + err: &mongo.CommandError{ + Code: 14, + Name: "TypeMismatch", + Message: "BSON field 'aggregate.maxTimeMS' is the wrong type 'array', expected types '[long, int, decimal, double']", + }, + altMessage: "BSON field 'aggregate.maxTimeMS' is the wrong type 'array', expected types '[long, int, decimal, double]'", + }, + "Document": { + command: bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{}}, + {"maxTimeMS", bson.D{{"foo", int32(42)}}}, + }, + err: &mongo.CommandError{ + Code: 14, + Name: "TypeMismatch", + Message: "BSON field 'aggregate.maxTimeMS' is the wrong type 'object', expected types '[long, int, decimal, double']", + }, + altMessage: "BSON field 'aggregate.maxTimeMS' is the wrong type 'object', expected types '[long, int, decimal, double]'", + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + + t.Parallel() + + require.NotNil(t, tc.err, "err must not be nil") + + var res bson.D + err := collection.Database().RunCommand(ctx, tc.command).Decode(&res) + AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err) + require.Nil(t, res) + }) + } +} + func TestAggregateCommandCursor(t *testing.T) { t.Parallel() ctx, collection := setup.Setup(t) diff --git a/integration/getmore_test.go b/integration/getmore_test.go index 66ad22d55176..1972f5516653 100644 --- a/integration/getmore_test.go +++ b/integration/getmore_test.go @@ -15,6 +15,7 @@ package integration import ( + "math" "net/url" "testing" @@ -26,6 +27,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" "github.com/FerretDB/FerretDB/integration/setup" + "github.com/FerretDB/FerretDB/integration/shareddata" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/must" ) @@ -701,3 +703,404 @@ func TestGetMoreCommandConnection(t *testing.T) { ) }) } + +func TestGetMoreCommandMaxTimeMSErrors(t *testing.T) { + t.Parallel() + ctx, collection := setup.Setup(t) + + for name, tc := range map[string]struct { //nolint:vet // used for testing only + command bson.D // required, command to run + + err *mongo.CommandError // required, expected error from MongoDB + altMessage string // optional, alternative error message for FerretDB, ignored if empty + skip string // optional, skip test with a specified reason + }{ + "NegativeLong": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", int64(-1)}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "-1 value for maxTimeMS is out of range", + }, + }, + "MaxLong": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", math.MaxInt64}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "9223372036854775807 value for maxTimeMS is out of range", + }, + }, + "Double": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", 1000.5}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "maxTimeMS has non-integral value", + }, + altMessage: "BSON field 'getMore.maxTimeMS' is the wrong type 'double', expected types '[long, int, decimal, double]'", + }, + "NegativeDouble": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", -14245345234123245.55}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "-14245345234123246 value for maxTimeMS is out of range", + }, + altMessage: "-1.4245345234123246e+16 value for maxTimeMS is out of range", + }, + "BigDouble": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", math.MaxFloat64}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "9223372036854775807 value for maxTimeMS is out of range", + }, + altMessage: "1.797693134862316e+308 value for maxTimeMS is out of range", + }, + "BigNegativeDouble": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", -math.MaxFloat64}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "-9223372036854775808 value for maxTimeMS is out of range", + }, + altMessage: "-1.797693134862316e+308 value for maxTimeMS is out of range", + }, + "NegativeInt": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", -1123123}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "-1123123 value for maxTimeMS is out of range", + }, + }, + "MaxInt": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", math.MaxInt32 + 1}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "2147483648 value for maxTimeMS is out of range", + }, + }, + "Null": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", nil}, + }, + err: &mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "maxTimeMS must be a number", + }, + }, + "String": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", "string"}, + }, + err: &mongo.CommandError{ + Code: 14, + Name: "TypeMismatch", + Message: "BSON field 'getMore.maxTimeMS' is the wrong type 'string', expected types '[long, int, decimal, double']", + }, + altMessage: "BSON field 'getMore.maxTimeMS' is the wrong type 'string', expected types '[long, int, decimal, double]'", + }, + "Array": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", bson.A{int32(42), "foo", nil}}, + }, + err: &mongo.CommandError{ + Code: 14, + Name: "TypeMismatch", + Message: "BSON field 'getMore.maxTimeMS' is the wrong type 'array', expected types '[long, int, decimal, double']", + }, + altMessage: "BSON field 'getMore.maxTimeMS' is the wrong type 'array', expected types '[long, int, decimal, double]'", + }, + "Document": { + command: bson.D{ + {"getMore", int64(112233)}, + {"collection", collection.Name()}, + {"maxTimeMS", bson.D{{"foo", int32(42)}}}, + }, + err: &mongo.CommandError{ + Code: 14, + Name: "TypeMismatch", + Message: "BSON field 'getMore.maxTimeMS' is the wrong type 'object', expected types '[long, int, decimal, double']", + }, + altMessage: "BSON field 'getMore.maxTimeMS' is the wrong type 'object', expected types '[long, int, decimal, double]'", + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + + t.Parallel() + + require.NotNil(t, tc.err, "err must not be nil") + + var res bson.D + err := collection.Database().RunCommand(ctx, tc.command).Decode(&res) + AssertEqualAltCommandError(t, *tc.err, tc.altMessage, err) + require.Nil(t, res) + }) + } +} + +func TestGetMoreCommandMaxTimeMSCursor(t *testing.T) { + // do not run tests in parallel to for server execution time to use maximum possible maxTimeMS + + // options are applied to create a client that uses single connection pool + s := setup.SetupWithOpts(t, &setup.SetupOpts{ + ExtraOptions: url.Values{ + "minPoolSize": []string{"1"}, + "maxPoolSize": []string{"1"}, + "maxIdleTimeMS": []string{"0"}, + }, + Providers: []shareddata.Provider{shareddata.Composites}, + }) + + ctx, collection := s.Ctx, s.Collection + + // need large amount of documents for time out to trigger + arr, _ := generateDocuments(0, 5000) + + _, err := collection.InsertMany(ctx, arr) + require.NoError(t, err) + + t.Run("FindExpire", func(tt *testing.T) { + t := setup.FailsForFerretDB(tt, "https://github.com/FerretDB/FerretDB/issues/1808") + + opts := options.Find(). + // set batchSize big enough to hit maxTimeMS + SetBatchSize(2000). + // set maxTimeMS small enough for find to expire + SetMaxTime(1). + // set sort to slow down the query more than 1ms + SetSort(bson.D{{"v", 1}}) + + _, err := collection.Find(ctx, bson.D{}, opts) + + // MongoDB returns Message or altMessage + AssertEqualAltCommandError( + t, + mongo.CommandError{ + Code: 50, + Name: "MaxTimeMSExpired", + Message: "Executor error during find command :: caused by :: operation exceeded time limit", + }, + "operation exceeded time limit", + err, + ) + }) + + t.Run("FindGetMorePropagateMaxTimeMS", func(t *testing.T) { + // this test case is not stable and frequently fails because + // `Find` unexpectedly timeout or `cursor.Next()` does not timeout expectedly + t.Skip("https://github.com/FerretDB/FerretDB/issues/2983") + + opts := options.Find(). + // setting zero on find sets nextBatch on getMore to unlimited + SetBatchSize(0). + // maxTimeMS is 1 but it won't expire because of zero BatchSize + SetMaxTime(1) + + cursor, err := collection.Find(ctx, bson.D{}, opts) + require.NoError(t, err) + + cursor.SetBatchSize(50000) + + // getMore uses maxTimeMS set on find + ok := cursor.Next(ctx) + assert.False(t, ok) + + // MongoDB returns Message or altMessage + AssertEqualAltCommandError( + t, + mongo.CommandError{ + Code: 50, + Name: "MaxTimeMSExpired", + Message: "Executor error during getMore :: caused by :: operation exceeded time limit", + }, + "operation exceeded time limit", + cursor.Err(), + ) + }) + + t.Run("FindGetMoreMaxTimeMS", func(tt *testing.T) { + t := setup.FailsForFerretDB(tt, "https://github.com/FerretDB/FerretDB/issues/1808") + + var res bson.D + err := collection.Database().RunCommand(ctx, bson.D{ + {"find", collection.Name()}, + {"batchSize", 0}, + }).Decode(&res) + require.NoError(t, err) + + doc := ConvertDocument(t, res) + + v, _ := doc.Get("cursor") + require.NotNil(t, v) + + cursor, ok := v.(*types.Document) + require.True(t, ok) + + cursorID, _ := cursor.Get("id") + require.NotZero(t, cursorID) + + err = collection.Database().RunCommand(ctx, bson.D{ + {"getMore", cursorID}, + {"collection", collection.Name()}, + {"batchSize", 2000}, + {"maxTimeMS", 1}, + }).Decode(&res) + + AssertEqualAltCommandError( + t, + mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "cannot set maxTimeMS on getMore command for a non-awaitData cursor", + }, + "", + err, + ) + }) + + t.Run("AggregateExpire", func(tt *testing.T) { + t := setup.FailsForFerretDB(tt, "https://github.com/FerretDB/FerretDB/issues/1808") + + opts := options.Aggregate(). + // set batchSize big enough to hit maxTimeMS + SetBatchSize(2000). + // set maxTimeMS small enough for aggregate to expire + SetMaxTime(1) + + // use $sort stage to slow down the query more than 1ms + _, err := collection.Aggregate(ctx, bson.A{bson.D{{"$sort", bson.D{{"v", 1}}}}}, opts) + + // MongoDB returns Message or altMessage + AssertEqualAltCommandError( + t, + mongo.CommandError{ + Code: 50, + Name: "MaxTimeMSExpired", + Message: "PlanExecutor error during aggregation :: caused by :: operation exceeded time limit", + }, + "operation exceeded time limit", + err, + ) + }) + + t.Run("AggregateGetMorePropagateMaxTimeMS", func(t *testing.T) { + // this test case is not stable and frequently fails because + // `Aggregate` unexpectedly timeout or `cursor.Next()` does not timeout expectedly + t.Skip("https://github.com/FerretDB/FerretDB/issues/2983") + + opts := options.Aggregate(). + // setting zero on aggregate sets nextBatch on getMore to unlimited + SetBatchSize(0). + // maxTimeMS is 1 but it won't expire on aggregate because of zero BatchSize + SetMaxTime(1) + + cursor, err := collection.Aggregate(ctx, bson.A{}, opts) + require.NoError(t, err) + + cursor.SetBatchSize(50000) + + // getMore uses maxTimeMS set on aggregate + ok := cursor.Next(ctx) + assert.False(t, ok) + + // MongoDB returns Message or altMessage + AssertEqualAltCommandError( + t, + mongo.CommandError{ + Code: 50, + Name: "MaxTimeMSExpired", + Message: "Executor error during getMore :: caused by :: operation exceeded time limit", + }, + "operation exceeded time limit", + cursor.Err(), + ) + }) + + t.Run("AggregateGetMoreMaxTimeMS", func(tt *testing.T) { + t := setup.FailsForFerretDB(tt, "https://github.com/FerretDB/FerretDB/issues/1808") + + var res bson.D + err := collection.Database().RunCommand(ctx, bson.D{ + {"aggregate", collection.Name()}, + {"pipeline", bson.A{}}, + {"cursor", bson.D{{"batchSize", 0}}}, + }).Decode(&res) + require.NoError(t, err) + + doc := ConvertDocument(t, res) + + v, _ := doc.Get("cursor") + require.NotNil(t, v) + + cursor, ok := v.(*types.Document) + require.True(t, ok) + + cursorID, _ := cursor.Get("id") + require.NotZero(t, cursorID) + + err = collection.Database().RunCommand(ctx, bson.D{ + {"getMore", cursorID}, + {"collection", collection.Name()}, + {"batchSize", 2000}, + {"maxTimeMS", 1}, + }).Decode(&res) + + AssertEqualAltCommandError( + t, + mongo.CommandError{ + Code: 2, + Name: "BadValue", + Message: "cannot set maxTimeMS on getMore command for a non-awaitData cursor", + }, + "", + err, + ) + }) +} diff --git a/internal/handlers/common/getmore.go b/internal/handlers/common/getmore.go index d2e8c99925c1..54da14e643b9 100644 --- a/internal/handlers/common/getmore.go +++ b/internal/handlers/common/getmore.go @@ -16,7 +16,9 @@ package common import ( "context" + "errors" "fmt" + "math" "github.com/FerretDB/FerretDB/internal/clientconn/conninfo" "github.com/FerretDB/FerretDB/internal/clientconn/cursor" @@ -81,7 +83,58 @@ func GetMore(ctx context.Context, msg *wire.OpMsg, registry *cursor.Registry) (* } // TODO maxTimeMS https://github.com/FerretDB/FerretDB/issues/1808 - // TODO comment + v, _ = document.Get("maxTimeMS") + if v == nil { + v = int64(0) + } + + // cannot use other existing commonparams function, they return different error codes + maxTimeMS, err := commonparams.GetWholeNumberParam(v) + if err != nil { + switch { + case errors.Is(err, commonparams.ErrUnexpectedType): + if _, ok = v.(types.NullType); ok { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + "maxTimeMS must be a number", + document.Command(), + ) + } + + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrTypeMismatch, + fmt.Sprintf( + `BSON field 'getMore.maxTimeMS' is the wrong type '%s', expected types '[long, int, decimal, double]'`, + commonparams.AliasFromType(v), + ), + document.Command(), + ) + case errors.Is(err, commonparams.ErrNotWholeNumber): + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + "maxTimeMS has non-integral value", + document.Command(), + ) + case errors.Is(err, commonparams.ErrLongExceededPositive) || errors.Is(err, commonparams.ErrLongExceededNegative): + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + fmt.Sprintf("%s value for maxTimeMS is out of range", types.FormatAnyValue(v)), + document.Command(), + ) + default: + return nil, lazyerrors.Error(err) + } + } + + if maxTimeMS < int64(0) || maxTimeMS > math.MaxInt32 { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + fmt.Sprintf("%v value for maxTimeMS is out of range", v), + document.Command(), + ) + } + + // TODO comment https://github.com/FerretDB/FerretDB/issues/2986 username, _ := conninfo.Get(ctx).Auth() diff --git a/internal/handlers/pg/msg_aggregate.go b/internal/handlers/pg/msg_aggregate.go index 04235f863198..63809ec73b00 100644 --- a/internal/handlers/pg/msg_aggregate.go +++ b/internal/handlers/pg/msg_aggregate.go @@ -16,7 +16,9 @@ package pg import ( "context" + "errors" "fmt" + "math" "os" "time" @@ -91,10 +93,64 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs v = int64(0) } - maxTimeMS, err := commonparams.GetValidatedNumberParamWithMinValue(document.Command(), "maxTimeMS", v, 0) + // cannot use other existing commonparams function, they return different error codes + maxTimeMS, err := commonparams.GetWholeNumberParam(v) if err != nil { - // unreachable for MongoDB GO driver, it validates maxTimeMS parameter - return nil, lazyerrors.Error(err) + switch { + case errors.Is(err, commonparams.ErrUnexpectedType): + if _, ok = v.(types.NullType); ok { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + "maxTimeMS must be a number", + document.Command(), + ) + } + + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrTypeMismatch, + fmt.Sprintf( + `BSON field 'aggregate.maxTimeMS' is the wrong type '%s', expected types '[long, int, decimal, double]'`, + commonparams.AliasFromType(v), + ), + document.Command(), + ) + case errors.Is(err, commonparams.ErrNotWholeNumber): + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + "maxTimeMS has non-integral value", + document.Command(), + ) + case errors.Is(err, commonparams.ErrLongExceededPositive): + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + fmt.Sprintf("%s value for maxTimeMS is out of range", types.FormatAnyValue(v)), + document.Command(), + ) + case errors.Is(err, commonparams.ErrLongExceededNegative): + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrValueNegative, + fmt.Sprintf("BSON field 'maxTimeMS' value must be >= 0, actual value '%s'", types.FormatAnyValue(v)), + document.Command(), + ) + default: + return nil, lazyerrors.Error(err) + } + } + + if maxTimeMS < int64(0) { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrValueNegative, + fmt.Sprintf("BSON field 'maxTimeMS' value must be >= 0, actual value '%s'", types.FormatAnyValue(v)), + document.Command(), + ) + } + + if maxTimeMS > math.MaxInt32 { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrBadValue, + fmt.Sprintf("%v value for maxTimeMS is out of range", v), + document.Command(), + ) } pipeline, err := common.GetRequiredParam[*types.Array](document, "pipeline") diff --git a/internal/handlers/pg/msg_find.go b/internal/handlers/pg/msg_find.go index 0cd27c30e850..7a2e86c03a9f 100644 --- a/internal/handlers/pg/msg_find.go +++ b/internal/handlers/pg/msg_find.go @@ -159,6 +159,7 @@ func (h *Handler) MsgFind(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg, er } if params.SingleBatch || firstBatch.Len() < int(params.BatchSize) { + // TODO: support tailable cursor https://github.com/FerretDB/FerretDB/issues/2963 // let the client know that there are no more results cursorID = 0