mirror of https://github.com/lightninglabs/loop
loopdb: store protocol version alongside with swaps
This commit adds the protocol version to each stored swap. This will be used to ensure that when swaps are resumed after a restart, they're correctly handled given any breaking protocol changes.pull/291/head
parent
a41b7c8ddd
commit
86db43a2cb
@ -1,8 +1,42 @@
|
||||
package loopdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// itob returns an 8-byte big endian representation of v.
|
||||
func itob(v uint64) []byte {
|
||||
b := make([]byte, 8)
|
||||
byteOrder.PutUint64(b, v)
|
||||
return b
|
||||
}
|
||||
|
||||
// UnmarshalProtocolVersion attempts to unmarshal a byte slice to a
|
||||
// ProtocolVersion value. If the unmarshal fails, ProtocolVersionUnrecorded is
|
||||
// returned along with an error.
|
||||
func UnmarshalProtocolVersion(b []byte) (ProtocolVersion, error) {
|
||||
if b == nil {
|
||||
return ProtocolVersionUnrecorded, nil
|
||||
}
|
||||
|
||||
if len(b) != 4 {
|
||||
return ProtocolVersionUnrecorded,
|
||||
fmt.Errorf("invalid size: %v", len(b))
|
||||
}
|
||||
|
||||
version := ProtocolVersion(byteOrder.Uint32(b))
|
||||
if !version.Valid() {
|
||||
return ProtocolVersionUnrecorded,
|
||||
fmt.Errorf("invalid protocol version: %v", version)
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// MarshalProtocolVersion marshals a ProtocolVersion value to a byte slice.
|
||||
func MarshalProtocolVersion(version ProtocolVersion) []byte {
|
||||
var versionBytes [4]byte
|
||||
byteOrder.PutUint32(versionBytes[:], uint32(version))
|
||||
|
||||
return versionBytes[:]
|
||||
}
|
||||
|
@ -0,0 +1,53 @@
|
||||
package loopdb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestProtocolVersionMarshalUnMarshal tests that marshalling and unmarshalling
|
||||
// looprpc.ProtocolVersion works correctly.
|
||||
func TestProtocolVersionMarshalUnMarshal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testVersions := [...]ProtocolVersion{
|
||||
ProtocolVersionLegacy,
|
||||
ProtocolVersionMultiLoopOut,
|
||||
ProtocolVersionSegwitLoopIn,
|
||||
ProtocolVersionPreimagePush,
|
||||
ProtocolVersionUserExpiryLoopOut,
|
||||
}
|
||||
|
||||
bogusVersion := []byte{0xFF, 0xFF, 0xFF, 0xFF}
|
||||
invalidSlice := []byte{0xFF, 0xFF, 0xFF}
|
||||
|
||||
for i := 0; i < len(testVersions); i++ {
|
||||
testVersion := testVersions[i]
|
||||
|
||||
// Test that unmarshal(marshal(v)) == v.
|
||||
version, err := UnmarshalProtocolVersion(
|
||||
MarshalProtocolVersion(testVersion),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testVersion, version)
|
||||
|
||||
// Test that unmarshalling a nil slice returns the default
|
||||
// version along with no error.
|
||||
version, err = UnmarshalProtocolVersion(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ProtocolVersionUnrecorded, version)
|
||||
|
||||
// Test that unmarshalling an unknown version returns the
|
||||
// default version along with an error.
|
||||
version, err = UnmarshalProtocolVersion(bogusVersion)
|
||||
require.Error(t, err, "expected invalid version")
|
||||
require.Equal(t, ProtocolVersionUnrecorded, version)
|
||||
|
||||
// Test that unmarshalling an invalid slice returns the
|
||||
// default version along with an error.
|
||||
version, err = UnmarshalProtocolVersion(invalidSlice)
|
||||
require.Error(t, err, "expected invalid size")
|
||||
require.Equal(t, ProtocolVersionUnrecorded, version)
|
||||
}
|
||||
}
|
@ -0,0 +1,76 @@
|
||||
package loopdb
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/lightninglabs/loop/looprpc"
|
||||
)
|
||||
|
||||
// ProtocolVersion represents the protocol version (declared on rpc level) that
|
||||
// the client declared to us.
|
||||
type ProtocolVersion uint32
|
||||
|
||||
const (
|
||||
// ProtocolVersionLegacy indicates that the client is a legacy version
|
||||
// that did not report its protocol version.
|
||||
ProtocolVersionLegacy ProtocolVersion = 0
|
||||
|
||||
// ProtocolVersionMultiLoopOut indicates that the client supports multi
|
||||
// loop out.
|
||||
ProtocolVersionMultiLoopOut ProtocolVersion = 1
|
||||
|
||||
// ProtocolVersionSegwitLoopIn indicates that the client supports segwit
|
||||
// loop in.
|
||||
ProtocolVersionSegwitLoopIn ProtocolVersion = 2
|
||||
|
||||
// ProtocolVersionPreimagePush indicates that the client will push loop
|
||||
// out preimages to the sever to speed up claim.
|
||||
ProtocolVersionPreimagePush ProtocolVersion = 3
|
||||
|
||||
// ProtocolVersionUserExpiryLoopOut indicates that the client will
|
||||
// propose a cltv expiry height for loop out.
|
||||
ProtocolVersionUserExpiryLoopOut ProtocolVersion = 4
|
||||
|
||||
// ProtocolVersionUnrecorded is set for swaps were created before we
|
||||
// started saving protocol version with swaps.
|
||||
ProtocolVersionUnrecorded ProtocolVersion = math.MaxUint32
|
||||
|
||||
// CurrentRpcProtocolVersion defines the version of the RPC protocol
|
||||
// that is currently supported by the loop client.
|
||||
CurrentRPCProtocolVersion = looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT
|
||||
|
||||
// CurrentInteranlProtocolVersionInternal defines the RPC current
|
||||
// protocol in the internal representation.
|
||||
CurrentInternalProtocolVersion = ProtocolVersion(CurrentRPCProtocolVersion)
|
||||
)
|
||||
|
||||
// Valid returns true if the value of the ProtocolVersion is valid.
|
||||
func (p ProtocolVersion) Valid() bool {
|
||||
return p <= CurrentInternalProtocolVersion
|
||||
}
|
||||
|
||||
// String returns the string representation of a protocol version.
|
||||
func (p ProtocolVersion) String() string {
|
||||
switch p {
|
||||
case ProtocolVersionUnrecorded:
|
||||
return "Unrecorded"
|
||||
|
||||
case ProtocolVersionLegacy:
|
||||
return "Legacy"
|
||||
|
||||
case ProtocolVersionMultiLoopOut:
|
||||
return "Multi Loop Out"
|
||||
|
||||
case ProtocolVersionSegwitLoopIn:
|
||||
return "Segwit Loop In"
|
||||
|
||||
case ProtocolVersionPreimagePush:
|
||||
return "Preimage Push"
|
||||
|
||||
case ProtocolVersionUserExpiryLoopOut:
|
||||
return "User Expiry Loop Out"
|
||||
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
package loopdb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lightninglabs/loop/looprpc"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestProtocolVersionSanity tests that protocol versions are sane, meaning
|
||||
// we always keep our stored protocol version in sync with the RPC protocol
|
||||
// version except for the unrecorded version.
|
||||
func TestProtocolVersionSanity(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
versions := [...]ProtocolVersion{
|
||||
ProtocolVersionLegacy,
|
||||
ProtocolVersionMultiLoopOut,
|
||||
ProtocolVersionSegwitLoopIn,
|
||||
ProtocolVersionPreimagePush,
|
||||
ProtocolVersionUserExpiryLoopOut,
|
||||
}
|
||||
|
||||
rpcVersions := [...]looprpc.ProtocolVersion{
|
||||
looprpc.ProtocolVersion_LEGACY,
|
||||
looprpc.ProtocolVersion_MULTI_LOOP_OUT,
|
||||
looprpc.ProtocolVersion_NATIVE_SEGWIT_LOOP_IN,
|
||||
looprpc.ProtocolVersion_PREIMAGE_PUSH_LOOP_OUT,
|
||||
looprpc.ProtocolVersion_USER_EXPIRY_LOOP_OUT,
|
||||
}
|
||||
|
||||
require.Equal(t, len(versions), len(rpcVersions))
|
||||
for i, version := range versions {
|
||||
require.Equal(t, uint32(version), uint32(rpcVersions[i]))
|
||||
}
|
||||
|
||||
// Finally test that the current version contants are up to date
|
||||
require.Equal(t,
|
||||
CurrentInternalProtocolVersion,
|
||||
versions[len(versions)-1],
|
||||
)
|
||||
|
||||
require.Equal(t,
|
||||
uint32(CurrentInternalProtocolVersion),
|
||||
uint32(CurrentRPCProtocolVersion),
|
||||
)
|
||||
}
|
Loading…
Reference in New Issue