Skip to content

Commit

Permalink
perf: merge nested preload query when using join (#6990)
Browse files Browse the repository at this point in the history
* pref: merge nest preload query

* fix: preload test
  • Loading branch information
a631807682 authored Apr 25, 2024
1 parent 5553ff3 commit 85299bf
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 12 deletions.
14 changes: 12 additions & 2 deletions callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,18 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
reflectValue.SetLen(rv.Len())
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue.Index(i).Set(frv.Addr())
} else {
reflectValue.Index(i).Set(frv)
}
}

tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
Expand Down
86 changes: 76 additions & 10 deletions tests/preload_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package tests_test

import (
"context"
"encoding/json"
"regexp"
"sort"
"strconv"
"sync"
"testing"

"github.com/stretchr/testify/require"
"time"

"gorm.io/gorm"
"gorm.io/gorm/clause"
Expand Down Expand Up @@ -337,7 +337,7 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{})
DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{})

value := Value{
value1 := Value{
Name: "value",
Nested: Nested{
Preloads: []*Preload{
Expand All @@ -346,32 +346,98 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
Join: Join{Value: "j1"},
},
}
if err := DB.Create(&value).Error; err != nil {
value2 := Value{
Name: "value2",
Nested: Nested{
Preloads: []*Preload{
{Value: "p3"}, {Value: "p4"}, {Value: "p5"},
},
Join: Join{Value: "j2"},
},
}

values := []*Value{&value1, &value2}
if err := DB.Create(&values).Error; err != nil {
t.Errorf("failed to create value, got err: %v", err)
}

var find1 Value
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1).Error
err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find1, value)
AssertEqual(t, find1, value1)

var find2 Value
// Joins will automatically add Nested queries.
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2).Error
err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find2, value)
AssertEqual(t, find2, value2)

var finds []Value
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
require.Len(t, finds, 1)
AssertEqual(t, finds[0], value)
AssertEqual(t, len(finds), 2)
AssertEqual(t, finds[0], value1)
AssertEqual(t, finds[1], value2)
}

func TestMergeNestedPreloadWithNestedJoin(t *testing.T) {
users := []User{
{
Name: "TestMergeNestedPreloadWithNestedJoin-1",
Manager: &User{
Name: "Alexis Manager",
Tools: []Tools{
{Name: "Alexis Tool 1"},
{Name: "Alexis Tool 2"},
},
},
},
{
Name: "TestMergeNestedPreloadWithNestedJoin-2",
Manager: &User{
Name: "Jinzhu Manager",
Tools: []Tools{
{Name: "Jinzhu Tool 1"},
{Name: "Jinzhu Tool 2"},
},
},
},
}

DB.Create(&users)

query := make([]string, 0)
sess := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger,
Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
query = append(query, sql)
},
}})

var result []User
err := sess.
Joins("Manager").
Preload("Manager.Tools").
Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%").
Find(&result).Error

if err != nil {
t.Fatalf("failed to preload and find users: %v", err)
}

AssertEqual(t, result, users)
AssertEqual(t, len(query), 2) // Check preload queries are merged

if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) {
t.Fatalf("Expected first query to preload manager tools, got: %s", query[0])
}
}

func TestEmbedPreload(t *testing.T) {
Expand Down

0 comments on commit 85299bf

Please sign in to comment.