Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support recursive operator calls for $sum aggregation accumulator #3116

Merged
merged 18 commits into from
Jul 31, 2023
23 changes: 22 additions & 1 deletion integration/aggregate_documents_compat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,28 @@ func TestAggregateCompatGroupSum(t *testing.T) {
}}},
bson.D{{"$sort", bson.D{{"_id", -1}}}},
},
skip: "https://github.com/FerretDB/FerretDB/issues/2694",
},
"RecursiveInvalid": {
pipeline: bson.A{
bson.D{{"$group", bson.D{{"sum", bson.D{{"$sum", bson.D{{"v", "$v"}}}}}}}},
},
resultType: emptyResult,
},
"RecursiveArrayInvalid": {
pipeline: bson.A{
bson.D{{"$group", bson.D{{"sum", bson.D{{"$sum", bson.D{{"$type", bson.A{"1", "2"}}}}}}}}},
},
resultType: emptyResult,
},
"RecursiveOperatorNonExistent": {
pipeline: bson.A{
bson.D{{"$group", bson.D{
{"_id", "$_id"},
// first $sum is accumulator operator, second $sum is operator
{"sum", bson.D{{"$sum", bson.D{{"$non-existent", "$v"}}}}},
}}},
},
resultType: emptyResult,
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"

"github.com/FerretDB/FerretDB/internal/handlers/common/aggregations"
"github.com/FerretDB/FerretDB/internal/handlers/common/aggregations/operators"
"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/iterator"
Expand All @@ -28,6 +29,7 @@ import (
// sum represents $sum aggregation operator.
type sum struct {
expression *aggregations.Expression
operator operators.Operator
number any
}

Expand All @@ -37,6 +39,28 @@ func newSum(accumulation *types.Document) (Accumulator, error) {
accumulator := new(sum)

switch expr := expression.(type) {
case *types.Document:
if !operators.IsOperator(expr) {
accumulator.number = int32(0)
break
}

op, err := operators.NewOperator(expr)
if err == nil {
// TODO https://github.com/FerretDB/FerretDB/issues/3129
_, err = op.Process(nil)
}

if err != nil {
var opErr operators.OperatorError
if !errors.As(err, &opErr) {
return nil, lazyerrors.Error(err)
}

return nil, opErr
}

accumulator.operator = op
case *types.Array:
return nil, commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrStageGroupUnaryOperator,
Expand All @@ -54,7 +78,6 @@ func newSum(accumulation *types.Document) (Accumulator, error) {
case int32, int64:
accumulator.number = expr
default:
// TODO https://github.com/FerretDB/FerretDB/issues/2694
accumulator.number = int32(0)
// $sum returns 0 on non-numeric field
}
Expand All @@ -77,7 +100,18 @@ func (s *sum) Accumulate(iter types.DocumentsIterator) (any, error) {
return nil, lazyerrors.Error(err)
}

if s.expression != nil {
switch {
case s.operator != nil:
v, err := s.operator.Process(doc)
if err != nil {
return nil, err
}
chilagrow marked this conversation as resolved.
Show resolved Hide resolved

numbers = append(numbers, v)

continue

case s.expression != nil:
value, err := s.expression.Evaluate(doc)

// sum fields that exist
Expand Down
6 changes: 3 additions & 3 deletions internal/handlers/common/aggregations/stages/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func newGroup(stage *types.Document) (aggregations.Stage, error) {

accumulator, err := accumulators.NewAccumulator("$group", field, v)
if err != nil {
return nil, err
return nil, processOperatorError(err)
}

groups = append(groups, groupBy{
Expand Down Expand Up @@ -149,7 +149,7 @@ func (g *group) Process(ctx context.Context, iter types.DocumentsIterator, close
for _, accumulation := range g.groupBy {
out, err := accumulation.accumulate(groupIter)
if err != nil {
return nil, err
return nil, processOperatorError(err)
}

if doc.Has(accumulation.outputField) {
Expand Down Expand Up @@ -374,7 +374,7 @@ func processOperatorError(err error) error {
}
}

return lazyerrors.Error(err)
return err
}

// check interfaces
Expand Down