2
0
mirror of https://github.com/lightninglabs/loop synced 2024-11-08 01:10:29 +00:00

sweepbatcher: close the quit channel when the batcher is shutting down

This commit is contained in:
Andras Banki-Horvath 2024-05-24 10:52:26 +02:00
parent c01e8014e1
commit e5ade6a0b1
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8
3 changed files with 32 additions and 20 deletions

View File

@ -197,7 +197,11 @@ type batch struct {
// main event loop. // main event loop.
callLeave chan struct{} callLeave chan struct{}
// quit signals that the batch must stop. // stopped signals that the batch has stopped.
stopped chan struct{}
// quit is owned by the parent batcher and signals that the batch must
// stop.
quit chan struct{} quit chan struct{}
// wallet is the wallet client used to create and publish the batch // wallet is the wallet client used to create and publish the batch
@ -261,6 +265,7 @@ type batchKit struct {
purger Purger purger Purger
store BatcherStore store BatcherStore
log btclog.Logger log btclog.Logger
quit chan struct{}
} }
// scheduleNextCall schedules the next call to the batch handler's main event // scheduleNextCall schedules the next call to the batch handler's main event
@ -270,6 +275,9 @@ func (b *batch) scheduleNextCall() (func(), error) {
case b.callEnter <- struct{}{}: case b.callEnter <- struct{}{}:
case <-b.quit: case <-b.quit:
return func() {}, ErrBatcherShuttingDown
case <-b.stopped:
return func() {}, ErrBatchShuttingDown return func() {}, ErrBatchShuttingDown
} }
@ -293,7 +301,8 @@ func NewBatch(cfg batchConfig, bk batchKit) *batch {
errChan: make(chan error, 1), errChan: make(chan error, 1),
callEnter: make(chan struct{}), callEnter: make(chan struct{}),
callLeave: make(chan struct{}), callLeave: make(chan struct{}),
quit: make(chan struct{}), stopped: make(chan struct{}),
quit: bk.quit,
batchTxid: bk.batchTxid, batchTxid: bk.batchTxid,
wallet: bk.wallet, wallet: bk.wallet,
chainNotifier: bk.chainNotifier, chainNotifier: bk.chainNotifier,
@ -320,7 +329,8 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) *batch {
errChan: make(chan error, 1), errChan: make(chan error, 1),
callEnter: make(chan struct{}), callEnter: make(chan struct{}),
callLeave: make(chan struct{}), callLeave: make(chan struct{}),
quit: make(chan struct{}), stopped: make(chan struct{}),
quit: bk.quit,
batchTxid: bk.batchTxid, batchTxid: bk.batchTxid,
batchPkScript: bk.batchPkScript, batchPkScript: bk.batchPkScript,
rbfCache: bk.rbfCache, rbfCache: bk.rbfCache,
@ -447,7 +457,7 @@ func (b *batch) Run(ctx context.Context) error {
runCtx, cancel := context.WithCancel(ctx) runCtx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
cancel() cancel()
close(b.quit) close(b.stopped)
b.wg.Wait() b.wg.Wait()
}() }()

View File

@ -216,6 +216,7 @@ func (b *Batcher) Run(ctx context.Context) error {
runCtx, cancel := context.WithCancel(ctx) runCtx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
cancel() cancel()
close(b.quit)
for _, batch := range b.batches { for _, batch := range b.batches {
batch.Wait() batch.Wait()
@ -379,6 +380,7 @@ func (b *Batcher) spinUpBatch(ctx context.Context) (*batch, error) {
verifySchnorrSig: b.VerifySchnorrSig, verifySchnorrSig: b.VerifySchnorrSig,
purger: b.AddSweep, purger: b.AddSweep,
store: b.store, store: b.store,
quit: b.quit,
} }
batch := NewBatch(cfg, batchKit) batch := NewBatch(cfg, batchKit)
@ -461,6 +463,7 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
purger: b.AddSweep, purger: b.AddSweep,
store: b.store, store: b.store,
log: batchPrefixLogger(fmt.Sprintf("%d", batch.id)), log: batchPrefixLogger(fmt.Sprintf("%d", batch.id)),
quit: b.quit,
} }
cfg := batchConfig{ cfg := batchConfig{

View File

@ -2,7 +2,7 @@ package sweepbatcher
import ( import (
"context" "context"
"strings" "errors"
"testing" "testing"
"time" "time"
@ -43,6 +43,15 @@ var dummyNotifier = SpendNotifier{
QuitChan: make(chan bool, ntfnBufferSize), QuitChan: make(chan bool, ntfnBufferSize),
} }
func checkBatcherError(t *testing.T, err error) {
if !errors.Is(err, context.Canceled) &&
!errors.Is(err, ErrBatcherShuttingDown) &&
!errors.Is(err, ErrBatchShuttingDown) {
require.NoError(t, err)
}
}
// TestSweepBatcherBatchCreation tests that sweep requests enter the expected // TestSweepBatcherBatchCreation tests that sweep requests enter the expected
// batch based on their timeout distance. // batch based on their timeout distance.
func TestSweepBatcherBatchCreation(t *testing.T) { func TestSweepBatcherBatchCreation(t *testing.T) {
@ -60,9 +69,7 @@ func TestSweepBatcherBatchCreation(t *testing.T) {
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store) testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
go func() { go func() {
err := batcher.Run(ctx) err := batcher.Run(ctx)
if !strings.Contains(err.Error(), "context canceled") { checkBatcherError(t, err)
require.NoError(t, err)
}
}() }()
// Create a sweep request. // Create a sweep request.
@ -215,9 +222,7 @@ func TestSweepBatcherSimpleLifecycle(t *testing.T) {
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store) testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
go func() { go func() {
err := batcher.Run(ctx) err := batcher.Run(ctx)
if !strings.Contains(err.Error(), "context canceled") { checkBatcherError(t, err)
require.NoError(t, err)
}
}() }()
// Create a sweep request. // Create a sweep request.
@ -354,9 +359,7 @@ func TestSweepBatcherSweepReentry(t *testing.T) {
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store) testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
go func() { go func() {
err := batcher.Run(ctx) err := batcher.Run(ctx)
if !strings.Contains(err.Error(), "context canceled") { checkBatcherError(t, err)
require.NoError(t, err)
}
}() }()
// Create some sweep requests with timeouts not too far away, in order // Create some sweep requests with timeouts not too far away, in order
@ -561,9 +564,7 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) {
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store) testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
go func() { go func() {
err := batcher.Run(ctx) err := batcher.Run(ctx)
if !strings.Contains(err.Error(), "context canceled") { checkBatcherError(t, err)
require.NoError(t, err)
}
}() }()
// Create a sweep request. // Create a sweep request.
@ -727,9 +728,7 @@ func TestSweepBatcherComposite(t *testing.T) {
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store) testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
go func() { go func() {
err := batcher.Run(ctx) err := batcher.Run(ctx)
if !strings.Contains(err.Error(), "context canceled") { checkBatcherError(t, err)
require.NoError(t, err)
}
}() }()
// Create a sweep request. // Create a sweep request.