diff --git a/internal/handlers/pg/pgdb/indexes.go b/internal/handlers/pg/pgdb/indexes.go index 0ca74fb8887b..6d83f51b48c3 100644 --- a/internal/handlers/pg/pgdb/indexes.go +++ b/internal/handlers/pg/pgdb/indexes.go @@ -204,8 +204,15 @@ func createPgIndexIfNotExists(ctx context.Context, tx pgx.Tx, schema, table, ind return lazyerrors.Errorf("unknown sort order: %d", field.Order) } - // It's important to sanitize field.Field data here, as it's a user-provided value. - fieldsDef[i] = fmt.Sprintf(`((_jsonb->%s)) %s`, quoteString(field.Field), order) + // if the key is foo.bar, then need to modify it to foo -> bar + fs := strings.Split(field.Field, ".") + transformedParts := make([]string, len(fs)) + + for j, f := range fs { + // It's important to sanitize field.Field data here, as it's a user-provided value. + transformedParts[j] = quoteString(f) + } + fieldsDef[i] = fmt.Sprintf(`((_jsonb->%s)) %s`, strings.Join(transformedParts, " -> "), order) } sql := `CREATE` + unique + ` INDEX IF NOT EXISTS ` + pgx.Identifier{index}.Sanitize() + diff --git a/internal/handlers/pg/pgdb/indexes_test.go b/internal/handlers/pg/pgdb/indexes_test.go index a4431bfa0d65..f16da198a5b1 100644 --- a/internal/handlers/pg/pgdb/indexes_test.go +++ b/internal/handlers/pg/pgdb/indexes_test.go @@ -39,32 +39,68 @@ func TestCreateIndexIfNotExists(t *testing.T) { collectionName := testutil.CollectionName(t) setupDatabase(ctx, t, pool, databaseName) - indexName := "test" - err := pool.InTransaction(ctx, func(tx pgx.Tx) error { - idx := Index{ - Name: indexName, - Key: []IndexKeyPair{{Field: "foo", Order: types.Ascending}, {Field: "bar", Order: types.Descending}}, - } - return CreateIndexIfNotExists(ctx, tx, databaseName, collectionName, &idx) - }) - require.NoError(t, err) - - tableName := collectionNameToTableName(collectionName) - pgIndexName := indexNameToPgIndexName(collectionName, indexName) + for name, tc := range map[string]struct { + expectedDefinition string // the expected index definition in postgresql + index Index // the index to create + }{ + "keyWithoutNestedField": { + index: Index{ + Name: "foo_and_bar", + Key: []IndexKeyPair{ + {Field: "foo", Order: types.Ascending}, + {Field: "bar", Order: types.Descending}, + }, + }, + expectedDefinition: "((_jsonb -> 'foo'::text)), ((_jsonb -> 'bar'::text)) DESC", + }, + "keyWithNestedField_level1": { + index: Index{ + Name: "foo_dot_bar", + Key: []IndexKeyPair{ + {Field: "foo.bar", Order: types.Ascending}, + }, + }, + expectedDefinition: "(((_jsonb -> 'foo'::text) -> 'bar'::text))", + }, + "keyWithNestedField_level2": { + index: Index{ + Name: "foo_dot_bar_dot_c", + Key: []IndexKeyPair{ + {Field: "foo.bar.c", Order: types.Ascending}, + }, + }, + expectedDefinition: "((((_jsonb -> 'foo'::text) -> 'bar'::text) -> 'c'::text))", + }, + } { + tc := tc - var indexdef string - err = pool.p.QueryRow( - ctx, - "SELECT indexdef FROM pg_indexes WHERE schemaname = $1 AND tablename = $2 AND indexname = $3", - databaseName, tableName, pgIndexName, - ).Scan(&indexdef) - require.NoError(t, err) + t.Run(name, func(t *testing.T) { + t.Helper() + err := pool.InTransaction(ctx, func(tx pgx.Tx) error { + if err := CreateIndexIfNotExists(ctx, tx, databaseName, collectionName, &tc.index); err != nil { + return err + } + return nil + }) + require.NoError(t, err) + tableName := collectionNameToTableName(collectionName) + pgIndexName := indexNameToPgIndexName(collectionName, tc.index.Name) + + var indexdef string + err = pool.p.QueryRow( + ctx, + "SELECT indexdef FROM pg_indexes WHERE schemaname = $1 AND tablename = $2 AND indexname = $3", + databaseName, tableName, pgIndexName, + ).Scan(&indexdef) + require.NoError(t, err) - expectedIndexdef := fmt.Sprintf( - "CREATE INDEX %s ON \"%s\".%s USING btree (((_jsonb -> 'foo'::text)), ((_jsonb -> 'bar'::text)) DESC)", - pgIndexName, databaseName, tableName, - ) - assert.Equal(t, expectedIndexdef, indexdef) + expectedIndexdef := fmt.Sprintf( + "CREATE INDEX %s ON \"%s\".%s USING btree (%s)", + pgIndexName, databaseName, tableName, tc.expectedDefinition, + ) + assert.Equal(t, expectedIndexdef, indexdef) + }) + } } // TestDropIndexes checks that we correctly drop indexes for various combination of existing indexes. @@ -113,6 +149,15 @@ func TestDropIndexes(t *testing.T) { {Name: "_id_", Key: []IndexKeyPair{{Field: "_id", Order: types.Ascending}}, Unique: true}, }, }, + "DropNestedField": { + toCreate: []Index{ + {Name: "foo_1", Key: []IndexKeyPair{{Field: "foo.bar", Order: types.Ascending}}}, + }, + toDrop: []Index{{Key: []IndexKeyPair{{Field: "foo.bar", Order: types.Ascending}}}}, + expected: []Index{ + {Name: "_id_", Key: []IndexKeyPair{{Field: "_id", Order: types.Ascending}}, Unique: true}, + }, + }, "DropOneFromTheBeginning": { toCreate: []Index{ {Name: "foo_1", Key: []IndexKeyPair{{Field: "foo", Order: types.Ascending}}},