Skip to content

Commit

Permalink
Merge branch 'main' into new-pg-stub
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekSi authored Aug 25, 2023
2 parents 1d6468e + 13240bf commit 824ef55
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@
package accumulators

import (
"errors"
"fmt"

"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/iterator"
"github.com/FerretDB/FerretDB/internal/util/must"
)

// newAccumulatorFunc is a type for a function that creates an accumulation operator.
type newAccumulatorFunc func(expression *types.Document) (Accumulator, error)
// It takes the arguments extracted from the accumulator document.
type newAccumulatorFunc func(args ...any) (Accumulator, error)

// Accumulator is a common interface for aggregation accumulation operators.
type Accumulator interface {
Expand Down Expand Up @@ -61,6 +65,30 @@ func NewAccumulator(stage, key string, value any) (Accumulator, error) {

operator := accumulation.Command()

expr := must.NotFail(accumulation.Get(operator))

var args []any

switch expr := expr.(type) {
case *types.Document:
args = append(args, expr)
case *types.Array:
iter := expr.Iterator()
defer iter.Close()

for {
_, v, err := iter.Next()

if errors.Is(err, iterator.ErrIteratorDone) {
break
}

args = append(args, v)
}
default:
args = append(args, expr)
}

newAccumulator, ok := Accumulators[operator]
if !ok {
return nil, commonerrors.NewCommandErrorMsgWithArgument(
Expand All @@ -70,7 +98,7 @@ func NewAccumulator(stage, key string, value any) (Accumulator, error) {
)
}

return newAccumulator(accumulation)
return newAccumulator(args...)
}

// Accumulators maps all aggregation accumulators.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package accumulators
import (
"errors"

"github.com/FerretDB/FerretDB/internal/handlers/common"
"github.com/FerretDB/FerretDB/internal/handlers/commonerrors"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/iterator"
Expand All @@ -27,9 +26,18 @@ import (
type count struct{}

// newCount creates a new $count aggregation operator.
func newCount(expr *types.Document) (Accumulator, error) {
expression, err := common.GetRequiredParam[*types.Document](expr, "$count")
if err != nil || expression.Len() != 0 {
func newCount(args ...any) (Accumulator, error) {
if len(args) != 1 {
return nil, commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrStageGroupUnaryOperator,
"The $count accumulator is a unary operator",
"$count (accumulator)",
)
}

doc, ok := args[0].(*types.Document)

if !ok || doc.Len() > 0 {
return nil, commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrTypeMismatch,
"$count takes no arguments, i.e. $count:{}",
Expand Down
73 changes: 35 additions & 38 deletions internal/handlers/common/aggregations/operators/accumulators/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/iterator"
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
"github.com/FerretDB/FerretDB/internal/util/must"
)

// sum represents $sum aggregation operator.
Expand All @@ -34,52 +33,50 @@ type sum struct {
}

// newSum creates a new $sum aggregation operator.
func newSum(accumulation *types.Document) (Accumulator, error) {
expression := must.NotFail(accumulation.Get("$sum"))
func newSum(args ...any) (Accumulator, error) {
accumulator := new(sum)

switch expr := expression.(type) {
case *types.Document:
if !operators.IsOperator(expr) {
accumulator.number = int32(0)
break
}

op, err := operators.NewOperator(expr)
if err == nil {
// TODO https://github.com/FerretDB/FerretDB/issues/3129
_, err = op.Process(nil)
}

if err != nil {
var opErr operators.OperatorError
if !errors.As(err, &opErr) {
return nil, lazyerrors.Error(err)
}

return nil, opErr
}

accumulator.operator = op
case *types.Array:
if len(args) != 1 {
return nil, commonerrors.NewCommandErrorMsgWithArgument(
commonerrors.ErrStageGroupUnaryOperator,
"The $sum accumulator is a unary operator",
"$sum (accumulator)",
)
case float64:
accumulator.number = expr
case string:
var err error
if accumulator.expression, err = aggregations.NewExpression(expr, nil); err != nil {
// $sum returns 0 on non-existent field.
}

for _, arg := range args {
switch arg := arg.(type) {
case *types.Document:
if !operators.IsOperator(arg) {
accumulator.number = int32(0)
break
}

op, err := operators.NewOperator(arg)
if err != nil {
var opErr operators.OperatorError
if !errors.As(err, &opErr) {
return nil, lazyerrors.Error(err)
}

return nil, opErr
}

accumulator.operator = op
case float64:
accumulator.number = arg
case string:
var err error
if accumulator.expression, err = aggregations.NewExpression(arg, nil); err != nil {
// $sum returns 0 on non-existent field.
accumulator.number = int32(0)
}
case int32, int64:
accumulator.number = arg
default:
accumulator.number = int32(0)
// $sum returns 0 on non-numeric field
}
case int32, int64:
accumulator.number = expr
default:
accumulator.number = int32(0)
// $sum returns 0 on non-numeric field
}

return accumulator, nil
Expand Down

0 comments on commit 824ef55

Please sign in to comment.