mirror of
https://github.com/lightninglabs/loop
synced 2024-11-08 01:10:29 +00:00
Merge pull request #741 from bhandras/fsm-observer-fixup
fsm: add WaitForStateAsync to the cached observer
This commit is contained in:
commit
3843c3906d
@ -243,3 +243,82 @@ func TestExampleFSMFlow(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestObserverAsyncWait tests the observer's WaitForStateAsync function.
|
||||
func TestObserverAsyncWait(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
waitTime time.Duration
|
||||
blockTime time.Duration
|
||||
expectTimeout bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
waitTime: time.Second,
|
||||
blockTime: time.Millisecond,
|
||||
expectTimeout: false,
|
||||
},
|
||||
{
|
||||
name: "timeout",
|
||||
waitTime: time.Millisecond,
|
||||
blockTime: time.Second,
|
||||
expectTimeout: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
service := &mockService{
|
||||
respondChan: make(chan bool),
|
||||
}
|
||||
|
||||
store := &mockStore{}
|
||||
|
||||
exampleContext := NewExampleFSMContext(service, store)
|
||||
cachedObserver := NewCachedObserver(100)
|
||||
exampleContext.RegisterObserver(cachedObserver)
|
||||
|
||||
t0 := time.Now()
|
||||
timeoutCtx, cancel := context.WithTimeout(
|
||||
context.Background(), tc.waitTime,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
// Wait for the final state.
|
||||
errChan := cachedObserver.WaitForStateAsync(
|
||||
timeoutCtx, StuffSuccess, true,
|
||||
)
|
||||
|
||||
go func() {
|
||||
err := exampleContext.SendEvent(
|
||||
OnRequestStuff,
|
||||
newInitStuffRequest(),
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(tc.blockTime)
|
||||
service.respondChan <- true
|
||||
}()
|
||||
|
||||
timeout := false
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
timeout = true
|
||||
|
||||
case <-errChan:
|
||||
}
|
||||
require.Equal(t, tc.expectTimeout, timeout)
|
||||
|
||||
t1 := time.Now()
|
||||
diff := t1.Sub(t0)
|
||||
if tc.expectTimeout {
|
||||
require.Less(t, diff, tc.blockTime)
|
||||
} else {
|
||||
require.Less(t, diff, tc.waitTime)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
33
fsm/fsm.go
33
fsm/fsm.go
@ -306,23 +306,36 @@ func NoOpAction(_ EventContext) EventType {
|
||||
}
|
||||
|
||||
// ErrConfigError is an error returned when the state machine is misconfigured.
|
||||
type ErrConfigError error
|
||||
type ErrConfigError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e ErrConfigError) Error() string {
|
||||
return fmt.Sprintf("config error: %s", e.msg)
|
||||
}
|
||||
|
||||
// NewErrConfigError creates a new ErrConfigError.
|
||||
func NewErrConfigError(msg string) ErrConfigError {
|
||||
return (ErrConfigError)(fmt.Errorf("config error: %s", msg))
|
||||
return ErrConfigError{
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrWaitingForStateTimeout is an error returned when the state machine times
|
||||
// out while waiting for a state.
|
||||
type ErrWaitingForStateTimeout error
|
||||
type ErrWaitingForStateTimeout struct {
|
||||
expected StateType
|
||||
}
|
||||
|
||||
// Error returns the error message.
|
||||
func (e ErrWaitingForStateTimeout) Error() string {
|
||||
return fmt.Sprintf("waiting for state timed out: %s", e.expected)
|
||||
}
|
||||
|
||||
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
|
||||
func NewErrWaitingForStateTimeout(expected,
|
||||
actual StateType) ErrWaitingForStateTimeout {
|
||||
|
||||
return (ErrWaitingForStateTimeout)(fmt.Errorf(
|
||||
"waiting for state timeout: expected %s, actual: %s",
|
||||
expected, actual,
|
||||
))
|
||||
func NewErrWaitingForStateTimeout(expected StateType) ErrWaitingForStateTimeout {
|
||||
return ErrWaitingForStateTimeout{
|
||||
expected: expected,
|
||||
}
|
||||
}
|
||||
|
@ -100,12 +100,11 @@ func WithAbortEarlyOnErrorOption() WaitForStateOption {
|
||||
// the given duration before checking the state. This is useful if the
|
||||
// function is called immediately after sending an event to the state machine
|
||||
// and the state machine needs some time to process the event.
|
||||
func (s *CachedObserver) WaitForState(ctx context.Context,
|
||||
func (c *CachedObserver) WaitForState(ctx context.Context,
|
||||
timeout time.Duration, state StateType,
|
||||
opts ...WaitForStateOption) error {
|
||||
|
||||
var options fsmOptions
|
||||
|
||||
for _, opt := range opts {
|
||||
opt.apply(&options)
|
||||
}
|
||||
@ -120,61 +119,77 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new context with a timeout.
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Channel to notify when the desired state is reached
|
||||
// or an error occurred.
|
||||
ch := make(chan error)
|
||||
ch := c.WaitForStateAsync(timeoutCtx, state, options.abortEarlyOnError)
|
||||
|
||||
// Goroutine to wait on condition variable
|
||||
// Wait for either the condition to be met or for a timeout.
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return NewErrWaitingForStateTimeout(state)
|
||||
|
||||
case err := <-ch:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForStateAsync waits asynchronously until the passed context is canceled
|
||||
// or the expected state is reached. The function returns a channel that will
|
||||
// receive an error if the expected state is reached or an error occurred. If
|
||||
// the context is canceled before the expected state is reached, the channel
|
||||
// will receive an ErrWaitingForStateTimeout error.
|
||||
func (c *CachedObserver) WaitForStateAsync(ctx context.Context, state StateType,
|
||||
abortOnEarlyError bool) chan error {
|
||||
|
||||
// Channel to notify when the desired state is reached or an error
|
||||
// occurred.
|
||||
ch := make(chan error, 1)
|
||||
|
||||
// Wait on the notification condition variable asynchronously to avoid
|
||||
// blocking the caller.
|
||||
go func() {
|
||||
s.notificationMx.Lock()
|
||||
defer s.notificationMx.Unlock()
|
||||
c.notificationMx.Lock()
|
||||
defer c.notificationMx.Unlock()
|
||||
|
||||
// writeResult writes the result to the channel. If the context
|
||||
// is canceled, an ErrWaitingForStateTimeout error is written
|
||||
// to the channel.
|
||||
writeResult := func(err error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ch <- NewErrWaitingForStateTimeout(
|
||||
state,
|
||||
)
|
||||
|
||||
case ch <- err:
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
// Check if the last state is the desired state
|
||||
if s.lastNotification.NextState == state {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return
|
||||
// Check if the last state is the desired state.
|
||||
if c.lastNotification.NextState == state {
|
||||
writeResult(nil)
|
||||
return
|
||||
}
|
||||
|
||||
case ch <- nil:
|
||||
// Check if an error has occurred.
|
||||
if c.lastNotification.Event == OnError {
|
||||
lastErr := c.lastNotification.LastActionError
|
||||
if abortOnEarlyError {
|
||||
writeResult(lastErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if an error occurred
|
||||
if s.lastNotification.Event == OnError {
|
||||
if options.abortEarlyOnError {
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return
|
||||
|
||||
case ch <- s.lastNotification.LastActionError:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, wait for the next notification
|
||||
s.notificationCond.Wait()
|
||||
// Otherwise use the conditional variable to wait for
|
||||
// the next notification.
|
||||
c.notificationCond.Wait()
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either the condition to be met or for a timeout
|
||||
select {
|
||||
case <-timeoutCtx.Done():
|
||||
return NewErrWaitingForStateTimeout(
|
||||
state, s.lastNotification.NextState,
|
||||
)
|
||||
|
||||
case lastActionErr := <-ch:
|
||||
if lastActionErr != nil {
|
||||
return lastActionErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// FixedSizeSlice is a slice with a fixed size.
|
||||
|
Loading…
Reference in New Issue
Block a user