Skip to content

Commit

Permalink
Allow transitions in callbacks (looplab#88)
Browse files Browse the repository at this point in the history
* Allow state transitions in callbacks

This adds the possibility of "starting" a state machine and have it
execute multiple state transitions in succession, given that no errors
occur.

The equivalent code without this change:
```go
var errTransition error

for errTransition == nil {
	transitions := request.FSM.AvailableTransitions()
	if len(transitions) == 0 {
		break
	}
	if len(transitions) > 1 {
		errTransition = errors.New("only 1 transition should be available")
	}
	errTransition = request.FSM.Event(transitions[0])
}

if errTransition != nil {
	fmt.Println(errTransition)
}
```

Arguably, that’s bad because of several reasons:
1. The state machine is used like a puppet.
2. The state transitions that make up the "happy path" are encoded
   outside the state machine.
3. The code really isn’t good.
4. There’s no way to intervene or make different decisions on which
   state to transition to next (reinforces bullet point 2).
5. There’s no way to add proper error handling.

It is possible to fix a certain number of those problems but not all
of them, especially 2 and 4 but also 1.

The added test is green and uses both an enter state and an after event
callback.

No other test case was touched in any way (besides enhancing the
context one that was added in the previous commit).

* Allow async state transition to be canceled

This adds a context and cancelation facility to the type `AsyncError`.
Async state transitions can now be canceled by calling `CancelTransition`
on the AsyncError returned by `fsm.Event`. The context on that error can
also be handed off as described in looplab#77 (comment).

* Add example for triggering transitions in callbacks

* Add example for canceling an async transition
  • Loading branch information
annismckenzie authored Oct 6, 2022
1 parent 54bbb61 commit 3637340
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 10 deletions.
7 changes: 7 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

package fsm

import (
"context"
)

// InvalidEventError is returned by FSM.Event() when the event cannot be called
// in the current state.
type InvalidEventError struct {
Expand Down Expand Up @@ -82,6 +86,9 @@ func (e CanceledError) Error() string {
// asynchronous state transition.
type AsyncError struct {
Err error

Ctx context.Context
CancelTransition func()
}

func (e AsyncError) Error() string {
Expand Down
54 changes: 54 additions & 0 deletions examples/cancel_async_transition.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//go:build ignore
// +build ignore

package main

import (
"context"
"fmt"
"time"

"github.com/looplab/fsm"
)

func main() {
f := fsm.NewFSM(
"start",
fsm.Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
},
fsm.Callbacks{
"leave_start": func(_ context.Context, e *fsm.Event) {
e.Async()
},
},
)

err := f.Event(context.Background(), "run")
asyncError, ok := err.(fsm.AsyncError)
if !ok {
panic(fmt.Sprintf("expected error to be 'AsyncError', got %v", err))
}
var asyncStateTransitionWasCanceled bool
go func() {
<-asyncError.Ctx.Done()
asyncStateTransitionWasCanceled = true
if asyncError.Ctx.Err() != context.Canceled {
panic(fmt.Sprintf("Expected error to be '%v' but was '%v'", context.Canceled, asyncError.Ctx.Err()))
}
}()
asyncError.CancelTransition()
time.Sleep(20 * time.Millisecond)

if err = f.Transition(); err != nil {
panic(fmt.Sprintf("Error encountered when transitioning: %v", err))
}
if !asyncStateTransitionWasCanceled {
panic("expected async state transition cancelation to have propagated")
}
if f.Current() != "start" {
panic("expected state to be 'start'")
}

fmt.Println("Successfully ran state machine.")
}
54 changes: 54 additions & 0 deletions examples/transition_callbacks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//go:build ignore
// +build ignore

package main

import (
"context"
"fmt"

"github.com/looplab/fsm"
)

func main() {
var afterFinishCalled bool
fsm := fsm.NewFSM(
"start",
fsm.Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
{Name: "finish", Src: []string{"end"}, Dst: "finished"},
{Name: "reset", Src: []string{"end", "finished"}, Dst: "start"},
},
fsm.Callbacks{
"enter_end": func(ctx context.Context, e *fsm.Event) {
if err := e.FSM.Event(ctx, "finish"); err != nil {
fmt.Println(err)
}
},
"after_finish": func(ctx context.Context, e *fsm.Event) {
afterFinishCalled = true
if e.Src != "end" {
panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src))
}
if err := e.FSM.Event(ctx, "reset"); err != nil {
fmt.Println(err)
}
},
},
)

if err := fsm.Event(context.Background(), "run"); err != nil {
panic(fmt.Sprintf("Error encountered when triggering the run event: %v", err))
}

if !afterFinishCalled {
panic(fmt.Sprintf("After finish callback should have run, current state: '%s'", fsm.Current()))
}

currentState := fsm.Current()
if currentState != "start" {
panic(fmt.Sprintf("expected state to be 'start', was '%s'", currentState))
}

fmt.Println("Successfully ran state machine.")
}
57 changes: 48 additions & 9 deletions fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,16 @@ func (f *FSM) SetMetadata(key string, dataValue interface{}) {
// internal bug.
func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) error {
f.eventMu.Lock()
defer f.eventMu.Unlock()
// in order to always unlock the event mutex, the defer is added
// in case the state transition goes through and enter/after callbacks
// are called; because these must be able to trigger new state
// transitions, it is explicitly unlocked in the code below
var unlocked bool
defer func() {
if !unlocked {
f.eventMu.Unlock()
}
}()

f.stateMu.RLock()
defer f.stateMu.RUnlock()
Expand Down Expand Up @@ -323,18 +332,48 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
}

