mirror of
https://github.com/lightninglabs/loop
synced 2024-11-09 19:10:47 +00:00
811e9dff99
By adding WaitForStateAsync to the observer we can always observe state changes in an atomic way without relying on the observer's internal cache.
342 lines
9.0 KiB
Go
342 lines
9.0 KiB
Go
package fsm
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// Sentinel errors exposed by the fsm package.
var (
	// ErrEventRejected is returned by SendEvent when the state machine
	// cannot process an event in the state that it is in.
	ErrEventRejected = errors.New("event rejected")

	// ErrWaitForStateTimedOut reports that a wait timed out.
	// NOTE(review): presumably returned by the observer's wait helpers,
	// which are not part of this file -- confirm.
	ErrWaitForStateTimedOut = errors.New("timed out while waiting for event")

	// ErrInvalidContextType reports an invalid event context.
	// NOTE(review): presumably used by actions whose EventContext is not
	// of the type they expect -- confirm against callers.
	ErrInvalidContextType = errors.New("invalid context")

	// ErrWaitingForStateEarlyAbortError reports that waiting for a state
	// was aborted early.
	// NOTE(review): the abort condition lives outside this file.
	ErrWaitingForStateEarlyAbortError = errors.New(
		"waiting for state early abort",
	)
)
|
|
|
|
const (
|
|
// EmptyState represents the default state of the system.
|
|
EmptyState StateType = ""
|
|
|
|
// NoOp represents a no-op event.
|
|
NoOp EventType = "NoOp"
|
|
|
|
// OnError can be used when an action returns a generic error.
|
|
OnError EventType = "OnError"
|
|
|
|
// ContextValidationFailed can be when the passed context if
|
|
// not of the expected type.
|
|
ContextValidationFailed EventType = "ContextValidationFailed"
|
|
)
|
|
|
|
// StateType names a state of the machine. It is a string type so that users of
// the package can declare their own states.
type StateType string
|
|
|
|
// EventType names an event that can be sent to the machine. It is a string
// type so that users of the package can declare their own events.
type EventType string
|
|
|
|
// EventContext is the opaque value handed to an Action when it is executed.
// Any value satisfies it; the action decides how to interpret it.
type EventContext interface{}
|
|
|
|
// Action represents the action to be executed in a given state.
|
|
type Action func(eventCtx EventContext) EventType
|
|
|
|
// Transitions represents a mapping of events and states.
|
|
type Transitions map[EventType]StateType
|
|
|
|
// State binds a state with an action and a set of events it can handle.
|
|
type State struct {
|
|
// EntryFunc is a function that is called when the state is entered.
|
|
EntryFunc func()
|
|
// ExitFunc is a function that is called when the state is exited.
|
|
ExitFunc func()
|
|
// Action is the action to be executed in the state.
|
|
Action Action
|
|
// Transitions is a mapping of events and states.
|
|
Transitions Transitions
|
|
}
|
|
|
|
// States represents a mapping of states and their implementations.
|
|
type States map[StateType]State
|
|
|
|
// Notification represents a notification sent to the state machine's
|
|
// notification channel.
|
|
type Notification struct {
|
|
// PreviousState is the state the state machine was in before the event
|
|
// was processed.
|
|
PreviousState StateType
|
|
// NextState is the state the state machine is in after the event was
|
|
// processed.
|
|
NextState StateType
|
|
// Event is the event that was processed.
|
|
Event EventType
|
|
// LastActionError is the error returned by the last action executed.
|
|
LastActionError error
|
|
}
|
|
|
|
// Observer is an interface that can be implemented by types that want to
|
|
// observe the state machine.
|
|
type Observer interface {
|
|
Notify(Notification)
|
|
}
|
|
|
|
// StateMachine represents the state machine.
|
|
type StateMachine struct {
|
|
// Context represents the state machine context.
|
|
States States
|
|
|
|
// ActionEntryFunc is a function that is called before an action is
|
|
// executed.
|
|
ActionEntryFunc func(Notification)
|
|
|
|
// ActionExitFunc is a function that is called after an action is
|
|
// executed, it is called with the EventType returned by the action.
|
|
ActionExitFunc func(NextEvent EventType)
|
|
|
|
// LastActionError is an error set by the last action executed.
|
|
LastActionError error
|
|
|
|
// DefaultObserver is the default observer that is notified when the
|
|
// state machine transitions between states.
|
|
DefaultObserver *CachedObserver
|
|
|
|
// previous represents the previous state.
|
|
previous StateType
|
|
|
|
// current represents the current state.
|
|
current StateType
|
|
|
|
// observers is a slice of observers that are notified when the state
|
|
// machine transitions between states.
|
|
observers []Observer
|
|
|
|
// observerMutex ensures that observers are only added or removed
|
|
// safely.
|
|
observerMutex sync.Mutex
|
|
|
|
// mutex ensures that only 1 event is processed by the state machine at
|
|
// any given time.
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
// NewStateMachine creates a new state machine.
|
|
func NewStateMachine(states States, observerSize int) *StateMachine {
|
|
return NewStateMachineWithState(states, EmptyState, observerSize)
|
|
}
|
|
|
|
// NewStateMachineWithState creates a new state machine and sets the initial
|
|
// state.
|
|
func NewStateMachineWithState(states States, current StateType,
|
|
observerSize int) *StateMachine {
|
|
|
|
observers := []Observer{}
|
|
var defaultObserver *CachedObserver
|
|
|
|
if observerSize > 0 {
|
|
defaultObserver = NewCachedObserver(observerSize)
|
|
observers = append(observers, defaultObserver)
|
|
}
|
|
|
|
return &StateMachine{
|
|
States: states,
|
|
current: current,
|
|
DefaultObserver: defaultObserver,
|
|
observers: observers,
|
|
}
|
|
}
|
|
|
|
// getNextState returns the next state for the event given the machine's current
|
|
// state, or an error if the event can't be handled in the given state.
|
|
func (s *StateMachine) getNextState(event EventType) (State, error) {
|
|
var (
|
|
state State
|
|
ok bool
|
|
)
|
|
|
|
stateMap := s.States
|
|
|
|
if state, ok = stateMap[s.current]; !ok {
|
|
return State{}, NewErrConfigError("current state not found")
|
|
}
|
|
|
|
if state.Transitions == nil {
|
|
return State{}, NewErrConfigError(
|
|
"current state has no transitions",
|
|
)
|
|
}
|
|
|
|
var next StateType
|
|
if next, ok = state.Transitions[event]; !ok {
|
|
return State{}, NewErrConfigError(
|
|
"event not found in current transitions",
|
|
)
|
|
}
|
|
|
|
// Identify the state definition for the next state.
|
|
state, ok = stateMap[next]
|
|
if !ok {
|
|
return State{}, NewErrConfigError("next state not found")
|
|
}
|
|
|
|
if state.Action == nil {
|
|
return State{}, NewErrConfigError("next state has no action")
|
|
}
|
|
|
|
// Transition over to the next state.
|
|
s.previous = s.current
|
|
s.current = next
|
|
|
|
return state, nil
|
|
}
|
|
|
|
// SendEvent sends an event to the state machine. It returns an error if the
|
|
// event cannot be processed in the current state. Otherwise, it only returns
|
|
// nil if the event for the last action is a no-op.
|
|
func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
|
|
if s.States == nil {
|
|
return NewErrConfigError("state machine config is nil")
|
|
}
|
|
|
|
for {
|
|
// Determine the next state for the event given the machine's
|
|
// current state.
|
|
state, err := s.getNextState(event)
|
|
if err != nil {
|
|
log.Errorf("unable to get next state: %v from event: "+
|
|
"%v, current state: %v", err, event, s.current)
|
|
return ErrEventRejected
|
|
}
|
|
|
|
// Notify the state machine's observers.
|
|
s.observerMutex.Lock()
|
|
notification := Notification{
|
|
PreviousState: s.previous,
|
|
NextState: s.current,
|
|
Event: event,
|
|
LastActionError: s.LastActionError,
|
|
}
|
|
|
|
for _, observer := range s.observers {
|
|
observer.Notify(notification)
|
|
}
|
|
s.observerMutex.Unlock()
|
|
|
|
// Execute the state machines ActionEntryFunc.
|
|
if s.ActionEntryFunc != nil {
|
|
s.ActionEntryFunc(notification)
|
|
}
|
|
|
|
// Execute the current state's entry function
|
|
if state.EntryFunc != nil {
|
|
state.EntryFunc()
|
|
}
|
|
|
|
// Execute the next state's action and loop over again if the
|
|
// event returned is not a no-op.
|
|
nextEvent := state.Action(eventCtx)
|
|
|
|
// Execute the current state's exit function
|
|
if state.ExitFunc != nil {
|
|
state.ExitFunc()
|
|
}
|
|
|
|
// Execute the state machines ActionExitFunc.
|
|
if s.ActionExitFunc != nil {
|
|
s.ActionExitFunc(nextEvent)
|
|
}
|
|
|
|
// If the next event is a no-op, we're done.
|
|
if nextEvent == NoOp {
|
|
return nil
|
|
}
|
|
|
|
event = nextEvent
|
|
}
|
|
}
|
|
|
|
// RegisterObserver registers an observer with the state machine.
|
|
func (s *StateMachine) RegisterObserver(observer Observer) {
|
|
s.observerMutex.Lock()
|
|
defer s.observerMutex.Unlock()
|
|
|
|
if observer != nil {
|
|
s.observers = append(s.observers, observer)
|
|
}
|
|
}
|
|
|
|
// RemoveObserver removes an observer from the state machine. It returns true
|
|
// if the observer was removed, false otherwise.
|
|
func (s *StateMachine) RemoveObserver(observer Observer) bool {
|
|
s.observerMutex.Lock()
|
|
defer s.observerMutex.Unlock()
|
|
|
|
for i, o := range s.observers {
|
|
if o == observer {
|
|
s.observers = append(
|
|
s.observers[:i], s.observers[i+1:]...,
|
|
)
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// HandleError is a helper function that can be used by actions to handle
|
|
// errors.
|
|
func (s *StateMachine) HandleError(err error) EventType {
|
|
log.Errorf("StateMachine error: %s", err)
|
|
s.LastActionError = err
|
|
return OnError
|
|
}
|
|
|
|
// NoOpAction is a no-op action that can be used by states that don't need to
|
|
// execute any action.
|
|
func NoOpAction(_ EventContext) EventType {
|
|
return NoOp
|
|
}
|
|
|
|
// ErrConfigError is the error type used to report a misconfigured state
// machine, for example a missing state or a state without an action.
type ErrConfigError struct {
	// msg describes what is wrong with the configuration.
	msg string
}
|
|
|
|
// Error returns the error message.
|
|
func (e ErrConfigError) Error() string {
|
|
return fmt.Sprintf("config error: %s", e.msg)
|
|
}
|
|
|
|
// NewErrConfigError creates a new ErrConfigError.
|
|
func NewErrConfigError(msg string) ErrConfigError {
|
|
return ErrConfigError{
|
|
msg: msg,
|
|
}
|
|
}
|
|
|
|
// ErrWaitingForStateTimeout is an error returned when the state machine times
|
|
// out while waiting for a state.
|
|
type ErrWaitingForStateTimeout struct {
|
|
expected StateType
|
|
}
|
|
|
|
// Error returns the error message.
|
|
func (e ErrWaitingForStateTimeout) Error() string {
|
|
return fmt.Sprintf("waiting for state timed out: %s", e.expected)
|
|
}
|
|
|
|
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
|
|
func NewErrWaitingForStateTimeout(expected StateType) ErrWaitingForStateTimeout {
|
|
return ErrWaitingForStateTimeout{
|
|
expected: expected,
|
|
}
|
|
}
|