Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

/go/libraries/doltcore/sql/dsess: parallelize sql.NewDatabase work #8740

Merged
merged 10 commits into from
Jan 14, 2025
Merged
48 changes: 27 additions & 21 deletions go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"golang.org/x/sync/errgroup"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
Expand Down Expand Up @@ -398,34 +399,39 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri
}

func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error {
for _, root := range roots {
r, err := root.ResolveRootValue(ctx)
if err != nil {
return err
}
eg, egCtx := errgroup.WithContext(ctx)
eg.SetLimit(128)

err = r.IterTables(ctx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
if !schema.HasAutoIncrement(sch) {
return false, nil
for _, root := range roots {
eg.Go(func() error {
if egCtx.Err() != nil {
return egCtx.Err()
}

seq, err := table.GetAutoIncrementValue(ctx)
if err != nil {
return true, err
r, rerr := root.ResolveRootValue(egCtx)
if rerr != nil {
return rerr
}

tableNameStr := tableName.ToLower().Name
if oldValue, loaded := a.sequences.LoadOrStore(tableNameStr, seq); loaded && seq > oldValue.(uint64) {
a.sequences.Store(tableNameStr, seq)
}
return r.IterTables(egCtx, func(tableName doltdb.TableName, table *doltdb.Table, sch schema.Schema) (bool, error) {
if !schema.HasAutoIncrement(sch) {
return false, nil
}

return false, nil
})
seq, iErr := table.GetAutoIncrementValue(egCtx)
if iErr != nil {
return true, iErr
}

if err != nil {
return err
}
tableNameStr := tableName.ToLower().Name
if oldValue, loaded := a.sequences.LoadOrStore(tableNameStr, seq); loaded && seq > oldValue.(uint64) {
a.sequences.Store(tableNameStr, seq)
}

return false, nil
})
})
}

return nil
return eg.Wait()
}
73 changes: 50 additions & 23 deletions go/libraries/doltcore/sqle/dsess/globalstate.go
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"sync"

"github.com/dolthub/go-mysql-server/sql"
"golang.org/x/sync/errgroup"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
Expand All @@ -40,35 +41,61 @@ func NewGlobalStateStoreForDb(ctx context.Context, dbName string, db *doltdb.Dol
rootRefs = append(rootRefs, branches...)
rootRefs = append(rootRefs, remotes...)

var roots []doltdb.Rootish
rootRefsChan := make(chan doltdb.Rootish, len(rootRefs))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a channel here, just allocate roots here:

roots := make([]doltdb.Rootish, len(rootRefs))

and then have the loop capture the iteration variable and write the root into the right place:

for i, b := range rootRefs {
    eg.Go(func() error {
        ...
            roots[i] = cm
        ...
            roots[i] = ws
        ...
    })
}

Probably best to eg.SetLimit() on this errgroup as well.

eg, egCtx := errgroup.WithContext(ctx)
wg := sync.WaitGroup{}

for _, b := range rootRefs {
switch b.GetType() {
case ref.BranchRefType:
wsRef, err := ref.WorkingSetRefForHead(b)
if err != nil {
return GlobalStateImpl{}, err
wg.Add(1)
eg.Go(func() error {
defer wg.Done()
if egCtx.Err() != nil {
return egCtx.Err()
}

ws, err := db.ResolveWorkingSet(ctx, wsRef)
if err == doltdb.ErrWorkingSetNotFound {
// use the branch head if there isn't a working set for it
cm, err := db.ResolveCommitRef(ctx, b)
switch b.GetType() {
case ref.BranchRefType:
wsRef, err := ref.WorkingSetRefForHead(b)
if err != nil {
return GlobalStateImpl{}, err
return err
}
roots = append(roots, cm)
} else if err != nil {
return GlobalStateImpl{}, err
} else {
roots = append(roots, ws)
}
case ref.RemoteRefType:
cm, err := db.ResolveCommitRef(ctx, b)
if err != nil {
return GlobalStateImpl{}, err

ws, err := db.ResolveWorkingSet(egCtx, wsRef)
if err == doltdb.ErrWorkingSetNotFound {
// use the branch head if there isn't a working set for it
cm, err := db.ResolveCommitRef(egCtx, b)
if err != nil {
return err
}
rootRefsChan <- cm
} else if err != nil {
return err
} else {
rootRefsChan <- ws
}
case ref.RemoteRefType:
cm, err := db.ResolveCommitRef(egCtx, b)
if err != nil {
return err
}
rootRefsChan <- cm
}
roots = append(roots, cm)
}
return nil
})
}

// prevent sending on closed channel
wg.Wait()
close(rootRefsChan)

err = eg.Wait()
if err != nil {
return GlobalStateImpl{}, err
}

var roots []doltdb.Rootish
for rootRef := range rootRefsChan {
roots = append(roots, rootRef)
}

tracker, err := NewAutoIncrementTracker(ctx, dbName, roots...)
Expand Down
Loading