diff --git a/integration/query_array_compat_test.go b/integration/query_array_compat_test.go index 92eb4cd767c4..2202a61ced88 100644 --- a/integration/query_array_compat_test.go +++ b/integration/query_array_compat_test.go @@ -142,7 +142,7 @@ func TestQueryArrayCompatElemMatch(t *testing.T) { {"v", bson.D{{"$elemMatch", bson.D{{"$gt", int32(0)}}}}}, }, resultType: emptyResult, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "GtZero": { filter: bson.D{{"v", bson.D{{"$elemMatch", bson.D{{"$gt", int32(0)}}}}}}, diff --git a/integration/query_comparison_compat_test.go b/integration/query_comparison_compat_test.go index 5d4f6e039267..5d1e76473e8f 100644 --- a/integration/query_comparison_compat_test.go +++ b/integration/query_comparison_compat_test.go @@ -67,147 +67,147 @@ func TestQueryComparisonCompatImplicit(t *testing.T) { }, "Int32": { filter: bson.D{{"v", int32(42)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64": { filter: bson.D{{"v", int64(42)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Double": { filter: bson.D{{"v", 42.13}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleMax": { filter: bson.D{{"v", math.MaxFloat64}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleSmallest": { filter: bson.D{{"v", math.SmallestNonzeroFloat64}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleBig": { filter: bson.D{{"v", float64(1 << 61)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleBigPlus": { filter: bson.D{{"v", float64((1 << 61) + 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleBigMinus": { filter: bson.D{{"v", float64((1 << 61) - 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleNegBig": { filter: bson.D{{"v", -float64(1 << 61)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleNegBigPlus": { filter: bson.D{{"v", -float64(1<<61) + 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DoubleNegBigMinus": { filter: bson.D{{"v", -float64(1<<61) - 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64Max": { filter: bson.D{{"v", int64(math.MaxInt64)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64Min": { filter: bson.D{{"v", int64(math.MinInt64)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Float64PrecMax": { filter: bson.D{{"v", float64(1 << 53)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Float64PrecMaxPlusOne": { filter: bson.D{{"v", float64(1<<53 + 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Float64PrecMaxMinusOne": { filter: bson.D{{"v", float64(1<<53 - 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Float64PrecMin": { filter: bson.D{{"v", -float64(1<<53 - 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Float64PrecMinPlus": { filter: bson.D{{"v", -float64(1<<53-1) + 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Float64PrecMinMinus": { filter: bson.D{{"v", -float64(1<<53-1) - 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64PrecMax": { filter: bson.D{{"v", int64(1 << 53)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64PrecMaxPlusOne": { filter: bson.D{{"v", int64(1<<53 + 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64PrecMaxMinusOne": { filter: bson.D{{"v", int64(1<<53 - 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64PrecMin": { filter: bson.D{{"v", -int64(1<<53 - 1)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64PrecMinPlus": { filter: bson.D{{"v", -int64(1<<53-1) + 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64PrecMinMinus": { filter: bson.D{{"v", -int64(1<<53-1) - 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64Big": { filter: bson.D{{"v", int64(1 << 61)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64BigPlus": { filter: bson.D{{"v", int64(1<<61) + 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64BigMinus": { filter: bson.D{{"v", int64(1<<61) - 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64NegBig": { filter: bson.D{{"v", -int64(1 << 61)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64NegBigPlus": { filter: bson.D{{"v", -int64(1<<61) + 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int64NegBigMinus": { filter: bson.D{{"v", -int64(1<<61) - 1}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "String": { filter: bson.D{{"v", "foo"}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "StringInt": { filter: bson.D{{"v", "42"}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "StringDouble": { filter: bson.D{{"v", "42.13"}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "StringEmpty": { filter: bson.D{{"v", ""}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Binary": { filter: bson.D{{"v", primitive.Binary{Subtype: 0x80, Data: []byte{42, 0, 13}}}}, @@ -217,27 +217,27 @@ func TestQueryComparisonCompatImplicit(t *testing.T) { }, "BoolFalse": { filter: bson.D{{"v", false}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "BoolTrue": { filter: bson.D{{"v", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Datetime": { filter: bson.D{{"v", primitive.NewDateTimeFromTime(time.Date(2021, 11, 1, 10, 18, 42, 123000000, time.UTC))}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DatetimeEpoch": { filter: bson.D{{"v", primitive.NewDateTimeFromTime(time.Unix(0, 0))}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DatetimeYearMin": { filter: bson.D{{"v", primitive.NewDateTimeFromTime(time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC))}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "DatetimeYearMax": { filter: bson.D{{"v", primitive.NewDateTimeFromTime(time.Date(9999, 12, 31, 23, 59, 59, 999000000, time.UTC))}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDNull": { filter: bson.D{{"_id", nil}}, @@ -246,21 +246,21 @@ func TestQueryComparisonCompatImplicit(t *testing.T) { "IDInt32": { filter: bson.D{{"_id", int32(1)}}, resultType: emptyResult, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDInt64": { filter: bson.D{{"_id", int64(1)}}, resultType: emptyResult, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDDouble": { filter: bson.D{{"_id", 4.2}}, resultType: emptyResult, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDString": { filter: bson.D{{"_id", "string"}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDObjectID": { filter: bson.D{{"_id", primitive.NilObjectID}}, @@ -274,7 +274,7 @@ func TestQueryComparisonCompatImplicit(t *testing.T) { }, "ValueNumber": { filter: bson.D{{"v", 42}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "ValueRegex": { filter: bson.D{{"v", primitive.Regex{Pattern: "^fo"}}}, @@ -283,7 +283,7 @@ func TestQueryComparisonCompatImplicit(t *testing.T) { "EmptyKey": { filter: bson.D{{"", "foo"}}, resultType: emptyResult, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, } diff --git a/integration/query_compat_test.go b/integration/query_compat_test.go index 2384cb6a63d0..2542b8cfc16a 100644 --- a/integration/query_compat_test.go +++ b/integration/query_compat_test.go @@ -201,15 +201,15 @@ func TestQueryCompatFilter(t *testing.T) { }, "String": { filter: bson.D{{"v", "foo"}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Int32": { filter: bson.D{{"v", int32(42)}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDString": { filter: bson.D{{"_id", "string"}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "IDNilObjectID": { filter: bson.D{{"_id", primitive.NilObjectID}}, @@ -221,7 +221,7 @@ func TestQueryCompatFilter(t *testing.T) { }, "ObjectID": { filter: bson.D{{"v", primitive.NilObjectID}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "UnknownFilterOperator": { filter: bson.D{{"v", bson.D{{"$someUnknownOperator", 42}}}}, diff --git a/integration/query_projection_compat_test.go b/integration/query_projection_compat_test.go index 9beb92a59f9a..02517bad4547 100644 --- a/integration/query_projection_compat_test.go +++ b/integration/query_projection_compat_test.go @@ -235,17 +235,17 @@ func TestQueryProjectionPositionalOperatorCompat(t *testing.T) { // e.g. missing {v: } in the filter. filter: bson.D{{"_id", "array"}}, projection: bson.D{{"v.$", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "Implicit": { filter: bson.D{{"v", float64(42)}}, projection: bson.D{{"v.$", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "ImplicitNoMatch": { filter: bson.D{{"v", "non-existent"}}, projection: bson.D{{"v.$", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, resultType: emptyResult, }, "Eq": { @@ -273,12 +273,12 @@ func TestQueryProjectionPositionalOperatorCompat(t *testing.T) { "ImplicitDotNotation": { filter: bson.D{{"v", float64(42)}}, projection: bson.D{{"v.foo.$", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "ImplicitDotNoMatch": { filter: bson.D{{"v", "non-existent"}}, projection: bson.D{{"v.foo.$", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, resultType: emptyResult, }, "GtDotNotation": { @@ -299,7 +299,7 @@ func TestQueryProjectionPositionalOperatorCompat(t *testing.T) { {"v", bson.D{{"$gt", 41}}}, }, projection: bson.D{{"v.$", true}}, - resultPushdown: pgPushdown, + resultPushdown: allPushdown, }, "TwoFilter": { filter: bson.D{ diff --git a/integration/query_test.go b/integration/query_test.go index 3a152a8c54b6..261de16ed2a1 100644 --- a/integration/query_test.go +++ b/integration/query_test.go @@ -880,7 +880,7 @@ func TestQueryCommandLimitPushDown(t *testing.T) { filter: bson.D{{"_id", "array"}}, limit: 3, len: 1, - queryPushdown: pgPushdown, + queryPushdown: allPushdown, limitPushdown: false, }, "ValueFilter": { @@ -888,7 +888,7 @@ func TestQueryCommandLimitPushDown(t *testing.T) { sort: bson.D{{"_id", 1}}, limit: 3, len: 3, - queryPushdown: pgPushdown, + queryPushdown: allPushdown, limitPushdown: false, }, "DotNotationFilter": { @@ -917,7 +917,7 @@ func TestQueryCommandLimitPushDown(t *testing.T) { sort: bson.D{{"_id", 1}}, limit: 3, len: 1, - queryPushdown: pgPushdown, + queryPushdown: allPushdown, limitPushdown: false, }, "ValueFilterSort": { @@ -925,7 +925,7 @@ func TestQueryCommandLimitPushDown(t *testing.T) { sort: bson.D{{"_id", 1}}, limit: 3, len: 3, - queryPushdown: pgPushdown, + queryPushdown: allPushdown, limitPushdown: false, }, "DotNotationFilterSort": { diff --git a/internal/backends/sqlite/collection.go b/internal/backends/sqlite/collection.go index b8f808443ef5..dc128cca98c6 100644 --- a/internal/backends/sqlite/collection.go +++ b/internal/backends/sqlite/collection.go @@ -16,9 +16,11 @@ package sqlite import ( "context" + "encoding/hex" "errors" "fmt" "strings" + "time" sqlite3 "modernc.org/sqlite" sqlite3lib "modernc.org/sqlite/lib" @@ -66,16 +68,12 @@ func (c *collection) Query(ctx context.Context, params *backends.QueryParams) (* var whereClause string var args []any + var err error - // that logic should exist in one place - // TODO https://github.com/FerretDB/FerretDB/issues/3235 - if params != nil && params.Filter.Len() == 1 { - v, _ := params.Filter.Get("_id") - if v != nil { - if id, ok := v.(types.ObjectID); ok { - whereClause = fmt.Sprintf(` WHERE %s = ?`, metadata.IDColumn) - args = []any{string(must.NotFail(sjson.MarshalSingleValue(id)))} - } + if params != nil { + whereClause, args, err = prepareWhereClause(params.Filter) + if err != nil { + return nil, lazyerrors.Error(err) } } @@ -91,6 +89,21 @@ func (c *collection) Query(ctx context.Context, params *backends.QueryParams) (* }, nil } +// parseValue parses the provided value to be used in SQLite query. +func parseValue(v any) any { + switch v := v.(type) { + case *types.Document, *types.Array, float64, types.Binary, string, + bool, types.NullType, types.Regex, int32, types.Timestamp, int64: + return v + case types.ObjectID: + return hex.EncodeToString(v[:]) + case time.Time: + return v.UnixMilli() + default: + panic(fmt.Sprintf("Unexpected type of value: %v", v)) + } +} + // InsertAll implements backends.Collection interface. func (c *collection) InsertAll(ctx context.Context, params *backends.InsertAllParams) (*backends.InsertAllResult, error) { if _, err := c.r.CollectionCreate(ctx, c.dbName, c.name); err != nil { @@ -234,23 +247,19 @@ func (c *collection) Explain(ctx context.Context, params *backends.ExplainParams }, nil } - var queryPushdown bool var whereClause string var args []any + var err error - // that logic should exist in one place - // TODO https://github.com/FerretDB/FerretDB/issues/3235 - if params != nil && params.Filter.Len() == 1 { - v, _ := params.Filter.Get("_id") - if v != nil { - if id, ok := v.(types.ObjectID); ok { - queryPushdown = true - whereClause = fmt.Sprintf(` WHERE %s = ?`, metadata.IDColumn) - args = []any{string(must.NotFail(sjson.MarshalSingleValue(id)))} - } + if params != nil { + whereClause, args, err = prepareWhereClause(params.Filter) + if err != nil { + return nil, lazyerrors.Error(err) } } + queryPushdown := whereClause != "" + q := fmt.Sprintf(`EXPLAIN QUERY PLAN SELECT %s FROM %q`+whereClause, metadata.DefaultColumn, meta.TableName) rows, err := db.QueryContext(ctx, q, args...) diff --git a/internal/backends/sqlite/pushdown.go b/internal/backends/sqlite/pushdown.go new file mode 100644 index 000000000000..74419e948c35 --- /dev/null +++ b/internal/backends/sqlite/pushdown.go @@ -0,0 +1,163 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/FerretDB/FerretDB/internal/backends/sqlite/metadata" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/iterator" + "github.com/FerretDB/FerretDB/internal/util/lazyerrors" +) + +// prepareWhereClause adds WHERE clause with filters given in the document. +// It returns the WHERE clause and the SQLite arguments. +func prepareWhereClause(filterDoc *types.Document) (string, []any, error) { + if filterDoc == nil { + return "", []any{}, nil + } + + iter := filterDoc.Iterator() + defer iter.Close() + + var filters []string + var args []any + + for { + k, v, err := iter.Next() + if errors.Is(err, iterator.ErrIteratorDone) { + break + } + + if err != nil { + return "", nil, lazyerrors.Error(err) + } + + // queryPath stores the path that is used in SQLite to access specific key + // if the key is _id we use our predifined path, as the handling of _id may + // change in the future + queryPath := metadata.IDColumn + + // keyArgs store the optional parameters used to query the key + var keyArgs []any + + if k != "_id" { + // To use parameters inside of SQLite json path the parameter token ("?") + // needs to be concatenated to path with || operator + queryPath = fmt.Sprintf(`%s->('$."' || ? || '"' )`, metadata.DefaultColumn) + keyArgs = append(keyArgs, k) + } + + // don't pushdown $comment + if strings.HasPrefix(k, "$") { + continue + } + + path, err := types.NewPathFromString(k) + + var pe *types.PathError + + switch { + case err == nil: + // TODO https://github.com/FerretDB/FerretDB/issues/2069 + if path.Len() > 1 { + continue + } + case errors.As(err, &pe): + // ignore empty key error, otherwise return error + if pe.Code() != types.ErrPathElementEmpty { + return "", nil, lazyerrors.Error(err) + } + default: + panic("Invalid error type: PathError expected") + } + + switch v := v.(type) { + case *types.Document, *types.Array, types.Binary, types.NullType, types.Regex, types.Timestamp: + // type not supported for pushdown + continue + + case float64: + comparison := ` = ?` + + switch { + case v > types.MaxSafeDouble: + comparison = ` > ?` + v = types.MaxSafeDouble + + case v < -types.MaxSafeDouble: + comparison = ` < ?` + v = -types.MaxSafeDouble + default: + // don't change the default eq query + } + + subquery := fmt.Sprintf(`EXISTS (SELECT value FROM json_each(%s) WHERE value %s)`, queryPath, comparison) + filters = append(filters, subquery) + + // TODO https://github.com/FerretDB/FerretDB/issues/3386 + args = append(args, keyArgs...) + args = append(args, parseValue(v)) + + case types.ObjectID, time.Time, string, bool, int32: + subquery := fmt.Sprintf(`EXISTS (SELECT value FROM json_each(%s) WHERE value = ?)`, queryPath) + filters = append(filters, subquery) + + // TODO https://github.com/FerretDB/FerretDB/issues/3386 + args = append(args, keyArgs...) + args = append(args, parseValue(v)) + + case int64: + comparison := ` = ?` + maxSafeDouble := int64(types.MaxSafeDouble) + + // If value cannot be safe double, fetch all numbers out of the safe range + switch { + case v > maxSafeDouble: + comparison = ` > ?` + v = maxSafeDouble + + case v < -maxSafeDouble: + comparison = `< ?` + v = -maxSafeDouble + default: + // don't change the default eq query + } + + // json_each returns top level json values, and the contents of arrays if any + // https://www.sqlite.org/json1.html#jeach + subquery := fmt.Sprintf(`EXISTS (SELECT value FROM json_each(%s) WHERE value %s)`, queryPath, comparison) + filters = append(filters, subquery) + + // TODO https://github.com/FerretDB/FerretDB/issues/3386 + args = append(args, keyArgs...) + args = append(args, parseValue(v)) + + default: + panic(fmt.Sprintf("Unexpected type of value: %v", v)) + } + } + + var whereClause string + if len(filters) > 0 { + whereClause = ` WHERE ` + strings.Join(filters, " AND ") + } + + return whereClause, args, nil +}