2
0
mirror of https://github.com/lightninglabs/loop synced 2024-11-04 06:00:21 +00:00
loop/sweepbatcher/sweep_batcher.go
2024-01-23 20:38:09 +02:00

658 lines
17 KiB
Go

package sweepbatcher
import (
"context"
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/utils"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
)
const (
// defaultMaxTimeoutDistance is the default maximum timeout distance
// of sweeps that can appear in the same batch.
defaultMaxTimeoutDistance = 288
// batchOpen is the string representation of the state of a batch that
// is open.
batchOpen = "open"
// batchClosed is the string representation of the state of a batch
// that is closed.
batchClosed = "closed"
// batchConfirmed is the string representation of the state of a batch
// that is confirmed.
batchConfirmed = "confirmed"
// defaultMainnetPublishDelay is the default publish delay that is used
// for mainnet.
defaultMainnetPublishDelay = 5 * time.Second
// defaultTestnetPublishDelay is the default publish delay that is used
// for testnet.
defaultPublishDelay = 500 * time.Millisecond
)
type BatcherStore interface {
// FetchUnconfirmedSweepBatches fetches all the batches from the
// database that are not in a confirmed state.
FetchUnconfirmedSweepBatches(ctx context.Context) ([]*dbBatch,
error)
// InsertSweepBatch inserts a batch into the database, returning the id
// of the inserted batch.
InsertSweepBatch(ctx context.Context,
batch *dbBatch) (int32, error)
// UpdateSweepBatch updates a batch in the database.
UpdateSweepBatch(ctx context.Context,
batch *dbBatch) error
// ConfirmBatch confirms a batch by setting its state to confirmed.
ConfirmBatch(ctx context.Context, id int32) error
// FetchBatchSweeps fetches all the sweeps that belong to a batch.
FetchBatchSweeps(ctx context.Context,
id int32) ([]*dbSweep, error)
// UpsertSweep inserts a sweep into the database, or updates an existing
// sweep if it already exists.
UpsertSweep(ctx context.Context, sweep *dbSweep) error
// GetSweepStatus returns the completed status of the sweep.
GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) (
bool, error)
}
// MuSig2SignSweep is a function that can be used to sign a sweep transaction
// cooperatively with the swap server.
type MuSig2SignSweep func(ctx context.Context,
protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash,
paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte,
prevoutMap map[wire.OutPoint]*wire.TxOut) (
[]byte, []byte, error)
// VerifySchnorrSig is a function that can be used to verify a schnorr
// signature.
type VerifySchnorrSig func(pubKey *btcec.PublicKey, hash, sig []byte) error
// SweepRequest is a request to sweep a specific outpoint.
type SweepRequest struct {
// SwapHash is the hash of the swap that is being swept.
SwapHash lntypes.Hash
// Outpoint is the outpoint that is being swept.
Outpoint wire.OutPoint
// Value is the value of the outpoint that is being swept.
Value btcutil.Amount
// Notifier is a notifier that is used to notify the requester of this
// sweep that the sweep was successful.
Notifier *SpendNotifier
}
// SpendNotifier is a notifier that is used to notify the requester of a sweep
// that the sweep was successful.
type SpendNotifier struct {
// SpendChan is a channel where the spend details are received.
SpendChan chan *wire.MsgTx
// SpendErrChan is a channel where spend errors are received.
SpendErrChan chan error
// QuitChan is a channel that can be closed to stop the notifier.
QuitChan chan bool
}
var (
ErrBatcherShuttingDown = fmt.Errorf("batcher shutting down")
)
// Batcher is a system that is responsible for accepting sweep requests and
// placing them in appropriate batches. It will spin up new batches as needed.
type Batcher struct {
// batches is a map of batch IDs to the currently active batches.
batches map[int32]*batch
// sweepReqs is a channel where sweep requests are received.
sweepReqs chan SweepRequest
// errChan is a channel where errors are received.
errChan chan error
// quit signals that the batch must stop.
quit chan struct{}
// wallet is the wallet kit client that is used by batches.
wallet lndclient.WalletKitClient
// chainNotifier is the chain notifier client that is used by batches.
chainNotifier lndclient.ChainNotifierClient
// signerClient is the signer client that is used by batches.
signerClient lndclient.SignerClient
// musig2ServerKit includes all the required functionality to collect
// and verify signatures by the swap server in order to cooperatively
// sweep funds.
musig2ServerSign MuSig2SignSweep
// verifySchnorrSig is a function that can be used to verify a schnorr
// signature.
VerifySchnorrSig VerifySchnorrSig
// chainParams are the chain parameters of the chain that is used by
// batches.
chainParams *chaincfg.Params
// store includes all the database interactions that are needed by the
// batcher and the batches.
store BatcherStore
// swapStore includes all the database interactions that are needed for
// interacting with swaps.
swapStore loopdb.SwapStore
// wg is a waitgroup that is used to wait for all the goroutines to
// exit.
wg sync.WaitGroup
}
// NewBatcher creates a new Batcher instance.
func NewBatcher(wallet lndclient.WalletKitClient,
chainNotifier lndclient.ChainNotifierClient,
signerClient lndclient.SignerClient, musig2ServerSigner MuSig2SignSweep,
verifySchnorrSig VerifySchnorrSig, chainparams *chaincfg.Params,
store BatcherStore, swapStore loopdb.SwapStore) *Batcher {
return &Batcher{
batches: make(map[int32]*batch),
sweepReqs: make(chan SweepRequest),
errChan: make(chan error, 1),
quit: make(chan struct{}),
wallet: wallet,
chainNotifier: chainNotifier,
signerClient: signerClient,
musig2ServerSign: musig2ServerSigner,
VerifySchnorrSig: verifySchnorrSig,
chainParams: chainparams,
store: store,
swapStore: swapStore,
}
}
// Run starts the batcher and processes incoming sweep requests.
func (b *Batcher) Run(ctx context.Context) error {
runCtx, cancel := context.WithCancel(ctx)
defer func() {
cancel()
for _, batch := range b.batches {
batch.Wait()
}
b.wg.Wait()
}()
// First we fetch all the batches that are not in a confirmed state from
// the database. We will then resume the execution of these batches.
batches, err := b.FetchUnconfirmedBatches(runCtx)
if err != nil {
return err
}
for _, batch := range batches {
err := b.spinUpBatchFromDB(runCtx, batch)
if err != nil {
return err
}
}
for {
select {
case sweepReq := <-b.sweepReqs:
sweep, err := b.fetchSweep(runCtx, sweepReq)
if err != nil {
return err
}
err = b.handleSweep(runCtx, sweep, sweepReq.Notifier)
if err != nil {
return err
}
case err := <-b.errChan:
return err
case <-runCtx.Done():
return runCtx.Err()
}
}
}
// AddSweep adds a sweep request to the batcher for handling. This will either
// place the sweep in an existing batch or create a new one.
func (b *Batcher) AddSweep(sweepReq *SweepRequest) error {
select {
case b.sweepReqs <- *sweepReq:
return nil
case <-b.quit:
return ErrBatcherShuttingDown
}
}
// handleSweep handles a sweep request by either placing it in an existing
// batch, or by spinning up a new batch for it.
func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
notifier *SpendNotifier) error {
completed, err := b.store.GetSweepStatus(ctx, sweep.swapHash)
if err != nil {
return err
}
log.Infof("Batcher handling sweep %x, completed=%v", sweep.swapHash[:6],
completed)
// If the sweep has already been completed in a confirmed batch then we
// can't attach its notifier to the batch as that is no longer running.
// Instead we directly detect and return the spend here.
if completed && *notifier != (SpendNotifier{}) {
go b.monitorSpendAndNotify(ctx, sweep, notifier)
return nil
}
sweep.notifier = notifier
// Check if the sweep is already in a batch. If that is the case, we
// provide the sweep to that batch and return.
for _, batch := range b.batches {
// This is a check to see if a batch is completed. In that case
// we just lazily delete it and continue our scan.
if batch.isComplete() {
delete(b.batches, batch.id)
continue
}
if batch.sweepExists(sweep.swapHash) {
accepted, err := batch.addSweep(ctx, sweep)
if err != nil {
return err
}
if !accepted {
return fmt.Errorf("existing sweep %x was not "+
"accepted by batch %d", sweep.swapHash[:6],
batch.id)
}
}
}
// If one of the batches accepts the sweep, we provide it to that batch.
for _, batch := range b.batches {
accepted, err := batch.addSweep(ctx, sweep)
if err != nil && err != ErrBatchShuttingDown {
return err
}
// If the sweep was accepted by this batch, we return, our job
// is done.
if accepted {
return nil
}
}
// If no batch is capable of accepting the sweep, we spin up a fresh
// batch and hand the sweep over to it.
batch, err := b.spinUpBatch(ctx)
if err != nil {
return err
}
// Add the sweep to the fresh batch.
accepted, err := batch.addSweep(ctx, sweep)
if err != nil {
return err
}
// If the sweep wasn't accepted by the fresh batch something is wrong,
// we should return the error.
if !accepted {
return fmt.Errorf("sweep %x was not accepted by new batch %d",
sweep.swapHash[:6], batch.id)
}
return nil
}
// spinUpBatch spins up a new batch and returns it.
func (b *Batcher) spinUpBatch(ctx context.Context) (*batch, error) {
cfg := batchConfig{
maxTimeoutDistance: defaultMaxTimeoutDistance,
batchConfTarget: defaultBatchConfTarget,
}
switch b.chainParams {
case &chaincfg.MainNetParams:
cfg.batchPublishDelay = defaultMainnetPublishDelay
default:
cfg.batchPublishDelay = defaultPublishDelay
}
batchKit := batchKit{
returnChan: b.sweepReqs,
wallet: b.wallet,
chainNotifier: b.chainNotifier,
signerClient: b.signerClient,
musig2SignSweep: b.musig2ServerSign,
verifySchnorrSig: b.VerifySchnorrSig,
purger: b.AddSweep,
store: b.store,
}
batch := NewBatch(cfg, batchKit)
id, err := batch.insertAndAcquireID(ctx)
if err != nil {
return nil, err
}
// We add the batch to our map of batches and start it.
b.batches[id] = batch
b.wg.Add(1)
go func() {
defer b.wg.Done()
err := batch.Run(ctx)
if err != nil {
_ = b.writeToErrChan(ctx, err)
}
}()
return batch, nil
}
// spinUpBatchDB spins up a batch that already existed in storage, then
// returns it.
func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
cfg := batchConfig{
maxTimeoutDistance: batch.cfg.maxTimeoutDistance,
batchConfTarget: defaultBatchConfTarget,
}
rbfCache := rbfCache{
LastHeight: batch.rbfCache.LastHeight,
FeeRate: batch.rbfCache.FeeRate,
}
dbSweeps, err := b.store.FetchBatchSweeps(ctx, batch.id)
if err != nil {
return err
}
if len(dbSweeps) == 0 {
return fmt.Errorf("batch %d has no sweeps", batch.id)
}
primarySweep := dbSweeps[0]
sweeps := make(map[lntypes.Hash]sweep)
for _, dbSweep := range dbSweeps {
sweep, err := b.convertSweep(dbSweep)
if err != nil {
return err
}
sweeps[sweep.swapHash] = *sweep
}
batchKit := batchKit{
id: batch.id,
batchTxid: batch.batchTxid,
batchPkScript: batch.batchPkScript,
state: batch.state,
primaryID: primarySweep.SwapHash,
sweeps: sweeps,
rbfCache: rbfCache,
returnChan: b.sweepReqs,
wallet: b.wallet,
chainNotifier: b.chainNotifier,
signerClient: b.signerClient,
musig2SignSweep: b.musig2ServerSign,
verifySchnorrSig: b.VerifySchnorrSig,
purger: b.AddSweep,
store: b.store,
log: batchPrefixLogger(fmt.Sprintf("%d", batch.id)),
}
newBatch := NewBatchFromDB(cfg, batchKit)
// We add the batch to our map of batches and start it.
b.batches[batch.id] = newBatch
b.wg.Add(1)
go func() {
defer b.wg.Done()
err := newBatch.Run(ctx)
if err != nil {
_ = b.writeToErrChan(ctx, err)
}
}()
return nil
}
// FetchUnconfirmedBatches fetches all the batches from the database that are
// not in a confirmed state.
func (b *Batcher) FetchUnconfirmedBatches(ctx context.Context) ([]*batch,
error) {
dbBatches, err := b.store.FetchUnconfirmedSweepBatches(ctx)
if err != nil {
return nil, err
}
batches := make([]*batch, 0, len(dbBatches))
for _, bch := range dbBatches {
bch := bch
batch := batch{}
batch.id = bch.ID
switch bch.State {
case batchOpen:
batch.state = Open
case batchClosed:
batch.state = Closed
case batchConfirmed:
batch.state = Confirmed
}
batch.batchTxid = &bch.BatchTxid
batch.batchPkScript = bch.BatchPkScript
rbfCache := rbfCache{
LastHeight: bch.LastRbfHeight,
FeeRate: chainfee.SatPerKWeight(bch.LastRbfSatPerKw),
}
batch.rbfCache = rbfCache
bchCfg := batchConfig{
maxTimeoutDistance: bch.MaxTimeoutDistance,
}
batch.cfg = &bchCfg
batches = append(batches, &batch)
}
return batches, nil
}
// monitorSpendAndNotify monitors the spend of a specific outpoint and writes
// the response back to the response channel.
func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep,
notifier *SpendNotifier) {
b.wg.Add(1)
defer b.wg.Done()
spendCtx, cancel := context.WithCancel(ctx)
defer cancel()
spendChan, spendErr, err := b.chainNotifier.RegisterSpendNtfn(
spendCtx, &sweep.outpoint, sweep.htlc.PkScript,
sweep.initiationHeight,
)
if err != nil {
select {
case notifier.SpendErrChan <- err:
case <-ctx.Done():
}
_ = b.writeToErrChan(ctx, err)
return
}
log.Infof("Batcher monitoring spend for swap %x", sweep.swapHash[:6])
for {
select {
case spend := <-spendChan:
select {
case notifier.SpendChan <- spend.SpendingTx:
case <-ctx.Done():
}
return
case err := <-spendErr:
select {
case notifier.SpendErrChan <- err:
case <-ctx.Done():
}
_ = b.writeToErrChan(ctx, err)
return
case <-notifier.QuitChan:
return
case <-ctx.Done():
return
}
}
}
func (b *Batcher) writeToErrChan(ctx context.Context, err error) error {
select {
case b.errChan <- err:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// convertSweep converts a fetched sweep from the database to a sweep that is
// ready to be processed by the batcher.
func (b *Batcher) convertSweep(dbSweep *dbSweep) (*sweep, error) {
swap := dbSweep.LoopOut
htlc, err := utils.GetHtlc(
dbSweep.SwapHash, &swap.Contract.SwapContract, b.chainParams,
)
if err != nil {
return nil, err
}
swapPaymentAddr, err := utils.ObtainSwapPaymentAddr(
swap.Contract.SwapInvoice, b.chainParams,
)
if err != nil {
return nil, err
}
return &sweep{
swapHash: swap.Hash,
outpoint: dbSweep.Outpoint,
value: dbSweep.Amount,
confTarget: swap.Contract.SweepConfTarget,
timeout: swap.Contract.CltvExpiry,
initiationHeight: swap.Contract.InitiationHeight,
htlc: *htlc,
preimage: swap.Contract.Preimage,
swapInvoicePaymentAddr: *swapPaymentAddr,
htlcKeys: swap.Contract.HtlcKeys,
htlcSuccessEstimator: htlc.AddSuccessToEstimator,
protocolVersion: swap.Contract.ProtocolVersion,
isExternalAddr: swap.Contract.IsExternalAddr,
destAddr: swap.Contract.DestAddr,
}, nil
}
// fetchSweep fetches the sweep related information from the database.
func (b *Batcher) fetchSweep(ctx context.Context,
sweepReq SweepRequest) (*sweep, error) {
swapHash, err := lntypes.MakeHash(sweepReq.SwapHash[:])
if err != nil {
return nil, fmt.Errorf("failed to parse swapHash: %v", err)
}
swap, err := b.swapStore.FetchLoopOutSwap(ctx, swapHash)
if err != nil {
return nil, fmt.Errorf("failed to fetch loop out for %x: %v",
swapHash[:6], err)
}
htlc, err := utils.GetHtlc(
swapHash, &swap.Contract.SwapContract, b.chainParams,
)
if err != nil {
return nil, fmt.Errorf("failed to get htlc: %v", err)
}
swapPaymentAddr, err := utils.ObtainSwapPaymentAddr(
swap.Contract.SwapInvoice, b.chainParams,
)
if err != nil {
return nil, fmt.Errorf("failed to get payment addr: %v", err)
}
return &sweep{
swapHash: swap.Hash,
outpoint: sweepReq.Outpoint,
value: sweepReq.Value,
confTarget: swap.Contract.SweepConfTarget,
timeout: swap.Contract.CltvExpiry,
initiationHeight: swap.Contract.InitiationHeight,
htlc: *htlc,
preimage: swap.Contract.Preimage,
swapInvoicePaymentAddr: *swapPaymentAddr,
htlcKeys: swap.Contract.HtlcKeys,
htlcSuccessEstimator: htlc.AddSuccessToEstimator,
protocolVersion: swap.Contract.ProtocolVersion,
isExternalAddr: swap.Contract.IsExternalAddr,
destAddr: swap.Contract.DestAddr,
}, nil
}