2
0
mirror of https://github.com/lightninglabs/loop synced 2024-11-04 06:00:21 +00:00
loop/testcontext_test.go
2024-01-23 20:38:11 +02:00

290 lines
6.9 KiB
Go

package loop
import (
"context"
"fmt"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
"github.com/lightninglabs/loop/sweep"
"github.com/lightninglabs/loop/sweepbatcher"
"github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require"
)
var (
testPreimage = lntypes.Preimage([32]byte{
1, 1, 1, 1, 2, 2, 2, 2,
3, 3, 3, 3, 4, 4, 4, 4,
1, 1, 1, 1, 2, 2, 2, 2,
3, 3, 3, 3, 4, 4, 4, 4,
})
)
// testContext contains functionality to support client unit tests.
type testContext struct {
test.Context
serverMock *serverMock
swapClient *Client
statusChan chan SwapInfo
store *loopdb.StoreMock
expiryChan chan time.Time
runErr chan error
stop func()
}
// mockVerifySchnorrSigFail is used to simulate failed taproot keyspend
// signature verification. If passed to the executeConfig we'll test an
// uncooperative server and will fall back to scriptspend sweep.
func mockVerifySchnorrSigFail(pubKey *btcec.PublicKey, hash,
sig []byte) error {
return fmt.Errorf("invalid sig")
}
// mockVerifySchnorrSigSuccess is used to simulate successful taproot keyspend
// signature verification. If passed to the executeConfig we'll test an
// uncooperative server and will fall back to scriptspend sweep.
func mockVerifySchnorrSigSuccess(pubKey *btcec.PublicKey, hash,
sig []byte) error {
return fmt.Errorf("invalid sig")
}
func mockMuSig2SignSweep(ctx context.Context,
protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash,
paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte,
prevoutMap map[wire.OutPoint]*wire.TxOut) (
[]byte, []byte, error) {
return nil, nil, nil
}
func newSwapClient(config *clientConfig) *Client {
sweeper := &sweep.Sweeper{
Lnd: config.LndServices,
}
lndServices := config.LndServices
batcherStore := sweepbatcher.NewStoreMock()
batcher := sweepbatcher.NewBatcher(
config.LndServices.WalletKit, config.LndServices.ChainNotifier,
config.LndServices.Signer, mockMuSig2SignSweep,
mockVerifySchnorrSigSuccess, config.LndServices.ChainParams,
batcherStore, config.Store,
)
executor := newExecutor(&executorConfig{
lnd: lndServices,
store: config.Store,
sweeper: sweeper,
batcher: batcher,
createExpiryTimer: config.CreateExpiryTimer,
cancelSwap: config.Server.CancelLoopOutSwap,
verifySchnorrSig: mockVerifySchnorrSigFail,
})
return &Client{
errChan: make(chan error),
clientConfig: *config,
lndServices: lndServices,
sweeper: sweeper,
executor: executor,
resumeReady: make(chan struct{}),
}
}
func createClientTestContext(t *testing.T,
pendingSwaps []*loopdb.LoopOut) *testContext {
clientLnd := test.NewMockLnd()
serverMock := newServerMock(clientLnd)
store := loopdb.NewStoreMock(t)
for _, s := range pendingSwaps {
store.LoopOutSwaps[s.Hash] = s.Contract
updates := []loopdb.SwapStateData{}
for _, e := range s.Events {
updates = append(updates, e.SwapStateData)
}
store.LoopOutUpdates[s.Hash] = updates
}
expiryChan := make(chan time.Time)
timerFactory := func(expiry time.Duration) <-chan time.Time {
return expiryChan
}
swapClient := newSwapClient(&clientConfig{
LndServices: &clientLnd.LndServices,
Server: serverMock,
Store: store,
CreateExpiryTimer: timerFactory,
})
statusChan := make(chan SwapInfo)
ctx := &testContext{
Context: test.NewContext(
t,
clientLnd,
),
swapClient: swapClient,
statusChan: statusChan,
expiryChan: expiryChan,
store: store,
serverMock: serverMock,
}
ctx.runErr = make(chan error)
runCtx, stop := context.WithCancel(context.Background())
ctx.stop = stop
go func() {
err := swapClient.Run(runCtx, statusChan)
log.Errorf("client run: %v", err)
ctx.runErr <- err
}()
return ctx
}
func (ctx *testContext) finish() {
ctx.stop()
select {
case err := <-ctx.runErr:
require.NoError(ctx.Context.T, err)
case <-time.After(test.Timeout):
ctx.Context.T.Fatal("client not stopping")
}
ctx.assertIsDone()
}
func (ctx *testContext) assertIsDone() {
require.NoError(ctx.Context.T, ctx.Context.Lnd.IsDone())
require.NoError(ctx.Context.T, ctx.store.IsDone())
select {
case <-ctx.statusChan:
ctx.Context.T.Fatalf("not all status updates read")
default:
}
}
func (ctx *testContext) assertStored() {
ctx.Context.T.Helper()
ctx.store.AssertLoopOutStored()
}
func (ctx *testContext) assertStorePreimageReveal() {
ctx.Context.T.Helper()
ctx.store.AssertStorePreimageReveal()
}
func (ctx *testContext) assertStoreFinished(expectedResult loopdb.SwapState) {
ctx.Context.T.Helper()
ctx.store.AssertStoreFinished(expectedResult)
}
func (ctx *testContext) assertStatus(expectedState loopdb.SwapState) {
ctx.Context.T.Helper()
for {
select {
case update := <-ctx.statusChan:
if update.SwapType != swap.TypeOut {
continue
}
if update.State == expectedState {
return
}
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("expected status %v not "+
"received in time", expectedState)
}
}
}
func (ctx *testContext) publishHtlc(script []byte,
amt btcutil.Amount) wire.OutPoint {
// Create the htlc tx.
htlcTx := wire.MsgTx{}
htlcTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: wire.OutPoint{},
})
htlcTx.AddTxOut(&wire.TxOut{
PkScript: script,
Value: int64(amt),
})
htlcTxHash := htlcTx.TxHash()
// Signal client that script has been published.
select {
case ctx.Lnd.ConfChannel <- &chainntnfs.TxConfirmation{
Tx: &htlcTx,
}:
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("htlc confirmed not consumed")
}
return wire.OutPoint{
Hash: htlcTxHash,
Index: 0,
}
}
// trackPayment asserts that a call to track payment was sent and sends the
// status provided into the updates channel.
func (ctx *testContext) trackPayment(status lnrpc.Payment_PaymentStatus) {
trackPayment := ctx.Context.AssertTrackPayment()
select {
case trackPayment.Updates <- lndclient.PaymentStatus{
State: status,
}:
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("could not send payment update")
}
}
// assertPreimagePush asserts that we made an attempt to push our preimage to
// the server.
func (ctx *testContext) assertPreimagePush(preimage lntypes.Preimage) {
select {
case pushedPreimage := <-ctx.serverMock.preimagePush:
require.Equal(ctx.Context.T, preimage, pushedPreimage)
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("preimage not pushed")
}
}
func (ctx *testContext) AssertEpochListeners(numListeners int32) {
ctx.Context.T.Helper()
require.Eventually(ctx.Context.T, func() bool {
return ctx.Lnd.EpochSubscribers() == numListeners
}, test.Timeout, time.Millisecond*250)
}