Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Type() interface from aggregation stage #3045

Merged
merged 5 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions internal/handlers/common/aggregations/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,8 @@ import (
"github.com/FerretDB/FerretDB/internal/util/iterator"
)

// StageType is a type for aggregation stage types.
type StageType int

const (
_ StageType = iota

// StageTypeDocuments is a type for stages that process documents.
StageTypeDocuments

// StageTypeStats is a type for stages that process statistics and doesn't need documents.
StageTypeStats
)

// Stage is a common interface for all aggregation stages.
type Stage interface {
// Process applies an aggregate stage on documents from iterator.
Process(ctx context.Context, iter types.DocumentsIterator, closer *iterator.MultiCloser) (types.DocumentsIterator, error)

// Type returns the type of the stage.
//
// TODO Remove it? https://github.com/FerretDB/FerretDB/issues/2423
Type() StageType
}
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/add_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ func (s *addFields) Process(_ context.Context, iter types.DocumentsIterator, clo
return common.AddFieldsIterator(iter, closer, s.newField), nil
}

// Type implements Stage interface.
func (s *addFields) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*addFields)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/collstats.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ func (c *collStats) Process(ctx context.Context, iter types.DocumentsIterator, c
return iter, nil
}

// Type implements Stage interface.
func (c *collStats) Type() aggregations.StageType {
return aggregations.StageTypeStats
}

// check interfaces
var (
_ aggregations.Stage = (*collStats)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ func (c *count) Process(ctx context.Context, iter types.DocumentsIterator, close
return common.CountIterator(iter, closer, c.field), nil
}

// Type implements Stage interface.
func (c *count) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*count)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,6 @@ func (m *groupMap) addOrAppend(groupKey any, docs ...*types.Document) {
})
}

// Type implements Stage interface.
func (g *group) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*group)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ func (l *limit) Process(ctx context.Context, iter types.DocumentsIterator, close
return common.LimitIterator(iter, closer, l.limit), nil
}

// Type implements Stage interface.
func (l *limit) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*limit)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ func (m *match) Process(ctx context.Context, iter types.DocumentsIterator, close
return common.FilterIterator(iter, closer, m.filter), nil
}

// Type implements Stage interface.
func (m *match) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*match)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ func (p *project) Process(_ context.Context, iter types.DocumentsIterator, close
return projection.ProjectionIterator(iter, closer, p.projection)
}

// Type implements Stage interface.
func (p *project) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*project)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@ func (s *set) Process(_ context.Context, iter types.DocumentsIterator, closer *i
return common.AddFieldsIterator(iter, closer, s.newField), nil
}

// Type implements Stage interface.
func (s *set) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*set)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ func (s *skip) Process(ctx context.Context, iter types.DocumentsIterator, closer
return common.SkipIterator(iter, closer, s.value), nil
}

// Type implements Stage interface.
func (s *skip) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*skip)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ func (s *sort) Process(ctx context.Context, iter types.DocumentsIterator, closer
return iter, nil
}

// Type implements Stage interface.
func (s *sort) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*sort)(nil)
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/unset.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,6 @@ func (u *unset) Process(_ context.Context, iter types.DocumentsIterator, closer
return projection.ProjectionIterator(iter, closer, u.exclusion)
}

// Type implements Stage interface.
func (u *unset) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// validateUnsetField returns error on invalid field value.
func validateUnsetField(field string) (*types.Path, error) {
if field == "" {
Expand Down
5 changes: 0 additions & 5 deletions internal/handlers/common/aggregations/stages/unwind.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,6 @@ func (u *unwind) Process(ctx context.Context, iter types.DocumentsIterator, clos
return iter, nil
}

// Type implements Stage interface.
func (u *unwind) Type() aggregations.StageType {
return aggregations.StageTypeDocuments
}

// check interfaces
var (
_ aggregations.Stage = (*unwind)(nil)
Expand Down
27 changes: 14 additions & 13 deletions internal/handlers/pg/msg_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs

aggregationStages := must.NotFail(iterator.ConsumeValues(pipeline.Iterator()))
stagesDocuments := make([]aggregations.Stage, 0, len(aggregationStages))
stagesStats := make([]aggregations.Stage, 0, len(aggregationStages))
collStatsDocuments := make([]aggregations.Stage, 0, len(aggregationStages))

for i, d := range aggregationStages {
d, ok := d.(*types.Document)
Expand All @@ -182,11 +182,8 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
return nil, err
}

switch s.Type() {
case aggregations.StageTypeDocuments:
stagesDocuments = append(stagesDocuments, s)
stagesStats = append(stagesStats, s) // It's possible to apply "documents" stages to statistics
case aggregations.StageTypeStats:
switch d.Command() {
case "$collStats":
if i > 0 {
// TODO Add a test to cover this error: https://github.com/FerretDB/FerretDB/issues/2349
return nil, commonerrors.NewCommandErrorMsgWithArgument(
Expand All @@ -195,9 +192,11 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
document.Command(),
)
}
stagesStats = append(stagesStats, s)

collStatsDocuments = append(collStatsDocuments, s)
default:
panic(fmt.Sprintf("unknown stage type: %v", s.Type()))
stagesDocuments = append(stagesDocuments, s)
collStatsDocuments = append(collStatsDocuments, s) // It's possible to apply any stage after $collStats stage
}
}

Expand Down Expand Up @@ -245,9 +244,9 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
var iter iterator.Interface[struct{}, *types.Document]

// At this point we have a list of stages to apply to the documents or stats.
// If stagesStats contains the same stages as stagesDocuments, we apply aggregation to documents fetched from the DB.
// If stagesStats contains more stages than stagesDocuments, we apply aggregation to statistics fetched from the DB.
if len(stagesStats) == len(stagesDocuments) {
// If collStatsDocuments contains the same stages as stagesDocuments, we apply aggregation to documents fetched from the DB.
// If collStatsDocuments contains more stages than stagesDocuments, we apply aggregation to statistics fetched from the DB.
if len(collStatsDocuments) == len(stagesDocuments) {
filter, sort := aggregations.GetPushdownQuery(aggregationStages)

// only documents stages or no stages - fetch documents from the DB and apply stages to them
Expand All @@ -267,10 +266,11 @@ func (h *Handler) MsgAggregate(ctx context.Context, msg *wire.OpMsg) (*wire.OpMs
iter, err = processStagesDocuments(ctx, closer, &stagesDocumentsParams{dbPool, &qp, stagesDocuments})
} else {
// stats stages are provided - fetch stats from the DB and apply stages to them
statistics := stages.GetStatistics(stagesStats)
// TODO move $collStatsDocuments specific logic to its stage https://github.com/FerretDB/FerretDB/issues/2423
statistics := stages.GetStatistics(collStatsDocuments)

iter, err = processStagesStats(ctx, closer, &stagesStatsParams{
dbPool, db, collection, statistics, stagesStats,
dbPool, db, collection, statistics, collStatsDocuments,
})
}

Expand Down Expand Up @@ -377,6 +377,7 @@ type stagesStatsParams struct {
}

// processStagesStats retrieves the statistics from the database and then processes them through the stages.
// TODO move $collStats specific logic to its stage https://github.com/FerretDB/FerretDB/issues/2423
func processStagesStats(ctx context.Context, closer *iterator.MultiCloser, p *stagesStatsParams) (types.DocumentsIterator, error) { //nolint:lll // for readability
// Clarify what needs to be retrieved from the database and retrieve it.
_, hasCount := p.statistics[stages.StatisticCount]
Expand Down