2
0
mirror of https://github.com/lightninglabs/loop synced 2024-11-08 01:10:29 +00:00
loop/testcontext_test.go
2023-03-29 15:56:40 +02:00

256 lines
5.7 KiB
Go

package loop
import (
"context"
"fmt"
"testing"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/lndclient"
"github.com/lightninglabs/loop/loopdb"
"github.com/lightninglabs/loop/swap"
"github.com/lightninglabs/loop/sweep"
"github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require"
)
var (
testPreimage = lntypes.Preimage([32]byte{
1, 1, 1, 1, 2, 2, 2, 2,
3, 3, 3, 3, 4, 4, 4, 4,
1, 1, 1, 1, 2, 2, 2, 2,
3, 3, 3, 3, 4, 4, 4, 4,
})
)
// testContext contains functionality to support client unit tests.
type testContext struct {
test.Context
serverMock *serverMock
swapClient *Client
statusChan chan SwapInfo
store *storeMock
expiryChan chan time.Time
runErr chan error
stop func()
}
// mockVerifySchnorrSigFail is used to simulate failed taproot keyspend
// signature verification. If passed to the executeConfig we'll test an
// uncooperative server and will fall back to scriptspend sweep.
func mockVerifySchnorrSigFail(pubKey *btcec.PublicKey, hash,
sig []byte) error {
return fmt.Errorf("invalid sig")
}
func newSwapClient(config *clientConfig) *Client {
sweeper := &sweep.Sweeper{
Lnd: config.LndServices,
}
lndServices := config.LndServices
executor := newExecutor(&executorConfig{
lnd: lndServices,
store: config.Store,
sweeper: sweeper,
createExpiryTimer: config.CreateExpiryTimer,
cancelSwap: config.Server.CancelLoopOutSwap,
verifySchnorrSig: mockVerifySchnorrSigFail,
})
return &Client{
errChan: make(chan error),
clientConfig: *config,
lndServices: lndServices,
sweeper: sweeper,
executor: executor,
resumeReady: make(chan struct{}),
}
}
func createClientTestContext(t *testing.T,
pendingSwaps []*loopdb.LoopOut) *testContext {
clientLnd := test.NewMockLnd()
serverMock := newServerMock(clientLnd)
store := newStoreMock(t)
for _, s := range pendingSwaps {
store.loopOutSwaps[s.Hash] = s.Contract
updates := []loopdb.SwapStateData{}
for _, e := range s.Events {
updates = append(updates, e.SwapStateData)
}
store.loopOutUpdates[s.Hash] = updates
}
expiryChan := make(chan time.Time)
timerFactory := func(expiry time.Duration) <-chan time.Time {
return expiryChan
}
swapClient := newSwapClient(&clientConfig{
LndServices: &clientLnd.LndServices,
Server: serverMock,
Store: store,
CreateExpiryTimer: timerFactory,
})
statusChan := make(chan SwapInfo)
ctx := &testContext{
Context: test.NewContext(
t,
clientLnd,
),
swapClient: swapClient,
statusChan: statusChan,
expiryChan: expiryChan,
store: store,
serverMock: serverMock,
}
ctx.runErr = make(chan error)
runCtx, stop := context.WithCancel(context.Background())
ctx.stop = stop
go func() {
err := swapClient.Run(
runCtx,
statusChan,
)
log.Errorf("client run: %v", err)
ctx.runErr <- err
}()
return ctx
}
func (ctx *testContext) finish() {
ctx.stop()
select {
case err := <-ctx.runErr:
require.NoError(ctx.Context.T, err)
case <-time.After(test.Timeout):
ctx.Context.T.Fatal("client not stopping")
}
ctx.assertIsDone()
}
func (ctx *testContext) assertIsDone() {
require.NoError(ctx.Context.T, ctx.Context.Lnd.IsDone())
require.NoError(ctx.Context.T, ctx.store.isDone())
select {
case <-ctx.statusChan:
ctx.Context.T.Fatalf("not all status updates read")
default:
}
}
func (ctx *testContext) assertStored() {
ctx.Context.T.Helper()
ctx.store.assertLoopOutStored()
}
func (ctx *testContext) assertStorePreimageReveal() {
ctx.Context.T.Helper()
ctx.store.assertStorePreimageReveal()
}
func (ctx *testContext) assertStoreFinished(expectedResult loopdb.SwapState) {
ctx.Context.T.Helper()
ctx.store.assertStoreFinished(expectedResult)
}
func (ctx *testContext) assertStatus(expectedState loopdb.SwapState) {
ctx.Context.T.Helper()
for {
select {
case update := <-ctx.statusChan:
if update.SwapType != swap.TypeOut {
continue
}
if update.State == expectedState {
return
}
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("expected status %v not "+
"received in time", expectedState)
}
}
}
func (ctx *testContext) publishHtlc(script []byte,
amt btcutil.Amount) wire.OutPoint {
// Create the htlc tx.
htlcTx := wire.MsgTx{}
htlcTx.AddTxIn(&wire.TxIn{
PreviousOutPoint: wire.OutPoint{},
})
htlcTx.AddTxOut(&wire.TxOut{
PkScript: script,
Value: int64(amt),
})
htlcTxHash := htlcTx.TxHash()
// Signal client that script has been published.
select {
case ctx.Lnd.ConfChannel <- &chainntnfs.TxConfirmation{
Tx: &htlcTx,
}:
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("htlc confirmed not consumed")
}
return wire.OutPoint{
Hash: htlcTxHash,
Index: 0,
}
}
// trackPayment asserts that a call to track payment was sent and sends the
// status provided into the updates channel.
func (ctx *testContext) trackPayment(status lnrpc.Payment_PaymentStatus) {
trackPayment := ctx.Context.AssertTrackPayment()
select {
case trackPayment.Updates <- lndclient.PaymentStatus{
State: status,
}:
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("could not send payment update")
}
}
// assertPreimagePush asserts that we made an attempt to push our preimage to
// the server.
func (ctx *testContext) assertPreimagePush(preimage lntypes.Preimage) {
select {
case pushedPreimage := <-ctx.serverMock.preimagePush:
require.Equal(ctx.Context.T, preimage, pushedPreimage)
case <-time.After(test.Timeout):
ctx.Context.T.Fatalf("preimage not pushed")
}
}