2
0
mirror of https://github.com/lightninglabs/loop synced 2024-11-09 19:10:47 +00:00
loop/test/context.go

268 lines
6.3 KiB
Go

package test
import (
"bytes"
"crypto/sha256"
"testing"
"time"
"github.com/btcsuite/btcd/wire"
"github.com/lightninglabs/lndclient"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/stretchr/testify/require"
)
// Context contains shared test context functions.
type Context struct {
T *testing.T
Lnd *LndMockServices
FailedInvoices map[lntypes.Hash]struct{}
PaidInvoices map[string]func(error)
}
// NewContext instanties a new common test context.
func NewContext(t *testing.T,
lnd *LndMockServices) Context {
return Context{
T: t,
Lnd: lnd,
FailedInvoices: make(map[lntypes.Hash]struct{}),
PaidInvoices: make(map[string]func(error)),
}
}
// ReceiveTx receives and decodes a published tx.
func (ctx *Context) ReceiveTx() *wire.MsgTx {
ctx.T.Helper()
select {
case tx := <-ctx.Lnd.TxPublishChannel:
return tx
case <-time.After(Timeout):
ctx.T.Fatalf("sweep not published")
return nil
}
}
// NotifySpend simulates a spend.
func (ctx *Context) NotifySpend(tx *wire.MsgTx, inputIndex uint32) {
ctx.T.Helper()
txHash := tx.TxHash()
select {
case ctx.Lnd.SpendChannel <- &chainntnfs.SpendDetail{
SpendingTx: tx,
SpenderTxHash: &txHash,
SpenderInputIndex: inputIndex,
}:
case <-time.After(Timeout):
ctx.T.Fatalf("htlc spend not consumed")
}
}
// NotifyConf simulates a conf.
func (ctx *Context) NotifyConf(tx *wire.MsgTx) {
ctx.T.Helper()
select {
case ctx.Lnd.ConfChannel <- &chainntnfs.TxConfirmation{
Tx: tx,
}:
case <-time.After(Timeout):
ctx.T.Fatalf("htlc spend not consumed")
}
}
// AssertRegisterSpendNtfn asserts that a register for spend has been received.
func (ctx *Context) AssertRegisterSpendNtfn(script []byte) {
ctx.T.Helper()
select {
case spendIntent := <-ctx.Lnd.RegisterSpendChannel:
require.Equal(
ctx.T, script, spendIntent.PkScript,
"server not listening for published htlc script",
)
case <-time.After(Timeout):
DumpGoroutines()
ctx.T.Fatalf("spend not subscribed to")
}
}
// AssertTrackPayment asserts that a call was made to track payment, and
// returns the track payment message so that it can be used to send updates
// to the test.
func (ctx *Context) AssertTrackPayment() TrackPaymentMessage {
ctx.T.Helper()
var msg TrackPaymentMessage
select {
case msg = <-ctx.Lnd.TrackPaymentChannel:
case <-time.After(Timeout):
DumpGoroutines()
ctx.T.Fatalf("payment not tracked")
}
return msg
}
// AssertRegisterConf asserts that a register for conf has been received.
func (ctx *Context) AssertRegisterConf(expectTxHash bool, confs int32) *ConfRegistration {
ctx.T.Helper()
// Expect client to register for conf
var confIntent *ConfRegistration
select {
case confIntent = <-ctx.Lnd.RegisterConfChannel:
switch {
case expectTxHash && confIntent.TxID == nil:
ctx.T.Fatalf("expected tx id for registration")
case !expectTxHash && confIntent.TxID != nil:
ctx.T.Fatalf("expected script only registration")
}
// Require that we registered for the number of confirmations
// the test expects.
require.Equal(ctx.T, confs, confIntent.NumConfs)
case <-time.After(Timeout):
ctx.T.Fatalf("htlc confirmed not subscribed to")
}
return confIntent
}
// AssertPaid asserts that the expected payment request has been paid. This
// function returns a complete function to signal the final payment result.
func (ctx *Context) AssertPaid(
expectedMemo string) func(error) {
ctx.T.Helper()
if done, ok := ctx.PaidInvoices[expectedMemo]; ok {
return done
}
// Assert that client pays swap invoice.
for {
var swapPayment RouterPaymentChannelMessage
select {
case swapPayment = <-ctx.Lnd.RouterSendPaymentChannel:
case <-time.After(Timeout):
ctx.T.Fatalf("no payment sent for invoice: %v",
expectedMemo)
}
payReq := ctx.DecodeInvoice(swapPayment.SendPaymentRequest.Invoice)
_, ok := ctx.PaidInvoices[*payReq.Description]
require.False(
ctx.T, ok,
"duplicate invoice paid: %v", *payReq.Description,
)
done := func(result error) {
if result != nil {
swapPayment.Errors <- result
return
}
swapPayment.Updates <- lndclient.PaymentStatus{
State: lnrpc.Payment_SUCCEEDED,
}
}
ctx.PaidInvoices[*payReq.Description] = done
if *payReq.Description == expectedMemo {
return done
}
}
}
// AssertSettled asserts that an invoice with the given hash is settled.
func (ctx *Context) AssertSettled(
expectedHash lntypes.Hash) lntypes.Preimage {
ctx.T.Helper()
select {
case preimage := <-ctx.Lnd.SettleInvoiceChannel:
hash := sha256.Sum256(preimage[:])
require.Equal(
ctx.T, expectedHash, lntypes.Hash(hash),
"server claims with wrong preimage",
)
return preimage
case <-time.After(Timeout):
}
ctx.T.Fatalf("invoice not settled")
return lntypes.Preimage{}
}
// AssertFailed asserts that an invoice with the given hash is failed.
func (ctx *Context) AssertFailed(expectedHash lntypes.Hash) {
ctx.T.Helper()
if _, ok := ctx.FailedInvoices[expectedHash]; ok {
return
}
for {
select {
case hash := <-ctx.Lnd.FailInvoiceChannel:
ctx.FailedInvoices[expectedHash] = struct{}{}
if expectedHash == hash {
return
}
case <-time.After(Timeout):
ctx.T.Fatalf("invoice not failed")
}
}
}
// DecodeInvoice decodes a payment request string.
func (ctx *Context) DecodeInvoice(request string) *zpay32.Invoice {
ctx.T.Helper()
payReq, err := ctx.Lnd.DecodeInvoice(request)
require.NoError(ctx.T, err)
return payReq
}
// GetOutputIndex returns the index in the tx outs of the given script hash.
func (ctx *Context) GetOutputIndex(tx *wire.MsgTx,
script []byte) int {
for idx, out := range tx.TxOut {
if bytes.Equal(out.PkScript, script) {
return idx
}
}
ctx.T.Fatal("htlc not present in tx")
return 0
}
// NotifyServerHeight notifies the server of the arrival of a new block and
// waits for the notification to be processed by selecting on a
// dedicated test channel.
func (ctx *Context) NotifyServerHeight(height int32) {
require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
}
func (ctx *Context) AssertEpochListeners(numListeners int32) {
require.Eventually(ctx.T, func() bool {
return ctx.Lnd.EpochSubscribers() == numListeners
}, Timeout, time.Millisecond*250)
}