diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 744afd8..04b3c6b 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -701,7 +701,7 @@ func (b *Batcher) convertSweep(ctx context.Context, dbSweep *dbSweep) ( s, err := b.sweepStore.FetchSweep(ctx, dbSweep.SwapHash) if err != nil { - return nil, fmt.Errorf("failed to fetch sweep data for %x: %v", + return nil, fmt.Errorf("failed to fetch sweep data for %x: %w", dbSweep.SwapHash[:6], err) } @@ -748,7 +748,7 @@ func (f *SwapStoreWrapper) FetchSweep(ctx context.Context, swap, err := f.swapStore.FetchLoopOutSwap(ctx, swapHash) if err != nil { - return nil, fmt.Errorf("failed to fetch loop out for %x: %v", + return nil, fmt.Errorf("failed to fetch loop out for %x: %w", swapHash[:6], err) } @@ -756,14 +756,14 @@ func (f *SwapStoreWrapper) FetchSweep(ctx context.Context, swapHash, &swap.Contract.SwapContract, f.chainParams, ) if err != nil { - return nil, fmt.Errorf("failed to get htlc: %v", err) + return nil, fmt.Errorf("failed to get htlc: %w", err) } swapPaymentAddr, err := utils.ObtainSwapPaymentAddr( swap.Contract.SwapInvoice, f.chainParams, ) if err != nil { - return nil, fmt.Errorf("failed to get payment addr: %v", err) + return nil, fmt.Errorf("failed to get payment addr: %w", err) } return &SweepInfo{ @@ -798,7 +798,7 @@ func (b *Batcher) fetchSweep(ctx context.Context, s, err := b.sweepStore.FetchSweep(ctx, sweepReq.SwapHash) if err != nil { - return nil, fmt.Errorf("failed to fetch sweep data for %x: %v", + return nil, fmt.Errorf("failed to fetch sweep data for %x: %w", sweepReq.SwapHash[:6], err) } diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index d211e88..e7ece3e 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -3,6 +3,7 @@ package sweepbatcher import ( "context" "errors" + "fmt" "sync" "testing" "time" @@ -30,6 +31,17 @@ const ( ntfnBufferSize = 1024 ) +// destAddr is a dummy p2wkh address to use as the destination address for +// the swaps. +var destAddr = func() btcutil.Address { + p2wkhAddr := "bcrt1qq68r6ff4k4pjx39efs44gcyccf7unqnu5qtjjz" + addr, err := btcutil.DecodeAddress(p2wkhAddr, nil) + if err != nil { + panic(err) + } + return addr +}() + func testMuSig2SignSweep(ctx context.Context, protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, @@ -54,21 +66,34 @@ func checkBatcherError(t *testing.T, err error) { } } -// TestSweepBatcherBatchCreation tests that sweep requests enter the expected +// getOnlyBatch makes sure the batcher has exactly one batch and returns it. +func getOnlyBatch(batcher *Batcher) *batch { + if len(batcher.batches) != 1 { + panic(fmt.Sprintf("getOnlyBatch called on a batcher having "+ + "%d batches", len(batcher.batches))) + } + + for _, batch := range batcher.batches { + return batch + } + + panic("unreachable") +} + +// testSweepBatcherBatchCreation tests that sweep requests enter the expected // batch based on their timeout distance. -func TestSweepBatcherBatchCreation(t *testing.T) { +func testSweepBatcherBatchCreation(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -94,6 +119,7 @@ func TestSweepBatcherBatchCreation(t *testing.T) { AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -133,7 +159,11 @@ func TestSweepBatcherBatchCreation(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance - 1, AmountRequested: 222, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{2}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -165,7 +195,11 @@ func TestSweepBatcherBatchCreation(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance + 1, AmountRequested: 333, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{3}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -211,21 +245,20 @@ func TestSweepBatcherBatchCreation(t *testing.T) { require.True(t, batcherStore.AssertSweepStored(sweepReq3.SwapHash)) } -// TestSweepBatcherSimpleLifecycle tests the simple lifecycle of the batches +// testSweepBatcherSimpleLifecycle tests the simple lifecycle of the batches // that are created and run by the batcher. -func TestSweepBatcherSimpleLifecycle(t *testing.T) { +func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -250,6 +283,7 @@ func TestSweepBatcherSimpleLifecycle(t *testing.T) { CltvExpiry: 111, AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, SweepConfTarget: 111, } @@ -351,21 +385,20 @@ func TestSweepBatcherSimpleLifecycle(t *testing.T) { }, test.Timeout, eventuallyCheckFrequency) } -// TestSweepBatcherSweepReentry tests that when an old version of the batch tx +// testSweepBatcherSweepReentry tests that when an old version of the batch tx // gets confirmed the sweep leftovers are sent back to the batcher. -func TestSweepBatcherSweepReentry(t *testing.T) { +func testSweepBatcherSweepReentry(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -391,6 +424,7 @@ func TestSweepBatcherSweepReentry(t *testing.T) { CltvExpiry: 111, AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, SweepConfTarget: 111, } @@ -413,7 +447,11 @@ func TestSweepBatcherSweepReentry(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111, AmountRequested: 222, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{2}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, SweepConfTarget: 111, } @@ -436,7 +474,11 @@ func TestSweepBatcherSweepReentry(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111, AmountRequested: 333, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{3}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, SweepConfTarget: 111, } @@ -561,21 +603,20 @@ func TestSweepBatcherSweepReentry(t *testing.T) { require.Equal(t, b.state, Open) } -// TestSweepBatcherNonWalletAddr tests that sweep requests that sweep to a non +// testSweepBatcherNonWalletAddr tests that sweep requests that sweep to a non // wallet address enter individual batches. -func TestSweepBatcherNonWalletAddr(t *testing.T) { +func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -601,6 +642,7 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) { AmountRequested: 111, }, IsExternalAddr: true, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -640,7 +682,11 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance - 1, AmountRequested: 222, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{2}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, IsExternalAddr: true, } @@ -677,7 +723,11 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance + 1, AmountRequested: 333, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{3}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, IsExternalAddr: true, } @@ -729,21 +779,20 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) { require.True(t, batcherStore.AssertSweepStored(sweepReq3.SwapHash)) } -// TestSweepBatcherComposite tests that sweep requests that sweep to both wallet +// testSweepBatcherComposite tests that sweep requests that sweep to both wallet // addresses and non-wallet addresses enter the correct batches. -func TestSweepBatcherComposite(t *testing.T) { +func testSweepBatcherComposite(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -769,6 +818,7 @@ func TestSweepBatcherComposite(t *testing.T) { AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -792,7 +842,11 @@ func TestSweepBatcherComposite(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance - 1, AmountRequested: 222, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{2}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -816,7 +870,11 @@ func TestSweepBatcherComposite(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance - 3, AmountRequested: 333, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{3}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, IsExternalAddr: true, } @@ -841,7 +899,11 @@ func TestSweepBatcherComposite(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance + 1, AmountRequested: 444, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{4}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -865,7 +927,11 @@ func TestSweepBatcherComposite(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance + 5, AmountRequested: 555, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{5}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -889,7 +955,11 @@ func TestSweepBatcherComposite(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111 + defaultMaxTimeoutDistance + 6, AmountRequested: 666, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{6}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, IsExternalAddr: true, } @@ -1013,9 +1083,11 @@ func makeTestTx(value int64) *wire.MsgTx { return tx } -// TestGetFeePortionForSweep tests that the fee portion for a sweep is correctly +// testGetFeePortionForSweep tests that the fee portion for a sweep is correctly // calculated. -func TestGetFeePortionForSweep(t *testing.T) { +func testGetFeePortionForSweep(t *testing.T, store testStore, + batcherStore testBatcherStore) { + tests := []struct { name string spendTxValue int64 @@ -1050,19 +1122,19 @@ func TestGetFeePortionForSweep(t *testing.T) { } } -// TestRestoringEmptyBatch tests that the batcher can be restored with an empty +// testRestoringEmptyBatch tests that the batcher can be restored with an empty // batch. -func TestRestoringEmptyBatch(t *testing.T) { +func testRestoringEmptyBatch(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() _, err = batcherStore.InsertSweepBatch(ctx, &dbBatch{}) require.NoError(t, err) @@ -1099,6 +1171,7 @@ func TestRestoringEmptyBatch(t *testing.T) { AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -1138,11 +1211,20 @@ func TestRestoringEmptyBatch(t *testing.T) { type loopStoreMock struct { loops map[lntypes.Hash]*loopdb.LoopOut mu sync.Mutex + + // backend is the store passed to the test. An empty swap with the ID + // passed is stored to this place to satisfy SQL foreign key constraint. + backend testStore + + // preimage is last preimage first byte used in fake swap in backend. + // It has to be unique to satisfy SQL constraint. + preimage byte } -func newLoopStoreMock() *loopStoreMock { +func newLoopStoreMock(backend testStore) *loopStoreMock { return &loopStoreMock{ - loops: make(map[lntypes.Hash]*loopdb.LoopOut), + loops: make(map[lntypes.Hash]*loopdb.LoopOut), + backend: backend, } } @@ -1164,23 +1246,67 @@ func (s *loopStoreMock) putLoopOutSwap(hash lntypes.Hash, out *loopdb.LoopOut) { s.mu.Lock() defer s.mu.Unlock() + _, existed := s.loops[hash] s.loops[hash] = out + + if existed { + // The swap exists, no need to create one in backend, since it + // stores fake data anyway. + return + } + + if _, ok := s.backend.(*loopdb.StoreMock); ok { + // Do not create a fake loop in loopdb.StoreMock, because it + // blocks on notification channels and this is not needed. + return + } + + // Put a swap with the same ID to backend store to satisfy SQL foreign + // key constraint. Don't store the data to ensure it is not used. + err := s.backend.CreateLoopOut(context.Background(), hash, + &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 999, + AmountRequested: 999, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{s.preimage}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + }, + ) + + s.backend.AssertLoopOutStored() + + // Make preimage unique to pass SQL constraints. + s.preimage++ + + if err != nil { + panic(err) + } } -// TestHandleSweepTwice tests that handing the same sweep twice must not +// AssertLoopOutStored asserts that a swap is stored. +func (s *loopStoreMock) AssertLoopOutStored() { + s.backend.AssertLoopOutStored() +} + +// testHandleSweepTwice tests that handing the same sweep twice must not // add it to different batches. -func TestHandleSweepTwice(t *testing.T) { +func testHandleSweepTwice(t *testing.T, backend testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) - store := newLoopStoreMock() + store := newLoopStoreMock(backend) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -1221,6 +1347,7 @@ func TestHandleSweepTwice(t *testing.T) { CltvExpiry: shortCltv, AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, }, } @@ -1244,6 +1371,7 @@ func TestHandleSweepTwice(t *testing.T) { CltvExpiry: longCltv, AmountRequested: 222, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, }, } @@ -1284,6 +1412,7 @@ func TestHandleSweepTwice(t *testing.T) { CltvExpiry: shortCltv, AmountRequested: 222, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, }, } @@ -1300,24 +1429,31 @@ func TestHandleSweepTwice(t *testing.T) { return false } + // Find the batch with largest ID. It must be the second batch. + // Variable batches is a map, not a slice, so we have to visit + // all the items and find the one with maximum id. + var secondBatch *batch + for _, batch := range batches { + if secondBatch == nil || batch.id > secondBatch.id { + secondBatch = batch + } + } + // Make sure the second batch has the second sweep. - sweep2, has := batches[1].sweeps[sweepReq2.SwapHash] + sweep2, has := secondBatch.sweeps[sweepReq2.SwapHash] if !has { return false } // Make sure the second sweep's timeout has been updated. - if sweep2.timeout != shortCltv { - return false - } - - return true + return sweep2.timeout == shortCltv }, test.Timeout, eventuallyCheckFrequency) // Make sure each batch has one sweep. If the second sweep was added to // both batches, the following check won't pass. - require.Equal(t, 1, len(batcher.batches[0].sweeps)) - require.Equal(t, 1, len(batcher.batches[1].sweeps)) + for _, batch := range batcher.batches { + require.Equal(t, 1, len(batch.sweeps)) + } // Now make the batcher quit by canceling the context. cancel() @@ -1326,20 +1462,19 @@ func TestHandleSweepTwice(t *testing.T) { checkBatcherError(t, runErr) } -// TestRestoringPreservesConfTarget tests that after the batch is written to DB +// testRestoringPreservesConfTarget tests that after the batch is written to DB // and loaded back, its batchConfTarget value is preserved. -func TestRestoringPreservesConfTarget(t *testing.T) { +func testRestoringPreservesConfTarget(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -1373,6 +1508,7 @@ func TestRestoringPreservesConfTarget(t *testing.T) { AmountRequested: 111, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, SweepConfTarget: 123, } @@ -1391,12 +1527,26 @@ func TestRestoringPreservesConfTarget(t *testing.T) { // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - // Make sure that the sweep was stored and we have exactly one - // active batch, with one sweep and proper batchConfTarget. - return batcherStore.AssertSweepStored(sweepReq.SwapHash) && - len(batcher.batches) == 1 && - len(batcher.batches[0].sweeps) == 1 && - batcher.batches[0].cfg.batchConfTarget == 123 + // Make sure that the sweep was stored + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + // Make sure there is exactly one active batch. + if len(batcher.batches) != 1 { + return false + } + + // Get the batch. + batch := getOnlyBatch(batcher) + + // Make sure the batch has one sweep. + if len(batch.sweeps) != 1 { + return false + } + + // Make sure the batch has proper batchConfTarget. + return batch.cfg.batchConfTarget == 123 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -1427,13 +1577,25 @@ func TestRestoringPreservesConfTarget(t *testing.T) { // Wait for batch to load. require.Eventually(t, func() bool { - return batcherStore.AssertSweepStored(sweepReq.SwapHash) && - len(batcher.batches) == 1 && - len(batcher.batches[0].sweeps) == 1 + // Make sure that the sweep was stored + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + // Make sure there is exactly one active batch. + if len(batcher.batches) != 1 { + return false + } + + // Get the batch. + batch := getOnlyBatch(batcher) + + // Make sure the batch has one sweep. + return len(batch.sweeps) == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure batchConfTarget was preserved. - require.Equal(t, 123, int(batcher.batches[0].cfg.batchConfTarget)) + require.Equal(t, 123, int(getOnlyBatch(batcher).cfg.batchConfTarget)) // Expect registration for spend notification. <-lnd.RegisterSpendChannel @@ -1456,8 +1618,10 @@ func (f *sweepFetcherMock) FetchSweep(ctx context.Context, hash lntypes.Hash) ( return f.store[hash], nil } -// TestSweepFetcher tests providing custom sweep fetcher to Batcher. -func TestSweepFetcher(t *testing.T) { +// testSweepFetcher tests providing custom sweep fetcher to Batcher. +func testSweepFetcher(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() @@ -1479,7 +1643,33 @@ func TestSweepFetcher(t *testing.T) { }, } - batcherStore := NewStoreMock() + // Create a sweep request. + sweepReq := SweepRequest{ + SwapHash: lntypes.Hash{1, 1, 1}, + Value: 111, + Outpoint: wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + }, + Notifier: &dummyNotifier, + } + + // Create a swap in the DB. It is needed to satisfy SQL constraints in + // case of SQL test. The data is not actually used, since we pass sweep + // fetcher, so put different conf target to make sure it is not used. + swap := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 222, + AmountRequested: 222, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 321, + } + err = store.CreateLoopOut(ctx, sweepReq.SwapHash, swap) + require.NoError(t, err) + store.AssertLoopOutStored() batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, @@ -1497,17 +1687,6 @@ func TestSweepFetcher(t *testing.T) { // Wait for the batcher to be initialized. <-batcher.initDone - // Create a sweep request. - sweepReq := SweepRequest{ - SwapHash: lntypes.Hash{1, 1, 1}, - Value: 111, - Outpoint: wire.OutPoint{ - Hash: chainhash.Hash{1, 1}, - Index: 1, - }, - Notifier: &dummyNotifier, - } - // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq)) @@ -1518,12 +1697,26 @@ func TestSweepFetcher(t *testing.T) { // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - // Make sure that the sweep was stored and we have exactly one - // active batch, with one sweep and proper batchConfTarget. - return batcherStore.AssertSweepStored(sweepReq.SwapHash) && - len(batcher.batches) == 1 && - len(batcher.batches[0].sweeps) == 1 && - batcher.batches[0].cfg.batchConfTarget == 123 + // Make sure that the sweep was stored + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + // Make sure there is exactly one active batch. + if len(batcher.batches) != 1 { + return false + } + + // Get the batch. + batch := getOnlyBatch(batcher) + + // Make sure the batch has one sweep. + if len(batch.sweeps) != 1 { + return false + } + + // Make sure the batch has proper batchConfTarget. + return batch.cfg.batchConfTarget == 123 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -1539,21 +1732,20 @@ func TestSweepFetcher(t *testing.T) { checkBatcherError(t, runErr) } -// TestSweepBatcherCloseDuringAdding tests that sweep batcher works correctly +// testSweepBatcherCloseDuringAdding tests that sweep batcher works correctly // if it is closed (stops running) during AddSweep call. -func TestSweepBatcherCloseDuringAdding(t *testing.T) { +func testSweepBatcherCloseDuringAdding(t *testing.T, store testStore, + batcherStore testBatcherStore) { + defer test.Guard(t)() lnd := test.NewMockLnd() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - store := loopdb.NewStoreMock(t) sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) require.NoError(t, err) - batcherStore := NewStoreMock() - batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, sweepStore) @@ -1571,8 +1763,12 @@ func TestSweepBatcherCloseDuringAdding(t *testing.T) { SwapContract: loopdb.SwapContract{ CltvExpiry: 111, AmountRequested: 111, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{i}, }, + DestAddr: destAddr, SwapInvoice: swapInvoice, } @@ -1634,3 +1830,173 @@ func TestSweepBatcherCloseDuringAdding(t *testing.T) { close(quit) <-registrationChan } + +// TestSweepBatcherBatchCreation tests that sweep requests enter the expected +// batch based on their timeout distance. +func TestSweepBatcherBatchCreation(t *testing.T) { + runTests(t, testSweepBatcherBatchCreation) +} + +// TestSweepBatcherSimpleLifecycle tests the simple lifecycle of the batches +// that are created and run by the batcher. +func TestSweepBatcherSimpleLifecycle(t *testing.T) { + runTests(t, testSweepBatcherSimpleLifecycle) +} + +// TestSweepBatcherSweepReentry tests that when an old version of the batch tx +// gets confirmed the sweep leftovers are sent back to the batcher. +func TestSweepBatcherSweepReentry(t *testing.T) { + runTests(t, testSweepBatcherSweepReentry) +} + +// TestSweepBatcherNonWalletAddr tests that sweep requests that sweep to a non +// wallet address enter individual batches. +func TestSweepBatcherNonWalletAddr(t *testing.T) { + runTests(t, testSweepBatcherNonWalletAddr) +} + +// TestSweepBatcherComposite tests that sweep requests that sweep to both wallet +// addresses and non-wallet addresses enter the correct batches. +func TestSweepBatcherComposite(t *testing.T) { + runTests(t, testSweepBatcherComposite) +} + +// TestGetFeePortionForSweep tests that the fee portion for a sweep is correctly +// calculated. +func TestGetFeePortionForSweep(t *testing.T) { + runTests(t, testGetFeePortionForSweep) +} + +// TestRestoringEmptyBatch tests that the batcher can be restored with an empty +// batch. +func TestRestoringEmptyBatch(t *testing.T) { + runTests(t, testRestoringEmptyBatch) +} + +// TestHandleSweepTwice tests that handing the same sweep twice must not +// add it to different batches. +func TestHandleSweepTwice(t *testing.T) { + runTests(t, testHandleSweepTwice) +} + +// TestRestoringPreservesConfTarget tests that after the batch is written to DB +// and loaded back, its batchConfTarget value is preserved. +func TestRestoringPreservesConfTarget(t *testing.T) { + runTests(t, testRestoringPreservesConfTarget) +} + +// TestSweepFetcher tests providing custom sweep fetcher to Batcher. +func TestSweepFetcher(t *testing.T) { + runTests(t, testSweepFetcher) +} + +// TestSweepBatcherCloseDuringAdding tests that sweep batcher works correctly +// if it is closed (stops running) during AddSweep call. +func TestSweepBatcherCloseDuringAdding(t *testing.T) { + runTests(t, testSweepBatcherCloseDuringAdding) +} + +// testBatcherStore is BatcherStore used in tests. +type testBatcherStore interface { + BatcherStore + + // AssertSweepStored asserts that a sweep is stored. + AssertSweepStored(id lntypes.Hash) bool +} + +type loopdbBatcherStore struct { + BatcherStore + + sweepsSet map[lntypes.Hash]struct{} +} + +// UpsertSweep inserts a sweep into the database, or updates an existing sweep +// if it already exists. This wrapper was added to update sweepsSet. +func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context, + sweep *dbSweep) error { + + err := s.BatcherStore.UpsertSweep(ctx, sweep) + if err == nil { + s.sweepsSet[sweep.SwapHash] = struct{}{} + } + return err +} + +// AssertSweepStored asserts that a sweep is stored. +func (s *loopdbBatcherStore) AssertSweepStored(id lntypes.Hash) bool { + _, has := s.sweepsSet[id] + return has +} + +// testStore is loopdb used in tests. +type testStore interface { + loopdb.SwapStore + + // AssertLoopOutStored asserts that a swap is stored. + AssertLoopOutStored() +} + +// loopdbStore wraps loopdb.SwapStore and implements testStore interface. +type loopdbStore struct { + loopdb.SwapStore + + t *testing.T + + loopOutStoreChan chan struct{} +} + +// newLoopdbStore creates new loopdbStore instance. +func newLoopdbStore(t *testing.T, swapStore loopdb.SwapStore) *loopdbStore { + return &loopdbStore{ + SwapStore: swapStore, + t: t, + loopOutStoreChan: make(chan struct{}, 1), + } +} + +// CreateLoopOut adds an initiated swap to the store. +func (s *loopdbStore) CreateLoopOut(ctx context.Context, hash lntypes.Hash, + swap *loopdb.LoopOutContract) error { + + err := s.SwapStore.CreateLoopOut(ctx, hash, swap) + if err == nil { + s.loopOutStoreChan <- struct{}{} + } + + return err +} + +// AssertLoopOutStored asserts that a swap is stored. +func (s *loopdbStore) AssertLoopOutStored() { + s.t.Helper() + + select { + case <-s.loopOutStoreChan: + case <-time.After(test.Timeout): + s.t.Fatalf("expected swap to be stored") + } +} + +// runTests runs a test with both mock and loopdb. +func runTests(t *testing.T, testFn func(t *testing.T, store testStore, + batcherStore testBatcherStore)) { + + t.Run("mocks", func(t *testing.T) { + store := loopdb.NewStoreMock(t) + batcherStore := NewStoreMock() + testFn(t, store, batcherStore) + }) + + t.Run("loopdb", func(t *testing.T) { + sqlDB := loopdb.NewTestDB(t) + typedSqlDB := loopdb.NewTypedStore[Querier](sqlDB) + lnd := test.NewMockLnd() + batcherStore := NewSQLStore(typedSqlDB, lnd.ChainParams) + testStore := newLoopdbStore(t, sqlDB) + testBatcherStore := &loopdbBatcherStore{ + BatcherStore: batcherStore, + sweepsSet: make(map[lntypes.Hash]struct{}), + } + testFn(t, testStore, testBatcherStore) + }) +}