mirror of
https://github.com/lightninglabs/loop
synced 2024-11-04 06:00:21 +00:00
Merge pull request #548 from GeorgeTsagk/autoloop-amount-backoff
Autoloop amount backoff
This commit is contained in:
commit
55845ff8ca
@ -199,16 +199,30 @@ func TestAutoLoopEnabled(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
singleLoopOut = &loopdb.LoopOut{
|
||||
Loop: loopdb.Loop{
|
||||
Events: []*loopdb.LoopEvent{
|
||||
{
|
||||
SwapStateData: loopdb.SwapStateData{
|
||||
State: loopdb.StateInitiated,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Tick our autolooper with no existing swaps, we expect a loop out
|
||||
// swap to be dispatched for each channel.
|
||||
step := &autoloopStep{
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
existingOutSingle: singleLoopOut,
|
||||
}
|
||||
|
||||
c.autoloop(step)
|
||||
|
||||
// Tick again with both of our swaps in progress. We haven't shifted our
|
||||
@ -220,9 +234,10 @@ func TestAutoLoopEnabled(t *testing.T) {
|
||||
}
|
||||
|
||||
step = &autoloopStep{
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
existingOut: existing,
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
existingOut: existing,
|
||||
existingOutSingle: singleLoopOut,
|
||||
}
|
||||
c.autoloop(step)
|
||||
|
||||
@ -278,11 +293,12 @@ func TestAutoLoopEnabled(t *testing.T) {
|
||||
// still has balances which reflect that we need to swap), but nothing
|
||||
// for channel 2, since it has had a failure.
|
||||
step = &autoloopStep{
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
existingOut: existing,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
existingOut: existing,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
existingOutSingle: singleLoopOut,
|
||||
}
|
||||
c.autoloop(step)
|
||||
|
||||
@ -299,10 +315,11 @@ func TestAutoLoopEnabled(t *testing.T) {
|
||||
}
|
||||
|
||||
step = &autoloopStep{
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
existingOut: existing,
|
||||
quotesOut: quotes,
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
existingOut: existing,
|
||||
quotesOut: quotes,
|
||||
existingOutSingle: singleLoopOut,
|
||||
}
|
||||
c.autoloop(step)
|
||||
|
||||
@ -446,13 +463,27 @@ func TestAutoloopAddress(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
singleLoopOut = &loopdb.LoopOut{
|
||||
Loop: loopdb.Loop{
|
||||
Events: []*loopdb.LoopEvent{
|
||||
{
|
||||
SwapStateData: loopdb.SwapStateData{
|
||||
State: loopdb.StateHtlcPublished,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
step := &autoloopStep{
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
minAmt: 1,
|
||||
maxAmt: amt + 1,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
existingOutSingle: singleLoopOut,
|
||||
keepDestAddr: true,
|
||||
}
|
||||
c.autoloop(step)
|
||||
|
||||
@ -606,6 +637,18 @@ func TestCompositeRules(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
singleLoopOut = &loopdb.LoopOut{
|
||||
Loop: loopdb.Loop{
|
||||
Events: []*loopdb.LoopEvent{
|
||||
{
|
||||
SwapStateData: loopdb.SwapStateData{
|
||||
State: loopdb.StateHtlcPublished,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Tick our autolooper with no existing swaps, we expect a loop out
|
||||
@ -613,10 +656,11 @@ func TestCompositeRules(t *testing.T) {
|
||||
// maximum to be greater than the swap amount for our peer swap (which
|
||||
// is the larger of the two swaps).
|
||||
step := &autoloopStep{
|
||||
minAmt: 1,
|
||||
maxAmt: peerAmount + 1,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
minAmt: 1,
|
||||
maxAmt: peerAmount + 1,
|
||||
quotesOut: quotes,
|
||||
expectedOut: loopOuts,
|
||||
existingOutSingle: singleLoopOut,
|
||||
}
|
||||
c.autoloop(step)
|
||||
|
||||
@ -928,6 +972,18 @@ func TestAutoloopBothTypes(t *testing.T) {
|
||||
Label: labels.AutoloopLabel(swap.TypeIn),
|
||||
Initiator: autoloopSwapInitiator,
|
||||
}
|
||||
|
||||
singleLoopOut = &loopdb.LoopOut{
|
||||
Loop: loopdb.Loop{
|
||||
Events: []*loopdb.LoopEvent{
|
||||
{
|
||||
SwapStateData: loopdb.SwapStateData{
|
||||
State: loopdb.StateHtlcPublished,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
step := &autoloopStep{
|
||||
@ -961,6 +1017,7 @@ func TestAutoloopBothTypes(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
existingOutSingle: singleLoopOut,
|
||||
}
|
||||
c.autoloop(step)
|
||||
c.stop()
|
||||
|
@ -2,7 +2,9 @@ package liquidity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/lightninglabs/lndclient"
|
||||
@ -11,8 +13,10 @@ import (
|
||||
"github.com/lightninglabs/loop/swap"
|
||||
"github.com/lightninglabs/loop/test"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/ticker"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type autoloopTestCtx struct {
|
||||
@ -45,9 +49,17 @@ type autoloopTestCtx struct {
|
||||
// loopOuts is a channel that we get existing loop out swaps on.
|
||||
loopOuts chan []*loopdb.LoopOut
|
||||
|
||||
// loopOutSingle is the single loop out returned from fetching a single
|
||||
// swap from store.
|
||||
loopOutSingle *loopdb.LoopOut
|
||||
|
||||
// loopIns is a channel that we get existing loop in swaps on.
|
||||
loopIns chan []*loopdb.LoopIn
|
||||
|
||||
// loopInSingle is the single loop in returned from fetching a single
|
||||
// swap from store.
|
||||
loopInSingle *loopdb.LoopIn
|
||||
|
||||
// restrictions is a channel that we get swap restrictions on.
|
||||
restrictions chan *Restrictions
|
||||
|
||||
@ -131,6 +143,9 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters,
|
||||
ListLoopOut: func() ([]*loopdb.LoopOut, error) {
|
||||
return <-testCtx.loopOuts, nil
|
||||
},
|
||||
GetLoopOut: func(hash lntypes.Hash) (*loopdb.LoopOut, error) {
|
||||
return testCtx.loopOutSingle, nil
|
||||
},
|
||||
ListLoopIn: func() ([]*loopdb.LoopIn, error) {
|
||||
return <-testCtx.loopIns, nil
|
||||
},
|
||||
@ -188,6 +203,10 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters,
|
||||
testCtx.manager = NewManager(cfg)
|
||||
err := testCtx.manager.setParameters(context.Background(), parameters)
|
||||
assert.NoError(t, err)
|
||||
// Override the payments check interval for the tests in order to not
|
||||
// timeout.
|
||||
testCtx.manager.params.CustomPaymentCheckInterval =
|
||||
150 * time.Millisecond
|
||||
<-done
|
||||
return testCtx
|
||||
}
|
||||
@ -241,14 +260,17 @@ type loopInRequestResp struct {
|
||||
// autoloopStep contains all of the information to required to step
|
||||
// through an autoloop tick.
|
||||
type autoloopStep struct {
|
||||
minAmt btcutil.Amount
|
||||
maxAmt btcutil.Amount
|
||||
existingOut []*loopdb.LoopOut
|
||||
existingIn []*loopdb.LoopIn
|
||||
quotesOut []quoteRequestResp
|
||||
quotesIn []quoteInRequestResp
|
||||
expectedOut []loopOutRequestResp
|
||||
expectedIn []loopInRequestResp
|
||||
minAmt btcutil.Amount
|
||||
maxAmt btcutil.Amount
|
||||
existingOut []*loopdb.LoopOut
|
||||
existingOutSingle *loopdb.LoopOut
|
||||
existingIn []*loopdb.LoopIn
|
||||
existingInSingle *loopdb.LoopIn
|
||||
quotesOut []quoteRequestResp
|
||||
quotesIn []quoteInRequestResp
|
||||
expectedOut []loopOutRequestResp
|
||||
expectedIn []loopInRequestResp
|
||||
keepDestAddr bool
|
||||
}
|
||||
|
||||
// autoloop walks our test context through the process of triggering our
|
||||
@ -269,6 +291,9 @@ func (c *autoloopTestCtx) autoloop(step *autoloopStep) {
|
||||
c.loopOuts <- step.existingOut
|
||||
c.loopIns <- step.existingIn
|
||||
|
||||
c.loopOutSingle = step.existingOutSingle
|
||||
c.loopInSingle = step.existingInSingle
|
||||
|
||||
// Assert that we query the server for a quote for each of our
|
||||
// recommended swaps. Note that this differs from our set of expected
|
||||
// swaps because we may get quotes for suggested swaps but then just
|
||||
@ -299,25 +324,77 @@ func (c *autoloopTestCtx) autoloop(step *autoloopStep) {
|
||||
c.quotes <- expected.quote
|
||||
}
|
||||
|
||||
// Assert that we dispatch the expected set of swaps.
|
||||
for _, expected := range step.expectedOut {
|
||||
require.True(c.t, c.matchLoopOuts(step.expectedOut, step.keepDestAddr))
|
||||
require.True(c.t, c.matchLoopIns(step.expectedIn))
|
||||
}
|
||||
|
||||
// matchLoopOuts checks that the actual loop out requests we got match the
|
||||
// expected ones. The argument keepDestAddr is used to indicate whether we keep
|
||||
// the actual loops destination address for the comparison. This is useful
|
||||
// because we don't want to compare the destination address generated by the
|
||||
// wallet mock. We want to compare the destination address when testing the
|
||||
// autoloop DestAddr parameter for loop outs.
|
||||
func (c *autoloopTestCtx) matchLoopOuts(swaps []loopOutRequestResp,
|
||||
keepDestAddr bool) bool {
|
||||
|
||||
swapsCopy := make([]loopOutRequestResp, len(swaps))
|
||||
copy(swapsCopy, swaps)
|
||||
|
||||
length := len(swapsCopy)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
actual := <-c.outRequest
|
||||
|
||||
// Set our destination address to nil so that we do not need to
|
||||
// provide the address that is obtained by the mock wallet kit.
|
||||
if expected.request.DestAddr == nil {
|
||||
if !keepDestAddr {
|
||||
actual.DestAddr = nil
|
||||
}
|
||||
|
||||
assert.Equal(c.t, expected.request, actual)
|
||||
c.loopOut <- expected.response
|
||||
inner:
|
||||
for index, swap := range swapsCopy {
|
||||
equal := reflect.DeepEqual(swap.request, actual)
|
||||
|
||||
if equal {
|
||||
c.loopOut <- swap.response
|
||||
|
||||
swapsCopy = append(
|
||||
swapsCopy[:index],
|
||||
swapsCopy[index+1:]...,
|
||||
)
|
||||
|
||||
break inner
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range step.expectedIn {
|
||||
return len(swapsCopy) == 0
|
||||
}
|
||||
|
||||
// matchLoopIns checks that the actual loop in requests we got match the
|
||||
// expected ones.
|
||||
func (c *autoloopTestCtx) matchLoopIns(
|
||||
swaps []loopInRequestResp) bool {
|
||||
|
||||
swapsCopy := make([]loopInRequestResp, len(swaps))
|
||||
copy(swapsCopy, swaps)
|
||||
|
||||
for i := 0; i < len(swapsCopy); i++ {
|
||||
actual := <-c.inRequest
|
||||
|
||||
assert.Equal(c.t, expected.request, actual)
|
||||
inner:
|
||||
for i, swap := range swapsCopy {
|
||||
equal := reflect.DeepEqual(swap.request, actual)
|
||||
|
||||
c.loopIn <- expected.response
|
||||
if equal {
|
||||
c.loopIn <- swap.response
|
||||
|
||||
swapsCopy = append(
|
||||
swapsCopy[:i], swapsCopy[i+1:]...,
|
||||
)
|
||||
|
||||
break inner
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return len(swapsCopy) == 0
|
||||
}
|
||||
|
@ -48,6 +48,7 @@ import (
|
||||
"github.com/lightninglabs/loop/swap"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/funding"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
@ -62,6 +63,22 @@ const (
|
||||
// a channel is part of a temporarily failed swap.
|
||||
defaultFailureBackoff = time.Hour * 24
|
||||
|
||||
// defaultAmountBackoff is the default backoff we apply to the amount
|
||||
// of a loop out swap that failed the off-chain payments.
|
||||
defaultAmountBackoff = float64(0.25)
|
||||
|
||||
// defaultAmountBackoffRetry is the default number of times we will
|
||||
// perform an amount backoff to a loop out swap before we give up.
|
||||
defaultAmountBackoffRetry = 5
|
||||
|
||||
// defaultSwapWaitTimeout is the default maximum amount of time we
|
||||
// wait for a swap to reach a terminal state.
|
||||
defaultSwapWaitTimeout = time.Hour * 24
|
||||
|
||||
// defaultPaymentCheckInterval is the default time that passes between
|
||||
// checks for loop out payments status.
|
||||
defaultPaymentCheckInterval = time.Second * 2
|
||||
|
||||
// defaultConfTarget is the default sweep target we use for loop outs.
|
||||
// We get our inbound liquidity quickly using preimage push, so we can
|
||||
// use a long conf target without worrying about ux impact.
|
||||
@ -78,7 +95,7 @@ const (
|
||||
|
||||
// DefaultAutoloopTicker is the default amount of time between automated
|
||||
// swap checks.
|
||||
DefaultAutoloopTicker = time.Minute * 10
|
||||
DefaultAutoloopTicker = time.Minute * 20
|
||||
|
||||
// autoloopSwapInitiator is the value we send in the initiator field of
|
||||
// a swap request when issuing an automatic swap.
|
||||
@ -164,6 +181,10 @@ type Config struct {
|
||||
// ListLoopOut returns all of the loop our swaps stored on disk.
|
||||
ListLoopOut func() ([]*loopdb.LoopOut, error)
|
||||
|
||||
// GetLoopOut returns a single loop out swap based on the provided swap
|
||||
// hash.
|
||||
GetLoopOut func(hash lntypes.Hash) (*loopdb.LoopOut, error)
|
||||
|
||||
// ListLoopIn returns all of the loop in swaps stored on disk.
|
||||
ListLoopIn func() ([]*loopdb.LoopIn, error)
|
||||
|
||||
@ -399,13 +420,10 @@ func (m *Manager) autoloop(ctx context.Context) error {
|
||||
swap.DestAddr = m.params.DestAddr
|
||||
}
|
||||
|
||||
loopOut, err := m.cfg.LoopOut(ctx, &swap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("loop out automatically dispatched: hash: %v, "+
|
||||
"address: %v", loopOut.SwapHash, loopOut.HtlcAddress)
|
||||
go m.dispatchStickyLoopOut(
|
||||
ctx, swap, defaultAmountBackoffRetry,
|
||||
defaultAmountBackoff,
|
||||
)
|
||||
}
|
||||
|
||||
for _, in := range suggestion.InSwaps {
|
||||
@ -1044,6 +1062,143 @@ func (m *Manager) refreshAutoloopBudget(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchStickyLoopOut attempts to dispatch a loop out swap that will
|
||||
// automatically retry its execution with an amount based backoff.
|
||||
func (m *Manager) dispatchStickyLoopOut(ctx context.Context,
|
||||
out loop.OutRequest, retryCount uint16, amountBackoff float64) {
|
||||
|
||||
for i := 0; i < int(retryCount); i++ {
|
||||
// Dispatch the swap.
|
||||
swap, err := m.cfg.LoopOut(ctx, &out)
|
||||
if err != nil {
|
||||
log.Errorf("unable to dispatch loop out, hash: %v, "+
|
||||
"err: %v", swap.SwapHash, err)
|
||||
}
|
||||
|
||||
log.Infof("loop out automatically dispatched: hash: %v, "+
|
||||
"address: %v, amount %v", swap.SwapHash,
|
||||
swap.HtlcAddress, out.Amount)
|
||||
|
||||
updates := make(chan *loopdb.SwapState, 1)
|
||||
|
||||
// Monitor the swap state and write the desired update to the
|
||||
// update channel. We do not want to read all of the swap state
|
||||
// updates, just the one that will help us assume the state of
|
||||
// the off-chain payment.
|
||||
go m.waitForSwapPayment(
|
||||
ctx, swap.SwapHash, updates, defaultSwapWaitTimeout,
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case update := <-updates:
|
||||
if update == nil {
|
||||
// If update is nil then no update occurred
|
||||
// within the defined timeout period. It's
|
||||
// better to return and not attempt a retry.
|
||||
log.Debug(
|
||||
"No payment update received for swap "+
|
||||
"%v, skipping amount backoff",
|
||||
swap.SwapHash,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if *update == loopdb.StateFailOffchainPayments {
|
||||
// Save the old amount so we can log it.
|
||||
oldAmt := out.Amount
|
||||
|
||||
// If we failed to pay the server, we will
|
||||
// decrease the amount of the swap and try
|
||||
// again.
|
||||
out.Amount -= btcutil.Amount(
|
||||
float64(out.Amount) * amountBackoff,
|
||||
)
|
||||
|
||||
log.Infof("swap %v: amount backoff old amount="+
|
||||
"%v, new amount=%v", swap.SwapHash,
|
||||
oldAmt, out.Amount)
|
||||
|
||||
continue
|
||||
} else {
|
||||
// If the update channel did not return an
|
||||
// off-chain payment failure we won't retry.
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForSwapPayment waits for a swap to progress beyond the stage of
|
||||
// forwarding the payment to the server through the network. It returns the
|
||||
// final update on the outcome through a channel.
|
||||
func (m *Manager) waitForSwapPayment(ctx context.Context, swapHash lntypes.Hash,
|
||||
updateChan chan *loopdb.SwapState, timeout time.Duration) {
|
||||
|
||||
startTime := time.Now()
|
||||
var (
|
||||
swap *loopdb.LoopOut
|
||||
err error
|
||||
interval time.Duration
|
||||
)
|
||||
|
||||
if m.params.CustomPaymentCheckInterval != 0 {
|
||||
interval = m.params.CustomPaymentCheckInterval
|
||||
} else {
|
||||
interval = defaultPaymentCheckInterval
|
||||
}
|
||||
|
||||
for time.Since(startTime) < timeout {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(interval):
|
||||
}
|
||||
|
||||
swap, err = m.cfg.GetLoopOut(swapHash)
|
||||
if err != nil {
|
||||
log.Errorf(
|
||||
"Error getting swap with hash %x: %v", swapHash,
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// If no update has occurred yet, continue in order to wait.
|
||||
update := swap.LastUpdate()
|
||||
if update == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write the update if the swap has reached a state the helps
|
||||
// us determine whether the off-chain payment successfully
|
||||
// reached the destination.
|
||||
switch update.State {
|
||||
case loopdb.StateFailInsufficientValue:
|
||||
fallthrough
|
||||
case loopdb.StateSuccess:
|
||||
fallthrough
|
||||
case loopdb.StateFailSweepTimeout:
|
||||
fallthrough
|
||||
case loopdb.StateFailTimeout:
|
||||
fallthrough
|
||||
case loopdb.StatePreimageRevealed:
|
||||
fallthrough
|
||||
case loopdb.StateFailOffchainPayments:
|
||||
updateChan <- &update.State
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If no update occurred within the defined timeout we return an empty
|
||||
// update to the channel, causing the sticky loop out to not retry
|
||||
// anymore.
|
||||
updateChan <- nil
|
||||
}
|
||||
|
||||
// swapTraffic contains a summary of our current and previously failed swaps.
|
||||
type swapTraffic struct {
|
||||
ongoingLoopOut map[lnwire.ShortChannelID]bool
|
||||
|
@ -87,6 +87,10 @@ type Parameters struct {
|
||||
// ChannelRules are exclusively set to prevent overlap between peer
|
||||
// and channel rules map to avoid ambiguity.
|
||||
PeerRules map[route.Vertex]*SwapRule
|
||||
|
||||
// CustomPaymentCheckInterval is an optional custom interval to use when
|
||||
// checking an autoloop loop out payments' payment status.
|
||||
CustomPaymentCheckInterval time.Duration
|
||||
}
|
||||
|
||||
// String returns the string representation of our parameters.
|
||||
|
@ -72,6 +72,7 @@ func getLiquidityManager(client *loop.Client) *liquidity.Manager {
|
||||
LoopOutQuote: client.LoopOutQuote,
|
||||
LoopInQuote: client.LoopInQuote,
|
||||
ListLoopOut: client.Store.FetchLoopOutSwaps,
|
||||
GetLoopOut: client.Store.FetchLoopOutSwap,
|
||||
ListLoopIn: client.Store.FetchLoopInSwaps,
|
||||
MinimumConfirmations: minConfTarget,
|
||||
PutLiquidityParams: client.Store.PutLiquidityParams,
|
||||
|
@ -12,6 +12,9 @@ type SwapStore interface {
|
||||
// FetchLoopOutSwaps returns all swaps currently in the store.
|
||||
FetchLoopOutSwaps() ([]*LoopOut, error)
|
||||
|
||||
// FetchLoopOutSwap returns the loop out swap with the given hash.
|
||||
FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error)
|
||||
|
||||
// CreateLoopOut adds an initiated swap to the store.
|
||||
CreateLoopOut(hash lntypes.Hash, swap *LoopOutContract) error
|
||||
|
||||
|
452
loopdb/store.go
452
loopdb/store.go
@ -255,111 +255,12 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// From the root bucket, we'll grab the next swap
|
||||
// bucket for this swap from its swaphash.
|
||||
swapBucket := rootBucket.Bucket(swapHash)
|
||||
if swapBucket == nil {
|
||||
return fmt.Errorf("swap bucket %x not found",
|
||||
swapHash)
|
||||
}
|
||||
|
||||
// With the main swap bucket obtained, we'll grab the
|
||||
// raw swap contract bytes and decode it.
|
||||
contractBytes := swapBucket.Get(contractKey)
|
||||
if contractBytes == nil {
|
||||
return errors.New("contract not found")
|
||||
}
|
||||
|
||||
contract, err := deserializeLoopOutContract(
|
||||
contractBytes, s.chainParams,
|
||||
)
|
||||
loop, err := s.fetchLoopOutSwap(rootBucket, swapHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get our label for this swap, if it is present.
|
||||
contract.Label = getLabel(swapBucket)
|
||||
|
||||
// Read the list of concatenated outgoing channel ids
|
||||
// that form the outgoing set.
|
||||
setBytes := swapBucket.Get(outgoingChanSetKey)
|
||||
if outgoingChanSetKey != nil {
|
||||
r := bytes.NewReader(setBytes)
|
||||
readLoop:
|
||||
for {
|
||||
var chanID uint64
|
||||
err := binary.Read(r, byteOrder, &chanID)
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
break readLoop
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
contract.OutgoingChanSet = append(
|
||||
contract.OutgoingChanSet,
|
||||
chanID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Set our default number of confirmations for the swap.
|
||||
contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations
|
||||
|
||||
// If we have the number of confirmations stored for
|
||||
// this swap, we overwrite our default with the stored
|
||||
// value.
|
||||
confBytes := swapBucket.Get(confirmationsKey)
|
||||
if confBytes != nil {
|
||||
r := bytes.NewReader(confBytes)
|
||||
err := binary.Read(
|
||||
r, byteOrder, &contract.HtlcConfirmations,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
updates, err := deserializeUpdates(swapBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal the protocol version for the swap.
|
||||
// If the protocol version is not stored (which is
|
||||
// the case for old clients), we'll assume the
|
||||
// ProtocolVersionUnrecorded instead.
|
||||
contract.ProtocolVersion, err =
|
||||
UnmarshalProtocolVersion(
|
||||
swapBucket.Get(protocolVersionKey),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal the key locator.
|
||||
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
|
||||
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
|
||||
swapBucket.Get(keyLocatorKey),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
loop := LoopOut{
|
||||
Loop: Loop{
|
||||
Events: updates,
|
||||
},
|
||||
Contract: contract,
|
||||
}
|
||||
|
||||
loop.Hash, err = lntypes.MakeHash(swapHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
swaps = append(swaps, &loop)
|
||||
swaps = append(swaps, loop)
|
||||
|
||||
return nil
|
||||
})
|
||||
@ -371,53 +272,33 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
|
||||
return swaps, nil
|
||||
}
|
||||
|
||||
// deserializeUpdates deserializes the list of swap updates that are stored as a
|
||||
// key of the given bucket.
|
||||
func deserializeUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) {
|
||||
// Once we have the raw swap, we'll also need to decode
|
||||
// each of the past updates to the swap itself.
|
||||
stateBucket := swapBucket.Bucket(updatesBucketKey)
|
||||
if stateBucket == nil {
|
||||
return nil, errors.New("updates bucket not found")
|
||||
}
|
||||
// FetchLoopOutSwap returns the loop out swap with the given hash.
|
||||
//
|
||||
// NOTE: Part of the loopdb.SwapStore interface.
|
||||
func (s *boltSwapStore) FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) {
|
||||
var swap *LoopOut
|
||||
|
||||
// Deserialize and collect each swap update into our slice of swap
|
||||
// events.
|
||||
var updates []*LoopEvent
|
||||
err := stateBucket.ForEach(func(k, v []byte) error {
|
||||
updateBucket := stateBucket.Bucket(k)
|
||||
if updateBucket == nil {
|
||||
return fmt.Errorf("expected state sub-bucket for %x", k)
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
// First, we'll grab our main loop out bucket key.
|
||||
rootBucket := tx.Bucket(loopOutBucketKey)
|
||||
if rootBucket == nil {
|
||||
return errors.New("bucket does not exist")
|
||||
}
|
||||
|
||||
basicState := updateBucket.Get(basicStateKey)
|
||||
if basicState == nil {
|
||||
return errors.New("no basic state for update")
|
||||
}
|
||||
|
||||
event, err := deserializeLoopEvent(basicState)
|
||||
loop, err := s.fetchLoopOutSwap(rootBucket, hash[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Deserialize htlc tx hash if this updates contains one.
|
||||
htlcTxHashBytes := updateBucket.Get(htlcTxHashKey)
|
||||
if htlcTxHashBytes != nil {
|
||||
htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
event.HtlcTxHash = htlcTxHash
|
||||
}
|
||||
swap = loop
|
||||
|
||||
updates = append(updates, event)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updates, nil
|
||||
return swap, nil
|
||||
}
|
||||
|
||||
// FetchLoopInSwaps returns all loop in swaps currently in the store.
|
||||
@ -442,71 +323,12 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// From the root bucket, we'll grab the next swap
|
||||
// bucket for this swap from its swaphash.
|
||||
swapBucket := rootBucket.Bucket(swapHash)
|
||||
if swapBucket == nil {
|
||||
return fmt.Errorf("swap bucket %x not found",
|
||||
swapHash)
|
||||
}
|
||||
|
||||
// With the main swap bucket obtained, we'll grab the
|
||||
// raw swap contract bytes and decode it.
|
||||
contractBytes := swapBucket.Get(contractKey)
|
||||
if contractBytes == nil {
|
||||
return errors.New("contract not found")
|
||||
}
|
||||
|
||||
contract, err := deserializeLoopInContract(
|
||||
contractBytes,
|
||||
)
|
||||
loop, err := s.fetchLoopInSwap(rootBucket, swapHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get our label for this swap, if it is present.
|
||||
contract.Label = getLabel(swapBucket)
|
||||
|
||||
updates, err := deserializeUpdates(swapBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal the protocol version for the swap.
|
||||
// If the protocol version is not stored (which is
|
||||
// the case for old clients), we'll assume the
|
||||
// ProtocolVersionUnrecorded instead.
|
||||
contract.ProtocolVersion, err =
|
||||
UnmarshalProtocolVersion(
|
||||
swapBucket.Get(protocolVersionKey),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to unmarshal the key locator.
|
||||
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
|
||||
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
|
||||
swapBucket.Get(keyLocatorKey),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
loop := LoopIn{
|
||||
Loop: Loop{
|
||||
Events: updates,
|
||||
},
|
||||
Contract: contract,
|
||||
}
|
||||
|
||||
loop.Hash, err = lntypes.MakeHash(swapHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
swaps = append(swaps, &loop)
|
||||
swaps = append(swaps, loop)
|
||||
|
||||
return nil
|
||||
})
|
||||
@ -824,3 +646,243 @@ func (s *boltSwapStore) FetchLiquidityParams() ([]byte, error) {
|
||||
|
||||
return params, err
|
||||
}
|
||||
|
||||
// fetchUpdates deserializes the list of swap updates that are stored as a
|
||||
// key of the given bucket.
|
||||
func fetchUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) {
|
||||
// Once we have the raw swap, we'll also need to decode
|
||||
// each of the past updates to the swap itself.
|
||||
stateBucket := swapBucket.Bucket(updatesBucketKey)
|
||||
if stateBucket == nil {
|
||||
return nil, errors.New("updates bucket not found")
|
||||
}
|
||||
|
||||
// Deserialize and collect each swap update into our slice of swap
|
||||
// events.
|
||||
var updates []*LoopEvent
|
||||
err := stateBucket.ForEach(func(k, v []byte) error {
|
||||
updateBucket := stateBucket.Bucket(k)
|
||||
if updateBucket == nil {
|
||||
return fmt.Errorf("expected state sub-bucket for %x", k)
|
||||
}
|
||||
|
||||
basicState := updateBucket.Get(basicStateKey)
|
||||
if basicState == nil {
|
||||
return errors.New("no basic state for update")
|
||||
}
|
||||
|
||||
event, err := deserializeLoopEvent(basicState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Deserialize htlc tx hash if this updates contains one.
|
||||
htlcTxHashBytes := updateBucket.Get(htlcTxHashKey)
|
||||
if htlcTxHashBytes != nil {
|
||||
htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
event.HtlcTxHash = htlcTxHash
|
||||
}
|
||||
|
||||
updates = append(updates, event)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updates, nil
|
||||
}
|
||||
|
||||
// fetchLoopOutSwap fetches and deserializes the raw swap bytes into a LoopOut
|
||||
// struct.
|
||||
func (s *boltSwapStore) fetchLoopOutSwap(rootBucket *bbolt.Bucket,
|
||||
swapHash []byte) (*LoopOut, error) {
|
||||
|
||||
// From the root bucket, we'll grab the next swap
|
||||
// bucket for this swap from its swaphash.
|
||||
swapBucket := rootBucket.Bucket(swapHash)
|
||||
if swapBucket == nil {
|
||||
return nil, fmt.Errorf("swap bucket %x not found",
|
||||
swapHash)
|
||||
}
|
||||
|
||||
hash, err := lntypes.MakeHash(swapHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// With the main swap bucket obtained, we'll grab the
|
||||
// raw swap contract bytes and decode it.
|
||||
contractBytes := swapBucket.Get(contractKey)
|
||||
if contractBytes == nil {
|
||||
return nil, errors.New("contract not found")
|
||||
}
|
||||
|
||||
contract, err := deserializeLoopOutContract(
|
||||
contractBytes, s.chainParams,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get our label for this swap, if it is present.
|
||||
contract.Label = getLabel(swapBucket)
|
||||
|
||||
// Read the list of concatenated outgoing channel ids
|
||||
// that form the outgoing set.
|
||||
setBytes := swapBucket.Get(outgoingChanSetKey)
|
||||
if outgoingChanSetKey != nil {
|
||||
r := bytes.NewReader(setBytes)
|
||||
readLoop:
|
||||
for {
|
||||
var chanID uint64
|
||||
err := binary.Read(r, byteOrder, &chanID)
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
break readLoop
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contract.OutgoingChanSet = append(
|
||||
contract.OutgoingChanSet,
|
||||
chanID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Set our default number of confirmations for the swap.
|
||||
contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations
|
||||
|
||||
// If we have the number of confirmations stored for
|
||||
// this swap, we overwrite our default with the stored
|
||||
// value.
|
||||
confBytes := swapBucket.Get(confirmationsKey)
|
||||
if confBytes != nil {
|
||||
r := bytes.NewReader(confBytes)
|
||||
err := binary.Read(
|
||||
r, byteOrder, &contract.HtlcConfirmations,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
updates, err := fetchUpdates(swapBucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to unmarshal the protocol version for the swap.
|
||||
// If the protocol version is not stored (which is
|
||||
// the case for old clients), we'll assume the
|
||||
// ProtocolVersionUnrecorded instead.
|
||||
contract.ProtocolVersion, err =
|
||||
UnmarshalProtocolVersion(
|
||||
swapBucket.Get(protocolVersionKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to unmarshal the key locator.
|
||||
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
|
||||
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
|
||||
swapBucket.Get(keyLocatorKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
loop := LoopOut{
|
||||
Loop: Loop{
|
||||
Events: updates,
|
||||
},
|
||||
Contract: contract,
|
||||
}
|
||||
|
||||
loop.Hash, err = lntypes.MakeHash(hash[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &loop, nil
|
||||
}
|
||||
|
||||
// fetchLoopInSwap fetches and deserializes the raw swap bytes into a LoopIn
|
||||
// struct.
|
||||
func (s *boltSwapStore) fetchLoopInSwap(rootBucket *bbolt.Bucket,
|
||||
swapHash []byte) (*LoopIn, error) {
|
||||
|
||||
// From the root bucket, we'll grab the next swap
|
||||
// bucket for this swap from its swaphash.
|
||||
swapBucket := rootBucket.Bucket(swapHash)
|
||||
if swapBucket == nil {
|
||||
return nil, fmt.Errorf("swap bucket %x not found",
|
||||
swapHash)
|
||||
}
|
||||
|
||||
hash, err := lntypes.MakeHash(swapHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// With the main swap bucket obtained, we'll grab the
|
||||
// raw swap contract bytes and decode it.
|
||||
contractBytes := swapBucket.Get(contractKey)
|
||||
if contractBytes == nil {
|
||||
return nil, errors.New("contract not found")
|
||||
}
|
||||
|
||||
contract, err := deserializeLoopInContract(
|
||||
contractBytes,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get our label for this swap, if it is present.
|
||||
contract.Label = getLabel(swapBucket)
|
||||
|
||||
updates, err := fetchUpdates(swapBucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to unmarshal the protocol version for the swap.
|
||||
// If the protocol version is not stored (which is
|
||||
// the case for old clients), we'll assume the
|
||||
// ProtocolVersionUnrecorded instead.
|
||||
contract.ProtocolVersion, err =
|
||||
UnmarshalProtocolVersion(
|
||||
swapBucket.Get(protocolVersionKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Try to unmarshal the key locator.
|
||||
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
|
||||
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
|
||||
swapBucket.Get(keyLocatorKey),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
loop := LoopIn{
|
||||
Loop: Loop{
|
||||
Events: updates,
|
||||
},
|
||||
Contract: contract,
|
||||
}
|
||||
|
||||
loop.Hash = hash
|
||||
|
||||
return &loop, nil
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -96,24 +95,20 @@ func TestLoopOutStore(t *testing.T) {
|
||||
// swap store for specific swap parameters.
|
||||
func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
|
||||
tempDirName, err := ioutil.TempDir("", "clientstore")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer os.RemoveAll(tempDirName)
|
||||
|
||||
store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// First, verify that an empty database has no active swaps.
|
||||
swaps, err := store.FetchLoopOutSwaps()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(swaps) != 0 {
|
||||
t.Fatal("expected empty store")
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, swaps)
|
||||
|
||||
hash := pendingSwap.Preimage.Hash()
|
||||
|
||||
// checkSwap is a test helper function that'll assert the state of a
|
||||
// swap.
|
||||
@ -121,43 +116,37 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
|
||||
t.Helper()
|
||||
|
||||
swaps, err := store.FetchLoopOutSwaps()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(swaps) != 1 {
|
||||
t.Fatal("expected pending swap in store")
|
||||
}
|
||||
require.Len(t, swaps, 1)
|
||||
|
||||
swap := swaps[0].Contract
|
||||
if !reflect.DeepEqual(swap, pendingSwap) {
|
||||
t.Fatal("invalid pending swap data")
|
||||
}
|
||||
swap, err := store.FetchLoopOutSwap(hash)
|
||||
require.NoError(t, err)
|
||||
|
||||
if swaps[0].State().State != expectedState {
|
||||
t.Fatalf("expected state %v, but got %v",
|
||||
expectedState, swaps[0].State(),
|
||||
)
|
||||
}
|
||||
require.Equal(t, hash, swap.Hash)
|
||||
require.Equal(t, hash, swaps[0].Hash)
|
||||
|
||||
swapContract := swap.Contract
|
||||
|
||||
require.Equal(t, swapContract, pendingSwap)
|
||||
|
||||
require.Equal(t, expectedState, swap.State().State)
|
||||
|
||||
if expectedState == StatePreimageRevealed {
|
||||
require.NotNil(t, swaps[0].State().HtlcTxHash)
|
||||
require.NotNil(t, swap.State().HtlcTxHash)
|
||||
}
|
||||
}
|
||||
|
||||
hash := pendingSwap.Preimage.Hash()
|
||||
|
||||
// If we create a new swap, then it should show up as being initialized
|
||||
// right after.
|
||||
if err := store.CreateLoopOut(hash, pendingSwap); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = store.CreateLoopOut(hash, pendingSwap)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSwap(StateInitiated)
|
||||
|
||||
// Trying to make the same swap again should result in an error.
|
||||
if err := store.CreateLoopOut(hash, pendingSwap); err == nil {
|
||||
t.Fatal("expected error on storing duplicate")
|
||||
}
|
||||
err = store.CreateLoopOut(hash, pendingSwap)
|
||||
require.Error(t, err)
|
||||
checkSwap(StateInitiated)
|
||||
|
||||
// Next, we'll update to the next state of the pre-image being
|
||||
@ -169,9 +158,8 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
|
||||
HtlcTxHash: &chainhash.Hash{1, 6, 2},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSwap(StatePreimageRevealed)
|
||||
|
||||
// Next, we'll update to the final state to ensure that the state is
|
||||
@ -182,21 +170,17 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
|
||||
State: StateFailInsufficientValue,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
checkSwap(StateFailInsufficientValue)
|
||||
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = store.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// If we re-open the same store, then the state of the current swap
|
||||
// should be the same.
|
||||
store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSwap(StateFailInsufficientValue)
|
||||
}
|
||||
|
||||
@ -242,24 +226,18 @@ func TestLoopInStore(t *testing.T) {
|
||||
|
||||
func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
|
||||
tempDirName, err := ioutil.TempDir("", "clientstore")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDirName)
|
||||
|
||||
store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// First, verify that an empty database has no active swaps.
|
||||
swaps, err := store.FetchLoopInSwaps()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(swaps) != 0 {
|
||||
t.Fatal("expected empty store")
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, swaps)
|
||||
|
||||
hash := sha256.Sum256(testPreimage[:])
|
||||
|
||||
// checkSwap is a test helper function that'll assert the state of a
|
||||
// swap.
|
||||
@ -267,39 +245,27 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
|
||||
t.Helper()
|
||||
|
||||
swaps, err := store.FetchLoopInSwaps()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(swaps) != 1 {
|
||||
t.Fatal("expected pending swap in store")
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Len(t, swaps, 1)
|
||||
|
||||
swap := swaps[0].Contract
|
||||
if !reflect.DeepEqual(swap, &pendingSwap) {
|
||||
t.Fatal("invalid pending swap data")
|
||||
}
|
||||
|
||||
if swaps[0].State().State != expectedState {
|
||||
t.Fatalf("expected state %v, but got %v",
|
||||
expectedState, swaps[0].State(),
|
||||
)
|
||||
}
|
||||
require.Equal(t, swap, &pendingSwap)
|
||||
|
||||
require.Equal(t, swaps[0].State().State, expectedState)
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(testPreimage[:])
|
||||
|
||||
// If we create a new swap, then it should show up as being initialized
|
||||
// right after.
|
||||
if err := store.CreateLoopIn(hash, &pendingSwap); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = store.CreateLoopIn(hash, &pendingSwap)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSwap(StateInitiated)
|
||||
|
||||
// Trying to make the same swap again should result in an error.
|
||||
if err := store.CreateLoopIn(hash, &pendingSwap); err == nil {
|
||||
t.Fatal("expected error on storing duplicate")
|
||||
}
|
||||
err = store.CreateLoopIn(hash, &pendingSwap)
|
||||
require.Error(t, err)
|
||||
|
||||
checkSwap(StateInitiated)
|
||||
|
||||
// Next, we'll update to the next state of the pre-image being
|
||||
@ -310,9 +276,8 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
|
||||
State: StatePreimageRevealed,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSwap(StatePreimageRevealed)
|
||||
|
||||
// Next, we'll update to the final state to ensure that the state is
|
||||
@ -323,21 +288,17 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
|
||||
State: StateFailInsufficientValue,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
checkSwap(StateFailInsufficientValue)
|
||||
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = store.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// If we re-open the same store, then the state of the current swap
|
||||
// should be the same.
|
||||
store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
checkSwap(StateFailInsufficientValue)
|
||||
}
|
||||
|
||||
@ -467,9 +428,8 @@ func TestLegacyOutgoingChannel(t *testing.T) {
|
||||
|
||||
// Assert that the outgoing channel is read properly.
|
||||
expectedChannelSet := ChannelSet{5}
|
||||
if !reflect.DeepEqual(swaps[0].Contract.OutgoingChanSet, expectedChannelSet) {
|
||||
t.Fatal("invalid outgoing channel")
|
||||
}
|
||||
|
||||
require.Equal(t, expectedChannelSet, swaps[0].Contract.OutgoingChanSet)
|
||||
}
|
||||
|
||||
// TestLiquidityParams checks that reading and writing to liquidty bucket are
|
||||
|
@ -70,6 +70,36 @@ func (s *storeMock) FetchLoopOutSwaps() ([]*loopdb.LoopOut, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FetchLoopOutSwaps returns all swaps currently in the store.
|
||||
//
|
||||
// NOTE: Part of the loopdb.SwapStore interface.
|
||||
func (s *storeMock) FetchLoopOutSwap(
|
||||
hash lntypes.Hash) (*loopdb.LoopOut, error) {
|
||||
|
||||
contract, ok := s.loopOutSwaps[hash]
|
||||
if !ok {
|
||||
return nil, errors.New("swap not found")
|
||||
}
|
||||
|
||||
updates := s.loopOutUpdates[hash]
|
||||
events := make([]*loopdb.LoopEvent, len(updates))
|
||||
for i, u := range updates {
|
||||
events[i] = &loopdb.LoopEvent{
|
||||
SwapStateData: u,
|
||||
}
|
||||
}
|
||||
|
||||
swap := &loopdb.LoopOut{
|
||||
Loop: loopdb.Loop{
|
||||
Hash: hash,
|
||||
Events: events,
|
||||
},
|
||||
Contract: contract,
|
||||
}
|
||||
|
||||
return swap, nil
|
||||
}
|
||||
|
||||
// CreateLoopOut adds an initiated swap to the store.
|
||||
//
|
||||
// NOTE: Part of the loopdb.SwapStore interface.
|
||||
|
Loading…
Reference in New Issue
Block a user