From bd5418337c81af9691b4f85e988e2649cabbc2ed Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Mon, 2 Sep 2024 13:35:38 +0200 Subject: [PATCH] loopdb: properly lock store mock for concurrent access --- loopdb/store_mock.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/loopdb/store_mock.go b/loopdb/store_mock.go index 955ae5c..efaf8c4 100644 --- a/loopdb/store_mock.go +++ b/loopdb/store_mock.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "testing" "time" @@ -14,6 +15,8 @@ import ( // StoreMock implements a mock client swap store. type StoreMock struct { + sync.RWMutex + LoopOutSwaps map[lntypes.Hash]*LoopOutContract LoopOutUpdates map[lntypes.Hash][]SwapStateData loopOutStoreChan chan LoopOutContract @@ -50,6 +53,9 @@ func NewStoreMock(t *testing.T) *StoreMock { // // NOTE: Part of the SwapStore interface. func (s *StoreMock) FetchLoopOutSwaps(ctx context.Context) ([]*LoopOut, error) { + s.RLock() + defer s.RUnlock() + result := []*LoopOut{} for hash, contract := range s.LoopOutSwaps { @@ -80,6 +86,9 @@ func (s *StoreMock) FetchLoopOutSwaps(ctx context.Context) ([]*LoopOut, error) { func (s *StoreMock) FetchLoopOutSwap(ctx context.Context, hash lntypes.Hash) (*LoopOut, error) { + s.RLock() + defer s.RUnlock() + contract, ok := s.LoopOutSwaps[hash] if !ok { return nil, errors.New("swap not found") @@ -110,6 +119,9 @@ func (s *StoreMock) FetchLoopOutSwap(ctx context.Context, func (s *StoreMock) CreateLoopOut(ctx context.Context, hash lntypes.Hash, swap *LoopOutContract) error { + s.Lock() + defer s.Unlock() + _, ok := s.LoopOutSwaps[hash] if ok { return errors.New("swap already exists") @@ -126,6 +138,9 @@ func (s *StoreMock) CreateLoopOut(ctx context.Context, hash lntypes.Hash, func (s *StoreMock) FetchLoopInSwaps(ctx context.Context) ([]*LoopIn, error) { + s.RLock() + defer s.RUnlock() + result := []*LoopIn{} for hash, contract := range s.LoopInSwaps { @@ -156,6 +171,9 @@ func (s *StoreMock) FetchLoopInSwaps(ctx context.Context) ([]*LoopIn, func (s *StoreMock) CreateLoopIn(ctx context.Context, hash lntypes.Hash, swap *LoopInContract) error { + s.Lock() + defer s.Unlock() + _, ok := s.LoopInSwaps[hash] if ok { return errors.New("swap already exists") @@ -176,6 +194,9 @@ func (s *StoreMock) CreateLoopIn(ctx context.Context, hash lntypes.Hash, func (s *StoreMock) UpdateLoopOut(ctx context.Context, hash lntypes.Hash, time time.Time, state SwapStateData) error { + s.Lock() + defer s.Unlock() + updates, ok := s.LoopOutUpdates[hash] if !ok { return errors.New("swap does not exists") @@ -196,6 +217,9 @@ func (s *StoreMock) UpdateLoopOut(ctx context.Context, hash lntypes.Hash, func (s *StoreMock) UpdateLoopIn(ctx context.Context, hash lntypes.Hash, time time.Time, state SwapStateData) error { + s.Lock() + defer s.Unlock() + updates, ok := s.LoopInUpdates[hash] if !ok { return errors.New("swap does not exists") @@ -347,6 +371,9 @@ func (b *StoreMock) BatchInsertUpdate(ctx context.Context, func (s *StoreMock) BatchUpdateLoopOutSwapCosts(ctx context.Context, costs map[lntypes.Hash]SwapCost) error { + s.Lock() + defer s.Unlock() + for hash, cost := range costs { if _, ok := s.LoopOutUpdates[hash]; !ok { return fmt.Errorf("swap has no updates: %v", hash) @@ -367,6 +394,9 @@ func (s *StoreMock) BatchUpdateLoopOutSwapCosts(ctx context.Context, func (s *StoreMock) HasMigration(ctx context.Context, migrationID string) ( bool, error) { + s.RLock() + defer s.RUnlock() + _, ok := s.migrations[migrationID] return ok, nil @@ -376,6 +406,9 @@ func (s *StoreMock) HasMigration(ctx context.Context, migrationID string) ( func (s *StoreMock) SetMigration(ctx context.Context, migrationID string) error { + s.Lock() + defer s.Unlock() + if _, ok := s.migrations[migrationID]; ok { return errors.New("migration already done") }