loopdb: add helper methods to update swap costs

This commit adds the necessary sqlc code and SwapStore function to
update swap costs for all swaps in one transaction.
pull/764/head
Andras Banki-Horvath 4 weeks ago
parent 08aa4db35d
commit 4f5c806ba5
No known key found for this signature in database
GPG Key ID: 80E5375C094198D8

@ -65,6 +65,11 @@ type SwapStore interface {
// it's decoding using the proto package's `Unmarshal` method.
FetchLiquidityParams(ctx context.Context) ([]byte, error)
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of
// loop out swaps.
BatchUpdateLoopOutSwapCosts(ctx context.Context,
swaps map[lntypes.Hash]SwapCost) error
// Close closes the underlying database.
Close() error
}

@ -407,6 +407,38 @@ func (s *BaseDB) BatchInsertUpdate(ctx context.Context,
})
}
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of loop out
// swaps.
func (b *BaseDB) BatchUpdateLoopOutSwapCosts(ctx context.Context,
costs map[lntypes.Hash]SwapCost) error {
writeOpts := &SqliteTxOptions{}
return b.ExecTx(ctx, writeOpts, func(tx *sqlc.Queries) error {
for swapHash, cost := range costs {
lastUpdateID, err := tx.GetLastUpdateID(
ctx, swapHash[:],
)
if err != nil {
return err
}
err = tx.OverrideSwapCosts(
ctx, sqlc.OverrideSwapCostsParams{
ID: lastUpdateID,
ServerCost: int64(cost.Server),
OnchainCost: int64(cost.Onchain),
OffchainCost: int64(cost.Offchain),
},
)
if err != nil {
return err
}
}
return nil
})
}
// loopToInsertArgs converts a SwapContract struct to the arguments needed to
// insert it into the database.
func loopToInsertArgs(hash lntypes.Hash,

@ -13,6 +13,7 @@ import (
"github.com/lightninglabs/loop/loopdb/sqlc"
"github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
)
@ -396,6 +397,124 @@ func TestIssue615(t *testing.T) {
require.NoError(t, err)
}
// TestBatchUpdateCost tests that we can batch update the cost of multiple swaps
// at once.
func TestBatchUpdateCost(t *testing.T) {
// Create a new sqlite store for testing.
store := NewTestDB(t)
destAddr := test.GetDestAddr(t, 0)
initiationTime := time.Date(2018, 11, 1, 0, 0, 0, 0, time.UTC)
testContract := LoopOutContract{
SwapContract: SwapContract{
AmountRequested: 100,
CltvExpiry: 144,
HtlcKeys: HtlcKeys{
SenderScriptKey: senderKey,
ReceiverScriptKey: receiverKey,
SenderInternalPubKey: senderInternalKey,
ReceiverInternalPubKey: receiverInternalKey,
ClientScriptKeyLocator: keychain.KeyLocator{
Family: 1,
Index: 2,
},
},
MaxMinerFee: 10,
MaxSwapFee: 20,
InitiationHeight: 99,
InitiationTime: initiationTime,
ProtocolVersion: ProtocolVersionMuSig2,
},
MaxPrepayRoutingFee: 40,
PrepayInvoice: "prepayinvoice",
DestAddr: destAddr,
SwapInvoice: "swapinvoice",
MaxSwapRoutingFee: 30,
SweepConfTarget: 2,
HtlcConfirmations: 2,
SwapPublicationDeadline: initiationTime,
PaymentTimeout: time.Second * 11,
}
makeSwap := func(preimage lntypes.Preimage) *LoopOutContract {
contract := testContract
contract.Preimage = preimage
return &contract
}
// Next, we'll add two swaps to the database.
preimage1 := testPreimage
preimage2 := lntypes.Preimage{4, 4, 4}
ctxb := context.Background()
swap1 := makeSwap(preimage1)
swap2 := makeSwap(preimage2)
hash1 := swap1.Preimage.Hash()
err := store.CreateLoopOut(ctxb, hash1, swap1)
require.NoError(t, err)
hash2 := swap2.Preimage.Hash()
err = store.CreateLoopOut(ctxb, hash2, swap2)
require.NoError(t, err)
// Add an update to both swaps containing the cost.
err = store.UpdateLoopOut(
ctxb, hash1, testTime,
SwapStateData{
State: StateSuccess,
Cost: SwapCost{
Server: 1,
Onchain: 2,
Offchain: 3,
},
},
)
require.NoError(t, err)
err = store.UpdateLoopOut(
ctxb, hash2, testTime,
SwapStateData{
State: StateSuccess,
Cost: SwapCost{
Server: 4,
Onchain: 5,
Offchain: 6,
},
},
)
require.NoError(t, err)
updateMap := map[lntypes.Hash]SwapCost{
hash1: {
Server: 2,
Onchain: 3,
Offchain: 4,
},
hash2: {
Server: 6,
Onchain: 7,
Offchain: 8,
},
}
require.NoError(t, store.BatchUpdateLoopOutSwapCosts(ctxb, updateMap))
swaps, err := store.FetchLoopOutSwaps(ctxb)
require.NoError(t, err)
require.Len(t, swaps, 2)
swapsMap := make(map[lntypes.Hash]*LoopOut)
swapsMap[swaps[0].Hash] = swaps[0]
swapsMap[swaps[1].Hash] = swaps[1]
require.Equal(t, updateMap[hash1], swapsMap[hash1].State().Cost)
require.Equal(t, updateMap[hash2], swapsMap[hash2].State().Cost)
}
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func randomString(length int) string {

@ -18,6 +18,7 @@ type Querier interface {
GetInstantOutSwap(ctx context.Context, swapHash []byte) (GetInstantOutSwapRow, error)
GetInstantOutSwapUpdates(ctx context.Context, swapHash []byte) ([]InstantoutUpdate, error)
GetInstantOutSwaps(ctx context.Context) ([]GetInstantOutSwapsRow, error)
GetLastUpdateID(ctx context.Context, swapHash []byte) (int32, error)
GetLoopInSwap(ctx context.Context, swapHash []byte) (GetLoopInSwapRow, error)
GetLoopInSwaps(ctx context.Context) ([]GetLoopInSwapsRow, error)
GetLoopOutSwap(ctx context.Context, swapHash []byte) (GetLoopOutSwapRow, error)
@ -38,6 +39,7 @@ type Querier interface {
InsertReservationUpdate(ctx context.Context, arg InsertReservationUpdateParams) error
InsertSwap(ctx context.Context, arg InsertSwapParams) error
InsertSwapUpdate(ctx context.Context, arg InsertSwapUpdateParams) error
OverrideSwapCosts(ctx context.Context, arg OverrideSwapCostsParams) error
UpdateBatch(ctx context.Context, arg UpdateBatchParams) error
UpdateInstantOut(ctx context.Context, arg UpdateInstantOutParams) error
UpdateReservation(ctx context.Context, arg UpdateReservationParams) error

@ -133,3 +133,19 @@ INSERT INTO htlc_keys(
) VALUES (
$1, $2, $3, $4, $5, $6, $7
);
-- name: GetLastUpdateID :one
SELECT id
FROM swap_updates
WHERE swap_hash = $1
ORDER BY update_timestamp DESC
LIMIT 1;
-- name: OverrideSwapCosts :exec
UPDATE swap_updates
SET
server_cost = $2,
onchain_cost = $3,
offchain_cost = $4
WHERE id = $1;

@ -10,6 +10,21 @@ import (
"time"
)
const getLastUpdateID = `-- name: GetLastUpdateID :one
SELECT id
FROM swap_updates
WHERE swap_hash = $1
ORDER BY update_timestamp DESC
LIMIT 1
`
func (q *Queries) GetLastUpdateID(ctx context.Context, swapHash []byte) (int32, error) {
row := q.db.QueryRowContext(ctx, getLastUpdateID, swapHash)
var id int32
err := row.Scan(&id)
return id, err
}
const getLoopInSwap = `-- name: GetLoopInSwap :one
SELECT
swaps.id, swaps.swap_hash, swaps.preimage, swaps.initiation_time, swaps.amount_requested, swaps.cltv_expiry, swaps.max_miner_fee, swaps.max_swap_fee, swaps.initiation_height, swaps.protocol_version, swaps.label,
@ -596,3 +611,29 @@ func (q *Queries) InsertSwapUpdate(ctx context.Context, arg InsertSwapUpdatePara
)
return err
}
const overrideSwapCosts = `-- name: OverrideSwapCosts :exec
UPDATE swap_updates
SET
server_cost = $2,
onchain_cost = $3,
offchain_cost = $4
WHERE id = $1
`
type OverrideSwapCostsParams struct {
ID int32
ServerCost int64
OnchainCost int64
OffchainCost int64
}
func (q *Queries) OverrideSwapCosts(ctx context.Context, arg OverrideSwapCostsParams) error {
_, err := q.db.ExecContext(ctx, overrideSwapCosts,
arg.ID,
arg.ServerCost,
arg.OnchainCost,
arg.OffchainCost,
)
return err
}

@ -1009,3 +1009,11 @@ func (b *boltSwapStore) BatchInsertUpdate(ctx context.Context,
return errUnimplemented
}
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of loop out
// swaps.
func (b *boltSwapStore) BatchUpdateLoopOutSwapCosts(ctx context.Context,
costs map[lntypes.Hash]SwapCost) error {
return errUnimplemented
}

@ -3,6 +3,7 @@ package loopdb
import (
"context"
"errors"
"fmt"
"testing"
"time"
@ -337,3 +338,24 @@ func (b *StoreMock) BatchInsertUpdate(ctx context.Context,
return errors.New("not implemented")
}
// BatchUpdateLoopOutSwapCosts updates the swap costs for a batch of loop out
// swaps.
func (s *StoreMock) BatchUpdateLoopOutSwapCosts(ctx context.Context,
costs map[lntypes.Hash]SwapCost) error {
for hash, cost := range costs {
if _, ok := s.LoopOutUpdates[hash]; !ok {
return fmt.Errorf("swap has no updates: %v", hash)
}
updates, ok := s.LoopOutUpdates[hash]
if !ok {
return fmt.Errorf("swap has no updates: %v", hash)
}
updates[len(updates)-1].Cost = cost
}
return nil
}

Loading…
Cancel
Save