From da9d09aa5ed97ea6b6e9961b167760f0c23e0037 Mon Sep 17 00:00:00 2001 From: rishflab Date: Tue, 28 Sep 2021 10:15:31 +1000 Subject: [PATCH] Create Database trait Use domain types in database API to prevent leaking of database types. This trait will allow us to smoothly introduce the sqlite database. --- swap/src/asb/event_loop.rs | 27 +-- swap/src/asb/recovery/cancel.rs | 10 +- swap/src/asb/recovery/punish.rs | 10 +- swap/src/asb/recovery/redeem.rs | 17 +- swap/src/asb/recovery/refund.rs | 12 +- swap/src/asb/recovery/safely_abort.rs | 10 +- swap/src/bin/asb.rs | 8 +- swap/src/bin/swap.rs | 46 +++-- swap/src/cli/cancel.rs | 11 +- swap/src/cli/refund.rs | 12 +- swap/src/database.rs | 19 ++ swap/src/database/alice.rs | 40 ++-- swap/src/database/sled.rs | 258 ++++++++++++-------------- swap/src/protocol.rs | 79 ++++++++ swap/src/protocol/alice.rs | 4 +- swap/src/protocol/alice/state.rs | 2 +- swap/src/protocol/alice/swap.rs | 7 +- swap/src/protocol/bob.rs | 13 +- swap/src/protocol/bob/state.rs | 2 +- swap/src/protocol/bob/swap.rs | 6 +- swap/tests/harness/mod.rs | 9 +- 21 files changed, 350 insertions(+), 252 deletions(-) diff --git a/swap/src/asb/event_loop.rs b/swap/src/asb/event_loop.rs index 13b85de8..c843ea9e 100644 --- a/swap/src/asb/event_loop.rs +++ b/swap/src/asb/event_loop.rs @@ -1,9 +1,9 @@ use crate::asb::{Behaviour, OutEvent, Rate}; -use crate::database::SledDatabase; use crate::network::quote::BidQuote; use crate::network::swap_setup::alice::WalletSnapshot; use crate::network::transfer_proof; use crate::protocol::alice::{AliceState, State3, Swap}; +use crate::protocol::{Database, State}; use crate::{bitcoin, env, kraken, monero}; use anyhow::{Context, Result}; use futures::future; @@ -14,7 +14,7 @@ use libp2p::swarm::SwarmEvent; use libp2p::{PeerId, Swarm}; use rust_decimal::Decimal; use std::collections::HashMap; -use std::convert::Infallible; +use std::convert::{Infallible, TryInto}; use std::fmt::Debug; use std::sync::Arc; use tokio::sync::mpsc; @@ -39,7 +39,7 @@ where env_config: env::Config, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, latest_rate: LR, min_buy: bitcoin::Amount, max_buy: bitcoin::Amount, @@ -71,7 +71,7 @@ where env_config: env::Config, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, latest_rate: LR, min_buy: bitcoin::Amount, max_buy: bitcoin::Amount, @@ -108,16 +108,21 @@ where self.inflight_encrypted_signatures .push(future::pending().boxed()); - let unfinished_swaps = match self.db.unfinished_alice() { - Ok(unfinished_swaps) => unfinished_swaps, - Err(_) => { - tracing::error!("Failed to load unfinished swaps"); + let swaps = match self.db.all().await { + Ok(swaps) => swaps, + Err(e) => { + tracing::error!("Failed to load swaps from database: {}", e); return; } }; + let unfinished_swaps = swaps + .into_iter() + .filter(|(_swap_id, state)| !state.swap_finished()) + .collect::>(); + for (swap_id, state) in unfinished_swaps { - let peer_id = match self.db.get_peer_id(swap_id) { + let peer_id = match self.db.get_peer_id(swap_id).await { Ok(peer_id) => peer_id, Err(_) => { tracing::warn!(%swap_id, "Resuming swap skipped because no peer-id found for swap in database"); @@ -133,7 +138,7 @@ where monero_wallet: self.monero_wallet.clone(), env_config: self.env_config, db: self.db.clone(), - state: state.into(), + state: state.try_into().expect("Alice state loaded from db"), swap_id, }; @@ -197,7 +202,7 @@ where } SwarmEvent::Behaviour(OutEvent::EncryptedSignatureReceived{ msg, channel, peer }) => { let swap_id = msg.swap_id; - let swap_peer = self.db.get_peer_id(swap_id); + let swap_peer = self.db.get_peer_id(swap_id).await; // Ensure that an incoming encrypted signature is sent by the peer-id associated with the swap let swap_peer = match swap_peer { diff --git a/swap/src/asb/recovery/cancel.rs b/swap/src/asb/recovery/cancel.rs index e571a0f1..ec70a8df 100644 --- a/swap/src/asb/recovery/cancel.rs +++ b/swap/src/asb/recovery/cancel.rs @@ -1,16 +1,17 @@ use crate::bitcoin::{parse_rpc_error_code, RpcErrorCode, Txid, Wallet}; -use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; +use crate::protocol::Database; use anyhow::{bail, Result}; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; pub async fn cancel( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, ) -> Result<(Txid, AliceState)> { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let (monero_wallet_restore_blockheight, transfer_proof, state3) = match state { @@ -58,8 +59,7 @@ pub async fn cancel( transfer_proof, state3, }; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok((txid, state)) diff --git a/swap/src/asb/recovery/punish.rs b/swap/src/asb/recovery/punish.rs index e797717e..e94abac8 100644 --- a/swap/src/asb/recovery/punish.rs +++ b/swap/src/asb/recovery/punish.rs @@ -1,7 +1,8 @@ use crate::bitcoin::{self, Txid}; -use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; +use crate::protocol::Database; use anyhow::{bail, Result}; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; @@ -14,9 +15,9 @@ pub enum Error { pub async fn punish( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, ) -> Result<(Txid, AliceState)> { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let state3 = match state { // Punish potentially possible (no knowledge of cancel transaction) @@ -46,8 +47,7 @@ pub async fn punish( let txid = state3.punish_btc(&bitcoin_wallet).await?; let state = AliceState::BtcPunished; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok((txid, state)) diff --git a/swap/src/asb/recovery/redeem.rs b/swap/src/asb/recovery/redeem.rs index c4729ae0..e4642feb 100644 --- a/swap/src/asb/recovery/redeem.rs +++ b/swap/src/asb/recovery/redeem.rs @@ -1,7 +1,8 @@ use crate::bitcoin::{Txid, Wallet}; -use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; +use crate::protocol::Database; use anyhow::{bail, Result}; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; @@ -23,10 +24,10 @@ impl Finality { pub async fn redeem( swap_id: Uuid, bitcoin_wallet: Arc, - db: Arc, + db: Arc, finality: Finality, ) -> Result<(Txid, AliceState)> { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; match state { AliceState::EncSigLearned { @@ -42,17 +43,14 @@ pub async fn redeem( subscription.wait_until_seen().await?; let state = AliceState::BtcRedeemTransactionPublished { state3 }; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) - .await?; + db.insert_latest_state(swap_id, state.into()).await?; if let Finality::Await = finality { subscription.wait_until_final().await?; } let state = AliceState::BtcRedeemed; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok((txid, state)) @@ -64,8 +62,7 @@ pub async fn redeem( } let state = AliceState::BtcRedeemed; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; let txid = state3.tx_redeem().txid(); diff --git a/swap/src/asb/recovery/refund.rs b/swap/src/asb/recovery/refund.rs index 89dc30f6..64e5c3f3 100644 --- a/swap/src/asb/recovery/refund.rs +++ b/swap/src/asb/recovery/refund.rs @@ -1,9 +1,10 @@ use crate::bitcoin::{self}; -use crate::database::{SledDatabase, Swap}; use crate::monero; use crate::protocol::alice::AliceState; +use crate::protocol::Database; use anyhow::{bail, Result}; use libp2p::PeerId; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; @@ -26,9 +27,9 @@ pub async fn refund( swap_id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, - db: Arc, + db: Arc, ) -> Result { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let (monero_wallet_restore_blockheight, transfer_proof, state3) = match state { // In case no XMR has been locked, move to Safely Aborted @@ -66,7 +67,7 @@ pub async fn refund( tracing::debug!(%swap_id, "Bitcoin refund transaction found, extracting key to refund Monero"); state3.extract_monero_private_key(published_refund_tx)? } else { - let bob_peer_id = db.get_peer_id(swap_id)?; + let bob_peer_id = db.get_peer_id(swap_id).await?; bail!(Error::RefundTransactionNotPublishedYet(bob_peer_id),); }; @@ -81,8 +82,7 @@ pub async fn refund( .await?; let state = AliceState::XmrRefunded; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok(state) diff --git a/swap/src/asb/recovery/safely_abort.rs b/swap/src/asb/recovery/safely_abort.rs index f8336c8c..ad162f8d 100644 --- a/swap/src/asb/recovery/safely_abort.rs +++ b/swap/src/asb/recovery/safely_abort.rs @@ -1,11 +1,12 @@ -use crate::database::{SledDatabase, Swap}; use crate::protocol::alice::AliceState; +use crate::protocol::Database; use anyhow::{bail, Result}; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; -pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { - let state = db.get_state(swap_id)?.try_into_alice()?.into(); +pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { + let state = db.get_state(swap_id).await?.try_into()?; match state { AliceState::Started { .. } @@ -13,8 +14,7 @@ pub async fn safely_abort(swap_id: Uuid, db: Arc) -> Result { let state = AliceState::SafelyAborted; - let db_state = (&state).into(); - db.insert_latest_state(swap_id, Swap::Alice(db_state)) + db.insert_latest_state(swap_id, state.clone().into()) .await?; Ok(state) diff --git a/swap/src/bin/asb.rs b/swap/src/bin/asb.rs index 10801c59..46660500 100644 --- a/swap/src/bin/asb.rs +++ b/swap/src/bin/asb.rs @@ -18,6 +18,7 @@ use libp2p::core::multiaddr::Protocol; use libp2p::core::Multiaddr; use libp2p::swarm::AddressScore; use libp2p::Swarm; +use std::convert::TryInto; use std::env; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; @@ -32,7 +33,8 @@ use swap::database::SledDatabase; use swap::monero::Amount; use swap::network::rendezvous::XmrBtcNamespace; use swap::network::swarm; -use swap::protocol::alice::run; +use swap::protocol::alice::{run, AliceState}; +use swap::protocol::Database; use swap::seed::Seed; use swap::tor::AuthenticatedClient; use swap::{asb, bitcoin, kraken, monero, tor}; @@ -93,6 +95,7 @@ async fn main() -> Result<()> { let db_path = config.data.dir.join("database"); let db = SledDatabase::open(config.data.dir.join(db_path).as_path()) + .await .context("Could not open database")?; let seed = @@ -208,7 +211,8 @@ async fn main() -> Result<()> { table.set_header(vec!["SWAP ID", "STATE"]); - for (swap_id, state) in db.all_alice()? { + for (swap_id, state) in db.all().await? { + let state: AliceState = state.try_into()?; table.add_row(vec![swap_id.to_string(), state.to_string()]); } diff --git a/swap/src/bin/swap.rs b/swap/src/bin/swap.rs index 3da42926..2c40aa7c 100644 --- a/swap/src/bin/swap.rs +++ b/swap/src/bin/swap.rs @@ -17,6 +17,7 @@ use comfy_table::Table; use qrcode::render::unicode; use qrcode::QrCode; use std::cmp::min; +use std::convert::TryInto; use std::env; use std::future::Future; use std::path::PathBuf; @@ -30,8 +31,8 @@ use swap::env::Config; use swap::libp2p_ext::MultiAddrExt; use swap::network::quote::BidQuote; use swap::network::swarm; -use swap::protocol::bob; -use swap::protocol::bob::Swap; +use swap::protocol::bob::{BobState, Swap}; +use swap::protocol::{bob, Database}; use swap::seed::Seed; use swap::{bitcoin, cli, monero}; use url::Url; @@ -66,8 +67,11 @@ async fn main() -> Result<()> { let swap_id = Uuid::new_v4(); cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = SledDatabase::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new( + SledDatabase::open(data_dir.join("database").as_path()) + .await + .context("Failed to open database")?, + ); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -140,13 +144,15 @@ async fn main() -> Result<()> { } Command::History => { let db = SledDatabase::open(data_dir.join("database").as_path()) + .await .context("Failed to open database")?; let mut table = Table::new(); table.set_header(vec!["SWAP ID", "STATE"]); - for (swap_id, state) in db.all_bob()? { + for (swap_id, state) in db.all().await? { + let state: BobState = state.try_into()?; table.add_row(vec![swap_id.to_string(), state.to_string()]); } @@ -215,8 +221,11 @@ async fn main() -> Result<()> { tor_socks5_port, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = SledDatabase::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new( + SledDatabase::open(data_dir.join("database").as_path()) + .await + .context("Failed to open database")?, + ); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -232,8 +241,8 @@ async fn main() -> Result<()> { init_monero_wallet(data_dir, monero_daemon_address, env_config).await?; let bitcoin_wallet = Arc::new(bitcoin_wallet); - let seller_peer_id = db.get_peer_id(swap_id)?; - let seller_addresses = db.get_addresses(seller_peer_id)?; + let seller_peer_id = db.get_peer_id(swap_id).await?; + let seller_addresses = db.get_addresses(seller_peer_id).await?; let behaviour = cli::Behaviour::new(seller_peer_id, env_config, bitcoin_wallet.clone()); let mut swarm = @@ -251,7 +260,7 @@ async fn main() -> Result<()> { EventLoop::new(swap_id, swarm, seller_peer_id, env_config)?; let handle = tokio::spawn(event_loop.run()); - let monero_receive_address = db.get_monero_address(swap_id)?; + let monero_receive_address = db.get_monero_address(swap_id).await?; let swap = Swap::from_db( db, swap_id, @@ -260,7 +269,8 @@ async fn main() -> Result<()> { env_config, event_loop_handle, monero_receive_address, - )?; + ) + .await?; tokio::select! { event_loop_result = handle => { @@ -277,8 +287,11 @@ async fn main() -> Result<()> { bitcoin_target_block, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = SledDatabase::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new( + SledDatabase::open(data_dir.join("database").as_path()) + .await + .context("Failed to open database")?, + ); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; @@ -300,8 +313,11 @@ async fn main() -> Result<()> { bitcoin_target_block, } => { cli::tracing::init(debug, json, data_dir.join("logs"), Some(swap_id))?; - let db = SledDatabase::open(data_dir.join("database").as_path()) - .context("Failed to open database")?; + let db = Arc::new( + SledDatabase::open(data_dir.join("database").as_path()) + .await + .context("Failed to open database")?, + ); let seed = Seed::from_file_or_generate(data_dir.as_path()) .context("Failed to read in seed file")?; diff --git a/swap/src/cli/cancel.rs b/swap/src/cli/cancel.rs index 42381eb0..bf973caa 100644 --- a/swap/src/cli/cancel.rs +++ b/swap/src/cli/cancel.rs @@ -1,16 +1,17 @@ use crate::bitcoin::{parse_rpc_error_code, RpcErrorCode, Txid, Wallet}; -use crate::database::{SledDatabase, Swap}; use crate::protocol::bob::BobState; +use crate::protocol::Database; use anyhow::{bail, Result}; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; pub async fn cancel( swap_id: Uuid, bitcoin_wallet: Arc, - db: SledDatabase, + db: Arc, ) -> Result<(Txid, BobState)> { - let state = db.get_state(swap_id)?.try_into_bob()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let state6 = match state { BobState::BtcLocked(state3) => state3.cancel(), @@ -48,8 +49,8 @@ pub async fn cancel( }; let state = BobState::BtcCancelled(state6); - let db_state = state.clone().into(); - db.insert_latest_state(swap_id, Swap::Bob(db_state)).await?; + db.insert_latest_state(swap_id, state.clone().into()) + .await?; Ok((txid, state)) } diff --git a/swap/src/cli/refund.rs b/swap/src/cli/refund.rs index 7b2244f0..25ccd0e0 100644 --- a/swap/src/cli/refund.rs +++ b/swap/src/cli/refund.rs @@ -1,16 +1,17 @@ use crate::bitcoin::Wallet; -use crate::database::{SledDatabase, Swap}; use crate::protocol::bob::BobState; +use crate::protocol::Database; use anyhow::{bail, Result}; +use std::convert::TryInto; use std::sync::Arc; use uuid::Uuid; pub async fn refund( swap_id: Uuid, bitcoin_wallet: Arc, - db: SledDatabase, + db: Arc, ) -> Result { - let state = db.get_state(swap_id)?.try_into_bob()?.into(); + let state = db.get_state(swap_id).await?.try_into()?; let state6 = match state { BobState::BtcLocked(state3) => state3.cancel(), @@ -35,9 +36,8 @@ pub async fn refund( state6.publish_refund_btc(bitcoin_wallet.as_ref()).await?; let state = BobState::BtcRefunded(state6); - let db_state = state.clone().into(); - - db.insert_latest_state(swap_id, Swap::Bob(db_state)).await?; + db.insert_latest_state(swap_id, state.clone().into()) + .await?; Ok(state) } diff --git a/swap/src/database.rs b/swap/src/database.rs index 1b59d4c8..c03a4e74 100644 --- a/swap/src/database.rs +++ b/swap/src/database.rs @@ -2,6 +2,7 @@ pub use self::sled::SledDatabase; pub use alice::Alice; pub use bob::Bob; +use crate::protocol::State; use anyhow::{bail, Result}; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -16,6 +17,24 @@ pub enum Swap { Bob(Bob), } +impl From for Swap { + fn from(state: State) -> Self { + match state { + State::Alice(state) => Swap::Alice(state.into()), + State::Bob(state) => Swap::Bob(state.into()), + } + } +} + +impl From for State { + fn from(value: Swap) -> Self { + match value { + Swap::Alice(alice) => State::Alice(alice.into()), + Swap::Bob(bob) => State::Bob(bob.into()), + } + } +} + impl From for Swap { fn from(from: Alice) -> Self { Swap::Alice(from) diff --git a/swap/src/database/alice.rs b/swap/src/database/alice.rs index 3358238f..1262d04a 100644 --- a/swap/src/database/alice.rs +++ b/swap/src/database/alice.rs @@ -78,8 +78,8 @@ pub enum AliceEndState { BtcPunished, } -impl From<&AliceState> for Alice { - fn from(alice_state: &AliceState) -> Self { +impl From for Alice { + fn from(alice_state: AliceState) -> Self { match alice_state { AliceState::Started { state3 } => Alice::Started { state3: state3.as_ref().clone(), @@ -95,8 +95,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::XmrLockTransactionSent { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::XmrLocked { @@ -104,8 +104,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::XmrLocked { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::XmrLockTransferProofSent { @@ -113,8 +113,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::XmrLockTransferProofSent { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::EncSigLearned { @@ -123,10 +123,10 @@ impl From<&AliceState> for Alice { state3, encrypted_signature, } => Alice::EncSigLearned { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), - encrypted_signature: *encrypted_signature.clone(), + encrypted_signature: encrypted_signature.as_ref().clone(), }, AliceState::BtcRedeemTransactionPublished { state3 } => { Alice::BtcRedeemTransactionPublished { @@ -139,8 +139,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::BtcCancelled { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::BtcRefunded { @@ -149,9 +149,9 @@ impl From<&AliceState> for Alice { spend_key, state3, } => Alice::BtcRefunded { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), - spend_key: *spend_key, + monero_wallet_restore_blockheight, + transfer_proof, + spend_key, state3: state3.as_ref().clone(), }, AliceState::BtcPunishable { @@ -159,8 +159,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::BtcPunishable { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::XmrRefunded => Alice::Done(AliceEndState::XmrRefunded), @@ -169,8 +169,8 @@ impl From<&AliceState> for Alice { transfer_proof, state3, } => Alice::CancelTimelockExpired { - monero_wallet_restore_blockheight: *monero_wallet_restore_blockheight, - transfer_proof: transfer_proof.clone(), + monero_wallet_restore_blockheight, + transfer_proof, state3: state3.as_ref().clone(), }, AliceState::BtcPunished => Alice::Done(AliceEndState::BtcPunished), diff --git a/swap/src/database/sled.rs b/swap/src/database/sled.rs index 512b5039..b1e2d9ab 100644 --- a/swap/src/database/sled.rs +++ b/swap/src/database/sled.rs @@ -1,6 +1,7 @@ -use crate::database::{Alice, Bob, Swap}; +use crate::database::Swap; +use crate::protocol::{Database, State}; use anyhow::{anyhow, Context, Result}; -use itertools::Itertools; +use async_trait::async_trait; use libp2p::{Multiaddr, PeerId}; use serde::de::DeserializeOwned; use serde::Serialize; @@ -8,6 +9,9 @@ use std::path::Path; use std::str::FromStr; use uuid::Uuid; +pub use crate::database::alice::Alice; +pub use crate::database::bob::Bob; + pub struct SledDatabase { swaps: sled::Tree, peers: sled::Tree, @@ -15,27 +19,9 @@ pub struct SledDatabase { monero_addresses: sled::Tree, } -impl SledDatabase { - pub fn open(path: &Path) -> Result { - tracing::debug!("Opening database at {}", path.display()); - - let db = - sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; - - let swaps = db.open_tree("swaps")?; - let peers = db.open_tree("peers")?; - let addresses = db.open_tree("addresses")?; - let monero_addresses = db.open_tree("monero_addresses")?; - - Ok(SledDatabase { - swaps, - peers, - addresses, - monero_addresses, - }) - } - - pub async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { +#[async_trait] +impl Database for SledDatabase { + async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()> { let peer_id_str = peer_id.to_string(); let key = serialize(&swap_id)?; @@ -50,7 +36,7 @@ impl SledDatabase { .context("Could not flush db") } - pub fn get_peer_id(&self, swap_id: Uuid) -> Result { + async fn get_peer_id(&self, swap_id: Uuid) -> Result { let key = serialize(&swap_id)?; let encoded = self @@ -62,11 +48,7 @@ impl SledDatabase { Ok(PeerId::from_str(peer_id.as_str())?) } - pub async fn insert_monero_address( - &self, - swap_id: Uuid, - address: monero::Address, - ) -> Result<()> { + async fn insert_monero_address(&self, swap_id: Uuid, address: monero::Address) -> Result<()> { let key = swap_id.as_bytes(); let value = serialize(&address)?; @@ -79,7 +61,7 @@ impl SledDatabase { .context("Could not flush db") } - pub fn get_monero_address(&self, swap_id: Uuid) -> Result { + async fn get_monero_address(&self, swap_id: Uuid) -> Result { let encoded = self .monero_addresses .get(swap_id.as_bytes())? @@ -95,7 +77,7 @@ impl SledDatabase { Ok(monero_address) } - pub async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { + async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()> { let key = peer_id.to_bytes(); let existing_addresses = self.addresses.get(&key)?; @@ -124,7 +106,7 @@ impl SledDatabase { .context("Could not flush db") } - pub fn get_addresses(&self, peer_id: PeerId) -> Result> { + async fn get_addresses(&self, peer_id: PeerId) -> Result> { let key = peer_id.to_bytes(); let addresses = match self.addresses.get(&key)? { @@ -135,9 +117,10 @@ impl SledDatabase { Ok(addresses) } - pub async fn insert_latest_state(&self, swap_id: Uuid, state: Swap) -> Result<()> { + async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()> { let key = serialize(&swap_id)?; - let new_value = serialize(&state).context("Could not serialize new state value")?; + let swap = Swap::from(state); + let new_value = serialize(&swap).context("Could not serialize new state value")?; let old_value = self.swaps.get(&key)?; @@ -153,7 +136,7 @@ impl SledDatabase { .context("Could not flush db") } - pub fn get_state(&self, swap_id: Uuid) -> Result { + async fn get_state(&self, swap_id: Uuid) -> Result { let key = serialize(&swap_id)?; let encoded = self @@ -161,47 +144,91 @@ impl SledDatabase { .get(&key)? .ok_or_else(|| anyhow!("Swap with id {} not found in database", swap_id))?; - let state = deserialize(&encoded).context("Could not deserialize state")?; + let swap = deserialize::(&encoded).context("Could not deserialize state")?; + + let state = State::from(swap); + Ok(state) } - pub fn all_alice(&self) -> Result> { - self.all_alice_iter().collect() + async fn all(&self) -> Result> { + self.all_iter().collect() + } +} + +impl SledDatabase { + pub async fn open(path: &Path) -> Result { + tracing::debug!("Opening database at {}", path.display()); + + let db = + sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; + + let swaps = db.open_tree("swaps")?; + let peers = db.open_tree("peers")?; + let addresses = db.open_tree("addresses")?; + let monero_addresses = db.open_tree("monero_addresses")?; + + Ok(SledDatabase { + swaps, + peers, + addresses, + monero_addresses, + }) } - fn all_alice_iter(&self) -> impl Iterator> { - self.all_swaps_iter().map(|item| { - let (swap_id, swap) = item?; - Ok((swap_id, swap.try_into_alice()?)) + pub fn get_all_peers(&self) -> impl Iterator> { + self.peers.iter().map(|item| { + let (key, value) = item.context("Failed to retrieve peer id from DB")?; + + let swap_id = deserialize::(&key)?; + let peer_id_bytes = + deserialize::>(&value).context("Failed to deserialize swap")?; + + let peer_id = PeerId::from_bytes(&peer_id_bytes)?; + + Ok((swap_id, peer_id)) }) } - pub fn all_bob(&self) -> Result> { - self.all_bob_iter().collect() + pub fn get_all_addresses(&self) -> impl Iterator)>> { + self.addresses.iter().map(|item| { + let (key, value) = item.context("Failed to retrieve peer address from DB")?; + + let peer_id_bytes = deserialize::>(&key)?; + let addr = + deserialize::>(&value).context("Failed to deserialize swap")?; + + let peer_id = PeerId::from_bytes(&peer_id_bytes)?; + + Ok((peer_id, addr)) + }) } - fn all_bob_iter(&self) -> impl Iterator> { - self.all_swaps_iter().map(|item| { - let (swap_id, swap) = item?; - Ok((swap_id, swap.try_into_bob()?)) + pub fn get_all_monero_addresses( + &self, + ) -> impl Iterator> { + self.monero_addresses.iter().map(|item| { + let (key, value) = item.context("Failed to retrieve monero address from DB")?; + + let swap_id = deserialize::(&key)?; + let addr = + deserialize::(&value).context("Failed to deserialize swap")?; + + Ok((swap_id, addr)) }) } - fn all_swaps_iter(&self) -> impl Iterator> { + fn all_iter(&self) -> impl Iterator> { self.swaps.iter().map(|item| { let (key, value) = item.context("Failed to retrieve swap from DB")?; let swap_id = deserialize::(&key)?; let swap = deserialize::(&value).context("Failed to deserialize swap")?; - Ok((swap_id, swap)) - }) - } + let state = State::from(swap); - pub fn unfinished_alice(&self) -> Result> { - self.all_alice_iter() - .filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_))) - .collect() + Ok((swap_id, state)) + }) } } @@ -222,22 +249,20 @@ where #[cfg(test)] mod tests { use super::*; - use crate::database::alice::{Alice, AliceEndState}; - use crate::database::bob::{Bob, BobEndState}; - use crate::database::{NotAlice, NotBob}; + use crate::protocol::alice::AliceState; #[tokio::test] async fn can_write_and_read_to_multiple_keys() { let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); - let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcRedeemed)); + let state_1 = State::from(AliceState::BtcRedeemed); let swap_id_1 = Uuid::new_v4(); db.insert_latest_state(swap_id_1, state_1.clone()) .await .expect("Failed to save second state"); - let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted)); + let state_2 = State::from(AliceState::BtcPunished); let swap_id_2 = Uuid::new_v4(); db.insert_latest_state(swap_id_2, state_2.clone()) .await @@ -245,10 +270,12 @@ mod tests { let recovered_1 = db .get_state(swap_id_1) + .await .expect("Failed to recover first state"); let recovered_2 = db .get_state(swap_id_2) + .await .expect("Failed to recover second state"); assert_eq!(recovered_1, state_1); @@ -258,9 +285,9 @@ mod tests { #[tokio::test] async fn can_write_twice_to_one_key() { let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); - let state = Swap::Alice(Alice::Done(AliceEndState::SafelyAborted)); + let state = State::from(AliceState::SafelyAborted); let swap_id = Uuid::new_v4(); db.insert_latest_state(swap_id, state.clone()) @@ -268,6 +295,7 @@ mod tests { .expect("Failed to save state the first time"); let recovered = db .get_state(swap_id) + .await .expect("Failed to recover state the first time"); // We insert and recover twice to ensure database implementation allows the @@ -277,84 +305,29 @@ mod tests { .expect("Failed to save state the second time"); let recovered = db .get_state(swap_id) + .await .expect("Failed to recover state the second time"); assert_eq!(recovered, state); } - #[tokio::test] - async fn all_swaps_as_alice() { - let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).unwrap(); - - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state.clone()); - let alice_swap_id = Uuid::new_v4(); - db.insert_latest_state(alice_swap_id, alice_swap) - .await - .expect("Failed to save alice state 1"); - - let alice_swaps = db.all_alice().unwrap(); - assert_eq!(alice_swaps.len(), 1); - assert!(alice_swaps.contains(&(alice_swap_id, alice_state))); - - let bob_state = Bob::Done(BobEndState::SafelyAborted); - let bob_swap = Swap::Bob(bob_state); - let bob_swap_id = Uuid::new_v4(); - db.insert_latest_state(bob_swap_id, bob_swap) - .await - .expect("Failed to save bob state 1"); - - let err = db.all_alice().unwrap_err(); - - assert_eq!(err.downcast_ref::().unwrap(), &NotAlice); - } - - #[tokio::test] - async fn all_swaps_as_bob() { - let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).unwrap(); - - let bob_state = Bob::Done(BobEndState::SafelyAborted); - let bob_swap = Swap::Bob(bob_state.clone()); - let bob_swap_id = Uuid::new_v4(); - db.insert_latest_state(bob_swap_id, bob_swap) - .await - .expect("Failed to save bob state 1"); - - let bob_swaps = db.all_bob().unwrap(); - assert_eq!(bob_swaps.len(), 1); - assert!(bob_swaps.contains(&(bob_swap_id, bob_state))); - - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); - let alice_swap_id = Uuid::new_v4(); - db.insert_latest_state(alice_swap_id, alice_swap) - .await - .expect("Failed to save alice state 1"); - - let err = db.all_bob().unwrap_err(); - - assert_eq!(err.downcast_ref::().unwrap(), &NotBob); - } - #[tokio::test] async fn can_save_swap_state_and_peer_id_with_same_swap_id() -> Result<()> { let db_dir = tempfile::tempdir().unwrap(); - let db = SledDatabase::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); let alice_id = Uuid::new_v4(); - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); + let alice_state = State::from(AliceState::BtcPunished); let peer_id = PeerId::random(); - db.insert_latest_state(alice_id, alice_swap.clone()).await?; + db.insert_latest_state(alice_id, alice_state.clone()) + .await?; db.insert_peer_id(alice_id, peer_id).await?; - let loaded_swap = db.get_state(alice_id)?; - let loaded_peer_id = db.get_peer_id(alice_id)?; + let loaded_swap = db.get_state(alice_id).await?; + let loaded_peer_id = db.get_peer_id(alice_id).await?; - assert_eq!(alice_swap, loaded_swap); + assert_eq!(alice_state, loaded_swap); assert_eq!(peer_id, loaded_peer_id); Ok(()) @@ -364,23 +337,23 @@ mod tests { async fn test_reopen_db() -> Result<()> { let db_dir = tempfile::tempdir().unwrap(); let alice_id = Uuid::new_v4(); - let alice_state = Alice::Done(AliceEndState::BtcPunished); - let alice_swap = Swap::Alice(alice_state); + let alice_state = State::from(AliceState::BtcPunished); let peer_id = PeerId::random(); { - let db = SledDatabase::open(db_dir.path()).unwrap(); - db.insert_latest_state(alice_id, alice_swap.clone()).await?; + let db = SledDatabase::open(db_dir.path()).await.unwrap(); + db.insert_latest_state(alice_id, alice_state.clone()) + .await?; db.insert_peer_id(alice_id, peer_id).await?; } - let db = SledDatabase::open(db_dir.path()).unwrap(); + let db = SledDatabase::open(db_dir.path()).await.unwrap(); - let loaded_swap = db.get_state(alice_id)?; - let loaded_peer_id = db.get_peer_id(alice_id)?; + let loaded_swap = db.get_state(alice_id).await?; + let loaded_peer_id = db.get_peer_id(alice_id).await?; - assert_eq!(alice_swap, loaded_swap); + assert_eq!(alice_state, loaded_swap); assert_eq!(peer_id, loaded_peer_id); Ok(()) @@ -394,12 +367,15 @@ mod tests { let home2 = "/ip4/127.0.0.1/tcp/2".parse::()?; { - let db = SledDatabase::open(db_dir.path())?; + let db = SledDatabase::open(db_dir.path()).await?; db.insert_address(peer_id, home1.clone()).await?; db.insert_address(peer_id, home2.clone()).await?; } - let addresses = SledDatabase::open(db_dir.path())?.get_addresses(peer_id)?; + let addresses = SledDatabase::open(db_dir.path()) + .await? + .get_addresses(peer_id) + .await?; assert_eq!(addresses, vec![home1, home2]); @@ -411,9 +387,11 @@ mod tests { let db_dir = tempfile::tempdir()?; let swap_id = Uuid::new_v4(); - SledDatabase::open(db_dir.path())?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; - let loaded_monero_address = - SledDatabase::open(db_dir.path())?.get_monero_address(swap_id)?; + SledDatabase::open(db_dir.path()).await?.insert_monero_address(swap_id, "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a".parse()?).await?; + let loaded_monero_address = SledDatabase::open(db_dir.path()) + .await? + .get_monero_address(swap_id) + .await?; assert_eq!(loaded_monero_address.to_string(), "53gEuGZUhP9JMEBZoGaFNzhwEgiG7hwQdMCqFxiyiTeFPmkbt1mAoNybEUvYBKHcnrSgxnVWgZsTvRBaHBNXPa8tHiCU51a"); diff --git a/swap/src/protocol.rs b/swap/src/protocol.rs index 59b100c0..88077bc3 100644 --- a/swap/src/protocol.rs +++ b/swap/src/protocol.rs @@ -1,10 +1,18 @@ +use crate::protocol::alice::swap::is_complete as alice_is_complete; +use crate::protocol::alice::AliceState; +use crate::protocol::bob::swap::is_complete as bob_is_complete; +use crate::protocol::bob::BobState; use crate::{bitcoin, monero}; +use anyhow::Result; +use async_trait::async_trait; use conquer_once::Lazy; use ecdsa_fun::fun::marker::Mark; +use libp2p::{Multiaddr, PeerId}; use serde::{Deserialize, Serialize}; use sha2::Sha256; use sigma_fun::ext::dl_secp256k1_ed25519_eq::{CrossCurveDLEQ, CrossCurveDLEQProof}; use sigma_fun::HashTranscript; +use std::convert::TryInto; use uuid::Uuid; pub mod alice; @@ -65,3 +73,74 @@ pub struct Message4 { tx_punish_sig: bitcoin::Signature, tx_cancel_sig: bitcoin::Signature, } + +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug, PartialEq)] +pub enum State { + Alice(AliceState), + Bob(BobState), +} + +impl State { + pub fn swap_finished(&self) -> bool { + match self { + State::Alice(state) => alice_is_complete(state), + State::Bob(state) => bob_is_complete(state), + } + } +} + +impl From for State { + fn from(alice: AliceState) -> Self { + Self::Alice(alice) + } +} + +impl From for State { + fn from(bob: BobState) -> Self { + Self::Bob(bob) + } +} + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] +#[error("Not in the role of Alice")] +pub struct NotAlice; + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] +#[error("Not in the role of Bob")] +pub struct NotBob; + +impl TryInto for State { + type Error = NotBob; + + fn try_into(self) -> std::result::Result { + match self { + State::Alice(_) => Err(NotBob), + State::Bob(state) => Ok(state), + } + } +} + +impl TryInto for State { + type Error = NotAlice; + + fn try_into(self) -> std::result::Result { + match self { + State::Alice(state) => Ok(state), + State::Bob(_) => Err(NotAlice), + } + } +} + +#[async_trait] +pub trait Database { + async fn insert_peer_id(&self, swap_id: Uuid, peer_id: PeerId) -> Result<()>; + async fn get_peer_id(&self, swap_id: Uuid) -> Result; + async fn insert_monero_address(&self, swap_id: Uuid, address: monero::Address) -> Result<()>; + async fn get_monero_address(&self, swap_id: Uuid) -> Result; + async fn insert_address(&self, peer_id: PeerId, address: Multiaddr) -> Result<()>; + async fn get_addresses(&self, peer_id: PeerId) -> Result>; + async fn insert_latest_state(&self, swap_id: Uuid, state: State) -> Result<()>; + async fn get_state(&self, swap_id: Uuid) -> Result; + async fn all(&self) -> Result>; +} diff --git a/swap/src/protocol/alice.rs b/swap/src/protocol/alice.rs index b0d610b9..c6c21512 100644 --- a/swap/src/protocol/alice.rs +++ b/swap/src/protocol/alice.rs @@ -1,7 +1,7 @@ //! Run an XMR/BTC swap in the role of Alice. //! Alice holds XMR and wishes receive BTC. -use crate::database::SledDatabase; use crate::env::Config; +use crate::protocol::Database; use crate::{asb, bitcoin, monero}; use std::sync::Arc; use uuid::Uuid; @@ -19,5 +19,5 @@ pub struct Swap { pub monero_wallet: Arc, pub env_config: Config, pub swap_id: Uuid, - pub db: Arc, + pub db: Arc, } diff --git a/swap/src/protocol/alice/state.rs b/swap/src/protocol/alice/state.rs index 18c2a18d..9dcf3f46 100644 --- a/swap/src/protocol/alice/state.rs +++ b/swap/src/protocol/alice/state.rs @@ -16,7 +16,7 @@ use sigma_fun::ext::dl_secp256k1_ed25519_eq::CrossCurveDLEQProof; use std::fmt; use uuid::Uuid; -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub enum AliceState { Started { state3: Box, diff --git a/swap/src/protocol/alice/swap.rs b/swap/src/protocol/alice/swap.rs index 76324667..77516d11 100644 --- a/swap/src/protocol/alice/swap.rs +++ b/swap/src/protocol/alice/swap.rs @@ -4,7 +4,7 @@ use crate::asb::{EventLoopHandle, LatestRate}; use crate::bitcoin::ExpiredTimelocks; use crate::env::Config; use crate::protocol::alice::{AliceState, Swap}; -use crate::{bitcoin, database, monero}; +use crate::{bitcoin, monero}; use anyhow::{bail, Context, Result}; use tokio::select; use tokio::time::timeout; @@ -40,9 +40,8 @@ where ) .await?; - let db_state = (¤t_state).into(); swap.db - .insert_latest_state(swap.swap_id, database::Swap::Alice(db_state)) + .insert_latest_state(swap.swap_id, current_state.clone().into()) .await?; } @@ -398,7 +397,7 @@ where }) } -fn is_complete(state: &AliceState) -> bool { +pub(crate) fn is_complete(state: &AliceState) -> bool { matches!( state, AliceState::XmrRefunded diff --git a/swap/src/protocol/bob.rs b/swap/src/protocol/bob.rs index bbc35e7f..0ef3e241 100644 --- a/swap/src/protocol/bob.rs +++ b/swap/src/protocol/bob.rs @@ -3,11 +3,12 @@ use std::sync::Arc; use anyhow::Result; use uuid::Uuid; -use crate::database::SledDatabase; +use crate::protocol::Database; use crate::{bitcoin, cli, env, monero}; pub use self::state::*; pub use self::swap::{run, run_until}; +use std::convert::TryInto; pub mod state; pub mod swap; @@ -15,7 +16,7 @@ pub mod swap; pub struct Swap { pub state: BobState, pub event_loop_handle: cli::EventLoopHandle, - pub db: SledDatabase, + pub db: Arc, pub bitcoin_wallet: Arc, pub monero_wallet: Arc, pub env_config: env::Config, @@ -26,7 +27,7 @@ pub struct Swap { impl Swap { #[allow(clippy::too_many_arguments)] pub fn new( - db: SledDatabase, + db: Arc, id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, @@ -52,8 +53,8 @@ impl Swap { } #[allow(clippy::too_many_arguments)] - pub fn from_db( - db: SledDatabase, + pub async fn from_db( + db: Arc, id: Uuid, bitcoin_wallet: Arc, monero_wallet: Arc, @@ -61,7 +62,7 @@ impl Swap { event_loop_handle: cli::EventLoopHandle, monero_receive_address: monero::Address, ) -> Result { - let state = db.get_state(id)?.try_into_bob()?.into(); + let state = db.get_state(id).await?.try_into()?; Ok(Self { state, diff --git a/swap/src/protocol/bob/state.rs b/swap/src/protocol/bob/state.rs index 8e0bac9a..effd05ee 100644 --- a/swap/src/protocol/bob/state.rs +++ b/swap/src/protocol/bob/state.rs @@ -21,7 +21,7 @@ use sigma_fun::ext::dl_secp256k1_ed25519_eq::CrossCurveDLEQProof; use std::fmt; use uuid::Uuid; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum BobState { Started { btc_amount: bitcoin::Amount, diff --git a/swap/src/protocol/bob/swap.rs b/swap/src/protocol/bob/swap.rs index e1e9f849..c0bccc69 100644 --- a/swap/src/protocol/bob/swap.rs +++ b/swap/src/protocol/bob/swap.rs @@ -1,6 +1,5 @@ use crate::bitcoin::{ExpiredTimelocks, TxCancel, TxRefund}; use crate::cli::EventLoopHandle; -use crate::database::Swap; use crate::network::swap_setup::bob::NewSwap; use crate::protocol::bob; use crate::protocol::bob::state::*; @@ -33,7 +32,7 @@ pub async fn run_until( while !is_target_state(¤t_state) { current_state = next_state( swap.id, - current_state, + current_state.clone(), &mut swap.event_loop_handle, swap.bitcoin_wallet.as_ref(), swap.monero_wallet.as_ref(), @@ -41,9 +40,8 @@ pub async fn run_until( ) .await?; - let db_state = current_state.clone().into(); swap.db - .insert_latest_state(swap.id, Swap::Bob(db_state)) + .insert_latest_state(swap.id, current_state.clone().into()) .await?; } diff --git a/swap/tests/harness/mod.rs b/swap/tests/harness/mod.rs index 6e56e22f..fbf1115a 100644 --- a/swap/tests/harness/mod.rs +++ b/swap/tests/harness/mod.rs @@ -222,7 +222,7 @@ async fn start_alice( bitcoin_wallet: Arc, monero_wallet: Arc, ) -> (AliceApplicationHandle, Receiver) { - let db = Arc::new(SledDatabase::open(db_path.as_path()).unwrap()); + let db = Arc::new(SledDatabase::open(db_path.as_path()).await.unwrap()); let min_buy = bitcoin::Amount::from_sat(u64::MIN); let max_buy = bitcoin::Amount::from_sat(u64::MAX); @@ -402,7 +402,7 @@ struct BobParams { impl BobParams { pub async fn new_swap_from_db(&self, swap_id: Uuid) -> Result<(bob::Swap, cli::EventLoop)> { let (event_loop, handle) = self.new_eventloop(swap_id).await?; - let db = SledDatabase::open(&self.db_path)?; + let db = Arc::new(SledDatabase::open(&self.db_path).await?); let swap = bob::Swap::from_db( db, @@ -412,7 +412,8 @@ impl BobParams { self.env_config, handle, self.monero_wallet.get_main_address(), - )?; + ) + .await?; Ok((swap, event_loop)) } @@ -424,7 +425,7 @@ impl BobParams { let swap_id = Uuid::new_v4(); let (event_loop, handle) = self.new_eventloop(swap_id).await?; - let db = SledDatabase::open(&self.db_path)?; + let db = Arc::new(SledDatabase::open(&self.db_path).await?); let swap = bob::Swap::new( db,