Merge pull request #741 from bhandras/fsm-observer-fixup

fsm: add WaitForStateAsync to the cached observer
pull/745/head
András Bánki-Horváth 3 weeks ago committed by GitHub
commit 3843c3906d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -243,3 +243,82 @@ func TestExampleFSMFlow(t *testing.T) {
})
}
}
// TestObserverAsyncWait tests the observer's WaitForStateAsync function.
func TestObserverAsyncWait(t *testing.T) {
testCases := []struct {
name string
waitTime time.Duration
blockTime time.Duration
expectTimeout bool
}{
{
name: "success",
waitTime: time.Second,
blockTime: time.Millisecond,
expectTimeout: false,
},
{
name: "timeout",
waitTime: time.Millisecond,
blockTime: time.Second,
expectTimeout: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
service := &mockService{
respondChan: make(chan bool),
}
store := &mockStore{}
exampleContext := NewExampleFSMContext(service, store)
cachedObserver := NewCachedObserver(100)
exampleContext.RegisterObserver(cachedObserver)
t0 := time.Now()
timeoutCtx, cancel := context.WithTimeout(
context.Background(), tc.waitTime,
)
defer cancel()
// Wait for the final state.
errChan := cachedObserver.WaitForStateAsync(
timeoutCtx, StuffSuccess, true,
)
go func() {
err := exampleContext.SendEvent(
OnRequestStuff,
newInitStuffRequest(),
)
require.NoError(t, err)
time.Sleep(tc.blockTime)
service.respondChan <- true
}()
timeout := false
select {
case <-timeoutCtx.Done():
timeout = true
case <-errChan:
}
require.Equal(t, tc.expectTimeout, timeout)
t1 := time.Now()
diff := t1.Sub(t0)
if tc.expectTimeout {
require.Less(t, diff, tc.blockTime)
} else {
require.Less(t, diff, tc.waitTime)
}
})
}
}

@ -306,23 +306,36 @@ func NoOpAction(_ EventContext) EventType {
}
// ErrConfigError is an error returned when the state machine is misconfigured.
type ErrConfigError error
type ErrConfigError struct {
msg string
}
// Error returns the error message.
func (e ErrConfigError) Error() string {
return fmt.Sprintf("config error: %s", e.msg)
}
// NewErrConfigError creates a new ErrConfigError.
func NewErrConfigError(msg string) ErrConfigError {
return (ErrConfigError)(fmt.Errorf("config error: %s", msg))
return ErrConfigError{
msg: msg,
}
}
// ErrWaitingForStateTimeout is an error returned when the state machine times
// out while waiting for a state.
type ErrWaitingForStateTimeout error
type ErrWaitingForStateTimeout struct {
expected StateType
}
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
func NewErrWaitingForStateTimeout(expected,
actual StateType) ErrWaitingForStateTimeout {
// Error returns the error message.
func (e ErrWaitingForStateTimeout) Error() string {
return fmt.Sprintf("waiting for state timed out: %s", e.expected)
}
return (ErrWaitingForStateTimeout)(fmt.Errorf(
"waiting for state timeout: expected %s, actual: %s",
expected, actual,
))
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
func NewErrWaitingForStateTimeout(expected StateType) ErrWaitingForStateTimeout {
return ErrWaitingForStateTimeout{
expected: expected,
}
}

@ -100,12 +100,11 @@ func WithAbortEarlyOnErrorOption() WaitForStateOption {
// the given duration before checking the state. This is useful if the
// function is called immediately after sending an event to the state machine
// and the state machine needs some time to process the event.
func (s *CachedObserver) WaitForState(ctx context.Context,
func (c *CachedObserver) WaitForState(ctx context.Context,
timeout time.Duration, state StateType,
opts ...WaitForStateOption) error {
var options fsmOptions
for _, opt := range opts {
opt.apply(&options)
}
@ -120,61 +119,77 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
}
}
// Create a new context with a timeout.
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Channel to notify when the desired state is reached
// or an error occurred.
ch := make(chan error)
ch := c.WaitForStateAsync(timeoutCtx, state, options.abortEarlyOnError)
// Goroutine to wait on condition variable
go func() {
s.notificationMx.Lock()
defer s.notificationMx.Unlock()
// Wait for either the condition to be met or for a timeout.
select {
case <-timeoutCtx.Done():
return NewErrWaitingForStateTimeout(state)
for {
// Check if the last state is the desired state
if s.lastNotification.NextState == state {
select {
case <-timeoutCtx.Done():
return
case err := <-ch:
return err
}
}
case ch <- nil:
return
}
// WaitForStateAsync waits asynchronously until the passed context is canceled
// or the expected state is reached. The function returns a channel that will
// receive an error if the expected state is reached or an error occurred. If
// the context is canceled before the expected state is reached, the channel
// will receive an ErrWaitingForStateTimeout error.
func (c *CachedObserver) WaitForStateAsync(ctx context.Context, state StateType,
abortOnEarlyError bool) chan error {
// Channel to notify when the desired state is reached or an error
// occurred.
ch := make(chan error, 1)
// Wait on the notification condition variable asynchronously to avoid
// blocking the caller.
go func() {
c.notificationMx.Lock()
defer c.notificationMx.Unlock()
// writeResult writes the result to the channel. If the context
// is canceled, an ErrWaitingForStateTimeout error is written
// to the channel.
writeResult := func(err error) {
select {
case <-ctx.Done():
ch <- NewErrWaitingForStateTimeout(
state,
)
case ch <- err:
}
}
// Check if an error occurred
if s.lastNotification.Event == OnError {
if options.abortEarlyOnError {
select {
case <-timeoutCtx.Done():
return
for {
// Check if the last state is the desired state.
if c.lastNotification.NextState == state {
writeResult(nil)
return
}
case ch <- s.lastNotification.LastActionError:
return
}
// Check if an error has occurred.
if c.lastNotification.Event == OnError {
lastErr := c.lastNotification.LastActionError
if abortOnEarlyError {
writeResult(lastErr)
return
}
}
// Otherwise, wait for the next notification
s.notificationCond.Wait()
// Otherwise use the conditional variable to wait for
// the next notification.
c.notificationCond.Wait()
}
}()
// Wait for either the condition to be met or for a timeout
select {
case <-timeoutCtx.Done():
return NewErrWaitingForStateTimeout(
state, s.lastNotification.NextState,
)
case lastActionErr := <-ch:
if lastActionErr != nil {
return lastActionErr
}
return nil
}
return ch
}
// FixedSizeSlice is a slice with a fixed size.

Loading…
Cancel
Save