diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_documents_compat_test.go similarity index 100% rename from integration/aggregate_compat_test.go rename to integration/aggregate_documents_compat_test.go diff --git a/integration/aggregate_stats_compat_test.go b/integration/aggregate_stats_compat_test.go new file mode 100644 index 000000000000..98000eb901fa --- /dev/null +++ b/integration/aggregate_stats_compat_test.go @@ -0,0 +1,133 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package integration + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" + + "github.com/FerretDB/FerretDB/integration/setup" + "github.com/FerretDB/FerretDB/integration/shareddata" +) + +func TestAggregateCompatCollStats(t *testing.T) { + t.Parallel() + + for name, tc := range map[string]struct { + skip string // skip test for all handlers, must have issue number mentioned + collStats bson.D // required + resultType compatTestCaseResultType // defaults to nonEmptyResult + }{ + "NilCollStats": { + collStats: nil, + resultType: emptyResult, + }, + "EmptyCollStats": { + collStats: bson.D{}, + }, + "Count": { + collStats: bson.D{{"count", bson.D{}}}, + }, + "StorageStats": { + collStats: bson.D{{"storageStats", bson.D{}}}, + }, + "StorageStatsWithScale": { + collStats: bson.D{{"storageStats", bson.D{{"scale", 1000}}}}, + }, + "CountAndStorageStats": { + collStats: bson.D{{"count", bson.D{}}, {"storageStats", bson.D{}}}, + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + if tc.skip != "" { + t.Skip(tc.skip) + } + + t.Helper() + t.Parallel() + + // It's enough to use a couple of providers: one for some collection and one for a non-existent collection. 
+ s := setup.SetupCompatWithOpts(t, &setup.SetupCompatOpts{ + Providers: []shareddata.Provider{shareddata.ArrayDocuments}, + AddNonExistentCollection: true, + }) + ctx, targetCollections, compatCollections := s.Ctx, s.TargetCollections, s.CompatCollections + + var nonEmptyResults bool + for i := range targetCollections { + targetCollection := targetCollections[i] + compatCollection := compatCollections[i] + + t.Run(targetCollection.Name(), func(t *testing.T) { + t.Helper() + + command := bson.A{bson.D{{"$collStats", tc.collStats}}} + + targetCursor, targetErr := targetCollection.Aggregate(ctx, command) + compatCursor, compatErr := compatCollection.Aggregate(ctx, command) + + if targetCursor != nil { + defer targetCursor.Close(ctx) + } + if compatCursor != nil { + defer compatCursor.Close(ctx) + } + + if targetErr != nil { + t.Logf("Target error: %v", targetErr) + t.Logf("Compat error: %v", compatErr) + AssertMatchesCommandError(t, compatErr, targetErr) + + return + } + require.NoError(t, compatErr, "compat error; target returned no error") + + var targetRes, compatRes []bson.D + require.NoError(t, targetCursor.All(ctx, &targetRes)) + require.NoError(t, compatCursor.All(ctx, &compatRes)) + + // $collStats returns one document per shard. 
+ require.Equal(t, 1, len(compatRes)) + require.Equal(t, 1, len(targetRes)) + + // Check the keys are the same + targetKeys := CollectKeys(t, targetRes[0]) + compatKeys := CollectKeys(t, compatRes[0]) + + require.Equal(t, compatKeys, targetKeys) + + if len(targetRes) > 0 || len(compatRes) > 0 { + nonEmptyResults = true + } + + // TODO Check the returned values when possible: https://github.com/FerretDB/FerretDB/issues/2349 + }) + } + + switch tc.resultType { + case nonEmptyResult: + assert.True(t, nonEmptyResults, "expected non-empty results (some documents should be modified)") + case emptyResult: + assert.False(t, nonEmptyResults, "expected empty results (no documents should be modified)") + default: + t.Fatalf("unknown result type %v", tc.resultType) + } + }) + } +} diff --git a/internal/handlers/common/aggregations/collstats.go b/internal/handlers/common/aggregations/collstats.go new file mode 100644 index 000000000000..ca4fb21371dc --- /dev/null +++ b/internal/handlers/common/aggregations/collstats.go @@ -0,0 +1,116 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregations + +import ( + "context" + "fmt" + + "github.com/FerretDB/FerretDB/internal/handlers/common" + "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" +) + +// collStatsStage represents $collStats stage. 
+type collStatsStage struct { + storageStats *storageStats + count bool + latencyStats bool + queryExecStats bool +} + +// storageStats represents $collStats.storageStats field. +type storageStats struct { + scale int32 +} + +// newCollStats creates a new $collStats stage. +func newCollStats(stage *types.Document) (Stage, error) { + fields, err := common.GetRequiredParam[*types.Document](stage, "$collStats") + if err != nil { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrStageCollStatsInvalidArg, + fmt.Sprintf("$collStats must take a nested object but found: %s", types.FormatAnyValue(stage)), + "$collStats (stage)", + ) + } + + var cs collStatsStage + + // TODO Return error on invalid type of count: https://github.com/FerretDB/FerretDB/issues/2336 + cs.count = fields.Has("count") + + // TODO Implement latencyStats: https://github.com/FerretDB/FerretDB/issues/2341 + cs.latencyStats = fields.Has("latencyStats") + + // TODO Implement queryExecStats: https://github.com/FerretDB/FerretDB/issues/2341 + cs.queryExecStats = fields.Has("queryExecStats") + + if fields.Has("storageStats") { + cs.storageStats = new(storageStats) + + // TODO Add proper support for scale: https://github.com/FerretDB/FerretDB/issues/1346 + cs.storageStats.scale, err = common.GetOptionalPositiveNumber( + must.NotFail(fields.Get("storageStats")).(*types.Document), + "scale", + ) + if err != nil || cs.storageStats.scale == 0 { + cs.storageStats.scale = 1 + } + } + + return &cs, nil +} + +// Process implements Stage interface. +// +// Processing consists of modification of the input document, so it contains all the necessary fields +// and the data is modified according to the given request. +func (c *collStatsStage) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { + // For non-sharded collections, the input must be an array with a single document. 
+ if len(in) != 1 { + panic(fmt.Sprintf("collStatsStage: Process: expected 1 document, got %d", len(in))) + } + + res := in[0] + + if c.storageStats != nil { + scale := c.storageStats.scale + + if c.storageStats.scale > 1 { + scalable := []string{"size", "avgObjSize", "storageSize", "freeStorageSize", "totalIndexSize"} + for _, key := range scalable { + path := types.NewStaticPath("storageStats", key) + val := must.NotFail(res.GetByPath(path)) + must.NoError(res.SetByPath(path, val.(int32)/scale)) + } + } + + must.NoError(res.SetByPath(types.NewStaticPath("storageStats", "scaleFactor"), scale)) + } + + return []*types.Document{res}, nil +} + +// Type implements Stage interface. +func (c *collStatsStage) Type() StageType { + return StageTypeStats +} + +// check interfaces +var ( + _ Stage = (*collStatsStage)(nil) +) diff --git a/internal/handlers/common/aggregations/count.go b/internal/handlers/common/aggregations/count.go index 231562825075..bcf4f905e88f 100644 --- a/internal/handlers/common/aggregations/count.go +++ b/internal/handlers/common/aggregations/count.go @@ -24,8 +24,8 @@ import ( "github.com/FerretDB/FerretDB/internal/util/must" ) -// count represents $count stage. -type count struct { +// countStage represents $count stage. +type countStage struct { field string } @@ -72,13 +72,13 @@ func newCount(stage *types.Document) (Stage, error) { ) } - return &count{ + return &countStage{ field: field, }, nil } // Process implements Stage interface. -func (c *count) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { +func (c *countStage) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { if len(in) == 0 { return nil, nil } @@ -88,7 +88,12 @@ func (c *count) Process(ctx context.Context, in []*types.Document) ([]*types.Doc return []*types.Document{res}, nil } +// Type implements Stage interface. 
+func (c *countStage) Type() StageType { + return StageTypeDocuments +} + // check interfaces var ( - _ Stage = (*count)(nil) + _ Stage = (*countStage)(nil) ) diff --git a/internal/handlers/common/aggregations/group.go b/internal/handlers/common/aggregations/group.go index 353d8e9aeb01..feb3080ac487 100644 --- a/internal/handlers/common/aggregations/group.go +++ b/internal/handlers/common/aggregations/group.go @@ -281,6 +281,11 @@ func (m *groupMap) addOrAppend(groupKey any, docs ...*types.Document) { }) } +// Type implements Stage interface. +func (g *groupStage) Type() StageType { + return StageTypeDocuments +} + // check interfaces var ( _ Stage = (*groupStage)(nil) diff --git a/internal/handlers/common/aggregations/limit.go b/internal/handlers/common/aggregations/limit.go index 5489472c3ff7..a9922c48f032 100644 --- a/internal/handlers/common/aggregations/limit.go +++ b/internal/handlers/common/aggregations/limit.go @@ -53,3 +53,13 @@ func (l *limit) Process(ctx context.Context, in []*types.Document) ([]*types.Doc return doc, nil } + +// Type implements Stage interface. +func (l *limit) Type() StageType { + return StageTypeDocuments +} + +// check interfaces +var ( + _ Stage = (*limit)(nil) +) diff --git a/internal/handlers/common/aggregations/match.go b/internal/handlers/common/aggregations/match.go index 21685a19824c..3e084fe8d103 100644 --- a/internal/handlers/common/aggregations/match.go +++ b/internal/handlers/common/aggregations/match.go @@ -61,6 +61,11 @@ func (m *match) Process(ctx context.Context, in []*types.Document) ([]*types.Doc return res, nil } +// Type implements Stage interface. 
+func (m *match) Type() StageType { + return StageTypeDocuments +} + // check interfaces var ( _ Stage = (*match)(nil) diff --git a/internal/handlers/common/aggregations/skip.go b/internal/handlers/common/aggregations/skip.go index c947994d35ea..5334134de6a5 100644 --- a/internal/handlers/common/aggregations/skip.go +++ b/internal/handlers/common/aggregations/skip.go @@ -45,8 +45,13 @@ func newSkip(stage *types.Document) (Stage, error) { } // Process implements Stage interface. -func (m *skip) Process(_ context.Context, in []*types.Document) ([]*types.Document, error) { - return common.SkipDocuments(in, m.value) +func (s *skip) Process(_ context.Context, in []*types.Document) ([]*types.Document, error) { + return common.SkipDocuments(in, s.value) +} + +// Type implements Stage interface. +func (s *skip) Type() StageType { + return StageTypeDocuments } // check interfaces diff --git a/internal/handlers/common/aggregations/sort.go b/internal/handlers/common/aggregations/sort.go index 27c81c6bfcd9..16d7979d95bc 100644 --- a/internal/handlers/common/aggregations/sort.go +++ b/internal/handlers/common/aggregations/sort.go @@ -58,8 +58,8 @@ func newSort(stage *types.Document) (Stage, error) { // Process implements Stage interface. // // If sort path is invalid, it returns a possibly wrapped types.DocumentPathError. -func (m *sort) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { - if err := common.SortDocuments(in, m.fields); err != nil { +func (s *sort) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { + if err := common.SortDocuments(in, s.fields); err != nil { var pathErr *types.DocumentPathError if errors.As(err, &pathErr) && pathErr.Code() == types.ErrDocumentPathEmptyKey { return nil, commonerrors.NewCommandErrorMsgWithArgument( @@ -75,6 +75,11 @@ func (m *sort) Process(ctx context.Context, in []*types.Document) ([]*types.Docu return in, nil } +// Type implements Stage interface. 
+func (s *sort) Type() StageType { + return StageTypeDocuments +} + // check interfaces var ( _ Stage = (*sort)(nil) diff --git a/internal/handlers/common/aggregations/stage.go b/internal/handlers/common/aggregations/stage.go index a8e7ae7b076a..f8fbd36bdf9b 100644 --- a/internal/handlers/common/aggregations/stage.go +++ b/internal/handlers/common/aggregations/stage.go @@ -25,24 +25,39 @@ import ( // newStageFunc is a type for a function that creates a new aggregation stage. type newStageFunc func(stage *types.Document) (Stage, error) +// StageType is a type for aggregation stage types. +type StageType int + +const ( + // StageTypeDocuments is a type for stages that process documents. + StageTypeDocuments StageType = iota + + // StageTypeStats is a type for stages that process statistics and don't need documents. + StageTypeStats +) + // Stage is a common interface for all aggregation stages. // TODO use iterators instead of slices of documents // https://github.com/FerretDB/FerretDB/issues/1889. type Stage interface { // Process applies an aggregate stage on `in` document, it could modify `in` in-place. Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) + + // Type returns the type of the stage. + Type() StageType } // stages maps all supported aggregation stages. 
var stages = map[string]newStageFunc{ // sorted alphabetically - "$count": newCount, - "$group": newGroup, - "$limit": newLimit, - "$match": newMatch, - "$skip": newSkip, - "$sort": newSort, - "$unwind": newUnwind, + "$collStats": newCollStats, + "$count": newCount, + "$group": newGroup, + "$limit": newLimit, + "$match": newMatch, + "$skip": newSkip, + "$sort": newSort, + "$unwind": newUnwind, // please keep sorted alphabetically } @@ -52,7 +67,6 @@ var unsupportedStages = map[string]struct{}{ "$bucket": {}, "$bucketAuto": {}, "$changeStream": {}, - "$collStats": {}, "$currentOp": {}, "$densify": {}, "$documents": {}, diff --git a/internal/handlers/common/aggregations/statistics.go b/internal/handlers/common/aggregations/statistics.go new file mode 100644 index 000000000000..01720bb436d0 --- /dev/null +++ b/internal/handlers/common/aggregations/statistics.go @@ -0,0 +1,55 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregations + +// Statistic represents a statistic that can be fetched from the DB. +type Statistic int32 + +// List of statistics that can be fetched from the DB. +const ( + StatisticCount Statistic = iota + StatisticLatency + StatisticQueryExec + StatisticStorage +) + +// GetStatistics has the same idea as GetPushdownQuery: it returns a list of statistics that need +// to be fetched from the DB, because they are needed for one or more stages. 
+func GetStatistics(stages []Stage) map[Statistic]struct{} { + stats := make(map[Statistic]struct{}, len(stages)) + + for _, stage := range stages { + switch st := stage.(type) { + case *collStatsStage: + if st.count { + stats[StatisticCount] = struct{}{} + } + + if st.latencyStats { + stats[StatisticLatency] = struct{}{} + } + + if st.queryExecStats { + stats[StatisticQueryExec] = struct{}{} + } + + if st.storageStats != nil { + stats[StatisticStorage] = struct{}{} + } + } + } + + return stats +} diff --git a/internal/handlers/common/aggregations/unwind.go b/internal/handlers/common/aggregations/unwind.go index 6ae9951f0ec9..89d0b688fb90 100644 --- a/internal/handlers/common/aggregations/unwind.go +++ b/internal/handlers/common/aggregations/unwind.go @@ -104,17 +104,17 @@ func newUnwind(stage *types.Document) (Stage, error) { } // Process implements Stage interface. -func (m *unwind) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { +func (u *unwind) Process(ctx context.Context, in []*types.Document) ([]*types.Document, error) { var out []*types.Document - if m.field == nil { + if u.field == nil { return nil, nil } - key := m.field.GetExpressionSuffix() + key := u.field.GetExpressionSuffix() for _, doc := range in { - d := m.field.Evaluate(doc) + d := u.field.Evaluate(doc) switch d := d.(type) { case *types.Array: iter := d.Iterator() @@ -144,3 +144,13 @@ func (m *unwind) Process(ctx context.Context, in []*types.Document) ([]*types.Do return out, nil } + +// Type implements Stage interface. 
+func (u *unwind) Type() StageType { + return StageTypeDocuments +} + +// check interfaces +var ( + _ Stage = (*unwind)(nil) +) diff --git a/internal/handlers/commonerrors/error.go b/internal/handlers/commonerrors/error.go index e728ff8cee81..aea4022f1b18 100644 --- a/internal/handlers/commonerrors/error.go +++ b/internal/handlers/commonerrors/error.go @@ -196,6 +196,9 @@ const ( // ErrFailedToParseInput indicates invalid input (absent or malformed fields). ErrFailedToParseInput = ErrorCode(40415) // Location40415 + // ErrCollStatsIsNotFirstStage indicates that $collStats must be the first stage in the pipeline. NOTE(review): the value below repeats ErrFailedToParseInput's 40415 although the line comment names Location40602; confirm the intended code and regenerate errorcode_string.go. + ErrCollStatsIsNotFirstStage = ErrorCode(40415) // Location40602 + // ErrFreeMonitoringDisabled indicates that free monitoring is disabled // by command-line or config file. ErrFreeMonitoringDisabled = ErrorCode(50840) // Location50840 @@ -220,6 +223,9 @@ const ( // ErrStageLimitInvalidArg indicates invalid argument for the aggregation $limit stage. ErrStageLimitInvalidArg = ErrorCode(5107201) // Location5107201 + + // ErrStageCollStatsInvalidArg indicates invalid argument for the aggregation $collStats stage. + ErrStageCollStatsInvalidArg = ErrorCode(5447000) // Location5447000 ) // ErrInfo represents additional optional error information. 
diff --git a/internal/handlers/commonerrors/errorcode_string.go b/internal/handlers/commonerrors/errorcode_string.go index 3b2cfd47eb67..fcbae36492d3 100644 --- a/internal/handlers/commonerrors/errorcode_string.go +++ b/internal/handlers/commonerrors/errorcode_string.go @@ -63,6 +63,7 @@ func _() { _ = x[ErrEmptyFieldPath-40352] _ = x[ErrMissingField-40414] _ = x[ErrFailedToParseInput-40415] + _ = x[ErrCollStatsIsNotFirstStage-40415] _ = x[ErrFreeMonitoringDisabled-50840] _ = x[ErrValueNegative-51024] _ = x[ErrRegexOptions-51075] @@ -71,9 +72,10 @@ func _() { _ = x[ErrDuplicateField-4822819] _ = x[ErrStageSkipBadValue-5107200] _ = x[ErrStageLimitInvalidArg-5107201] + _ = x[ErrStageCollStatsInvalidArg-5447000] } -const _ErrorCode_name = "UnsetInternalErrorBadValueFailedToParseTypeMismatchNamespaceNotFoundIndexNotFoundUnsuitableValueTypeConflictingUpdateOperatorsCursorNotFoundNamespaceExistsInvalidIDEmptyNameCommandNotFoundCannotCreateIndexInvalidOptionsInvalidNamespaceIndexOptionsConflictIndexKeySpecsConflictOperationFailedDocumentValidationFailureNotImplementedMechanismUnavailableLocation11000Location15947Location15948Location15955Location15958Location15959Location15973Location15974Location15975Location15976Location15981Location15998Location16410Location16872Location17276Location28667Location28724Location28812Location28818Location31253Location31254Location40156Location40157Location40158Location40160Location40234Location40237Location40238Location40323Location40352Location40414Location40415Location50840Location51024Location51075Location51091Location51108Location4822819Location5107200Location5107201" +const _ErrorCode_name = 
"UnsetInternalErrorBadValueFailedToParseTypeMismatchNamespaceNotFoundIndexNotFoundUnsuitableValueTypeConflictingUpdateOperatorsCursorNotFoundNamespaceExistsInvalidIDEmptyNameCommandNotFoundCannotCreateIndexInvalidOptionsInvalidNamespaceIndexOptionsConflictIndexKeySpecsConflictOperationFailedDocumentValidationFailureNotImplementedMechanismUnavailableLocation11000Location15947Location15948Location15955Location15958Location15959Location15973Location15974Location15975Location15976Location15981Location15998Location16410Location16872Location17276Location28667Location28724Location28812Location28818Location31253Location31254Location40156Location40157Location40158Location40160Location40234Location40237Location40238Location40323Location40352Location40414Location40415Location50840Location51024Location51075Location51091Location51108Location4822819Location5107200Location5107201Location5447000" var _ErrorCode_map = map[ErrorCode]string{ 0: _ErrorCode_name[0:5], @@ -139,6 +141,7 @@ var _ErrorCode_map = map[ErrorCode]string{ 4822819: _ErrorCode_name[831:846], 5107200: _ErrorCode_name[846:861], 5107201: _ErrorCode_name[861:876], + 5447000: _ErrorCode_name[876:891], } func (i ErrorCode) String() string { diff --git a/internal/handlers/pg/msg_aggregate.go b/internal/handlers/pg/msg_aggregate.go index b9fb54a699ac..9a3692ec9ecf 100644 --- a/internal/handlers/pg/msg_aggregate.go +++ b/internal/handlers/pg/msg_aggregate.go @@ -16,6 +16,10 @@ package pg import ( "context" + "errors" + "fmt" + "os" + "time" "github.com/jackc/pgx/v4" @@ -54,13 +58,13 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs "allowDiskUse", "maxTimeMS", "bypassDocumentValidation", "readConcern", "hint", "comment", "writeConcern", ) - var qp pgdb.QueryParams + var db string - if qp.DB, err = common.GetRequiredParam[string](document, "$db"); err != nil { + if db, err = common.GetRequiredParam[string](document, "$db"); err != nil { return nil, err } - collection, err := 
document.Get(document.Command()) + collectionParam, err := document.Get(document.Command()) if err != nil { return nil, err } @@ -68,7 +72,9 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs // TODO handle collection-agnostic pipelines ({aggregate: 1}) // https://github.com/FerretDB/FerretDB/issues/1890 var ok bool - if qp.Collection, ok = collection.(string); !ok { + var collection string + + if collection, ok = collectionParam.(string); !ok { return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrFailedToParse, "Invalid command format: the 'aggregate' field must specify a collection name or 1", @@ -85,10 +91,11 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs ) } - stagesDocs := must.NotFail(iterator.ConsumeValues(pipeline.Iterator())) - stages := make([]aggregations.Stage, len(stagesDocs)) + stages := must.NotFail(iterator.ConsumeValues(pipeline.Iterator())) + stagesDocuments := make([]aggregations.Stage, 0, len(stages)) + stagesStats := make([]aggregations.Stage, 0, len(stages)) - for i, d := range stagesDocs { + for i, d := range stages { d, ok := d.(*types.Document) if !ok { return nil, commonerrors.NewCommandErrorMsgWithArgument( @@ -99,39 +106,60 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs } var s aggregations.Stage + if s, err = aggregations.NewStage(d); err != nil { return nil, err } - stages[i] = s + switch s.Type() { + case aggregations.StageTypeDocuments: + stagesDocuments = append(stagesDocuments, s) + stagesStats = append(stagesStats, s) // It's possible to apply "documents" stages to statistics + case aggregations.StageTypeStats: + if i > 0 { + // TODO Add a test to cover this error: https://github.com/FerretDB/FerretDB/issues/2349 + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrCollStatsIsNotFirstStage, + "$collStats is only valid as the first stage in a pipeline", + document.Command(), + ) + } 
+ stagesStats = append(stagesStats, s) + default: + panic(fmt.Sprintf("unknown stage type: %v", s.Type())) + } } - qp.Filter = aggregations.GetPushdownQuery(stagesDocs) + var resDocs []*types.Document - var docs []*types.Document - err = dbPool.InTransaction(ctx, func(tx pgx.Tx) error { - iter, getErr := pgdb.QueryDocuments(ctx, tx, &qp) - if getErr != nil { - return getErr + // At this point we have a list of stages to apply to the documents or stats. + // If stagesStats contains the same stages as stagesDocuments, we apply aggregation to documents fetched from the DB. + // If stagesStats contains more stages than stagesDocuments, we apply aggregation to statistics fetched from the DB. + if len(stagesStats) == len(stagesDocuments) { + // only documents stages or no stages - fetch documents from the DB and apply stages to them + qp := pgdb.QueryParams{ + DB: db, + Collection: collection, + Filter: aggregations.GetPushdownQuery(stages), } - docs, err = iterator.ConsumeValues(iterator.Interface[struct{}, *types.Document](iter)) - return err - }) + resDocs, err = processStagesDocuments(ctx, &stagesDocumentsParams{dbPool, &qp, stagesDocuments}) + } else { + // stats stages are provided - fetch stats from the DB and apply stages to them + statistics := aggregations.GetStatistics(stagesStats) - if err != nil { - return nil, err + resDocs, err = processStagesStats(ctx, &stagesStatsParams{ + dbPool, db, collection, statistics, stagesStats, + }) } - for _, s := range stages { - if docs, err = s.Process(ctx, docs); err != nil { - return nil, err - } + if err != nil { + return nil, err } // TODO https://github.com/FerretDB/FerretDB/issues/1892 - firstBatch := types.MakeArray(len(docs)) - for _, doc := range docs { + firstBatch := types.MakeArray(len(resDocs)) + for _, doc := range resDocs { firstBatch.Append(doc) } @@ -141,7 +169,7 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs "cursor", must.NotFail(types.NewDocument( "firstBatch", 
firstBatch, "id", int64(0), - "ns", qp.DB+"."+qp.Collection, + "ns", db+"."+collection, )), "ok", float64(1), ))}, @@ -149,3 +177,128 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs return &reply, nil } + +// stagesDocumentsParams contains the parameters for processStagesDocuments. +type stagesDocumentsParams struct { + dbPool *pgdb.Pool + qp *pgdb.QueryParams + stages []aggregations.Stage +} + +// processStagesDocuments retrieves the documents from the database and then processes them through the stages. +func processStagesDocuments(ctx context.Context, p *stagesDocumentsParams) ([]*types.Document, error) { //nolint:lll // for readability + var docs []*types.Document + + if err := p.dbPool.InTransaction(ctx, func(tx pgx.Tx) error { + iter, getErr := pgdb.QueryDocuments(ctx, tx, p.qp) + if getErr != nil { + return getErr + } + + var err error + docs, err = iterator.ConsumeValues(iterator.Interface[struct{}, *types.Document](iter)) + return err + }); err != nil { + return nil, err + } + + for _, s := range p.stages { + var err error + if docs, err = s.Process(ctx, docs); err != nil { + return nil, err + } + } + + return docs, nil +} + +// stagesStatsParams contains the parameters for processStagesStats. +type stagesStatsParams struct { + dbPool *pgdb.Pool + db string + collection string + statistics map[aggregations.Statistic]struct{} + stages []aggregations.Stage +} + +// processStagesStats retrieves the statistics from the database and then processes them through the stages. +func processStagesStats(ctx context.Context, p *stagesStatsParams) ([]*types.Document, error) { + // Clarify what needs to be retrieved from the database and retrieve it. 
+ _, hasCount := p.statistics[aggregations.StatisticCount] + _, hasStorage := p.statistics[aggregations.StatisticStorage] + + var host string + var err error + + host, err = os.Hostname() + if err != nil { + return nil, lazyerrors.Error(err) + } + + doc := must.NotFail(types.NewDocument( + "ns", p.db+"."+p.collection, + "host", host, + "localTime", time.Now().UTC().Format(time.RFC3339), + )) + + var dbStats *pgdb.DBStats + + if hasCount || hasStorage { + dbStats, err = p.dbPool.Stats(ctx, p.db, p.collection) + + switch { + case err == nil: + // do nothing + case errors.Is(err, pgdb.ErrTableNotExist): + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrNamespaceNotFound, + fmt.Sprintf("ns not found: %s.%s", p.db, p.collection), + "aggregate", + ) + default: + return nil, err + } + } + + if hasStorage { + var avgObjSize int32 + if dbStats.CountRows > 0 { + avgObjSize = int32(dbStats.SizeRelation) / dbStats.CountRows + } + + doc.Set( + "storageStats", must.NotFail(types.NewDocument( + "size", int32(dbStats.SizeTotal), + "count", dbStats.CountRows, + "avgObjSize", avgObjSize, + "storageSize", int32(dbStats.SizeRelation), + "freeStorageSize", int32(0), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "capped", false, // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "wiredTiger", must.NotFail(types.NewDocument()), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "nindexes", dbStats.CountIndexes, + "indexDetails", must.NotFail(types.NewDocument()), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "indexBuilds", must.NotFail(types.NewDocument()), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "totalIndexSize", int32(dbStats.SizeIndexes), + "totalSize", int32(dbStats.SizeTotal), + "indexSizes", must.NotFail(types.NewDocument()), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + )), + ) + } + + if hasCount { + doc.Set( + "count", dbStats.CountRows, + ) + } + + // Process the retrieved 
statistics through the stages, feeding each stage the output of the previous one. + res := []*types.Document{doc} + + for _, s := range p.stages { + if res, err = s.Process(ctx, res); err != nil { + return nil, err + } + } + + return res, nil +} diff --git a/internal/handlers/pg/msg_collstats.go b/internal/handlers/pg/msg_collstats.go index fae4e70e358f..6d2d462e802c 100644 --- a/internal/handlers/pg/msg_collstats.go +++ b/internal/handlers/pg/msg_collstats.go @@ -16,8 +16,10 @@ package pg import ( "context" + "errors" "github.com/FerretDB/FerretDB/internal/handlers/common" + "github.com/FerretDB/FerretDB/internal/handlers/pg/pgdb" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" @@ -57,7 +59,14 @@ func (h *Handler) MsgCollStats(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs } stats, err := dbPool.Stats(ctx, db, collection) - if err != nil { + + switch { + case err == nil: + // do nothing + case errors.Is(err, pgdb.ErrTableNotExist): + // Return empty stats for non-existent collections. 
+ stats = new(pgdb.DBStats) + default: return nil, lazyerrors.Error(err) } @@ -66,10 +75,10 @@ func (h *Handler) MsgCollStats(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs Documents: []*types.Document{must.NotFail(types.NewDocument( "ns", db+"."+collection, "count", stats.CountRows, - "size", float64(stats.SizeTotal/int64(scale)), - "storageSize", float64(stats.SizeRelation/int64(scale)), - "totalIndexSize", float64(stats.SizeIndexes/int64(scale)), - "totalSize", float64(stats.SizeTotal/int64(scale)), + "size", int32(stats.SizeTotal)/scale, + "storageSize", int32(stats.SizeRelation)/scale, + "totalIndexSize", int32(stats.SizeIndexes)/scale, + "totalSize", int32(stats.SizeTotal)/scale, "scaleFactor", scale, "ok", float64(1), ))}, diff --git a/internal/handlers/pg/msg_datasize.go b/internal/handlers/pg/msg_datasize.go index ee5ee51de833..c7872e8744c5 100644 --- a/internal/handlers/pg/msg_datasize.go +++ b/internal/handlers/pg/msg_datasize.go @@ -64,14 +64,16 @@ func (h *Handler) MsgDataSize(ctx context.Context, msg *wire.OpMsg) (*wire.OpMsg elapses := time.Since(started) addEstimate := true - if err != nil { - if !errors.Is(err, pgx.ErrNoRows) { - return nil, lazyerrors.Error(err) - } + switch { + case err == nil, errors.Is(err, pgx.ErrNoRows): + // do nothing + case errors.Is(err, pgdb.ErrTableNotExist): // return zeroes for non-existent collection stats = new(pgdb.DBStats) addEstimate = false + default: + return nil, lazyerrors.Error(err) } var pairs []any diff --git a/internal/handlers/pg/pgdb/pool.go b/internal/handlers/pg/pgdb/pool.go index 434d7e80dd17..136c6fe54d64 100644 --- a/internal/handlers/pg/pgdb/pool.go +++ b/internal/handlers/pg/pgdb/pool.go @@ -16,6 +16,7 @@ package pgdb import ( "context" + "errors" "fmt" "net/url" "strings" @@ -202,6 +203,8 @@ func (pgPool *Pool) checkConnection(ctx context.Context) error { // Stats returns a set of statistics for FerretDB server, database, collection // - or, in terms of PostgreSQL, database, schema, table. 
+// +// It returns ErrTableNotExist if the given collection does not exist, and ignores other errors. func (pgPool *Pool) Stats(ctx context.Context, db, collection string) (*DBStats, error) { res := &DBStats{ Name: db, @@ -242,7 +245,14 @@ func (pgPool *Pool) Stats(ctx context.Context, db, collection string) (*DBStats, return row.Scan(&res.CountTables, &res.CountRows, &res.SizeTotal, &res.SizeIndexes, &res.SizeRelation, &res.CountIndexes) }) - if err != nil { + + switch { + case err == nil: + // do nothing + case errors.Is(err, ErrTableNotExist): + // return this error as is because it can be handled by the caller + return nil, err + default: // just log it for now // TODO https://github.com/FerretDB/FerretDB/issues/1346 pgPool.p.Config().ConnConfig.Logger.Log( diff --git a/internal/handlers/tigris/msg_aggregate.go b/internal/handlers/tigris/msg_aggregate.go index 62b51f5a4e47..997f4b058a0a 100644 --- a/internal/handlers/tigris/msg_aggregate.go +++ b/internal/handlers/tigris/msg_aggregate.go @@ -16,6 +16,9 @@ package tigris import ( "context" + "fmt" + "os" + "time" "github.com/FerretDB/FerretDB/internal/handlers/common" "github.com/FerretDB/FerretDB/internal/handlers/common/aggregations" @@ -52,13 +55,13 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs "allowDiskUse", "maxTimeMS", "bypassDocumentValidation", "readConcern", "hint", "comment", "writeConcern", ) - var qp tigrisdb.QueryParams + var db string - if qp.DB, err = common.GetRequiredParam[string](document, "$db"); err != nil { + if db, err = common.GetRequiredParam[string](document, "$db"); err != nil { return nil, err } - collection, err := document.Get(document.Command()) + collectionParam, err := document.Get(document.Command()) if err != nil { return nil, err } @@ -66,7 +69,9 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs // TODO handle collection-agnostic pipelines ({aggregate: 1}) //
https://github.com/FerretDB/FerretDB/issues/1890 var ok bool - if qp.Collection, ok = collection.(string); !ok { + var collection string + + if collection, ok = collectionParam.(string); !ok { return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrFailedToParse, "Invalid command format: the 'aggregate' field must specify a collection name or 1", @@ -83,10 +88,11 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs ) } - stagesDocs := must.NotFail(iterator.ConsumeValues(pipeline.Iterator())) - stages := make([]aggregations.Stage, len(stagesDocs)) + stages := must.NotFail(iterator.ConsumeValues(pipeline.Iterator())) + stagesDocuments := make([]aggregations.Stage, 0, len(stages)) + stagesStats := make([]aggregations.Stage, 0, len(stages)) - for i, d := range stagesDocs { + for i, d := range stages { d, ok := d.(*types.Document) if !ok { return nil, commonerrors.NewCommandErrorMsgWithArgument( @@ -97,38 +103,62 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs } var s aggregations.Stage + if s, err = aggregations.NewStage(d); err != nil { return nil, err } - stages[i] = s + switch s.Type() { + case aggregations.StageTypeDocuments: + stagesDocuments = append(stagesDocuments, s) + stagesStats = append(stagesStats, s) // It's possible to apply "documents" stages to statistics + case aggregations.StageTypeStats: + if i > 0 { + // TODO Add a test to cover this error: https://github.com/FerretDB/FerretDB/issues/2349 + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrCollStatsIsNotFirstStage, + "$collStats is only valid as the first stage in a pipeline", + document.Command(), + ) + } + stagesStats = append(stagesStats, s) + default: + panic(fmt.Sprintf("unknown stage type: %v", s.Type())) + } } - qp.Filter = aggregations.GetPushdownQuery(stagesDocs) - - var docs []*types.Document + var resDocs []*types.Document - iter, err := dbPool.QueryDocuments(ctx, &qp) - if err != 
nil { - return nil, err - } + // At this point we have a list of stages to apply to the documents or stats. + // If stagesStats contains the same stages as stagesDocuments, we apply aggregation to documents fetched from the DB. + // If stagesStats contains more stages than stagesDocuments, we apply aggregation to statistics fetched from the DB. + if len(stagesStats) == len(stagesDocuments) { + qp := tigrisdb.QueryParams{ + DB: db, + Collection: collection, + Filter: aggregations.GetPushdownQuery(stages), + } - defer iter.Close() + qp.Filter = aggregations.GetPushdownQuery(stages) - docs, err = iterator.ConsumeValues(iterator.Interface[struct{}, *types.Document](iter)) - if err != nil { - return nil, err - } + if resDocs, err = processStagesDocuments(ctx, &stagesDocumentsParams{ + dbPool, &qp, stagesDocuments, + }); err != nil { + return nil, err + } + } else { + statistics := aggregations.GetStatistics(stagesStats) - for _, s := range stages { - if docs, err = s.Process(ctx, docs); err != nil { + if resDocs, err = processStagesStats(ctx, &stagesStatsParams{ + dbPool, db, collection, statistics, stagesStats, + }); err != nil { return nil, err } } // TODO https://github.com/FerretDB/FerretDB/issues/1892 - firstBatch := types.MakeArray(len(docs)) - for _, doc := range docs { + firstBatch := types.MakeArray(len(resDocs)) + for _, doc := range resDocs { firstBatch.Append(doc) } @@ -138,7 +168,7 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs "cursor", must.NotFail(types.NewDocument( "firstBatch", firstBatch, "id", int64(0), - "ns", qp.DB+"."+qp.Collection, + "ns", db+"."+collection, )), "ok", float64(1), ))}, @@ -146,3 +176,132 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs return &reply, nil } + +// stagesDocumentsParams contains the parameters for processStagesDocuments. 
+type stagesDocumentsParams struct { + dbPool *tigrisdb.TigrisDB + qp *tigrisdb.QueryParams + stages []aggregations.Stage +} + +// processStagesDocuments retrieves the documents from the database and then processes them through the stages. +func processStagesDocuments(ctx context.Context, p *stagesDocumentsParams) ([]*types.Document, error) { //nolint:lll // for readability + var docs []*types.Document + + iter, err := p.dbPool.QueryDocuments(ctx, p.qp) + if err != nil { + return nil, err + } + + defer iter.Close() + + docs, err = iterator.ConsumeValues(iterator.Interface[struct{}, *types.Document](iter)) + if err != nil { + return nil, err + } + + for _, s := range p.stages { + if docs, err = s.Process(ctx, docs); err != nil { + return nil, err + } + } + + return docs, nil +} + +// stagesStatsParams contains the parameters for processStagesStats. +type stagesStatsParams struct { + dbPool *tigrisdb.TigrisDB + db string + collection string + statistics map[aggregations.Statistic]struct{} + stages []aggregations.Stage +} + +// processStagesStats retrieves the statistics from the database and then processes them through the stages. +func processStagesStats(ctx context.Context, p *stagesStatsParams) ([]*types.Document, error) { + // Clarify what needs to be retrieved from the database and retrieve it. 
+ _, hasCount := p.statistics[aggregations.StatisticCount] + _, hasStorage := p.statistics[aggregations.StatisticStorage] + + var host string + var err error + + host, err = os.Hostname() + if err != nil { + return nil, lazyerrors.Error(err) + } + + doc := must.NotFail(types.NewDocument( + "ns", p.db+"."+p.collection, + "host", host, + "localTime", time.Now().UTC().Format(time.RFC3339), + )) + + var dbStats *tigrisdb.CollectionStats + + if hasCount || hasStorage { + var exists bool + + if exists, err = p.dbPool.CollectionExists(ctx, p.db, p.collection); err != nil { + return nil, lazyerrors.Error(err) + } + + if !exists { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrNamespaceNotFound, + fmt.Sprintf("ns not found: %s.%s", p.db, p.collection), + "aggregate", + ) + } + + querier := p.dbPool.Driver.UseDatabase(p.db) + dbStats, err = tigrisdb.FetchStats(ctx, querier, p.collection) + + if err != nil { + return nil, lazyerrors.Error(err) + } + } + + if hasStorage { + var avgObjSize int32 + if dbStats.NumObjects > 0 { + avgObjSize = int32(dbStats.Size) / dbStats.NumObjects + } + + doc.Set( + "storageStats", must.NotFail(types.NewDocument( + "size", int32(dbStats.Size), + "count", dbStats.NumObjects, + "avgObjSize", avgObjSize, + "storageSize", int32(dbStats.Size), + "freeStorageSize", int32(0), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "capped", false, // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "wiredTiger", must.NotFail(types.NewDocument()), // TODO https://github.com/FerretDB/FerretDB/issues/2342 + "nindexes", int32(0), // Not supported for Tigris + "indexDetails", must.NotFail(types.NewDocument()), // Not supported for Tigris + "indexBuilds", must.NotFail(types.NewDocument()), // Not supported for Tigris + "totalIndexSize", int32(0), // Not supported for Tigris + "totalSize", int32(dbStats.Size), + "indexSizes", must.NotFail(types.NewDocument()), // Not supported for Tigris + )), + ) + } + + if hasCount { 
+ doc.Set( + "count", dbStats.NumObjects, + ) + } + + // Process the retrieved statistics through the stages. + var res []*types.Document + + for _, s := range p.stages { + if res, err = s.Process(ctx, []*types.Document{doc}); err != nil { + return nil, err + } + } + + return res, nil +} diff --git a/internal/handlers/tigris/msg_collstats.go b/internal/handlers/tigris/msg_collstats.go index f4ab68145fb3..1b2cd441bed1 100644 --- a/internal/handlers/tigris/msg_collstats.go +++ b/internal/handlers/tigris/msg_collstats.go @@ -49,6 +49,14 @@ func (h *Handler) MsgCollStats(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs return nil, err } + // TODO Add proper support for scale: https://github.com/FerretDB/FerretDB/issues/1346 + var scale int32 + + scale, err = common.GetOptionalPositiveNumber(document, "scale") + if err != nil || scale == 0 { + scale = 1 + } + querier := dbPool.Driver.UseDatabase(db) stats, err := tigrisdb.FetchStats(ctx, querier, collection) @@ -61,11 +69,11 @@ func (h *Handler) MsgCollStats(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs Documents: []*types.Document{must.NotFail(types.NewDocument( "ns", db+"."+collection, "count", stats.NumObjects, - "size", stats.Size, - "storageSize", stats.Size, - "totalIndexSize", int64(0), - "totalSize", stats.Size, - "scaleFactor", int32(1), + "size", int32(stats.Size)/scale, + "storageSize", int32(stats.Size)/scale, + "totalIndexSize", int32(0), + "totalSize", int32(stats.Size)/scale, + "scaleFactor", scale, "ok", float64(1), ))}, }))