diff --git a/internal/handlers/common/aggregations/aggregations.go b/internal/handlers/common/aggregations/aggregations.go index d45995781e9a..6224f2aa41a3 100644 --- a/internal/handlers/common/aggregations/aggregations.go +++ b/internal/handlers/common/aggregations/aggregations.go @@ -22,26 +22,8 @@ import ( "github.com/FerretDB/FerretDB/internal/util/iterator" ) -// StageType is a type for aggregation stage types. -type StageType int - -const ( - _ StageType = iota - - // StageTypeDocuments is a type for stages that process documents. - StageTypeDocuments - - // StageTypeStats is a type for stages that process statistics and doesn't need documents. - StageTypeStats -) - // Stage is a common interface for all aggregation stages. type Stage interface { // Process applies an aggregate stage on documents from iterator. Process(ctx context.Context, iter types.DocumentsIterator, closer *iterator.MultiCloser) (types.DocumentsIterator, error) - - // Type returns the type of the stage. - // - // TODO Remove it? https://github.com/FerretDB/FerretDB/issues/2423 - Type() StageType } diff --git a/internal/handlers/common/aggregations/stages/add_fields.go b/internal/handlers/common/aggregations/stages/add_fields.go index a973a79b31fe..ca99d0e93aa2 100644 --- a/internal/handlers/common/aggregations/stages/add_fields.go +++ b/internal/handlers/common/aggregations/stages/add_fields.go @@ -67,11 +67,6 @@ func (s *addFields) Process(_ context.Context, iter types.DocumentsIterator, clo return common.AddFieldsIterator(iter, closer, s.newField), nil } -// Type implements Stage interface. -func (s *addFields) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*addFields)(nil) diff --git a/internal/handlers/common/aggregations/stages/collstats.go b/internal/handlers/common/aggregations/stages/collstats.go index ed8761ba65b8..9045b9ab066e 100644 --- a/internal/handlers/common/aggregations/stages/collstats.go +++ b/internal/handlers/common/aggregations/stages/collstats.go @@ -124,11 +124,6 @@ func (c *collStats) Process(ctx context.Context, iter types.DocumentsIterator, c return iter, nil } -// Type implements Stage interface. -func (c *collStats) Type() aggregations.StageType { - return aggregations.StageTypeStats -} - // check interfaces var ( _ aggregations.Stage = (*collStats)(nil) diff --git a/internal/handlers/common/aggregations/stages/count.go b/internal/handlers/common/aggregations/stages/count.go index 1efab303bdbd..d66c722f31c7 100644 --- a/internal/handlers/common/aggregations/stages/count.go +++ b/internal/handlers/common/aggregations/stages/count.go @@ -83,11 +83,6 @@ func (c *count) Process(ctx context.Context, iter types.DocumentsIterator, close return common.CountIterator(iter, closer, c.field), nil } -// Type implements Stage interface. -func (c *count) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*count)(nil) diff --git a/internal/handlers/common/aggregations/stages/group.go b/internal/handlers/common/aggregations/stages/group.go index 118a371e92b7..5b44ef149488 100644 --- a/internal/handlers/common/aggregations/stages/group.go +++ b/internal/handlers/common/aggregations/stages/group.go @@ -265,11 +265,6 @@ func (m *groupMap) addOrAppend(groupKey any, docs ...*types.Document) { }) } -// Type implements Stage interface. -func (g *group) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*group)(nil) diff --git a/internal/handlers/common/aggregations/stages/limit.go b/internal/handlers/common/aggregations/stages/limit.go index a9b66e27f702..aeab86deeb20 100644 --- a/internal/handlers/common/aggregations/stages/limit.go +++ b/internal/handlers/common/aggregations/stages/limit.go @@ -51,11 +51,6 @@ func (l *limit) Process(ctx context.Context, iter types.DocumentsIterator, close return common.LimitIterator(iter, closer, l.limit), nil } -// Type implements Stage interface. -func (l *limit) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*limit)(nil) diff --git a/internal/handlers/common/aggregations/stages/match.go b/internal/handlers/common/aggregations/stages/match.go index 041e263289e3..eed9d7da76ed 100644 --- a/internal/handlers/common/aggregations/stages/match.go +++ b/internal/handlers/common/aggregations/stages/match.go @@ -50,11 +50,6 @@ func (m *match) Process(ctx context.Context, iter types.DocumentsIterator, close return common.FilterIterator(iter, closer, m.filter), nil } -// Type implements Stage interface. -func (m *match) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*match)(nil) diff --git a/internal/handlers/common/aggregations/stages/project.go b/internal/handlers/common/aggregations/stages/project.go index 3fdee8b3f505..faab241b2cb4 100644 --- a/internal/handlers/common/aggregations/stages/project.go +++ b/internal/handlers/common/aggregations/stages/project.go @@ -67,11 +67,6 @@ func (p *project) Process(_ context.Context, iter types.DocumentsIterator, close return projection.ProjectionIterator(iter, closer, p.projection) } -// Type implements Stage interface. -func (p *project) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*project)(nil) diff --git a/internal/handlers/common/aggregations/stages/set.go b/internal/handlers/common/aggregations/stages/set.go index 96ae8321696d..27a51a7bb7e5 100644 --- a/internal/handlers/common/aggregations/stages/set.go +++ b/internal/handlers/common/aggregations/stages/set.go @@ -69,11 +69,6 @@ func (s *set) Process(_ context.Context, iter types.DocumentsIterator, closer *i return common.AddFieldsIterator(iter, closer, s.newField), nil } -// Type implements Stage interface. -func (s *set) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*set)(nil) diff --git a/internal/handlers/common/aggregations/stages/skip.go b/internal/handlers/common/aggregations/stages/skip.go index d6d792e9d1b6..671219b9fbe3 100644 --- a/internal/handlers/common/aggregations/stages/skip.go +++ b/internal/handlers/common/aggregations/stages/skip.go @@ -51,11 +51,6 @@ func (s *skip) Process(ctx context.Context, iter types.DocumentsIterator, closer return common.SkipIterator(iter, closer, s.value), nil } -// Type implements Stage interface. -func (s *skip) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*skip)(nil) diff --git a/internal/handlers/common/aggregations/stages/sort.go b/internal/handlers/common/aggregations/stages/sort.go index 5c27345d97ae..00db7238184b 100644 --- a/internal/handlers/common/aggregations/stages/sort.go +++ b/internal/handlers/common/aggregations/stages/sort.go @@ -78,11 +78,6 @@ func (s *sort) Process(ctx context.Context, iter types.DocumentsIterator, closer return iter, nil } -// Type implements Stage interface. -func (s *sort) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*sort)(nil) diff --git a/internal/handlers/common/aggregations/stages/unset.go b/internal/handlers/common/aggregations/stages/unset.go index 1750b1605149..2bfb996f3249 100644 --- a/internal/handlers/common/aggregations/stages/unset.go +++ b/internal/handlers/common/aggregations/stages/unset.go @@ -144,11 +144,6 @@ func (u *unset) Process(_ context.Context, iter types.DocumentsIterator, closer return projection.ProjectionIterator(iter, closer, u.exclusion) } -// Type implements Stage interface. -func (u *unset) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // validateUnsetField returns error on invalid field value. func validateUnsetField(field string) (*types.Path, error) { if field == "" { diff --git a/internal/handlers/common/aggregations/stages/unwind.go b/internal/handlers/common/aggregations/stages/unwind.go index 0be1f1a41999..8549d32b1d45 100644 --- a/internal/handlers/common/aggregations/stages/unwind.go +++ b/internal/handlers/common/aggregations/stages/unwind.go @@ -160,11 +160,6 @@ func (u *unwind) Process(ctx context.Context, iter types.DocumentsIterator, clos return iter, nil } -// Type implements Stage interface. -func (u *unwind) Type() aggregations.StageType { - return aggregations.StageTypeDocuments -} - // check interfaces var ( _ aggregations.Stage = (*unwind)(nil) diff --git a/internal/handlers/pg/msg_aggregate.go b/internal/handlers/pg/msg_aggregate.go index da7fb78ddeac..cbaa263149a1 100644 --- a/internal/handlers/pg/msg_aggregate.go +++ b/internal/handlers/pg/msg_aggregate.go @@ -164,7 +164,7 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs aggregationStages := must.NotFail(iterator.ConsumeValues(pipeline.Iterator())) stagesDocuments := make([]aggregations.Stage, 0, len(aggregationStages)) - stagesStats := make([]aggregations.Stage, 0, len(aggregationStages)) + collStatsDocuments := make([]aggregations.Stage, 0, len(aggregationStages)) for i, d := range aggregationStages { d, ok := d.(*types.Document) @@ -182,11 +182,8 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs return nil, err } - switch s.Type() { - case aggregations.StageTypeDocuments: - stagesDocuments = append(stagesDocuments, s) - stagesStats = append(stagesStats, s) // It's possible to apply "documents" stages to statistics - case aggregations.StageTypeStats: + switch d.Command() { + case "$collStats": if i > 0 { // TODO Add a test to cover this error: https://github.com/FerretDB/FerretDB/issues/2349 return nil, commonerrors.NewCommandErrorMsgWithArgument( @@ -195,9 +192,11 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs document.Command(), ) } - stagesStats = append(stagesStats, s) + + collStatsDocuments = append(collStatsDocuments, s) default: - panic(fmt.Sprintf("unknown stage type: %v", s.Type())) + stagesDocuments = append(stagesDocuments, s) + collStatsDocuments = append(collStatsDocuments, s) // It's possible to apply any stage after $collStats stage } } @@ -245,9 +244,9 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs var iter iterator.Interface[struct{}, *types.Document] // At this point we have a list of stages to apply to the documents or stats. - // If stagesStats contains the same stages as stagesDocuments, we apply aggregation to documents fetched from the DB. - // If stagesStats contains more stages than stagesDocuments, we apply aggregation to statistics fetched from the DB. - if len(stagesStats) == len(stagesDocuments) { + // If collStatsDocuments contains the same stages as stagesDocuments, we apply aggregation to documents fetched from the DB. + // If collStatsDocuments contains more stages than stagesDocuments, we apply aggregation to statistics fetched from the DB. + if len(collStatsDocuments) == len(stagesDocuments) { filter, sort := aggregations.GetPushdownQuery(aggregationStages) // only documents stages or no stages - fetch documents from the DB and apply stages to them @@ -267,10 +266,11 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs iter, err = processStagesDocuments(ctx, closer, &stagesDocumentsParams{dbPool, &qp, stagesDocuments}) } else { // stats stages are provided - fetch stats from the DB and apply stages to them - statistics := stages.GetStatistics(stagesStats) + // TODO move $collStatsDocuments specific logic to its stage https://github.com/FerretDB/FerretDB/issues/2423 + statistics := stages.GetStatistics(collStatsDocuments) iter, err = processStagesStats(ctx, closer, &stagesStatsParams{ - dbPool, db, collection, statistics, stagesStats, + dbPool, db, collection, statistics, collStatsDocuments, }) } @@ -377,6 +377,7 @@ type stagesStatsParams struct { } // processStagesStats retrieves the statistics from the database and then processes them through the stages. +// TODO move $collStats specific logic to its stage https://github.com/FerretDB/FerretDB/issues/2423 func processStagesStats(ctx context.Context, closer *iterator.MultiCloser, p *stagesStatsParams) (types.DocumentsIterator, error) { //nolint:lll // for readability // Clarify what needs to be retrieved from the database and retrieve it. _, hasCount := p.statistics[stages.StatisticCount]