Skip to content

Commit

Permalink
Support recursive operator calls for $sum aggregation accumulator (F…
Browse files Browse the repository at this point in the history
  • Loading branch information
noisersup authored and yonarw committed Aug 31, 2023
1 parent 0ca5351 commit ae24839
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
23 changes: 22 additions & 1 deletion integration/aggregate_documents_compat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,28 @@ func TestAggregateCompatGroupSum(t *testing.T) {
}}},
bson.D{{"$sort", bson.D{{"_id", -1}}}},
},
skip: "https://github.com/FerretDB/FerretDB/issues/2694",
},
"RecursiveInvalid": {
pipeline: bson.A{
bson.D{{"$group", bson.D{{"sum", bson.D{{"$sum", bson.D{{"v", "$v"}}}}}}}},
},
resultType: emptyResult,
},
"RecursiveArrayInvalid": {
pipeline: bson.A{
bson.D{{"$group", bson.D{{"sum", bson.D{{"$sum", bson.D{{"$type", bson.A{"1", "2"}}}}}}}}},
},
resultType: emptyResult,
},
"RecursiveOperatorNonExistent": {
pipeline: bson.A{
bson.D{{"$group", bson.D{
{"_id", "$_id"},
// first $sum is accumulator operator, second $sum is operator
{"sum", bson.D{{"$sum", bson.D{{"$non-existent", "$v"}}}}},
}}},
},
resultType: emptyResult,
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"

"github.com/FerretDB/FerretDB/internal/handlers/common/aggregations"
"github.com/FerretDB/FerretDB/internal/handlers/common/aggregations/operators"
"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/iterator"
Expand All @@ -28,6 +29,7 @@ import (
// sum represents $sum aggregation operator.
type sum struct {
expression *aggregations.Expression
operator operators.Operator
number any
}

Expand All @@ -37,6 +39,28 @@ func newSum(accumulation *types.Document) (Accumulator, error) {
accumulator := new(sum)

switch expr := expression.(type) {
case *types.Document:
if !operators.IsOperator(expr) {
accumulator.number = int32(0)
break
}

op, err := operators.NewOperator(expr)
if err == nil {
// TODO https://github.com/FerretDB/FerretDB/issues/3129
_, err = op.Process(nil)
}

if err != nil {
var opErr operators.OperatorError
if !errors.As(err, &opErr) {
return nil, lazyerrors.Error(err)
}

return nil, opErr
}

accumulator.operator = op
case *types.Array:
return nil, commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrStageGroupUnaryOperator,
Expand All @@ -54,7 +78,6 @@ func newSum(accumulation *types.Document) (Accumulator, error) {
case int32, int64:
accumulator.number = expr
default:
// TODO https://github.com/FerretDB/FerretDB/issues/2694
accumulator.number = int32(0)
// $sum returns 0 on non-numeric field
}
Expand All @@ -77,7 +100,18 @@ func (s *sum) Accumulate(iter types.DocumentsIterator) (any, error) {
return nil, lazyerrors.Error(err)
}

if s.expression != nil {
switch {
case s.operator != nil:
v, err := s.operator.Process(doc)
if err != nil {
return nil, err
}

numbers = append(numbers, v)

continue

case s.expression != nil:
value, err := s.expression.Evaluate(doc)

// sum fields that exist
Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/common/aggregations/stages/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func newGroup(stage *types.Document) (aggregations.Stage, error) {

accumulator, err := accumulators.NewAccumulator("$group", field, v)
if err != nil {
return nil, err
return nil, processOperatorError(err)
}

groups = append(groups, groupBy{
Expand Down Expand Up @@ -149,7 +149,7 @@ func (g *group) Process(ctx context.Context, iter types.DocumentsIterator, close
for _, accumulation := range g.groupBy {
out, err := accumulation.accumulate(groupIter)
if err != nil {
return nil, err
return nil, processOperatorError(err)
}

if doc.Has(accumulation.outputField) {
Expand Down Expand Up @@ -374,7 +374,7 @@ func processOperatorError(err error) error {
}
}

return lazyerrors.Error(err)
return err
}

// check interfaces
Expand Down

0 comments on commit ae24839

Please sign in to comment.