// Setup the transition, call it later.
f.transition = func() {
f.stateMu.Lock()
f.current = dst
f.stateMu.Unlock()
transitionFunc := func(ctx context.Context, async bool) func() {
return func() {
if ctx.Err() != nil {
if e.Err == nil {
e.Err = ctx.Err()
}
return
}

f.enterStateCallbacks(ctx, e)
f.afterEventCallbacks(ctx, e)
f.stateMu.Lock()
f.current = dst
f.stateMu.Unlock()

// at this point, we unlock the event mutex in order to allow
// enter state callbacks to trigger another transition
// for aynchronous state transitions this doesn't happen because
// the event mutex has already been unlocked
if !async {
f.eventMu.Unlock()
unlocked = true
}
f.transition = nil // treat the state transition as done
f.enterStateCallbacks(ctx, e)
f.afterEventCallbacks(ctx, e)
}
}

f.transition = transitionFunc(ctx, false)

if err = f.leaveStateCallbacks(ctx, e); err != nil {
if _, ok := err.(CanceledError); ok {
f.transition = nil
} else if asyncError, ok := err.(AsyncError); ok {
// setup a new context in order for async state transitions to work correctly
// this "uncancels" the original context which ignores its cancelation
// but keeps the values of the original context available to callers
ctx, cancel := uncancelContext(ctx)
e.cancelFunc = cancel
asyncError.Ctx = ctx
asyncError.CancelTransition = cancel
f.transition = transitionFunc(ctx, true)
return asyncError
}
return err
}
Expand Down Expand Up @@ -405,15 +444,15 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error {
if e.canceled {
return CanceledError{e.Err}
} else if e.async {
return AsyncError{e.Err}
return AsyncError{Err: e.Err}
}
}
if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok {
fn(ctx, e)
if e.canceled {
return CanceledError{e.Err}
} else if e.async {
return AsyncError{e.Err}
return AsyncError{Err: e.Err}
}
}
return nil
Expand Down
88 changes: 87 additions & 1 deletion fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package fsm

import (
"context"
"errors"
"fmt"
"sort"
"sync"
Expand Down Expand Up @@ -526,6 +527,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) {
}
}

func TestCancelAsyncTransition(t *testing.T) {
fsm := NewFSM(
"start",
Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
},
Callbacks{
"leave_start": func(_ context.Context, e *Event) {
e.Async()
},
},
)
err := fsm.Event(context.Background(), "run")
asyncError, ok := err.(AsyncError)
if !ok {
t.Errorf("expected error to be 'AsyncError', got %v", err)
}
var asyncStateTransitionWasCanceled bool
go func() {
<-asyncError.Ctx.Done()
asyncStateTransitionWasCanceled = true
}()
asyncError.CancelTransition()
time.Sleep(20 * time.Millisecond)

if err = fsm.Transition(); err != nil {
t.Errorf("expected no error, got %v", err)
}
if !asyncStateTransitionWasCanceled {
t.Error("expected async state transition cancelation to have propagated")
}
if fsm.Current() != "start" {
t.Error("expected state to be 'start'")
}
}

func TestCallbackNoError(t *testing.T) {
fsm := NewFSM(
"start",
Expand Down Expand Up @@ -695,6 +732,47 @@ func TestDoubleTransition(t *testing.T) {
wg.Wait()
}

func TestTransitionInCallbacks(t *testing.T) {
var fsm *FSM
var afterFinishCalled bool
fsm = NewFSM(
"start",
Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
{Name: "finish", Src: []string{"end"}, Dst: "finished"},
{Name: "reset", Src: []string{"end", "finished"}, Dst: "start"},
},
Callbacks{
"enter_end": func(ctx context.Context, e *Event) {
if err := e.FSM.Event(ctx, "finish"); err != nil {
fmt.Println(err)
}
},
"after_finish": func(ctx context.Context, e *Event) {
afterFinishCalled = true
if e.Src != "end" {
panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src))
}
if err := e.FSM.Event(ctx, "reset"); err != nil {
fmt.Println(err)
}
},
},
)

if err := fsm.Event(context.Background(), "run"); err != nil {
t.Errorf("expected no error, got %v", err)
}
if !afterFinishCalled {
t.Error("expected after_finish callback to have been executed but it wasn't")
}

currentState := fsm.Current()
if currentState != "start" {
t.Errorf("expected state to be 'start', was '%s'", currentState)
}
}

func TestContextInCallbacks(t *testing.T) {
var fsm *FSM
var enterEndAsyncWorkDone bool
Expand All @@ -711,6 +789,11 @@ func TestContextInCallbacks(t *testing.T) {
<-ctx.Done()
enterEndAsyncWorkDone = true
}()

<-ctx.Done()
if err := e.FSM.Event(ctx, "finish"); err != nil {
e.Err = fmt.Errorf("transitioning to the finished state failed: %w", err)
}
},
},
)
Expand All @@ -719,7 +802,10 @@ func TestContextInCallbacks(t *testing.T) {
go func() {
cancel()
}()
fsm.Event(ctx, "run")
err := fsm.Event(ctx, "run")
if !errors.Is(err, context.Canceled) {
t.Errorf("expected 'context canceled' error, got %v", err)
}
time.Sleep(20 * time.Millisecond)

if !enterEndAsyncWorkDone {
Expand Down
21 changes: 21 additions & 0 deletions uncancel_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package fsm

import (
"context"
"time"
)

type uncancel struct {
context.Context
}

func (*uncancel) Deadline() (deadline time.Time, ok bool) { return }
func (*uncancel) Done() <-chan struct{} { return nil }
func (*uncancel) Err() error { return nil }

// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values.
// Also returns a new cancel function.
// This is useful to keep a background task running while the initial request is finished.
func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) {
return context.WithCancel(&uncancel{ctx})
}
Loading

0 comments on commit 3637340

Please sign in to comment.