Skip to content

Commit

Permalink
Merge pull request #2 from ash2k/go-when-done
Browse files Browse the repository at this point in the history
Go when done
  • Loading branch information
ash2k authored May 3, 2024
2 parents 05e107b + 755e2bb commit 63189f9
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 2 deletions.
8 changes: 8 additions & 0 deletions stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ type Stage interface {
// Stage context is passed to f as an argument. f should stop when context signals done.
// If f returns a non-nil error, the stager starts performing shutdown.
Go(f func(context.Context) error)
// GoWhenDone starts f in a new goroutine attached to the Stage when the stage starts shutting down.
// Stage shutdown waits for f to exit.
GoWhenDone(f func() error)
}

type stage struct {
ctx context.Context
cancelStage context.CancelFunc
cancelStagerRun context.CancelFunc
whenDone []func() error
errChan chan error
n int
}
Expand All @@ -29,3 +33,7 @@ func (s *stage) Go(f func(context.Context) error) {
s.errChan <- err
}()
}

func (s *stage) GoWhenDone(f func() error) {
s.whenDone = append(s.whenDone, f)
}
11 changes: 9 additions & 2 deletions stager.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (sr *stager) NextStageWithContext(ctxParent context.Context) Stage {
ctx: ctx,
cancelStage: cancel,
cancelStagerRun: sr.runCancel,
errChan: make(chan error, 1),
errChan: make(chan error),
}
sr.stages = append(sr.stages, st)
return st
Expand All @@ -54,7 +54,14 @@ func (sr *stager) Run(ctx context.Context) error {
for i := len(sr.stages) - 1; i >= 0; i-- {
st := sr.stages[i]
st.cancelStage()
for i := 0; i < st.n; i++ {
for _, whenDone := range st.whenDone {
whenDone := whenDone
go func() {
st.errChan <- whenDone()
}()
}
n := st.n + len(st.whenDone)
for i := 0; i < n; i++ {
err := <-st.errChan
if firstErr == nil {
firstErr = err
Expand Down
100 changes: 100 additions & 0 deletions stager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,103 @@ func TestEmptyStagerStops(t *testing.T) {
}
})
}

func TestGoWhenDone_NoOtherGoroutines(t *testing.T) {
st := New()

ctx, cancel := context.WithCancel(context.Background())
cancel()

ran := 0

s := st.NextStage()
s.GoWhenDone(func() error {
ran++
return nil
})

err := st.Run(ctx)
if err != nil {
t.Fatal(err)
}
if ran != 1 {
t.Fatal(ran)
}
}

func TestGoWhenDone_WithOtherGoroutines(t *testing.T) {
st := New()

ctx, cancel := context.WithCancel(context.Background())
cancel()

ran := 0

s := st.NextStage()
s.Go(func(ctx context.Context) error {
return nil
})
s.GoWhenDone(func() error {
ran++
return nil
})

err := st.Run(ctx)
if err != nil {
t.Fatal(err)
}
if ran != 1 {
t.Fatal(ran)
}
}

func TestGoWhenDone_Error(t *testing.T) {
st := New()

ctx, cancel := context.WithCancel(context.Background())
cancel()

ran := 0

e := errors.New("boom")

s := st.NextStage()
s.GoWhenDone(func() error {
ran++
return e
})

err := st.Run(ctx)
if err != e {
t.Fatal(err)
}
if ran != 1 {
t.Fatal(ran)
}
}

func TestGoWhenDone_ErrorFromOtherGoroutine(t *testing.T) {
st := New()

ran := 0

e1 := errors.New("boom1")
e2 := errors.New("boom2")

s := st.NextStage()
s.Go(func(ctx context.Context) error {
return e1
})
s.GoWhenDone(func() error {
ran++
return e2
})

err := st.Run(context.Background())
if err != e1 {
t.Fatal(err)
}
if ran != 1 {
t.Fatal(ran)
}
}

0 comments on commit 63189f9

Please sign in to comment.