diff --git a/internal/backends/postgresql/collection.go b/internal/backends/postgresql/collection.go index 7d350b061e9f..b585f197c5a5 100644 --- a/internal/backends/postgresql/collection.go +++ b/internal/backends/postgresql/collection.go @@ -56,6 +56,10 @@ func (c *collection) Query(ctx context.Context, params *backends.QueryParams) (* return nil, lazyerrors.Error(err) } + if params == nil { + params = new(backends.QueryParams) + } + if p == nil { return &backends.QueryResult{ Iter: newQueryIterator(ctx, nil), @@ -73,16 +77,18 @@ func (c *collection) Query(ctx context.Context, params *backends.QueryParams) (* }, nil } - // TODO https://github.com/FerretDB/FerretDB/issues/3490 + q := prepareSelectClause(c.dbName, meta.TableName) - // TODO https://github.com/FerretDB/FerretDB/issues/3414 - q := fmt.Sprintf( - `SELECT %s FROM %s`, - metadata.DefaultColumn, - pgx.Identifier{c.dbName, meta.TableName}.Sanitize(), - ) + var placeholder metadata.Placeholder - rows, err := p.Query(ctx, q) + where, args, err := prepareWhereClause(&placeholder, params.Filter) + if err != nil { + return nil, lazyerrors.Error(err) + } + + q += where + + rows, err := p.Query(ctx, q, args...) if err != nil { return nil, lazyerrors.Error(err) } @@ -258,8 +264,10 @@ func (c *collection) Explain(ctx context.Context, params *backends.ExplainParams return nil, lazyerrors.Error(err) } + res := new(backends.ExplainResult) + if p == nil { - return new(backends.ExplainResult), nil + return res, nil } meta, err := c.r.CollectionGet(ctx, c.dbName, c.name) @@ -268,20 +276,25 @@ func (c *collection) Explain(ctx context.Context, params *backends.ExplainParams } if meta == nil { - return &backends.ExplainResult{ - QueryPlanner: must.NotFail(types.NewDocument()), - }, nil + res.QueryPlanner = must.NotFail(types.NewDocument()) + return res, nil } - // TODO https://github.com/FerretDB/FerretDB/issues/3414 - q := fmt.Sprintf( - `EXPLAIN (VERBOSE true, FORMAT JSON) SELECT %s FROM %s`, - metadata.DefaultColumn, - pgx.Identifier{c.dbName, meta.TableName}.Sanitize(), - ) + q := `EXPLAIN (VERBOSE true, FORMAT JSON) ` + prepareSelectClause(c.dbName, meta.TableName) + + var placeholder metadata.Placeholder + + where, args, err := prepareWhereClause(&placeholder, params.Filter) + if err != nil { + return nil, lazyerrors.Error(err) + } + + res.QueryPushdown = where != "" + + q += where var b []byte - err = p.QueryRow(ctx, q).Scan(&b) + err = p.QueryRow(ctx, q, args...).Scan(&b) if err != nil { return nil, lazyerrors.Error(err) @@ -292,9 +305,9 @@ func (c *collection) Explain(ctx context.Context, params *backends.ExplainParams return nil, lazyerrors.Error(err) } - return &backends.ExplainResult{ - QueryPlanner: must.NotFail(types.NewDocument("Plan", queryPlan)), - }, nil + res.QueryPlanner = queryPlan + + return res, nil } // Stats implements backends.Collection interface. diff --git a/internal/backends/postgresql/query.go b/internal/backends/postgresql/query.go new file mode 100644 index 000000000000..d8729972e1db --- /dev/null +++ b/internal/backends/postgresql/query.go @@ -0,0 +1,228 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgresql + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5" + + "github.com/FerretDB/FerretDB/internal/backends/postgresql/metadata" + "github.com/FerretDB/FerretDB/internal/handlers/sjson" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/iterator" + "github.com/FerretDB/FerretDB/internal/util/lazyerrors" + "github.com/FerretDB/FerretDB/internal/util/must" +) + +// prepareSelectClause returns simple SELECT clause for provided db and table name, +// that can be used to construct the SQL query. +func prepareSelectClause(db, table string) string { + return fmt.Sprintf( + `SELECT %s FROM %s`, + metadata.DefaultColumn, + pgx.Identifier{db, table}.Sanitize(), + ) +} + +// prepareWhereClause adds WHERE clause with given filters to the query and returns the query and arguments. +func prepareWhereClause(p *metadata.Placeholder, sqlFilters *types.Document) (string, []any, error) { + var filters []string + var args []any + + iter := sqlFilters.Iterator() + defer iter.Close() + + // iterate through root document + for { + rootKey, rootVal, err := iter.Next() + if err != nil { + if errors.Is(err, iterator.ErrIteratorDone) { + break + } + + return "", nil, lazyerrors.Error(err) + } + + // don't pushdown $comment, it's attached to query in handlers + if strings.HasPrefix(rootKey, "$") { + continue + } + + path, err := types.NewPathFromString(rootKey) + + var pe *types.PathError + + switch { + case err == nil: + // Handle dot notation. + // TODO https://github.com/FerretDB/FerretDB/issues/2069 + if path.Len() > 1 { + continue + } + case errors.As(err, &pe): + // ignore empty key error, otherwise return error + if pe.Code() != types.ErrPathElementEmpty { + return "", nil, lazyerrors.Error(err) + } + default: + panic("Invalid error type: PathError expected") + } + + switch v := rootVal.(type) { + case *types.Document: + iter := v.Iterator() + defer iter.Close() + + // iterate through subdocument, as it may contain operators + for { + k, v, err := iter.Next() + if err != nil { + if errors.Is(err, iterator.ErrIteratorDone) { + break + } + + return "", nil, lazyerrors.Error(err) + } + + switch k { + case "$eq": + if f, a := filterEqual(p, rootKey, v); f != "" { + filters = append(filters, f) + args = append(args, a...) + } + + case "$ne": + sql := `NOT ( ` + + // does document contain the key, + // it is necessary, as NOT won't work correctly if the key does not exist. + `_jsonb ? %[1]s AND ` + + // does the value under the key is equal to filter value + `_jsonb->%[1]s @> %[2]s AND ` + + // does the value type is equal to the filter's one + `_jsonb->'$s'->'p'->%[1]s->'t' = '"%[3]s"' )` + + switch v := v.(type) { + case *types.Document, *types.Array, types.Binary, + types.NullType, types.Regex, types.Timestamp: + // type not supported for pushdown + + case float64, bool, int32, int64: + filters = append(filters, fmt.Sprintf(sql, p.Next(), p.Next(), sjson.GetTypeOfValue(v))) + args = append(args, rootKey, v) + + case string, types.ObjectID, time.Time: + filters = append(filters, fmt.Sprintf(sql, p.Next(), p.Next(), sjson.GetTypeOfValue(v))) + args = append(args, rootKey, string(must.NotFail(sjson.MarshalSingleValue(v)))) + + default: + panic(fmt.Sprintf("Unexpected type of value: %v", v)) + } + + default: + // $gt and $lt + // TODO https://github.com/FerretDB/FerretDB/issues/1875 + continue + } + } + + case *types.Array, types.Binary, types.NullType, types.Regex, types.Timestamp: + // type not supported for pushdown + + case float64, string, types.ObjectID, bool, time.Time, int32, int64: + if f, a := filterEqual(p, rootKey, v); f != "" { + filters = append(filters, f) + args = append(args, a...) + } + + default: + panic(fmt.Sprintf("Unexpected type of value: %v", v)) + } + } + + var filter string + if len(filters) > 0 { + filter = ` WHERE ` + strings.Join(filters, " AND ") + } + + return filter, args, nil +} + +// filterEqual returns the proper SQL filter with arguments that filters documents +// where the value under k is equal to v. +func filterEqual(p *metadata.Placeholder, k string, v any) (filter string, args []any) { + // Select if value under the key is equal to provided value. + sql := `_jsonb->%[1]s @> %[2]s` + + switch v := v.(type) { + case *types.Document, *types.Array, types.Binary, + types.NullType, types.Regex, types.Timestamp: + // type not supported for pushdown + + case float64: + // If value is not safe double, fetch all numbers out of safe range. + switch { + case v > types.MaxSafeDouble: + sql = `_jsonb->%[1]s > %[2]s` + v = types.MaxSafeDouble + + case v < -types.MaxSafeDouble: + sql = `_jsonb->%[1]s < %[2]s` + v = -types.MaxSafeDouble + default: + // don't change the default eq query + } + + filter = fmt.Sprintf(sql, p.Next(), p.Next()) + args = append(args, k, v) + + case string, types.ObjectID, time.Time: + // don't change the default eq query + filter = fmt.Sprintf(sql, p.Next(), p.Next()) + args = append(args, k, string(must.NotFail(sjson.MarshalSingleValue(v)))) + + case bool, int32: + // don't change the default eq query + filter = fmt.Sprintf(sql, p.Next(), p.Next()) + args = append(args, k, v) + + case int64: + maxSafeDouble := int64(types.MaxSafeDouble) + + // If value cannot be safe double, fetch all numbers out of the safe range. + switch { + case v > maxSafeDouble: + sql = `_jsonb->%[1]s > %[2]s` + v = maxSafeDouble + + case v < -maxSafeDouble: + sql = `_jsonb->%[1]s < %[2]s` + v = -maxSafeDouble + default: + // don't change the default eq query + } + + filter = fmt.Sprintf(sql, p.Next(), p.Next()) + args = append(args, k, v) + + default: + panic(fmt.Sprintf("Unexpected type of value: %v", v)) + } + + return +} diff --git a/internal/backends/postgresql/query_test.go b/internal/backends/postgresql/query_test.go new file mode 100644 index 000000000000..277366bb4a7d --- /dev/null +++ b/internal/backends/postgresql/query_test.go @@ -0,0 +1,255 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgresql + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/FerretDB/FerretDB/internal/backends/postgresql/metadata" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" +) + +func TestPrepareWhereClause(t *testing.T) { + t.Parallel() + objectID := types.ObjectID{0x62, 0x56, 0xc5, 0xba, 0x0b, 0xad, 0xc0, 0xff, 0xee, 0xff, 0xff, 0xff} + + // WHERE clauses occurring frequently in tests + whereContain := " WHERE _jsonb->$1 @> $2" + whereGt := " WHERE _jsonb->$1 > $2" + whereNotEq := ` WHERE NOT ( _jsonb ? $1 AND _jsonb->$1 @> $2 AND _jsonb->'$s'->'p'->$1->'t' = ` + + for name, tc := range map[string]struct { + filter *types.Document + expected string + skip string + args []any // if empty, check is disabled + }{ + "IDObjectID": { + filter: must.NotFail(types.NewDocument("_id", objectID)), + expected: whereContain, + }, + "IDString": { + filter: must.NotFail(types.NewDocument("_id", "foo")), + expected: whereContain, + }, + "IDBool": { + filter: must.NotFail(types.NewDocument("_id", "foo")), + expected: whereContain, + }, + "IDDotNotation": { + filter: must.NotFail(types.NewDocument("_id.doc", "foo")), + }, + + "DotNotation": { + filter: must.NotFail(types.NewDocument("v.doc", "foo")), + }, + "DotNotationArrayIndex": { + filter: must.NotFail(types.NewDocument("v.arr.0", "foo")), + }, + + "ImplicitString": { + filter: must.NotFail(types.NewDocument("v", "foo")), + expected: whereContain, + }, + "ImplicitEmptyString": { + filter: must.NotFail(types.NewDocument("v", "")), + expected: whereContain, + }, + "ImplicitInt32": { + filter: must.NotFail(types.NewDocument("v", int32(42))), + expected: whereContain, + }, + "ImplicitInt64": { + filter: must.NotFail(types.NewDocument("v", int64(42))), + expected: whereContain, + }, + "ImplicitFloat64": { + filter: must.NotFail(types.NewDocument("v", float64(42.13))), + expected: whereContain, + }, + "ImplicitMaxFloat64": { + filter: must.NotFail(types.NewDocument("v", math.MaxFloat64)), + expected: whereGt, + }, + "ImplicitBool": { + filter: must.NotFail(types.NewDocument("v", true)), + expected: whereContain, + }, + "ImplicitDatetime": { + filter: must.NotFail(types.NewDocument( + "v", time.Date(2021, 11, 1, 10, 18, 42, 123000000, time.UTC), + )), + expected: whereContain, + }, + "ImplicitObjectID": { + filter: must.NotFail(types.NewDocument("v", objectID)), + expected: whereContain, + }, + + "EqString": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", "foo")), + )), + args: []any{`v`, `"foo"`}, + expected: whereContain, + }, + "EqEmptyString": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", "")), + )), + expected: whereContain, + }, + "EqInt32": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", int32(42))), + )), + expected: whereContain, + }, + "EqInt64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", int64(42))), + )), + expected: whereContain, + }, + "EqFloat64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", float64(42.13))), + )), + expected: whereContain, + }, + "EqMaxFloat64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", math.MaxFloat64)), + )), + args: []any{`v`, types.MaxSafeDouble}, + expected: whereGt, + }, + "EqDoubleBigInt64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", float64(2<<61))), + )), + args: []any{`v`, types.MaxSafeDouble}, + expected: whereGt, + }, + "EqBool": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", true)), + )), + expected: whereContain, + }, + "EqDatetime": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument( + "$eq", time.Date(2021, 11, 1, 10, 18, 42, 123000000, time.UTC), + )), + )), + expected: whereContain, + }, + "EqObjectID": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$eq", objectID)), + )), + expected: whereContain, + }, + + "NeString": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", "foo")), + )), + expected: whereNotEq + `'"string"' )`, + }, + "NeEmptyString": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", "")), + )), + expected: whereNotEq + `'"string"' )`, + }, + "NeInt32": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", int32(42))), + )), + expected: whereNotEq + `'"int"' )`, + }, + "NeInt64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", int64(42))), + )), + expected: whereNotEq + `'"long"' )`, + }, + "NeFloat64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", float64(42.13))), + )), + expected: whereNotEq + `'"double"' )`, + }, + "NeMaxFloat64": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", math.MaxFloat64)), + )), + args: []any{`v`, math.MaxFloat64}, + expected: whereNotEq + `'"double"' )`, + }, + "NeBool": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", true)), + )), + expected: whereNotEq + `'"bool"' )`, + }, + "NeDatetime": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument( + "$ne", time.Date(2021, 11, 1, 10, 18, 42, 123000000, time.UTC), + )), + )), + expected: whereNotEq + `'"date"' )`, + }, + "NeObjectID": { + filter: must.NotFail(types.NewDocument( + "v", must.NotFail(types.NewDocument("$ne", objectID)), + )), + expected: whereNotEq + `'"objectId"' )`, + }, + + "Comment": { + filter: must.NotFail(types.NewDocument("$comment", "I'm comment")), + }, + } { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + if tc.skip != "" { + t.Skip(tc.skip) + } + + actual, args, err := prepareWhereClause(new(metadata.Placeholder), tc.filter) + require.NoError(t, err) + + assert.Equal(t, tc.expected, actual) + + if len(tc.args) == 0 { + return + } + + assert.Equal(t, tc.args, args) + }) + } +}