From 6a0f709e5bb6073d275cd8ab34e7e470fec6c09f Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 10:30:23 +0900 Subject: [PATCH 01/11] init --- integration/aggregate_compat_test.go | 116 ++++++++++++++++ integration/shareddata/scalars.go | 17 +++ integration/shareddata/shareddata.go | 1 + .../handlers/common/aggregations/group.go | 1 + .../handlers/common/aggregations/group_sum.go | 117 ++++++++++++++++ internal/handlers/commonerrors/error.go | 3 + .../handlers/commonerrors/errorcode_string.go | 26 ++-- internal/types/number.go | 129 ++++++++++++++++++ internal/types/numbererrorcode_string.go | 25 ++++ 9 files changed, 423 insertions(+), 12 deletions(-) create mode 100644 internal/handlers/common/aggregations/group_sum.go create mode 100644 internal/types/number.go create mode 100644 internal/types/numbererrorcode_string.go diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index 55ea7a0aab87..757c1df5500c 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -348,6 +348,7 @@ func TestAggregateCompatGroupDeterministicCollections(t *testing.T) { shareddata.Doubles, shareddata.BigDoubles, + shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, shareddata.ObjectIDs, @@ -551,6 +552,7 @@ func TestAggregateCompatGroupDotNotation(t *testing.T) { shareddata.Doubles, shareddata.BigDoubles, + shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, shareddata.ObjectIDs, @@ -624,6 +626,7 @@ func TestAggregateCompatGroupDocDotNotation(t *testing.T) { shareddata.Doubles, shareddata.BigDoubles, + shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, shareddata.ObjectIDs, @@ -712,6 +715,119 @@ func TestAggregateCompatGroupCount(t *testing.T) { testAggregateStagesCompat(t, testCases) } +func TestAggregateCompatGroupSum(t *testing.T) { + // Scalars and BigDoubles are skipped as they produce `Infinity`. + providers := []shareddata.Provider{ + // shareddata.Scalars, + + // TODO: handle doubles close to max precision in doubles. + // shareddata.Doubles, + // shareddata.BigDoubles, + shareddata.SmallDoubles, + shareddata.Strings, + shareddata.Binaries, + shareddata.ObjectIDs, + shareddata.Bools, + shareddata.DateTimes, + shareddata.Nulls, + shareddata.Regexes, + shareddata.Int32s, + shareddata.Timestamps, + shareddata.Int64s, + shareddata.Unsets, + shareddata.ObjectIDKeys, + + shareddata.Composites, + shareddata.PostgresEdgeCases, + + shareddata.DocumentsDoubles, + shareddata.DocumentsStrings, + shareddata.DocumentsDocuments, + + shareddata.ArrayStrings, + shareddata.ArrayDoubles, + shareddata.ArrayInt32s, + shareddata.ArrayRegexes, + shareddata.ArrayDocuments, + + shareddata.Mixed, + shareddata.ArrayAndDocuments, + } + + testCases := map[string]aggregateStagesCompatTestCase{ + "Value": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", "$v"}}}, + }}}, + }, + }, + "EmptyString": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"count", bson.D{{"$sum", ""}}}, + }}}}, + }, + "NonExpression": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", "v"}}}, + }}}}, + }, + "NonExistent": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", "$non-existent"}}}, + }}}}, + }, + "Document": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", bson.D{}}}}, + }}}}, + }, + "ArraySum": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", bson.A{"$v", "$c"}}}}, + }}}}, + resultType: emptyResult, + }, + "Int32": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", int32(1)}}}}}}}, + }, + "Int64": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", int64(20)}}}}}}}, + }, + "Double": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", 43.7}}}}}}}, + }, + "Bool": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", true}}}}}}}, + }, + "Duplicate": { + pipeline: bson.A{bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", "$v"}}}, + {"sum", bson.D{{"$sum", "$s"}}}, + }}}}, + resultType: emptyResult, + }, + } + + testAggregateStagesCompatWithProviders(t, providers, testCases) +} + func TestAggregateCompatMatch(t *testing.T) { testCases := map[string]aggregateStagesCompatTestCase{ "ID": { diff --git a/integration/shareddata/scalars.go b/integration/shareddata/scalars.go index 16c4ba67c7ed..c983fbd31ac7 100644 --- a/integration/shareddata/scalars.go +++ b/integration/shareddata/scalars.go @@ -182,6 +182,23 @@ var BigDoubles = &Values[string]{ }, } +// SmallDoubles contains double values that does not go close to +// maximum precision for tests. +var SmallDoubles = &Values[string]{ + name: "SmallDoubles", + backends: []string{"ferretdb-pg", "ferretdb-tigris", "mongodb"}, + validators: map[string]map[string]any{ + "ferretdb-tigris": { + "$tigrisSchemaString": tigrisSchema(`"type": "number"`), + }, + }, + data: map[string]any{ + "double": 42.13, + "double-whole": 42.0, + "double-smallest": math.SmallestNonzeroFloat64, + }, +} + // Strings contains string values for tests. // Tigris JSON schema validator contains extra properties to make it suitable for more tests. var Strings = &Values[string]{ diff --git a/integration/shareddata/shareddata.go b/integration/shareddata/shareddata.go index d3177809b34b..5f997e41e7bd 100644 --- a/integration/shareddata/shareddata.go +++ b/integration/shareddata/shareddata.go @@ -51,6 +51,7 @@ func AllProviders() Providers { Doubles, BigDoubles, + SmallDoubles, Strings, Binaries, ObjectIDs, diff --git a/internal/handlers/common/aggregations/group.go b/internal/handlers/common/aggregations/group.go index 5f3e4e42ef2a..2f104d137cc6 100644 --- a/internal/handlers/common/aggregations/group.go +++ b/internal/handlers/common/aggregations/group.go @@ -40,6 +40,7 @@ type Accumulator interface { var accumulators = map[string]newAccumulatorFunc{ // sorted alphabetically "$count": newCountAccumulator, + "$sum": newSumAccumulator, } // groupStage represents $group stage. diff --git a/internal/handlers/common/aggregations/group_sum.go b/internal/handlers/common/aggregations/group_sum.go new file mode 100644 index 000000000000..732a781cc3d8 --- /dev/null +++ b/internal/handlers/common/aggregations/group_sum.go @@ -0,0 +1,117 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregations + +import ( + "context" + "errors" + "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/lazyerrors" + "github.com/FerretDB/FerretDB/internal/util/must" + "time" +) + +// sumAccumulator represents $sum accumulator for $group. +type sumAccumulator struct { + expression types.Expression + n any +} + +// newSumAccumulator creates a new $sum accumulator for $group. +func newSumAccumulator(accumulation *types.Document) (Accumulator, error) { + expr := must.NotFail(accumulation.Get("$sum")) + + accumulator := new(sumAccumulator) + + switch expr := expr.(type) { + case *types.Document: + case *types.Array: + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrStageGroupUnaryOperator, + "The $sum accumulator is a unary operator", + "$sum (accumulator)", + ) + case float64: + accumulator.n = expr + case string: + // get field expression + var err error + accumulator.expression, err = types.NewExpression(expr) + + var fieldPathErr *types.FieldPathError + if errors.As(err, &fieldPathErr) && fieldPathErr.Code() == types.ErrNotFieldPath { + // when field is not a path, ignore this error. + } else { + if err != nil { + return nil, commonerrors.NewCommandErrorMsgWithArgument( + commonerrors.ErrTypeMismatch, + "$sum takes no arguments, i.e. $sum:{}", + "$sum (accumulator)", + ) + } + } + case types.Binary: + case types.ObjectID: + case bool: + case time.Time: + case types.NullType: + case types.Regex: + case int32: + accumulator.n = expr + case types.Timestamp: + case int64: + accumulator.n = expr + default: + // $sum ignores non-existent field + } + + return accumulator, nil +} + +// Accumulate implements Accumulator interface. +func (s *sumAccumulator) Accumulate(ctx context.Context, groupID any, grouped []*types.Document) (any, error) { + if s.expression != nil { + var values []any + for _, doc := range grouped { + v := s.expression.Evaluate(doc) + values = append(values, v) + } + + res, err := types.AddNumbers(values...) + if err != nil { + // handle INF + return nil, lazyerrors.Error(err) + } + return res, err + } + + switch n := s.n.(type) { + case float64: + return float64(len(grouped)) * n, nil + case int32: + return int32(len(grouped)) * n, nil + case int64: + return int64(len(grouped)) * n, nil + } + + // $sum returns 0 on non-existent and non-numeric field. + return int32(0), nil +} + +// check interfaces +var ( + _ Accumulator = (*sumAccumulator)(nil) +) diff --git a/internal/handlers/commonerrors/error.go b/internal/handlers/commonerrors/error.go index 2a02e1170e29..526ec9263aa6 100644 --- a/internal/handlers/commonerrors/error.go +++ b/internal/handlers/commonerrors/error.go @@ -145,6 +145,9 @@ const ( // ErrStageCountBadValue indicates that $count stage contains invalid value. ErrStageCountBadValue = ErrorCode(40160) // Location40160 + // ErrStageGroupUnaryOperator indicates that $sum is a unary operator. + ErrStageGroupUnaryOperator = ErrorCode(40237) // Location40237 + // ErrStageGroupMultipleAccumulator indicates that group field must specify one accumulator. ErrStageGroupMultipleAccumulator = ErrorCode(40238) // Location40238 diff --git a/internal/handlers/commonerrors/errorcode_string.go b/internal/handlers/commonerrors/errorcode_string.go index 336dab84b5e0..fa155581141b 100644 --- a/internal/handlers/commonerrors/errorcode_string.go +++ b/internal/handlers/commonerrors/errorcode_string.go @@ -46,6 +46,7 @@ func _() { _ = x[ErrStageCountNonEmptyString-40157] _ = x[ErrStageCountBadPrefix-40158] _ = x[ErrStageCountBadValue-40160] + _ = x[ErrStageGroupUnaryOperator-40237] _ = x[ErrStageGroupMultipleAccumulator-40238] _ = x[ErrStageInvalid-40323] _ = x[ErrStageGroupInvalidAccumulator-40234] @@ -60,7 +61,7 @@ func _() { _ = x[ErrDuplicateField-4822819] } -const _ErrorCode_name = "UnsetInternalErrorBadValueFailedToParseTypeMismatchNamespaceNotFoundUnsuitableValueTypeConflictingUpdateOperatorsCursorNotFoundNamespaceExistsInvalidIDEmptyNameCommandNotFoundInvalidNamespaceOperationFailedDocumentValidationFailureNotImplementedMechanismUnavailableLocation11000Location15947Location15948Location15955Location15959Location15973Location15974Location15975Location15976Location15998Location16872Location17276Location28667Location28724Location31253Location31254Location40156Location40157Location40158Location40160Location40234Location40238Location40323Location40352Location40414Location40415Location50840Location51024Location51075Location51091Location51108Location4822819" +const _ErrorCode_name = "UnsetInternalErrorBadValueFailedToParseTypeMismatchNamespaceNotFoundUnsuitableValueTypeConflictingUpdateOperatorsCursorNotFoundNamespaceExistsInvalidIDEmptyNameCommandNotFoundInvalidNamespaceOperationFailedDocumentValidationFailureNotImplementedMechanismUnavailableLocation11000Location15947Location15948Location15955Location15959Location15973Location15974Location15975Location15976Location15998Location16872Location17276Location28667Location28724Location31253Location31254Location40156Location40157Location40158Location40160Location40234Location40237Location40238Location40323Location40352Location40414Location40415Location50840Location51024Location51075Location51091Location51108Location4822819" var _ErrorCode_map = map[ErrorCode]string{ 0: _ErrorCode_name[0:5], @@ -102,17 +103,18 @@ var _ErrorCode_map = map[ErrorCode]string{ 40158: _ErrorCode_name[499:512], 40160: _ErrorCode_name[512:525], 40234: _ErrorCode_name[525:538], - 40238: _ErrorCode_name[538:551], - 40323: _ErrorCode_name[551:564], - 40352: _ErrorCode_name[564:577], - 40414: _ErrorCode_name[577:590], - 40415: _ErrorCode_name[590:603], - 50840: _ErrorCode_name[603:616], - 51024: _ErrorCode_name[616:629], - 51075: _ErrorCode_name[629:642], - 51091: _ErrorCode_name[642:655], - 51108: _ErrorCode_name[655:668], - 4822819: _ErrorCode_name[668:683], + 40237: _ErrorCode_name[538:551], + 40238: _ErrorCode_name[551:564], + 40323: _ErrorCode_name[564:577], + 40352: _ErrorCode_name[577:590], + 40414: _ErrorCode_name[590:603], + 40415: _ErrorCode_name[603:616], + 50840: _ErrorCode_name[616:629], + 51024: _ErrorCode_name[629:642], + 51075: _ErrorCode_name[642:655], + 51091: _ErrorCode_name[655:668], + 51108: _ErrorCode_name[668:681], + 4822819: _ErrorCode_name[681:696], } func (i ErrorCode) String() string { diff --git a/internal/types/number.go b/internal/types/number.go new file mode 100644 index 000000000000..619bf06ef626 --- /dev/null +++ b/internal/types/number.go @@ -0,0 +1,129 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "math" + "math/big" +) + +//go:generate ../../bin/stringer -linecomment -type NumberErrorCode + +// NumberErrorCode represents error code from numerical operation. +type NumberErrorCode int + +const ( + _ NumberErrorCode = iota + + // ErrLongExceeded indicates that long exceeded in its size. + ErrLongExceeded + + // ErrNotExactResult indicates that float addition dropped precision. + ErrNotExactResult +) + +// NumberError describes an error that occurs applying number operation. +type NumberError struct { + code NumberErrorCode +} + +// newNumberError creates a new NumberError. +func newNumberError(code NumberErrorCode) error { + return &NumberError{code: code} +} + +// Error implements the error interface. +func (e *NumberError) Error() string { + return e.code.String() +} + +// Code returns the FieldPathError code. +func (e *NumberError) Code() NumberErrorCode { + return e.code +} + +// AddNumbers returns the result of addition and error if addition failed. +func AddNumbers(vs ...any) (any, error) { + sum := big.NewInt(0) + sumFloat := big.NewFloat(0) + + var hasFloat64, hasInt64 bool + + for _, v := range vs { + switch v := v.(type) { + case float64: + hasFloat64 = true + if v > MaxSafeDouble { + // todo handle lost precision + smallPart := v - MaxSafeDouble + sum.Add(sum, big.NewInt(int64(MaxSafeDouble))) + sumFloat.Add(sumFloat, big.NewFloat(smallPart)) + continue + } + + if v < -MaxSafeDouble { + // todo handle lost precision + smallPart := v + MaxSafeDouble + + sum.Add(sum, big.NewInt(int64(-MaxSafeDouble))) + sumFloat.Add(sumFloat, big.NewFloat(smallPart)) + continue + } + + // todo check overflow + sumFloat.Add(sumFloat, big.NewFloat(v)) + case int32: + sum.Add(sum, big.NewInt(int64(v))) + case int64: + hasInt64 = true + sum.Add(sum, big.NewInt(v)) + default: + // ignore non-number + } + } + + if !sum.IsInt64() { + return nil, newNumberError(ErrLongExceeded) + } + + sumBig := sum.Int64() + + res := sumBig + + if hasFloat64 { + f, accuracy := sumFloat.Float64() + if accuracy != big.Exact { + return nil, newNumberError(ErrNotExactResult) + } + + // todo check overflow + if sumBig > int64(MaxSafeDouble) || sumBig < -int64(MaxSafeDouble) { + // not accurate result + return float64(sumBig) + f, nil + } + + return float64(sumBig) + f, nil + } + + if hasInt64 { + return res, nil + } + + if res < math.MaxInt32 && res > math.MinInt32 { + return int32(res), nil + } + + return res, nil +} diff --git a/internal/types/numbererrorcode_string.go b/internal/types/numbererrorcode_string.go new file mode 100644 index 000000000000..ed7263202253 --- /dev/null +++ b/internal/types/numbererrorcode_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -linecomment -type NumberErrorCode"; DO NOT EDIT. + +package types + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ErrLongExceeded-1] + _ = x[ErrNotExactResult-2] +} + +const _NumberErrorCode_name = "ErrLongExceededErrNotExactResult" + +var _NumberErrorCode_index = [...]uint8{0, 15, 32} + +func (i NumberErrorCode) String() string { + i -= 1 + if i < 0 || i >= NumberErrorCode(len(_NumberErrorCode_index)-1) { + return "NumberErrorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _NumberErrorCode_name[_NumberErrorCode_index[i]:_NumberErrorCode_index[i+1]] +} From 36cc33e034b16757520ccf73d56c359dc94fd309 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 13:17:14 +0900 Subject: [PATCH 02/11] more impl --- integration/aggregate_compat_test.go | 166 ++++++++++++------ integration/shareddata/scalars.go | 18 +- integration/shareddata/shareddata.go | 2 +- integration/update_field_compat_test.go | 4 +- .../handlers/common/aggregations/group_sum.go | 61 ++----- .../handlers/common/aggregations/number.go | 71 ++++++++ internal/types/number.go | 129 -------------- internal/types/numbererrorcode_string.go | 25 --- 8 files changed, 218 insertions(+), 258 deletions(-) create mode 100644 internal/handlers/common/aggregations/number.go delete mode 100644 internal/types/number.go delete mode 100644 internal/types/numbererrorcode_string.go diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index 757c1df5500c..8f8e1c690ed6 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -347,7 +347,7 @@ func TestAggregateCompatGroupDeterministicCollections(t *testing.T) { // shareddata.Scalars, shareddata.Doubles, - shareddata.BigDoubles, + shareddata.OverflowVergeDoubles, shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, @@ -551,7 +551,7 @@ func TestAggregateCompatGroupDotNotation(t *testing.T) { shareddata.Scalars, shareddata.Doubles, - shareddata.BigDoubles, + shareddata.OverflowVergeDoubles, shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, @@ -625,7 +625,7 @@ func TestAggregateCompatGroupDocDotNotation(t *testing.T) { shareddata.Scalars, shareddata.Doubles, - shareddata.BigDoubles, + shareddata.OverflowVergeDoubles, shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, @@ -716,13 +716,15 @@ func TestAggregateCompatGroupCount(t *testing.T) { } func TestAggregateCompatGroupSum(t *testing.T) { - // Scalars and BigDoubles are skipped as they produce `Infinity`. + // Doubles is skipped as they produce wrong result due to inaccurate precision. + // Composites, ArrayStrings, ArrayInt32s, Mixed and ArrayAndDocuments are skipped due to + // https://github.com/FerretDB/FerretDB/issues/2185. providers := []shareddata.Provider{ - // shareddata.Scalars, + shareddata.Scalars, // TODO: handle doubles close to max precision in doubles. // shareddata.Doubles, - // shareddata.BigDoubles, + shareddata.OverflowVergeDoubles, shareddata.SmallDoubles, shareddata.Strings, shareddata.Binaries, @@ -737,90 +739,154 @@ func TestAggregateCompatGroupSum(t *testing.T) { shareddata.Unsets, shareddata.ObjectIDKeys, - shareddata.Composites, + // shareddata.Composites, shareddata.PostgresEdgeCases, shareddata.DocumentsDoubles, shareddata.DocumentsStrings, shareddata.DocumentsDocuments, - shareddata.ArrayStrings, + // shareddata.ArrayStrings, shareddata.ArrayDoubles, - shareddata.ArrayInt32s, + // shareddata.ArrayInt32s, shareddata.ArrayRegexes, shareddata.ArrayDocuments, - shareddata.Mixed, - shareddata.ArrayAndDocuments, + // shareddata.Mixed, + // shareddata.ArrayAndDocuments, } testCases := map[string]aggregateStagesCompatTestCase{ - "Value": { + "GroupNullID": { pipeline: bson.A{ + // Without $sort sum of large values results in wrong result. bson.D{{"$sort", bson.D{{"_id", 1}}}}, bson.D{{"$group", bson.D{ {"_id", nil}, {"sum", bson.D{{"$sum", "$v"}}}, }}}, + // Without $sort documents are ordered not the same. + // Descending sort is used because it is more unique than + // ascending sort for shared data. + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, + }, + "GroupByID": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$_id"}, + {"sum", bson.D{{"$sum", "$v"}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, + }, + "GroupByValue": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", "$v"}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, }, }, "EmptyString": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"count", bson.D{{"$sum", ""}}}, - }}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", ""}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "NonExpression": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", "v"}}}, - }}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", nil}, + {"sum", bson.D{{"$sum", "v"}}}, + }}}}, }, "NonExistent": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", "$non-existent"}}}, - }}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", "$non-existent"}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "Document": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", bson.D{}}}}, - }}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", bson.D{}}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, - "ArraySum": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", bson.A{"$v", "$c"}}}}, - }}}}, + "Array": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", bson.A{"$v", "$c"}}}}, + }}}}, resultType: emptyResult, }, "Int32": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", int32(1)}}}}}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", int32(1)}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "Int64": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", int64(20)}}}}}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", int64(20)}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "Double": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", 43.7}}}}}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", 43.7}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "Bool": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", nil}, - {"sum", bson.D{{"$sum", true}}}}}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", true}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "Duplicate": { - pipeline: bson.A{bson.D{{"$group", bson.D{ - {"_id", "$v"}, - {"sum", bson.D{{"$sum", "$v"}}}, - {"sum", bson.D{{"$sum", "$s"}}}, - }}}}, + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", "$v"}}}, + {"sum", bson.D{{"$sum", "$s"}}}, + }}}}, resultType: emptyResult, }, } diff --git a/integration/shareddata/scalars.go b/integration/shareddata/scalars.go index c983fbd31ac7..ec8696c236eb 100644 --- a/integration/shareddata/scalars.go +++ b/integration/shareddata/scalars.go @@ -163,13 +163,13 @@ var Doubles = &Values[string]{ }, } -// BigDoubles contains double values which would overflow on +// OverflowVergeDoubles contains double values which would overflow on // numeric update operation such as $mul. Upon such, // target returns error and compat returns +INF or -INF. -// BigDoubles may be excluded on such update tests and tested +// OverflowVergeDoubles may be excluded on such update tests and tested // in diff tests https://github.com/FerretDB/dance. -var BigDoubles = &Values[string]{ - name: "BigDoubles", +var OverflowVergeDoubles = &Values[string]{ + name: "OverflowVergeDoubles", backends: []string{"ferretdb-pg", "ferretdb-tigris", "mongodb"}, validators: map[string]map[string]any{ "ferretdb-tigris": { @@ -183,7 +183,7 @@ var BigDoubles = &Values[string]{ } // SmallDoubles contains double values that does not go close to -// maximum precision for tests. +// the maximum safe precision for tests. var SmallDoubles = &Values[string]{ name: "SmallDoubles", backends: []string{"ferretdb-pg", "ferretdb-tigris", "mongodb"}, @@ -193,9 +193,11 @@ var SmallDoubles = &Values[string]{ }, }, data: map[string]any{ - "double": 42.13, - "double-whole": 42.0, - "double-smallest": math.SmallestNonzeroFloat64, + "double": 42.13, + "double-whole": 42.0, + "double-1": 4080.1234, + "double-2": 1048560.0099, + "double-3": 268435440.2, }, } diff --git a/integration/shareddata/shareddata.go b/integration/shareddata/shareddata.go index 5f997e41e7bd..1256a6108c7f 100644 --- a/integration/shareddata/shareddata.go +++ b/integration/shareddata/shareddata.go @@ -50,7 +50,7 @@ func AllProviders() Providers { Scalars, Doubles, - BigDoubles, + OverflowVergeDoubles, SmallDoubles, Strings, Binaries, diff --git a/integration/update_field_compat_test.go b/integration/update_field_compat_test.go index 26f45eb6ef45..b6e2ffcad6c2 100644 --- a/integration/update_field_compat_test.go +++ b/integration/update_field_compat_test.go @@ -1023,9 +1023,9 @@ func TestUpdateFieldCompatMul(t *testing.T) { t.Parallel() providers := shareddata.AllProviders(). - // BigDoubles and Scalars contain numbers that produces +INF on compat, + // OverflowVergeDoubles and Scalars contain numbers that produces +INF on compat, // validation error on target upon $mul operation. - Remove("BigDoubles", "Scalars") + Remove("OverflowVergeDoubles", "Scalars") testCases := map[string]updateCompatTestCase{ "Int32": { diff --git a/internal/handlers/common/aggregations/group_sum.go b/internal/handlers/common/aggregations/group_sum.go index 732a781cc3d8..426d763f3982 100644 --- a/internal/handlers/common/aggregations/group_sum.go +++ b/internal/handlers/common/aggregations/group_sum.go @@ -16,28 +16,24 @@ package aggregations import ( "context" - "errors" + "github.com/FerretDB/FerretDB/internal/handlers/commonerrors" "github.com/FerretDB/FerretDB/internal/types" - "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" - "time" ) // sumAccumulator represents $sum accumulator for $group. type sumAccumulator struct { expression types.Expression - n any + count any } // newSumAccumulator creates a new $sum accumulator for $group. func newSumAccumulator(accumulation *types.Document) (Accumulator, error) { expr := must.NotFail(accumulation.Get("$sum")) - accumulator := new(sumAccumulator) switch expr := expr.(type) { - case *types.Document: case *types.Array: return nil, commonerrors.NewCommandErrorMsgWithArgument( commonerrors.ErrStageGroupUnaryOperator, @@ -45,37 +41,18 @@ func newSumAccumulator(accumulation *types.Document) (Accumulator, error) { "$sum (accumulator)", ) case float64: - accumulator.n = expr + accumulator.count = expr case string: - // get field expression var err error - accumulator.expression, err = types.NewExpression(expr) - - var fieldPathErr *types.FieldPathError - if errors.As(err, &fieldPathErr) && fieldPathErr.Code() == types.ErrNotFieldPath { - // when field is not a path, ignore this error. - } else { - if err != nil { - return nil, commonerrors.NewCommandErrorMsgWithArgument( - commonerrors.ErrTypeMismatch, - "$sum takes no arguments, i.e. $sum:{}", - "$sum (accumulator)", - ) - } + if accumulator.expression, err = types.NewExpression(expr); err != nil { + // $sum returns 0 on non-existent field. + accumulator.count = int32(0) } - case types.Binary: - case types.ObjectID: - case bool: - case time.Time: - case types.NullType: - case types.Regex: - case int32: - accumulator.n = expr - case types.Timestamp: - case int64: - accumulator.n = expr + case int32, int64: + accumulator.count = expr default: - // $sum ignores non-existent field + accumulator.count = int32(0) + // $sum returns 0 on non-numeric field } return accumulator, nil @@ -85,26 +62,24 @@ func newSumAccumulator(accumulation *types.Document) (Accumulator, error) { func (s *sumAccumulator) Accumulate(ctx context.Context, groupID any, grouped []*types.Document) (any, error) { if s.expression != nil { var values []any + for _, doc := range grouped { v := s.expression.Evaluate(doc) values = append(values, v) } - res, err := types.AddNumbers(values...) - if err != nil { - // handle INF - return nil, lazyerrors.Error(err) - } - return res, err + res := sumNumbers(values...) + + return res, nil } - switch n := s.n.(type) { + switch count := s.count.(type) { case float64: - return float64(len(grouped)) * n, nil + return float64(len(grouped)) * count, nil case int32: - return int32(len(grouped)) * n, nil + return int32(len(grouped)) * count, nil case int64: - return int64(len(grouped)) * n, nil + return int64(len(grouped)) * count, nil } // $sum returns 0 on non-existent and non-numeric field. diff --git a/internal/handlers/common/aggregations/number.go b/internal/handlers/common/aggregations/number.go new file mode 100644 index 000000000000..f21eaac2a568 --- /dev/null +++ b/internal/handlers/common/aggregations/number.go @@ -0,0 +1,71 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregations + +import ( + "math" + "math/big" +) + +// sumNumbers accumulate numbers and returns the result of summation. +// It ignores non-number values. +// This should only be used for aggregation, aggregation does not return +// error on overflow. +func sumNumbers(vs ...any) any { + // use big.Int to accumulate values larger than math.MaxInt64. + sumInt := big.NewInt(0) + + var sumFloat float64 + + var hasFloat64, hasInt64 bool + + for _, v := range vs { + switch v := v.(type) { + case float64: + hasFloat64 = true + + sumFloat = sumFloat + v + case int32: + sumInt.Add(sumInt, big.NewInt(int64(v))) + case int64: + hasInt64 = true + + sumInt.Add(sumInt, big.NewInt(v)) + default: + // ignore non-number + } + } + + if !sumInt.IsInt64() { + // TODO: handle overflow + return sumInt.Int64() + } + + if hasFloat64 { + // return float64 + // TODO: handle infinity + return float64(sumInt.Int64()) + sumFloat + } + + res := sumInt.Int64() + + if !hasInt64 && res <= math.MaxInt32 && res >= math.MinInt32 { + // convert to int32 when input is int32 only and can be represented in int32. + return int32(res) + } + + // return int64 + return res +} diff --git a/internal/types/number.go b/internal/types/number.go deleted file mode 100644 index 619bf06ef626..000000000000 --- a/internal/types/number.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types - -import ( - "math" - "math/big" -) - -//go:generate ../../bin/stringer -linecomment -type NumberErrorCode - -// NumberErrorCode represents error code from numerical operation. -type NumberErrorCode int - -const ( - _ NumberErrorCode = iota - - // ErrLongExceeded indicates that long exceeded in its size. - ErrLongExceeded - - // ErrNotExactResult indicates that float addition dropped precision. - ErrNotExactResult -) - -// NumberError describes an error that occurs applying number operation. -type NumberError struct { - code NumberErrorCode -} - -// newNumberError creates a new NumberError. -func newNumberError(code NumberErrorCode) error { - return &NumberError{code: code} -} - -// Error implements the error interface. -func (e *NumberError) Error() string { - return e.code.String() -} - -// Code returns the FieldPathError code. -func (e *NumberError) Code() NumberErrorCode { - return e.code -} - -// AddNumbers returns the result of addition and error if addition failed. -func AddNumbers(vs ...any) (any, error) { - sum := big.NewInt(0) - sumFloat := big.NewFloat(0) - - var hasFloat64, hasInt64 bool - - for _, v := range vs { - switch v := v.(type) { - case float64: - hasFloat64 = true - if v > MaxSafeDouble { - // todo handle lost precision - smallPart := v - MaxSafeDouble - sum.Add(sum, big.NewInt(int64(MaxSafeDouble))) - sumFloat.Add(sumFloat, big.NewFloat(smallPart)) - continue - } - - if v < -MaxSafeDouble { - // todo handle lost precision - smallPart := v + MaxSafeDouble - - sum.Add(sum, big.NewInt(int64(-MaxSafeDouble))) - sumFloat.Add(sumFloat, big.NewFloat(smallPart)) - continue - } - - // todo check overflow - sumFloat.Add(sumFloat, big.NewFloat(v)) - case int32: - sum.Add(sum, big.NewInt(int64(v))) - case int64: - hasInt64 = true - sum.Add(sum, big.NewInt(v)) - default: - // ignore non-number - } - } - - if !sum.IsInt64() { - return nil, newNumberError(ErrLongExceeded) - } - - sumBig := sum.Int64() - - res := sumBig - - if hasFloat64 { - f, accuracy := sumFloat.Float64() - if accuracy != big.Exact { - return nil, newNumberError(ErrNotExactResult) - } - - // todo check overflow - if sumBig > int64(MaxSafeDouble) || sumBig < -int64(MaxSafeDouble) { - // not accurate result - return float64(sumBig) + f, nil - } - - return float64(sumBig) + f, nil - } - - if hasInt64 { - return res, nil - } - - if res < math.MaxInt32 && res > math.MinInt32 { - return int32(res), nil - } - - return res, nil -} diff --git a/internal/types/numbererrorcode_string.go b/internal/types/numbererrorcode_string.go deleted file mode 100644 index ed7263202253..000000000000 --- a/internal/types/numbererrorcode_string.go +++ /dev/null @@ -1,25 +0,0 @@ -// Code generated by "stringer -linecomment -type NumberErrorCode"; DO NOT EDIT. - -package types - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[ErrLongExceeded-1] - _ = x[ErrNotExactResult-2] -} - -const _NumberErrorCode_name = "ErrLongExceededErrNotExactResult" - -var _NumberErrorCode_index = [...]uint8{0, 15, 32} - -func (i NumberErrorCode) String() string { - i -= 1 - if i < 0 || i >= NumberErrorCode(len(_NumberErrorCode_index)-1) { - return "NumberErrorCode(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _NumberErrorCode_name[_NumberErrorCode_index[i]:_NumberErrorCode_index[i+1]] -} From 54fe0968710e1a9f97fd5c235a88282978767eb0 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 14:03:59 +0900 Subject: [PATCH 03/11] comment and refactor --- integration/aggregate_compat_test.go | 54 +++++++++++++++++-- .../handlers/common/aggregations/group_sum.go | 28 +++++----- .../handlers/common/aggregations/number.go | 18 ++++--- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index 8f8e1c690ed6..b73f2cbf8c0e 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -15,6 +15,7 @@ package integration import ( + "math" "testing" "github.com/stretchr/testify/assert" @@ -716,13 +717,14 @@ func TestAggregateCompatGroupCount(t *testing.T) { } func TestAggregateCompatGroupSum(t *testing.T) { - // Doubles is skipped as they produce wrong result due to inaccurate precision. + // Doubles is skipped as they produce wrong result due to lost precision. // Composites, ArrayStrings, ArrayInt32s, Mixed and ArrayAndDocuments are skipped due to // https://github.com/FerretDB/FerretDB/issues/2185. providers := []shareddata.Provider{ shareddata.Scalars, - // TODO: handle doubles close to max precision in doubles. + // TODO: handle accumulation of doubles close to max precision. + // https://github.com/FerretDB/FerretDB/issues/2300 // shareddata.Doubles, shareddata.OverflowVergeDoubles, shareddata.SmallDoubles, @@ -807,7 +809,9 @@ func TestAggregateCompatGroupSum(t *testing.T) { bson.D{{"$group", bson.D{ {"_id", nil}, {"sum", bson.D{{"$sum", "v"}}}, - }}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, }, "NonExistent": { pipeline: bson.A{ @@ -836,7 +840,9 @@ func TestAggregateCompatGroupSum(t *testing.T) { bson.D{{"$group", bson.D{ {"_id", "$v"}, {"sum", bson.D{{"$sum", bson.A{"$v", "$c"}}}}, - }}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, resultType: emptyResult, }, "Int32": { @@ -849,6 +855,26 @@ func TestAggregateCompatGroupSum(t *testing.T) { bson.D{{"$sort", bson.D{{"_id", -1}}}}, }, }, + "MaxInt32": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", math.MaxInt32}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, + }, + "NegativeInt32": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", int32(-1)}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, + }, "Int64": { pipeline: bson.A{ bson.D{{"$sort", bson.D{{"_id", 1}}}}, @@ -859,6 +885,16 @@ func TestAggregateCompatGroupSum(t *testing.T) { bson.D{{"$sort", bson.D{{"_id", -1}}}}, }, }, + "MaxInt64": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", math.MaxInt64}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, + }, "Double": { pipeline: bson.A{ bson.D{{"$sort", bson.D{{"_id", 1}}}}, @@ -869,6 +905,16 @@ func TestAggregateCompatGroupSum(t *testing.T) { bson.D{{"$sort", bson.D{{"_id", -1}}}}, }, }, + "MaxDouble": { + pipeline: bson.A{ + bson.D{{"$sort", bson.D{{"_id", 1}}}}, + bson.D{{"$group", bson.D{ + {"_id", "$v"}, + {"sum", bson.D{{"$sum", math.MaxFloat64}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, + }, "Bool": { pipeline: bson.A{ bson.D{{"$sort", bson.D{{"_id", 1}}}}, diff --git a/internal/handlers/common/aggregations/group_sum.go b/internal/handlers/common/aggregations/group_sum.go index 426d763f3982..641e6bf7c37e 100644 --- a/internal/handlers/common/aggregations/group_sum.go +++ b/internal/handlers/common/aggregations/group_sum.go @@ -25,7 +25,7 @@ import ( // sumAccumulator represents $sum accumulator for $group. type sumAccumulator struct { expression types.Expression - count any + number any } // newSumAccumulator creates a new $sum accumulator for $group. @@ -41,17 +41,17 @@ func newSumAccumulator(accumulation *types.Document) (Accumulator, error) { "$sum (accumulator)", ) case float64: - accumulator.count = expr + accumulator.number = expr case string: var err error if accumulator.expression, err = types.NewExpression(expr); err != nil { // $sum returns 0 on non-existent field. - accumulator.count = int32(0) + accumulator.number = int32(0) } case int32, int64: - accumulator.count = expr + accumulator.number = expr default: - accumulator.count = int32(0) + accumulator.number = int32(0) // $sum returns 0 on non-numeric field } @@ -73,13 +73,17 @@ func (s *sumAccumulator) Accumulate(ctx context.Context, groupID any, grouped [] return res, nil } - switch count := s.count.(type) { - case float64: - return float64(len(grouped)) * count, nil - case int32: - return int32(len(grouped)) * count, nil - case int64: - return int64(len(grouped)) * count, nil + switch number := s.number.(type) { + case float64, int32, int64: + // Below is equivalent of len(grouped)*number, + // with handling conversion on int32/int64 overflows. + // For example, { $sum: 1 } is equivalent of $count. + numbers := make([]any, len(grouped)) + for i := 0; i < len(grouped); i++ { + numbers[i] = number + } + + return sumNumbers(numbers...), nil } // $sum returns 0 on non-existent and non-numeric field. diff --git a/internal/handlers/common/aggregations/number.go b/internal/handlers/common/aggregations/number.go index f21eaac2a568..2df8207e183c 100644 --- a/internal/handlers/common/aggregations/number.go +++ b/internal/handlers/common/aggregations/number.go @@ -27,6 +27,8 @@ func sumNumbers(vs ...any) any { // use big.Int to accumulate values larger than math.MaxInt64. sumInt := big.NewInt(0) + // TODO: handle accumulation of doubles close to max precision. + // https://github.com/FerretDB/FerretDB/issues/2300 var sumFloat float64 var hasFloat64, hasInt64 bool @@ -48,17 +50,21 @@ func sumNumbers(vs ...any) any { } } - if !sumInt.IsInt64() { - // TODO: handle overflow - return sumInt.Int64() - } - if hasFloat64 { // return float64 - // TODO: handle infinity return float64(sumInt.Int64()) + sumFloat } + if !sumInt.IsInt64() { + // int64 is bigger than maximum of int64, convert to float. + bigFloat := new(big.Float).SetInt(sumInt) + + // ignore accuracy because there is no rounding from int64. + res, _ := bigFloat.Float64() + + return res + } + res := sumInt.Int64() if !hasInt64 && res <= math.MaxInt32 && res >= math.MinInt32 { From 776fa6a562a4edd018311775ea88150e9df37860 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 14:13:53 +0900 Subject: [PATCH 04/11] comment --- internal/handlers/common/aggregations/number.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/handlers/common/aggregations/number.go b/internal/handlers/common/aggregations/number.go index 2df8207e183c..981cbc7695be 100644 --- a/internal/handlers/common/aggregations/number.go +++ b/internal/handlers/common/aggregations/number.go @@ -20,6 +20,9 @@ import ( ) // sumNumbers accumulate numbers and returns the result of summation. +// The result has the same type as the input, except when the result +// cannot be presented accurately. Then int32 is converted to int64, +// and int64 is converted to float64. // It ignores non-number values. // This should only be used for aggregation, aggregation does not return // error on overflow. @@ -56,7 +59,7 @@ func sumNumbers(vs ...any) any { } if !sumInt.IsInt64() { - // int64 is bigger than maximum of int64, convert to float. + // int64 is bigger than maximum of int64, convert to float64. bigFloat := new(big.Float).SetInt(sumInt) // ignore accuracy because there is no rounding from int64. @@ -68,7 +71,7 @@ func sumNumbers(vs ...any) any { res := sumInt.Int64() if !hasInt64 && res <= math.MaxInt32 && res >= math.MinInt32 { - // convert to int32 when input is int32 only and can be represented in int32. + // convert to int32 if input has no int64 and can be represented in int32. return int32(res) } From 93e7b339153122bdf673da4d031271f9900450d0 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 14:29:52 +0900 Subject: [PATCH 05/11] lint --- integration/aggregate_compat_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index b73f2cbf8c0e..74ef1b5b9b0a 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -932,7 +932,9 @@ func TestAggregateCompatGroupSum(t *testing.T) { {"_id", "$v"}, {"sum", bson.D{{"$sum", "$v"}}}, {"sum", bson.D{{"$sum", "$s"}}}, - }}}}, + }}}, + bson.D{{"$sort", bson.D{{"_id", -1}}}}, + }, resultType: emptyResult, }, } From 1970142037c3b93426a6932d202b3bf7cff8085e Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 16:37:08 +0900 Subject: [PATCH 06/11] use remove instead of list providers --- integration/aggregate_compat_test.go | 48 ++++++---------------------- 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index 74ef1b5b9b0a..e36d9b463b87 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -717,46 +717,16 @@ func TestAggregateCompatGroupCount(t *testing.T) { } func TestAggregateCompatGroupSum(t *testing.T) { - // Doubles is skipped as they produce wrong result due to lost precision. - // Composites, ArrayStrings, ArrayInt32s, Mixed and ArrayAndDocuments are skipped due to - // https://github.com/FerretDB/FerretDB/issues/2185. - providers := []shareddata.Provider{ - shareddata.Scalars, - - // TODO: handle accumulation of doubles close to max precision. + providers := shareddata.AllProviders(). + // skipped due to https://github.com/FerretDB/FerretDB/issues/2185. + Remove("Composites"). + Remove("ArrayStrings"). + Remove("ArrayInt32s"). + Remove("Mixed"). + Remove("ArrayAndDocuments"). + // TODO: handle $sum of doubles near max precision. // https://github.com/FerretDB/FerretDB/issues/2300 - // shareddata.Doubles, - shareddata.OverflowVergeDoubles, - shareddata.SmallDoubles, - shareddata.Strings, - shareddata.Binaries, - shareddata.ObjectIDs, - shareddata.Bools, - shareddata.DateTimes, - shareddata.Nulls, - shareddata.Regexes, - shareddata.Int32s, - shareddata.Timestamps, - shareddata.Int64s, - shareddata.Unsets, - shareddata.ObjectIDKeys, - - // shareddata.Composites, - shareddata.PostgresEdgeCases, - - shareddata.DocumentsDoubles, - shareddata.DocumentsStrings, - shareddata.DocumentsDocuments, - - // shareddata.ArrayStrings, - shareddata.ArrayDoubles, - // shareddata.ArrayInt32s, - shareddata.ArrayRegexes, - shareddata.ArrayDocuments, - - // shareddata.Mixed, - // shareddata.ArrayAndDocuments, - } + Remove("Doubles") testCases := map[string]aggregateStagesCompatTestCase{ "GroupNullID": { From 09faa1a45b4175cd5660f094d526b1512a162613 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 17:00:48 +0900 Subject: [PATCH 07/11] update --- .../handlers/common/aggregations/number.go | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/internal/handlers/common/aggregations/number.go b/internal/handlers/common/aggregations/number.go index 981cbc7695be..f3b67dc32908 100644 --- a/internal/handlers/common/aggregations/number.go +++ b/internal/handlers/common/aggregations/number.go @@ -28,11 +28,11 @@ import ( // error on overflow. func sumNumbers(vs ...any) any { // use big.Int to accumulate values larger than math.MaxInt64. - sumInt := big.NewInt(0) + intSum := big.NewInt(0) // TODO: handle accumulation of doubles close to max precision. // https://github.com/FerretDB/FerretDB/issues/2300 - var sumFloat float64 + var floatSum float64 var hasFloat64, hasInt64 bool @@ -41,34 +41,29 @@ func sumNumbers(vs ...any) any { case float64: hasFloat64 = true - sumFloat = sumFloat + v + floatSum = floatSum + v case int32: - sumInt.Add(sumInt, big.NewInt(int64(v))) + intSum.Add(intSum, big.NewInt(int64(v))) case int64: hasInt64 = true - sumInt.Add(sumInt, big.NewInt(v)) + intSum.Add(intSum, big.NewInt(v)) default: // ignore non-number } } - if hasFloat64 { - // return float64 - return float64(sumInt.Int64()) + sumFloat - } - - if !sumInt.IsInt64() { - // int64 is bigger than maximum of int64, convert to float64. - bigFloat := new(big.Float).SetInt(sumInt) + if hasFloat64 || !intSum.IsInt64() { + // intSum may be bigger than maximum of int64, convert to float64. + intAsBigFloat := new(big.Float).SetInt(intSum) // ignore accuracy because there is no rounding from int64. - res, _ := bigFloat.Float64() + intAsFloat, _ := intAsBigFloat.Float64() - return res + return intAsFloat + floatSum } - res := sumInt.Int64() + res := intSum.Int64() if !hasInt64 && res <= math.MaxInt32 && res >= math.MinInt32 { // convert to int32 if input has no int64 and can be represented in int32. From dd74f8a92da7ca6c26ea9c5335a7dfff1c34a288 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 17:20:14 +0900 Subject: [PATCH 08/11] refactor --- internal/handlers/common/aggregations/group_sum.go | 14 ++++++-------- internal/handlers/common/aggregations/number.go | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/internal/handlers/common/aggregations/group_sum.go b/internal/handlers/common/aggregations/group_sum.go index 641e6bf7c37e..bfe6fd7d7aec 100644 --- a/internal/handlers/common/aggregations/group_sum.go +++ b/internal/handlers/common/aggregations/group_sum.go @@ -59,24 +59,22 @@ func newSumAccumulator(accumulation *types.Document) (Accumulator, error) { } // Accumulate implements Accumulator interface. -func (s *sumAccumulator) Accumulate(ctx context.Context, groupID any, grouped []*types.Document) (any, error) { - if s.expression != nil { +func (a *sumAccumulator) Accumulate(ctx context.Context, groupID any, grouped []*types.Document) (any, error) { + if a.expression != nil { var values []any for _, doc := range grouped { - v := s.expression.Evaluate(doc) + v := a.expression.Evaluate(doc) values = append(values, v) } - res := sumNumbers(values...) - - return res, nil + return sumNumbers(values...), nil } - switch number := s.number.(type) { + switch number := a.number.(type) { case float64, int32, int64: // Below is equivalent of len(grouped)*number, - // with handling conversion on int32/int64 overflows. + // with conversion handling upon overflow of int32 and int64. // For example, { $sum: 1 } is equivalent of $count. numbers := make([]any, len(grouped)) for i := 0; i < len(grouped); i++ { diff --git a/internal/handlers/common/aggregations/number.go b/internal/handlers/common/aggregations/number.go index f3b67dc32908..d3e455959968 100644 --- a/internal/handlers/common/aggregations/number.go +++ b/internal/handlers/common/aggregations/number.go @@ -22,8 +22,7 @@ import ( // sumNumbers accumulate numbers and returns the result of summation. // The result has the same type as the input, except when the result // cannot be presented accurately. Then int32 is converted to int64, -// and int64 is converted to float64. -// It ignores non-number values. +// and int64 is converted to float64. It ignores non-number values. // This should only be used for aggregation, aggregation does not return // error on overflow. func sumNumbers(vs ...any) any { From 45e8620c6a275d6da2538f0078d6015fc16b8385 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 17:25:21 +0900 Subject: [PATCH 09/11] update comment --- integration/aggregate_compat_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index e36d9b463b87..2fcf57083cd0 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -731,15 +731,15 @@ func TestAggregateCompatGroupSum(t *testing.T) { testCases := map[string]aggregateStagesCompatTestCase{ "GroupNullID": { pipeline: bson.A{ - // Without $sort sum of large values results in wrong result. + // Without $sort, the sum of large values results different in compat and target. bson.D{{"$sort", bson.D{{"_id", 1}}}}, bson.D{{"$group", bson.D{ {"_id", nil}, {"sum", bson.D{{"$sum", "$v"}}}, }}}, - // Without $sort documents are ordered not the same. + // Without $sort, documents are ordered not the same. // Descending sort is used because it is more unique than - // ascending sort for shared data. + // ascending sort for shareddata collections. bson.D{{"$sort", bson.D{{"_id", -1}}}}, }, }, From 68267bc201c3df0b8d51dbceada3e3c60cae5f91 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 17:35:46 +0900 Subject: [PATCH 10/11] tests are run parallel --- integration/aggregate_compat_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/integration/aggregate_compat_test.go b/integration/aggregate_compat_test.go index 2fcf57083cd0..3bac47daa6d7 100644 --- a/integration/aggregate_compat_test.go +++ b/integration/aggregate_compat_test.go @@ -225,6 +225,8 @@ func testAggregateCommandCompat(t *testing.T, testCases map[string]aggregateComm } func TestAggregateCommandCompat(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateCommandCompatTestCase{ "CollectionAgnostic": { command: bson.D{ @@ -258,6 +260,8 @@ func TestAggregateCommandCompat(t *testing.T) { } func TestAggregateCompatStages(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateStagesCompatTestCase{ "MatchAndCount": { pipeline: bson.A{ @@ -282,6 +286,8 @@ func TestAggregateCompatStages(t *testing.T) { } func TestAggregateCompatEmptyPipeline(t *testing.T) { + t.Parallel() + providers := []shareddata.Provider{ // for testing empty pipeline use a collection with single document, // because sorting will not matter. @@ -298,6 +304,8 @@ func TestAggregateCompatEmptyPipeline(t *testing.T) { } func TestAggregateCompatCount(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateStagesCompatTestCase{ "Value": { pipeline: bson.A{bson.D{{"$count", "v"}}}, @@ -331,6 +339,8 @@ func TestAggregateCompatCount(t *testing.T) { } func TestAggregateCompatGroupDeterministicCollections(t *testing.T) { + t.Parallel() + // Scalars collection is not included because aggregation groups // numbers of different types for $group, and this causes output // _id to be different number type between compat and target. @@ -412,6 +422,8 @@ func TestAggregateCompatGroupDeterministicCollections(t *testing.T) { } func TestAggregateCompatGroup(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateStagesCompatTestCase{ "NullID": { pipeline: bson.A{bson.D{{"$group", bson.D{ @@ -541,6 +553,8 @@ func TestAggregateCompatGroup(t *testing.T) { } func TestAggregateCompatGroupDotNotation(t *testing.T) { + t.Parallel() + // Providers Composites, ArrayAndDocuments and Mixed // cannot be used due to sorting difference. // FerretDB always sorts empty array is less than null. @@ -616,6 +630,8 @@ func TestAggregateCompatGroupDotNotation(t *testing.T) { } func TestAggregateCompatGroupDocDotNotation(t *testing.T) { + t.Parallel() + // Providers Composites and Mixed cannot be used due to sorting difference. // FerretDB always sorts empty array is less than null. // In compat, for `.sort()` an empty array is less than null. @@ -670,6 +686,8 @@ func TestAggregateCompatGroupDocDotNotation(t *testing.T) { } func TestAggregateCompatGroupCount(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateStagesCompatTestCase{ "CountNull": { pipeline: bson.A{bson.D{{"$group", bson.D{ @@ -717,6 +735,8 @@ func TestAggregateCompatGroupCount(t *testing.T) { } func TestAggregateCompatGroupSum(t *testing.T) { + t.Parallel() + providers := shareddata.AllProviders(). // skipped due to https://github.com/FerretDB/FerretDB/issues/2185. Remove("Composites"). @@ -913,6 +933,8 @@ func TestAggregateCompatGroupSum(t *testing.T) { } func TestAggregateCompatMatch(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateStagesCompatTestCase{ "ID": { pipeline: bson.A{bson.D{{"$match", bson.D{{"_id", "string"}}}}}, @@ -964,6 +986,8 @@ func TestAggregateCompatMatch(t *testing.T) { } func TestAggregateCompatSort(t *testing.T) { + t.Parallel() + testCases := map[string]aggregateStagesCompatTestCase{ "AscendingID": { pipeline: bson.A{bson.D{{"$sort", bson.D{{"_id", 1}}}}}, From 45524e8441380068268ff429e65659045a561d02 Mon Sep 17 00:00:00 2001 From: Chi Fujii Date: Tue, 28 Mar 2023 17:46:50 +0900 Subject: [PATCH 11/11] rename --- internal/handlers/common/aggregations/number.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/handlers/common/aggregations/number.go b/internal/handlers/common/aggregations/number.go index d3e455959968..a041bd1a37da 100644 --- a/internal/handlers/common/aggregations/number.go +++ b/internal/handlers/common/aggregations/number.go @@ -52,23 +52,22 @@ func sumNumbers(vs ...any) any { } } + // handle float64 or intSum bigger than the maximum of int64. if hasFloat64 || !intSum.IsInt64() { - // intSum may be bigger than maximum of int64, convert to float64. - intAsBigFloat := new(big.Float).SetInt(intSum) - // ignore accuracy because there is no rounding from int64. - intAsFloat, _ := intAsBigFloat.Float64() + intAsFloat, _ := new(big.Float).SetInt(intSum).Float64() return intAsFloat + floatSum } - res := intSum.Int64() + integer := intSum.Int64() - if !hasInt64 && res <= math.MaxInt32 && res >= math.MinInt32 { + // handle int32 + if !hasInt64 && integer <= math.MaxInt32 && integer >= math.MinInt32 { // convert to int32 if input has no int64 and can be represented in int32. - return int32(res) + return int32(integer) } // return int64 - return res + return integer }