mirror of https://github.com/lightninglabs/loop
Merge pull request #631 from sputn1ck/instantloopout_1
[1/?] Instant loop out: Add FSM modulepull/638/head
commit
077d702bc8
@ -0,0 +1,127 @@
|
||||
package fsm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ExampleService is an example service that we want to wait for in the FSM.
|
||||
type ExampleService interface {
|
||||
WaitForStuffHappening() (<-chan bool, error)
|
||||
}
|
||||
|
||||
// ExampleStore is an example store that we want to use in our exitFunc.
|
||||
type ExampleStore interface {
|
||||
StoreStuff() error
|
||||
}
|
||||
|
||||
// ExampleFSM implements the FSM and uses the ExampleService and ExampleStore
|
||||
// to implement the actions.
|
||||
type ExampleFSM struct {
|
||||
*StateMachine
|
||||
|
||||
service ExampleService
|
||||
store ExampleStore
|
||||
}
|
||||
|
||||
// NewExampleFSMContext creates a new example FSM context.
|
||||
func NewExampleFSMContext(service ExampleService,
|
||||
store ExampleStore) *ExampleFSM {
|
||||
|
||||
exampleFSM := &ExampleFSM{
|
||||
service: service,
|
||||
store: store,
|
||||
}
|
||||
exampleFSM.StateMachine = NewStateMachine(exampleFSM.GetStates())
|
||||
|
||||
return exampleFSM
|
||||
}
|
||||
|
||||
// States.
|
||||
const (
|
||||
InitFSM = StateType("InitFSM")
|
||||
StuffSentOut = StateType("StuffSentOut")
|
||||
WaitingForStuff = StateType("WaitingForStuff")
|
||||
StuffFailed = StateType("StuffFailed")
|
||||
StuffSuccess = StateType("StuffSuccess")
|
||||
)
|
||||
|
||||
// Events.
|
||||
var (
|
||||
OnRequestStuff = EventType("OnRequestStuff")
|
||||
OnStuffSentOut = EventType("OnStuffSentOut")
|
||||
OnStuffSuccess = EventType("OnStuffSuccess")
|
||||
)
|
||||
|
||||
// GetStates returns the states for the example FSM.
|
||||
func (e *ExampleFSM) GetStates() States {
|
||||
return States{
|
||||
Default: State{
|
||||
Transitions: Transitions{
|
||||
OnRequestStuff: InitFSM,
|
||||
},
|
||||
},
|
||||
InitFSM: State{
|
||||
Action: e.initFSM,
|
||||
Transitions: Transitions{
|
||||
OnStuffSentOut: StuffSentOut,
|
||||
OnError: StuffFailed,
|
||||
},
|
||||
},
|
||||
StuffSentOut: State{
|
||||
Action: e.waitForStuff,
|
||||
Transitions: Transitions{
|
||||
OnStuffSuccess: StuffSuccess,
|
||||
OnError: StuffFailed,
|
||||
},
|
||||
},
|
||||
StuffFailed: State{
|
||||
Action: NoOpAction,
|
||||
},
|
||||
StuffSuccess: State{
|
||||
Action: NoOpAction,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// InitStuffRequest is the event context for the InitFSM state.
|
||||
type InitStuffRequest struct {
|
||||
Stuff string
|
||||
respondChan chan<- string
|
||||
}
|
||||
|
||||
// initFSM is the action for the InitFSM state.
|
||||
func (e *ExampleFSM) initFSM(eventCtx EventContext) EventType {
|
||||
req, ok := eventCtx.(*InitStuffRequest)
|
||||
if !ok {
|
||||
return e.HandleError(
|
||||
fmt.Errorf("invalid event context type: %T", eventCtx),
|
||||
)
|
||||
}
|
||||
|
||||
err := e.store.StoreStuff()
|
||||
if err != nil {
|
||||
return e.HandleError(err)
|
||||
}
|
||||
|
||||
req.respondChan <- req.Stuff
|
||||
|
||||
return OnStuffSentOut
|
||||
}
|
||||
|
||||
// waitForStuff is an action that waits for stuff to happen.
|
||||
func (e *ExampleFSM) waitForStuff(eventCtx EventContext) EventType {
|
||||
waitChan, err := e.service.WaitForStuffHappening()
|
||||
if err != nil {
|
||||
return e.HandleError(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-waitChan
|
||||
err := e.SendEvent(OnStuffSuccess, nil)
|
||||
if err != nil {
|
||||
log.Errorf("unable to send event: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return NoOp
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> InitFSM: OnRequestStuff
|
||||
InitFSM
|
||||
InitFSM --> StuffFailed: OnError
|
||||
InitFSM --> StuffSentOut: OnStuffSentOut
|
||||
StuffFailed
|
||||
StuffSentOut
|
||||
StuffSentOut --> StuffFailed: OnError
|
||||
StuffSentOut --> StuffSuccess: OnStuffSuccess
|
||||
StuffSuccess
|
||||
```
|
@ -0,0 +1,245 @@
|
||||
package fsm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
errService = errors.New("service error")
|
||||
errStore = errors.New("store error")
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
storeErr error
|
||||
}
|
||||
|
||||
func (m *mockStore) StoreStuff() error {
|
||||
return m.storeErr
|
||||
}
|
||||
|
||||
type mockService struct {
|
||||
respondChan chan bool
|
||||
respondErr error
|
||||
}
|
||||
|
||||
func (m *mockService) WaitForStuffHappening() (<-chan bool, error) {
|
||||
return m.respondChan, m.respondErr
|
||||
}
|
||||
|
||||
func newInitStuffRequest() *InitStuffRequest {
|
||||
return &InitStuffRequest{
|
||||
Stuff: "stuff",
|
||||
respondChan: make(chan<- string, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func TestExampleFSM(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
expectedState StateType
|
||||
eventCtx EventContext
|
||||
expectedLastActionError error
|
||||
|
||||
sendEvent EventType
|
||||
sendEventErr error
|
||||
|
||||
serviceErr error
|
||||
storeErr error
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
expectedState: StuffSuccess,
|
||||
eventCtx: newInitStuffRequest(),
|
||||
sendEvent: OnRequestStuff,
|
||||
},
|
||||
{
|
||||
name: "service error",
|
||||
expectedState: StuffFailed,
|
||||
eventCtx: newInitStuffRequest(),
|
||||
sendEvent: OnRequestStuff,
|
||||
serviceErr: errService,
|
||||
expectedLastActionError: errService,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
expectedLastActionError: errStore,
|
||||
storeErr: errStore,
|
||||
sendEvent: OnRequestStuff,
|
||||
expectedState: StuffFailed,
|
||||
eventCtx: newInitStuffRequest(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
respondChan := make(chan string, 1)
|
||||
if req, ok := tc.eventCtx.(*InitStuffRequest); ok {
|
||||
req.respondChan = respondChan
|
||||
}
|
||||
|
||||
serviceResponseChan := make(chan bool, 1)
|
||||
serviceResponseChan <- true
|
||||
|
||||
service := &mockService{
|
||||
respondChan: serviceResponseChan,
|
||||
respondErr: tc.serviceErr,
|
||||
}
|
||||
|
||||
store := &mockStore{
|
||||
storeErr: tc.storeErr,
|
||||
}
|
||||
|
||||
exampleContext := NewExampleFSMContext(service, store)
|
||||
cachedObserver := NewCachedObserver(100)
|
||||
|
||||
exampleContext.RegisterObserver(cachedObserver)
|
||||
|
||||
err := exampleContext.SendEvent(
|
||||
tc.sendEvent, tc.eventCtx,
|
||||
)
|
||||
require.Equal(t, tc.sendEventErr, err)
|
||||
|
||||
require.Equal(
|
||||
t,
|
||||
tc.expectedLastActionError,
|
||||
exampleContext.LastActionError,
|
||||
)
|
||||
|
||||
err = cachedObserver.WaitForState(
|
||||
context.Background(),
|
||||
time.Second,
|
||||
tc.expectedState,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getTestContext returns a test context for the example FSM and a cached
|
||||
// observer that can be used to verify the state transitions.
|
||||
func getTestContext() (*ExampleFSM, *CachedObserver) {
|
||||
service := &mockService{
|
||||
respondChan: make(chan bool, 1),
|
||||
}
|
||||
service.respondChan <- true
|
||||
|
||||
store := &mockStore{}
|
||||
|
||||
exampleContext := NewExampleFSMContext(service, store)
|
||||
cachedObserver := NewCachedObserver(100)
|
||||
|
||||
exampleContext.RegisterObserver(cachedObserver)
|
||||
|
||||
return exampleContext, cachedObserver
|
||||
}
|
||||
|
||||
// TestExampleFSMFlow tests different flows that the example FSM can go through.
|
||||
func TestExampleFSMFlow(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
expectedStateFlow []StateType
|
||||
expectedEventFlow []EventType
|
||||
storeError error
|
||||
serviceError error
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
expectedStateFlow: []StateType{
|
||||
InitFSM,
|
||||
StuffSentOut,
|
||||
StuffSuccess,
|
||||
},
|
||||
expectedEventFlow: []EventType{
|
||||
OnRequestStuff,
|
||||
OnStuffSentOut,
|
||||
OnStuffSuccess,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failure on store",
|
||||
expectedStateFlow: []StateType{
|
||||
InitFSM,
|
||||
StuffFailed,
|
||||
},
|
||||
expectedEventFlow: []EventType{
|
||||
OnRequestStuff,
|
||||
OnError,
|
||||
},
|
||||
storeError: errStore,
|
||||
},
|
||||
{
|
||||
name: "failure on service",
|
||||
expectedStateFlow: []StateType{
|
||||
InitFSM,
|
||||
StuffSentOut,
|
||||
StuffFailed,
|
||||
},
|
||||
expectedEventFlow: []EventType{
|
||||
OnRequestStuff,
|
||||
OnStuffSentOut,
|
||||
OnError,
|
||||
},
|
||||
serviceError: errService,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
exampleContext, cachedObserver := getTestContext()
|
||||
|
||||
if tc.storeError != nil {
|
||||
exampleContext.store.(*mockStore).
|
||||
storeErr = tc.storeError
|
||||
}
|
||||
|
||||
if tc.serviceError != nil {
|
||||
exampleContext.service.(*mockService).
|
||||
respondErr = tc.serviceError
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := exampleContext.SendEvent(
|
||||
OnRequestStuff,
|
||||
newInitStuffRequest(),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Wait for the final state.
|
||||
err := cachedObserver.WaitForState(
|
||||
context.Background(),
|
||||
time.Second,
|
||||
tc.expectedStateFlow[len(
|
||||
tc.expectedStateFlow,
|
||||
)-1],
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
allNotifications := cachedObserver.
|
||||
GetCachedNotifications()
|
||||
|
||||
for index, notification := range allNotifications {
|
||||
require.Equal(
|
||||
t,
|
||||
tc.expectedStateFlow[index],
|
||||
notification.NextState,
|
||||
)
|
||||
require.Equal(
|
||||
t,
|
||||
tc.expectedEventFlow[index],
|
||||
notification.Event,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -0,0 +1,296 @@
|
||||
package fsm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrEventRejected is the error returned when the state machine cannot process
|
||||
// an event in the state that it is in.
|
||||
var (
|
||||
ErrEventRejected = errors.New("event rejected")
|
||||
ErrWaitForStateTimedOut = errors.New(
|
||||
"timed out while waiting for event",
|
||||
)
|
||||
ErrInvalidContextType = errors.New("invalid context")
|
||||
)
|
||||
|
||||
const (
|
||||
// Default represents the default state of the system.
|
||||
Default StateType = ""
|
||||
|
||||
// NoOp represents a no-op event.
|
||||
NoOp EventType = "NoOp"
|
||||
|
||||
// OnError can be used when an action returns a generic error.
|
||||
OnError EventType = "OnError"
|
||||
|
||||
// ContextValidationFailed can be when the passed context if
|
||||
// not of the expected type.
|
||||
ContextValidationFailed EventType = "ContextValidationFailed"
|
||||
)
|
||||
|
||||
// StateType represents an extensible state type in the state machine.
|
||||
type StateType string
|
||||
|
||||
// EventType represents an extensible event type in the state machine.
|
||||
type EventType string
|
||||
|
||||
// EventContext represents the context to be passed to the action
|
||||
// implementation.
|
||||
type EventContext interface{}
|
||||
|
||||
// Action represents the action to be executed in a given state.
|
||||
type Action func(eventCtx EventContext) EventType
|
||||
|
||||
// Transitions represents a mapping of events and states.
|
||||
type Transitions map[EventType]StateType
|
||||
|
||||
// State binds a state with an action and a set of events it can handle.
|
||||
type State struct {
|
||||
// EntryFunc is a function that is called when the state is entered.
|
||||
EntryFunc func()
|
||||
// ExitFunc is a function that is called when the state is exited.
|
||||
ExitFunc func()
|
||||
// Action is the action to be executed in the state.
|
||||
Action Action
|
||||
// Transitions is a mapping of events and states.
|
||||
Transitions Transitions
|
||||
}
|
||||
|
||||
// States represents a mapping of states and their implementations.
|
||||
type States map[StateType]State
|
||||
|
||||
// Notification represents a notification sent to the state machine's
|
||||
// notification channel.
|
||||
type Notification struct {
|
||||
// PreviousState is the state the state machine was in before the event
|
||||
// was processed.
|
||||
PreviousState StateType
|
||||
// NextState is the state the state machine is in after the event was
|
||||
// processed.
|
||||
NextState StateType
|
||||
// Event is the event that was processed.
|
||||
Event EventType
|
||||
}
|
||||
|
||||
// Observer is an interface that can be implemented by types that want to
|
||||
// observe the state machine.
|
||||
type Observer interface {
|
||||
Notify(Notification)
|
||||
}
|
||||
|
||||
// StateMachine represents the state machine.
|
||||
type StateMachine struct {
|
||||
// Context represents the state machine context.
|
||||
States States
|
||||
|
||||
// ActionEntryFunc is a function that is called before an action is
|
||||
// executed.
|
||||
ActionEntryFunc func()
|
||||
|
||||
// ActionExitFunc is a function that is called after an action is
|
||||
// executed.
|
||||
ActionExitFunc func()
|
||||
|
||||
// mutex ensures that only 1 event is processed by the state machine at
|
||||
// any given time.
|
||||
mutex sync.Mutex
|
||||
|
||||
// LastActionError is an error set by the last action executed.
|
||||
LastActionError error
|
||||
|
||||
// previous represents the previous state.
|
||||
previous StateType
|
||||
|
||||
// current represents the current state.
|
||||
current StateType
|
||||
|
||||
// observers is a slice of observers that are notified when the state
|
||||
// machine transitions between states.
|
||||
observers []Observer
|
||||
|
||||
// observerMutex ensures that observers are only added or removed
|
||||
// safely.
|
||||
observerMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new state machine.
|
||||
func NewStateMachine(states States) *StateMachine {
|
||||
return &StateMachine{
|
||||
States: states,
|
||||
observers: make([]Observer, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// getNextState returns the next state for the event given the machine's current
|
||||
// state, or an error if the event can't be handled in the given state.
|
||||
func (s *StateMachine) getNextState(event EventType) (State, error) {
|
||||
var (
|
||||
state State
|
||||
ok bool
|
||||
)
|
||||
|
||||
stateMap := s.States
|
||||
|
||||
if state, ok = stateMap[s.current]; !ok {
|
||||
return State{}, NewErrConfigError("current state not found")
|
||||
}
|
||||
|
||||
if state.Transitions == nil {
|
||||
return State{}, NewErrConfigError(
|
||||
"current state has no transitions",
|
||||
)
|
||||
}
|
||||
|
||||
var next StateType
|
||||
if next, ok = state.Transitions[event]; !ok {
|
||||
return State{}, NewErrConfigError(
|
||||
"event not found in current transitions",
|
||||
)
|
||||
}
|
||||
|
||||
// Identify the state definition for the next state.
|
||||
state, ok = stateMap[next]
|
||||
if !ok {
|
||||
return State{}, NewErrConfigError("next state not found")
|
||||
}
|
||||
|
||||
if state.Action == nil {
|
||||
return State{}, NewErrConfigError("next state has no action")
|
||||
}
|
||||
|
||||
// Transition over to the next state.
|
||||
s.previous = s.current
|
||||
s.current = next
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// SendEvent sends an event to the state machine. It returns an error if the
|
||||
// event cannot be processed in the current state. Otherwise, it only returns
|
||||
// nil if the event for the last action is a no-op.
|
||||
func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.States == nil {
|
||||
return NewErrConfigError("state machine config is nil")
|
||||
}
|
||||
|
||||
for {
|
||||
// Determine the next state for the event given the machine's
|
||||
// current state.
|
||||
state, err := s.getNextState(event)
|
||||
if err != nil {
|
||||
return ErrEventRejected
|
||||
}
|
||||
|
||||
// Notify the state machine's observers.
|
||||
s.observerMutex.Lock()
|
||||
for _, observer := range s.observers {
|
||||
observer.Notify(Notification{
|
||||
PreviousState: s.previous,
|
||||
NextState: s.current,
|
||||
Event: event,
|
||||
})
|
||||
}
|
||||
s.observerMutex.Unlock()
|
||||
|
||||
// Execute the state machines ActionEntryFunc.
|
||||
if s.ActionEntryFunc != nil {
|
||||
s.ActionEntryFunc()
|
||||
}
|
||||
|
||||
// Execute the current state's entry function
|
||||
if state.EntryFunc != nil {
|
||||
state.EntryFunc()
|
||||
}
|
||||
|
||||
// Execute the next state's action and loop over again if the
|
||||
// event returned is not a no-op.
|
||||
nextEvent := state.Action(eventCtx)
|
||||
|
||||
// Execute the current state's exit function
|
||||
if state.ExitFunc != nil {
|
||||
state.ExitFunc()
|
||||
}
|
||||
|
||||
// Execute the state machines ActionExitFunc.
|
||||
if s.ActionExitFunc != nil {
|
||||
s.ActionExitFunc()
|
||||
}
|
||||
|
||||
// If the next event is a no-op, we're done.
|
||||
if nextEvent == NoOp {
|
||||
return nil
|
||||
}
|
||||
|
||||
event = nextEvent
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterObserver registers an observer with the state machine.
|
||||
func (s *StateMachine) RegisterObserver(observer Observer) {
|
||||
s.observerMutex.Lock()
|
||||
defer s.observerMutex.Unlock()
|
||||
|
||||
if observer != nil {
|
||||
s.observers = append(s.observers, observer)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveObserver removes an observer from the state machine. It returns true
|
||||
// if the observer was removed, false otherwise.
|
||||
func (s *StateMachine) RemoveObserver(observer Observer) bool {
|
||||
s.observerMutex.Lock()
|
||||
defer s.observerMutex.Unlock()
|
||||
|
||||
for i, o := range s.observers {
|
||||
if o == observer {
|
||||
s.observers = append(
|
||||
s.observers[:i], s.observers[i+1:]...,
|
||||
)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HandleError is a helper function that can be used by actions to handle
|
||||
// errors.
|
||||
func (s *StateMachine) HandleError(err error) EventType {
|
||||
log.Errorf("StateMachine error: %s", err)
|
||||
s.LastActionError = err
|
||||
return OnError
|
||||
}
|
||||
|
||||
// NoOpAction is a no-op action that can be used by states that don't need to
|
||||
// execute any action.
|
||||
func NoOpAction(_ EventContext) EventType {
|
||||
return NoOp
|
||||
}
|
||||
|
||||
// ErrConfigError is an error returned when the state machine is misconfigured.
|
||||
type ErrConfigError error
|
||||
|
||||
// NewErrConfigError creates a new ErrConfigError.
|
||||
func NewErrConfigError(msg string) ErrConfigError {
|
||||
return (ErrConfigError)(fmt.Errorf("config error: %s", msg))
|
||||
}
|
||||
|
||||
// ErrWaitingForStateTimeout is an error returned when the state machine times
|
||||
// out while waiting for a state.
|
||||
type ErrWaitingForStateTimeout error
|
||||
|
||||
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
|
||||
func NewErrWaitingForStateTimeout(expected,
|
||||
actual StateType) ErrWaitingForStateTimeout {
|
||||
|
||||
return (ErrWaitingForStateTimeout)(fmt.Errorf(
|
||||
"waiting for state timeout: expected %s, actual: %s",
|
||||
expected, actual,
|
||||
))
|
||||
}
|
@ -0,0 +1,139 @@
|
||||
# Finite State Machine Module
|
||||
|
||||
This module provides a simple golang finite state machine (FSM) implementation.
|
||||
|
||||
|
||||
## Introduction
|
||||
|
||||
The state machine uses events and actions to transition between states. The
|
||||
events are used to trigger a transition and the actions are used to perform
|
||||
some work when entering a state. Actions return new events which are then
|
||||
used to trigger the next transition.
|
||||
|
||||
## Usage
|
||||
|
||||
A simple way to use the FSM is to embed it into a struct:
|
||||
|
||||
```go
|
||||
type LightSwitchFSM struct {
|
||||
*StateMachine
|
||||
}
|
||||
```
|
||||
|
||||
In order to use the FSM you need to define the events, actions and statemaps
|
||||
for the FSM. events are defined as constants, actions are defined as functions
|
||||
on the `LightSwitchFSM` struct and statemaps are in a map of `State` to `StateMap`
|
||||
where `StateMap` is a map of `Event` to `Action`.
|
||||
|
||||
For the `LightSwitchFSM` we can first define the states
|
||||
```go
|
||||
const (
|
||||
OffState = StateType("Off")
|
||||
OnState = StateType("On")
|
||||
)
|
||||
|
||||
const (
|
||||
SwitchOff = EventType("SwitchOff")
|
||||
SwitchOn = EventType("SwitchOn")
|
||||
)
|
||||
```
|
||||
|
||||
Next we define the actions, here we're simply going to log from the action.
|
||||
```go
|
||||
func (a *LightSwitchFSM) OffAction(_ EventContext) EventType {
|
||||
fmt.Println("The light has been switched off")
|
||||
return NoOp
|
||||
}
|
||||
|
||||
func (a *LightSwitchFSM) OnAction(_ EventContext) EventType {
|
||||
fmt.Println("The light has been switched on")
|
||||
return NoOp
|
||||
}
|
||||
```
|
||||
|
||||
Next we define the statemap, here we're going to implement a getStates()
|
||||
function that returns the statemap.
|
||||
```go
|
||||
func (l *LightSwitchFSM) getStates() States {
|
||||
return States{
|
||||
OffState: State{
|
||||
Action: l.OffAction,
|
||||
Transitions: Transitions{
|
||||
SwitchOn: OnState,
|
||||
},
|
||||
},
|
||||
OnState: State{
|
||||
Action: l.OnAction,
|
||||
Transitions: Transitions{
|
||||
SwitchOff: OffState,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Finally, we can create the FSM and use it.
|
||||
|
||||
```go
|
||||
func NewLightSwitchFSM() *LightSwitchFSM {
|
||||
fsm := &LightSwitchFSM{}
|
||||
fsm.StateMachine = &StateMachine{
|
||||
States: fsm.getStates(),
|
||||
Current: OffState,
|
||||
}
|
||||
return fsm
|
||||
}
|
||||
```
|
||||
|
||||
This is what it would look like to use the FSM:
|
||||
```go
|
||||
func TestLightSwitchFSM(t *testing.T) {
|
||||
// Create a new light switch FSM.
|
||||
lightSwitch := NewLightSwitchFSM()
|
||||
|
||||
// Expect the light to be off
|
||||
require.Equal(t, lightSwitch.Current, OffState)
|
||||
|
||||
// Send the On Event
|
||||
err := lightSwitch.SendEvent(SwitchOn, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect the light to be on
|
||||
require.Equal(t, lightSwitch.Current, OnState)
|
||||
|
||||
// Send the Off Event
|
||||
err = lightSwitch.SendEvent(SwitchOff, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect the light to be off
|
||||
require.Equal(t, lightSwitch.Current, OffState)
|
||||
}
|
||||
```
|
||||
|
||||
## Observing the state machine
|
||||
The state machine can be observed by registering an observer. The observer
|
||||
will be called when the state machine transitions between states. The observer
|
||||
is called with the old state, the new state and the event that triggered the
|
||||
transition.
|
||||
|
||||
An observer can be registered by calling the `RegisterObserver` function on
|
||||
the state machine. The observer must implement the `Observer` interface.
|
||||
|
||||
```go
|
||||
type Observer interface {
|
||||
Notify(Notification)
|
||||
}
|
||||
```
|
||||
|
||||
An example of a cached observer can be found in [observer.go](./observer.go).
|
||||
|
||||
|
||||
## More Examples
|
||||
A more elaborate example that uses error handling, event context and more
|
||||
elaborate actions can be found in here [examples_fsm.go](./example_fsm.go).
|
||||
With the tests in [examples_fsm_test.go](./example_fsm_test.go) showing how to
|
||||
use the FSM.
|
||||
|
||||
## Visualizing the FSM
|
||||
The FSM can be visualized to mermaid markdown using the [stateparser.go](./stateparser/stateparser.go)
|
||||
tool. The visualization for the exampleFSM can be found in [example_fsm.md](./example_fsm.md).
|
@ -0,0 +1,117 @@
|
||||
package fsm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
errAction = errors.New("action error")
|
||||
)
|
||||
|
||||
// TestStateMachineContext is a test context for the state machine.
|
||||
type TestStateMachineContext struct {
|
||||
*StateMachine
|
||||
}
|
||||
|
||||
// GetStates returns the states for the test state machine.
|
||||
// The StateMap looks like this:
|
||||
// State1 -> Event1 -> State2 .
|
||||
func (c *TestStateMachineContext) GetStates() States {
|
||||
return States{
|
||||
"State1": State{
|
||||
Action: func(ctx EventContext) EventType {
|
||||
return "Event1"
|
||||
},
|
||||
Transitions: Transitions{
|
||||
"Event1": "State2",
|
||||
},
|
||||
},
|
||||
"State2": State{
|
||||
Action: func(ctx EventContext) EventType {
|
||||
return "NoOp"
|
||||
},
|
||||
Transitions: Transitions{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// errorAction returns an error.
|
||||
func (c *TestStateMachineContext) errorAction(eventCtx EventContext) EventType {
|
||||
return c.StateMachine.HandleError(errAction)
|
||||
}
|
||||
|
||||
func setupTestStateMachineContext() *TestStateMachineContext {
|
||||
ctx := &TestStateMachineContext{}
|
||||
|
||||
ctx.StateMachine = &StateMachine{
|
||||
States: ctx.GetStates(),
|
||||
current: "State1",
|
||||
previous: "",
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// TestStateMachine_Success tests the state machine with a successful event.
|
||||
func TestStateMachine_Success(t *testing.T) {
|
||||
ctx := setupTestStateMachineContext()
|
||||
|
||||
// Send an event to the state machine.
|
||||
err := ctx.SendEvent("Event1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the state machine has transitioned to the next state.
|
||||
require.Equal(t, StateType("State2"), ctx.current)
|
||||
}
|
||||
|
||||
// TestStateMachine_ConfigurationError tests the state machine with a
|
||||
// configuration error.
|
||||
func TestStateMachine_ConfigurationError(t *testing.T) {
|
||||
ctx := setupTestStateMachineContext()
|
||||
ctx.StateMachine.States = nil
|
||||
|
||||
err := ctx.SendEvent("Event1", nil)
|
||||
require.EqualError(
|
||||
t, err,
|
||||
NewErrConfigError("state machine config is nil").Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// TestStateMachine_ActionError tests the state machine with an action error.
|
||||
func TestStateMachine_ActionError(t *testing.T) {
|
||||
ctx := setupTestStateMachineContext()
|
||||
|
||||
states := ctx.StateMachine.States
|
||||
|
||||
// Add a Transition to State2 if the Action on Stat2 fails.
|
||||
// The new StateMap looks like this:
|
||||
// State1 -> Event1 -> State2
|
||||
// State2 -> OnError -> ErrorState
|
||||
states["State2"] = State{
|
||||
Action: ctx.errorAction,
|
||||
Transitions: Transitions{
|
||||
OnError: "ErrorState",
|
||||
},
|
||||
}
|
||||
|
||||
states["ErrorState"] = State{
|
||||
Action: func(ctx EventContext) EventType {
|
||||
return "NoOp"
|
||||
},
|
||||
Transitions: Transitions{},
|
||||
}
|
||||
|
||||
err := ctx.SendEvent("Event1", nil)
|
||||
|
||||
// Sending an event to the state machine should not return an error.
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure that the last error is set.
|
||||
require.Equal(t, errAction, ctx.StateMachine.LastActionError)
|
||||
|
||||
// Expect the state machine to have transitioned to the ErrorState.
|
||||
require.Equal(t, StateType("ErrorState"), ctx.StateMachine.current)
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package fsm
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
)
|
||||
|
||||
// Subsystem defines the sub system name of this package.
|
||||
const Subsystem = "FSM"
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
// means the package will not perform any logging by default until the caller
|
||||
// requests it.
|
||||
var log btclog.Logger
|
||||
|
||||
// The default amount of logging is none.
|
||||
func init() {
|
||||
UseLogger(build.NewSubLogger(Subsystem, nil))
|
||||
}
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
// This should be used in preference to SetLogWriter if the caller is also
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
@ -0,0 +1,134 @@
|
||||
package fsm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CachedObserver is an observer that caches all states and transitions of
|
||||
// the observed state machine.
|
||||
type CachedObserver struct {
|
||||
lastNotification Notification
|
||||
cachedNotifications *FixedSizeSlice[Notification]
|
||||
|
||||
notificationCond *sync.Cond
|
||||
notificationMx sync.Mutex
|
||||
}
|
||||
|
||||
// NewCachedObserver creates a new cached observer with the given maximum
|
||||
// number of cached notifications.
|
||||
func NewCachedObserver(maxElements int) *CachedObserver {
|
||||
fixedSizeSlice := NewFixedSizeSlice[Notification](maxElements)
|
||||
observer := &CachedObserver{
|
||||
cachedNotifications: fixedSizeSlice,
|
||||
}
|
||||
observer.notificationCond = sync.NewCond(&observer.notificationMx)
|
||||
|
||||
return observer
|
||||
}
|
||||
|
||||
// Notify implements the Observer interface.
|
||||
func (c *CachedObserver) Notify(notification Notification) {
|
||||
c.notificationMx.Lock()
|
||||
defer c.notificationMx.Unlock()
|
||||
|
||||
c.cachedNotifications.Add(notification)
|
||||
c.lastNotification = notification
|
||||
c.notificationCond.Broadcast()
|
||||
}
|
||||
|
||||
// GetCachedNotifications returns a copy of the cached notifications.
|
||||
func (c *CachedObserver) GetCachedNotifications() []Notification {
|
||||
c.notificationMx.Lock()
|
||||
defer c.notificationMx.Unlock()
|
||||
|
||||
return c.cachedNotifications.Get()
|
||||
}
|
||||
|
||||
// WaitForState waits for the state machine to reach the given state.
|
||||
func (s *CachedObserver) WaitForState(ctx context.Context,
|
||||
timeout time.Duration, state StateType) error {
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Channel to notify when the desired state is reached
|
||||
ch := make(chan struct{})
|
||||
|
||||
// Goroutine to wait on condition variable
|
||||
go func() {
|
||||
s.notificationMx.Lock()
|
||||
defer s.notificationMx.Unlock()
|
||||
|
||||
for {
|
||||
// Check if the last state is the desired state
|
||||
if s.lastNotification.NextState == state {
|
||||
ch <- struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, wait for the next notification
|
||||
s.notificationCond.Wait()
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either the condition to be met or for a timeout
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return NewErrWaitingForStateTimeout(
|
||||
state, s.lastNotification.NextState,
|
||||
)
|
||||
case <-ch:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// FixedSizeSlice is a slice with a fixed size.
|
||||
type FixedSizeSlice[T any] struct {
|
||||
data []T
|
||||
maxLen int
|
||||
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// NewFixedSlice initializes a new FixedSlice with a given maximum length.
|
||||
func NewFixedSizeSlice[T any](maxLen int) *FixedSizeSlice[T] {
|
||||
return &FixedSizeSlice[T]{
|
||||
data: make([]T, 0, maxLen),
|
||||
maxLen: maxLen,
|
||||
}
|
||||
}
|
||||
|
||||
// Add appends a new element to the slice. If the slice reaches its maximum
|
||||
// length, the first element is removed.
|
||||
func (fs *FixedSizeSlice[T]) Add(element T) {
|
||||
fs.Lock()
|
||||
defer fs.Unlock()
|
||||
|
||||
if len(fs.data) == fs.maxLen {
|
||||
// Remove the first element
|
||||
fs.data = fs.data[1:]
|
||||
}
|
||||
// Add the new element
|
||||
fs.data = append(fs.data, element)
|
||||
}
|
||||
|
||||
// Get returns a copy of the slice.
|
||||
func (fs *FixedSizeSlice[T]) Get() []T {
|
||||
fs.Lock()
|
||||
defer fs.Unlock()
|
||||
|
||||
data := make([]T, len(fs.data))
|
||||
copy(data, fs.data)
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// GetElement returns the element at the given index.
|
||||
func (fs *FixedSizeSlice[T]) GetElement(index int) T {
|
||||
fs.Lock()
|
||||
defer fs.Unlock()
|
||||
|
||||
return fs.data[index]
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/lightninglabs/loop/fsm"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
out := flag.String("out", "", "outfile")
|
||||
stateMachine := flag.String("fsm", "", "the swap state machine to parse")
|
||||
flag.Parse()
|
||||
|
||||
if filepath.Ext(*out) != ".md" {
|
||||
return errors.New("wrong argument: out must be a .md file")
|
||||
}
|
||||
|
||||
fp, err := filepath.Abs(*out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch *stateMachine {
|
||||
case "example":
|
||||
exampleFSM := &fsm.ExampleFSM{}
|
||||
err = writeMermaidFile(fp, exampleFSM.GetStates())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
fmt.Println("Missing or wrong argument: fsm must be one of:")
|
||||
fmt.Println("\treservations")
|
||||
fmt.Println("\texample")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeMermaidFile(filename string, states fsm.States) error {
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var b bytes.Buffer
|
||||
fmt.Fprint(&b, "```mermaid\nstateDiagram-v2\n")
|
||||
|
||||
sortedStates := sortedKeys(states)
|
||||
for _, state := range sortedStates {
|
||||
edges := states[fsm.StateType(state)]
|
||||
// write state name
|
||||
if len(state) > 0 {
|
||||
fmt.Fprintf(&b, "%s\n", state)
|
||||
} else {
|
||||
state = "[*]"
|
||||
}
|
||||
// write transitions
|
||||
for edge, target := range edges.Transitions {
|
||||
fmt.Fprintf(&b, "%s --> %s: %s\n", state, target, edge)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprint(&b, "```")
|
||||
_, err = f.Write(b.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortedKeys(m fsm.States) []string {
|
||||
keys := make([]string, len(m))
|
||||
i := 0
|
||||
for k := range m {
|
||||
keys[i] = string(k)
|
||||
i++
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env bash
|
||||
go run ./fsm/stateparser/stateparser.go --out ./fsm/example_fsm.md --fsm example
|
Loading…
Reference in New Issue