Save and recover protocol state from disk
NOTE: This implementation saves secrets to disk! It is not secure. The storage API allows the caller to atomically record the state of the protocol. The user can retrieve this recorded state and re-commence the protocol from that point. The state is recorded using a hard coded key, causing it to overwrite the previously recorded state. This limitation means that this recovery mechanism should not be used in a program that simultaneously manages the execution of multiple swaps. An e2e test was added to show how to save, recover and resume protocol execution. This logic could also be integrated into the run_until functions to automate saving but was not included at this stage as protocol execution is currently under development. Serialisation and deserialisation was implemented on the states to allow the to be stored using the database. Currently the secret's are also being stored to disk but should be recovered from a seed or wallets.pull/12/head
parent
ea064c95b4
commit
39afb4196b
@ -0,0 +1,210 @@
|
||||
pub mod ecdsa_fun_signature {
|
||||
use serde::{de, de::Visitor, Deserializer, Serializer};
|
||||
use std::{convert::TryFrom, fmt};
|
||||
|
||||
struct Bytes64Visitor;
|
||||
|
||||
impl<'de> Visitor<'de> for Bytes64Visitor {
|
||||
type Value = ecdsa_fun::Signature;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(formatter, "a string containing 64 bytes")
|
||||
}
|
||||
|
||||
fn visit_bytes<E>(self, s: &[u8]) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
if let Ok(value) = <[u8; 64]>::try_from(s) {
|
||||
let sig = ecdsa_fun::Signature::from_bytes(value)
|
||||
.expect("bytes represent an integer greater than or equal to the curve order");
|
||||
Ok(sig)
|
||||
} else {
|
||||
Err(de::Error::invalid_length(s.len(), &self))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize<S>(x: &ecdsa_fun::Signature, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
s.serialize_bytes(&x.to_bytes())
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<ecdsa_fun::Signature, <D as Deserializer<'de>>::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let sig = deserializer.deserialize_bytes(Bytes64Visitor)?;
|
||||
Ok(sig)
|
||||
}
|
||||
}
|
||||
|
||||
pub mod cross_curve_dleq_scalar {
|
||||
use serde::{de, de::Visitor, Deserializer, Serializer};
|
||||
use std::{convert::TryFrom, fmt};
|
||||
|
||||
struct Bytes32Visitor;
|
||||
|
||||
impl<'de> Visitor<'de> for Bytes32Visitor {
|
||||
type Value = cross_curve_dleq::Scalar;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(formatter, "a string containing 32 bytes")
|
||||
}
|
||||
|
||||
fn visit_bytes<E>(self, s: &[u8]) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
if let Ok(value) = <[u8; 32]>::try_from(s) {
|
||||
Ok(cross_curve_dleq::Scalar::from(value))
|
||||
} else {
|
||||
Err(de::Error::invalid_length(s.len(), &self))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize<S>(x: &cross_curve_dleq::Scalar, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
// Serialise as ed25519 because the inner bytes are private
|
||||
// TODO: Open PR in cross_curve_dleq to allow accessing the inner bytes
|
||||
s.serialize_bytes(&x.into_ed25519().to_bytes())
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<cross_curve_dleq::Scalar, <D as Deserializer<'de>>::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let dleq = deserializer.deserialize_bytes(Bytes32Visitor)?;
|
||||
Ok(dleq)
|
||||
}
|
||||
}
|
||||
|
||||
pub mod monero_private_key {
|
||||
use serde::{de, de::Visitor, Deserializer, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
struct BytesVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for BytesVisitor {
|
||||
type Value = monero::PrivateKey;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(formatter, "a string containing 32 bytes")
|
||||
}
|
||||
|
||||
fn visit_bytes<E>(self, s: &[u8]) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
if let Ok(key) = monero::PrivateKey::from_slice(s) {
|
||||
Ok(key)
|
||||
} else {
|
||||
Err(de::Error::invalid_length(s.len(), &self))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize<S>(x: &monero::PrivateKey, s: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
s.serialize_bytes(x.as_bytes())
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<monero::PrivateKey, <D as Deserializer<'de>>::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let key = deserializer.deserialize_bytes(BytesVisitor)?;
|
||||
Ok(key)
|
||||
}
|
||||
}
|
||||
|
||||
pub mod bitcoin_amount {
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
|
||||
pub fn serialize<S>(value: &bitcoin::Amount, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_u64(value.as_sat())
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<bitcoin::Amount, <D as Deserializer<'de>>::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = u64::deserialize(deserializer)?;
|
||||
let amount = bitcoin::Amount::from_sat(value);
|
||||
|
||||
Ok(amount)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ::bitcoin::SigHash;
|
||||
use curve25519_dalek::scalar::Scalar;
|
||||
use rand::rngs::OsRng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct CrossCurveDleqScalar(
|
||||
#[serde(with = "cross_curve_dleq_scalar")] cross_curve_dleq::Scalar,
|
||||
);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ECDSAFunSignature(#[serde(with = "ecdsa_fun_signature")] ecdsa_fun::Signature);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct MoneroPrivateKey(#[serde(with = "monero_private_key")] crate::monero::PrivateKey);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct BitcoinAmount(#[serde(with = "bitcoin_amount")] ::bitcoin::Amount);
|
||||
|
||||
#[test]
|
||||
fn serde_cross_curv_dleq_scalar() {
|
||||
let scalar = CrossCurveDleqScalar(cross_curve_dleq::Scalar::random(&mut OsRng));
|
||||
let encoded = serde_cbor::to_vec(&scalar).unwrap();
|
||||
let decoded: CrossCurveDleqScalar = serde_cbor::from_slice(&encoded).unwrap();
|
||||
assert_eq!(scalar, decoded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_ecdsa_fun_sig() {
|
||||
let secret_key = crate::bitcoin::SecretKey::new_random(&mut OsRng);
|
||||
let sig = ECDSAFunSignature(secret_key.sign(SigHash::default()));
|
||||
let encoded = serde_cbor::to_vec(&sig).unwrap();
|
||||
let decoded: ECDSAFunSignature = serde_cbor::from_slice(&encoded).unwrap();
|
||||
assert_eq!(sig, decoded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_monero_private_key() {
|
||||
let key = MoneroPrivateKey(monero::PrivateKey::from_scalar(Scalar::random(&mut OsRng)));
|
||||
let encoded = serde_cbor::to_vec(&key).unwrap();
|
||||
let decoded: MoneroPrivateKey = serde_cbor::from_slice(&encoded).unwrap();
|
||||
assert_eq!(key, decoded);
|
||||
}
|
||||
#[test]
|
||||
fn serde_bitcoin_amount() {
|
||||
let amount = BitcoinAmount(::bitcoin::Amount::from_sat(100));
|
||||
let encoded = serde_cbor::to_vec(&amount).unwrap();
|
||||
let decoded: BitcoinAmount = serde_cbor::from_slice(&encoded).unwrap();
|
||||
assert_eq!(amount, decoded);
|
||||
}
|
||||
}
|
@ -0,0 +1,159 @@
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
pub struct Database {
|
||||
db: sled::Db,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
const LAST_STATE_KEY: &'static str = "latest_state";
|
||||
|
||||
pub fn open(path: &Path) -> Result<Self> {
|
||||
let path = path
|
||||
.to_str()
|
||||
.ok_or_else(|| anyhow!("The path is not utf-8 valid: {:?}", path))?;
|
||||
let db = sled::open(path).with_context(|| format!("Could not open the DB at {}", path))?;
|
||||
|
||||
Ok(Database { db })
|
||||
}
|
||||
|
||||
pub async fn insert_latest_state<T>(&self, state: &T) -> Result<()>
|
||||
where
|
||||
T: Serialize + DeserializeOwned,
|
||||
{
|
||||
let key = serialize(&Self::LAST_STATE_KEY)?;
|
||||
let new_value = serialize(&state).context("Could not serialize new state value")?;
|
||||
|
||||
let old_value = self.db.get(&key)?;
|
||||
|
||||
self.db
|
||||
.compare_and_swap(key, old_value, Some(new_value))
|
||||
.context("Could not write in the DB")?
|
||||
.context("Stored swap somehow changed, aborting saving")?; // let _ =
|
||||
|
||||
self.db
|
||||
.flush_async()
|
||||
.await
|
||||
.map(|_| ())
|
||||
.context("Could not flush db")
|
||||
}
|
||||
|
||||
pub fn get_latest_state<T>(&self) -> anyhow::Result<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let key = serialize(&Self::LAST_STATE_KEY)?;
|
||||
|
||||
let encoded = self
|
||||
.db
|
||||
.get(&key)?
|
||||
.ok_or_else(|| anyhow!("State does not exist {:?}", key))?;
|
||||
|
||||
let state = deserialize(&encoded).context("Could not deserialize state")?;
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize<T>(t: &T) -> anyhow::Result<Vec<u8>>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
Ok(serde_cbor::to_vec(t)?)
|
||||
}
|
||||
|
||||
pub fn deserialize<T>(v: &[u8]) -> anyhow::Result<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
Ok(serde_cbor::from_slice(&v)?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(non_snake_case)]
|
||||
use super::*;
|
||||
use bitcoin::SigHash;
|
||||
use curve25519_dalek::scalar::Scalar;
|
||||
use ecdsa_fun::fun::rand_core::OsRng;
|
||||
use std::str::FromStr;
|
||||
use xmr_btc::serde::{
|
||||
bitcoin_amount, cross_curve_dleq_scalar, ecdsa_fun_signature, monero_private_key,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct TestState {
|
||||
A: xmr_btc::bitcoin::PublicKey,
|
||||
a: xmr_btc::bitcoin::SecretKey,
|
||||
#[serde(with = "cross_curve_dleq_scalar")]
|
||||
s_a: ::cross_curve_dleq::Scalar,
|
||||
#[serde(with = "monero_private_key")]
|
||||
s_b: monero::PrivateKey,
|
||||
S_a_monero: ::monero::PublicKey,
|
||||
S_a_bitcoin: xmr_btc::bitcoin::PublicKey,
|
||||
v: xmr_btc::monero::PrivateViewKey,
|
||||
#[serde(with = "bitcoin_amount")]
|
||||
btc: ::bitcoin::Amount,
|
||||
xmr: xmr_btc::monero::Amount,
|
||||
refund_timelock: u32,
|
||||
refund_address: ::bitcoin::Address,
|
||||
transaction: ::bitcoin::Transaction,
|
||||
#[serde(with = "ecdsa_fun_signature")]
|
||||
tx_punish_sig: xmr_btc::bitcoin::Signature,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recover_state_from_db() {
|
||||
let db = Database::open(Path::new("../target/test_recover.db")).unwrap();
|
||||
|
||||
let a = crate::bitcoin::SecretKey::new_random(&mut OsRng);
|
||||
let s_a = cross_curve_dleq::Scalar::random(&mut OsRng);
|
||||
let s_b = monero::PrivateKey::from_scalar(Scalar::random(&mut OsRng));
|
||||
let v_a = xmr_btc::monero::PrivateViewKey::new_random(&mut OsRng);
|
||||
let S_a_monero = monero::PublicKey::from_private_key(&monero::PrivateKey {
|
||||
scalar: s_a.into_ed25519(),
|
||||
});
|
||||
let S_a_bitcoin = s_a.into_secp256k1().into();
|
||||
let tx_punish_sig = a.sign(SigHash::default());
|
||||
|
||||
let state = TestState {
|
||||
A: a.public(),
|
||||
a,
|
||||
s_b,
|
||||
s_a,
|
||||
S_a_monero,
|
||||
S_a_bitcoin,
|
||||
v: v_a,
|
||||
btc: ::bitcoin::Amount::from_sat(100),
|
||||
xmr: crate::monero::Amount::from_piconero(1000),
|
||||
refund_timelock: 0,
|
||||
refund_address: ::bitcoin::Address::from_str("1L5wSMgerhHg8GZGcsNmAx5EXMRXSKR3He")
|
||||
.unwrap(),
|
||||
transaction: ::bitcoin::Transaction {
|
||||
version: 0,
|
||||
lock_time: 0,
|
||||
input: vec![::bitcoin::TxIn::default()],
|
||||
output: vec![::bitcoin::TxOut::default()],
|
||||
},
|
||||
tx_punish_sig,
|
||||
};
|
||||
|
||||
db.insert_latest_state(&state)
|
||||
.await
|
||||
.expect("Failed to save state the first time");
|
||||
let recovered: TestState = db
|
||||
.get_latest_state()
|
||||
.expect("Failed to recover state the first time");
|
||||
|
||||
// We insert and recover twice to ensure database implementation allows the
|
||||
// caller to write to an existing key
|
||||
db.insert_latest_state(&recovered)
|
||||
.await
|
||||
.expect("Failed to save state the second time");
|
||||
let recovered: TestState = db
|
||||
.get_latest_state()
|
||||
.expect("Failed to recover state the second time");
|
||||
|
||||
assert_eq!(state, recovered);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue