diff --git a/Cargo.lock b/Cargo.lock index 99880a98..fbbec2b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1552,6 +1552,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37d572918e350e82412fe766d24b15e6682fb2ed2bbe018280caa810397cb319" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.7" @@ -2547,7 +2556,7 @@ checksum = "32d3ebd75ac2679c2af3a92246639f9fcc8a442ee420719cc4fe195b98dd5fa3" dependencies = [ "bytes 1.0.1", "heck", - "itertools", + "itertools 0.9.0", "log 0.4.14", "multimap", "petgraph", @@ -2564,7 +2573,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "169a15f3008ecb5160cba7d37bcd690a7601b6d30cfb87a117d45e59d52af5d4" dependencies = [ "anyhow", - "itertools", + "itertools 0.9.0", "proc-macro2", "quote", "syn", @@ -3522,6 +3531,7 @@ dependencies = [ "futures", "get-port", "hyper 0.14.5", + "itertools 0.10.0", "libp2p", "libp2p-async-await", "miniscript", diff --git a/swap/Cargo.toml b/swap/Cargo.toml index 09014eba..d36589dc 100644 --- a/swap/Cargo.toml +++ b/swap/Cargo.toml @@ -25,6 +25,7 @@ dialoguer = "0.8" directories-next = "2" ecdsa_fun = { git = "https://github.com/LLFourn/secp256kfun", features = ["libsecp_compat", "serde"] } futures = { version = "0.3", default-features = false } +itertools = "0.10" libp2p = { version = "0.36", default-features = false, features = ["tcp-tokio", "yamux", "mplex", "dns-tokio", "noise", "request-response"] } libp2p-async-await = { git = "https://github.com/comit-network/rust-libp2p-async-await" } miniscript = { version = "5", features = ["serde"] } diff --git a/swap/src/bin/asb.rs b/swap/src/bin/asb.rs index fb48a148..d77bf271 100644 --- a/swap/src/bin/asb.rs +++ b/swap/src/bin/asb.rs @@ -130,7 +130,7 @@ async fn main() -> Result<()> { table.add_row(row!["SWAP ID", "STATE"]); - for (swap_id, state) in db.all()? { + for (swap_id, state) in db.all_alice()? { table.add_row(row![swap_id, state]); } diff --git a/swap/src/bin/swap.rs b/swap/src/bin/swap.rs index c1d6721f..a4b877c4 100644 --- a/swap/src/bin/swap.rs +++ b/swap/src/bin/swap.rs @@ -158,7 +158,7 @@ async fn main() -> Result<()> { table.add_row(row!["SWAP ID", "STATE"]); - for (swap_id, state) in db.all()? { + for (swap_id, state) in db.all_bob()? { table.add_row(row![swap_id, state]); } diff --git a/swap/src/database.rs b/swap/src/database.rs index 98b7f092..4e62af59 100644 --- a/swap/src/database.rs +++ b/swap/src/database.rs @@ -2,6 +2,7 @@ pub use alice::Alice; pub use bob::Bob; use anyhow::{anyhow, bail, Context, Result}; +use itertools::Itertools; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -38,11 +39,26 @@ impl Display for Swap { } } +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] +#[error("Not in the role of Alice")] +struct NotAlice; + +#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq)] +#[error("Not in the role of Bob")] +struct NotBob; + impl Swap { + pub fn try_into_alice(self) -> Result { + match self { + Swap::Alice(alice) => Ok(alice), + Swap::Bob(_) => bail!(NotAlice), + } + } + pub fn try_into_bob(self) -> Result { match self { Swap::Bob(bob) => Ok(bob), - Swap::Alice(_) => bail!("Swap instance is not Bob"), + Swap::Alice(_) => bail!(NotBob), } } } @@ -90,22 +106,42 @@ impl Database { Ok(state) } - pub fn all(&self) -> Result> { - self.0 - .iter() - .map(|item| match item { - Ok((key, value)) => { - let swap_id = deserialize::(&key); - let swap = deserialize::(&value).context("Failed to deserialize swap"); - - match (swap_id, swap) { - (Ok(swap_id), Ok(swap)) => Ok((swap_id, swap)), - (Ok(_), Err(err)) => Err(err), - _ => bail!("Failed to deserialize swap"), - } - } - Err(err) => Err(err).context("Failed to retrieve swap from DB"), - }) + pub fn all_alice(&self) -> Result> { + self.all_alice_iter().collect() + } + + fn all_alice_iter(&self) -> impl Iterator> { + self.all_swaps_iter().map(|item| { + let (swap_id, swap) = item?; + Ok((swap_id, swap.try_into_alice()?)) + }) + } + + pub fn all_bob(&self) -> Result> { + self.all_bob_iter().collect() + } + + fn all_bob_iter(&self) -> impl Iterator> { + self.all_swaps_iter().map(|item| { + let (swap_id, swap) = item?; + Ok((swap_id, swap.try_into_bob()?)) + }) + } + + fn all_swaps_iter(&self) -> impl Iterator> { + self.0.iter().map(|item| { + let (key, value) = item.context("Failed to retrieve swap from DB")?; + + let swap_id = deserialize::(&key)?; + let swap = deserialize::(&value).context("Failed to deserialize swap")?; + + Ok((swap_id, swap)) + }) + } + + pub fn unfinished_alice(&self) -> Result> { + self.all_alice_iter() + .filter_ok(|(_swap_id, alice)| !matches!(alice, Alice::Done(_))) .collect() } } @@ -187,26 +223,58 @@ mod tests { } #[tokio::test] - async fn can_fetch_all_keys() { + async fn all_swaps_as_alice() { let db_dir = tempfile::tempdir().unwrap(); let db = Database::open(db_dir.path()).unwrap(); - let state_1 = Swap::Alice(Alice::Done(AliceEndState::BtcPunished)); - let swap_id_1 = Uuid::new_v4(); - db.insert_latest_state(swap_id_1, state_1.clone()) + let alice_state = Alice::Done(AliceEndState::BtcPunished); + let alice_swap = Swap::Alice(alice_state.clone()); + let alice_swap_id = Uuid::new_v4(); + db.insert_latest_state(alice_swap_id, alice_swap) .await - .expect("Failed to save second state"); + .expect("Failed to save alice state 1"); - let state_2 = Swap::Bob(Bob::Done(BobEndState::SafelyAborted)); - let swap_id_2 = Uuid::new_v4(); - db.insert_latest_state(swap_id_2, state_2.clone()) + let alice_swaps = db.all_alice().unwrap(); + assert_eq!(alice_swaps.len(), 1); + assert!(alice_swaps.contains(&(alice_swap_id, alice_state))); + + let bob_state = Bob::Done(BobEndState::SafelyAborted); + let bob_swap = Swap::Bob(bob_state); + let bob_swap_id = Uuid::new_v4(); + db.insert_latest_state(bob_swap_id, bob_swap) .await - .expect("Failed to save first state"); + .expect("Failed to save bob state 1"); + + let err = db.all_alice().unwrap_err(); + + assert_eq!(err.downcast_ref::().unwrap(), &NotAlice); + } + + #[tokio::test] + async fn all_swaps_as_bob() { + let db_dir = tempfile::tempdir().unwrap(); + let db = Database::open(db_dir.path()).unwrap(); + + let bob_state = Bob::Done(BobEndState::SafelyAborted); + let bob_swap = Swap::Bob(bob_state.clone()); + let bob_swap_id = Uuid::new_v4(); + db.insert_latest_state(bob_swap_id, bob_swap) + .await + .expect("Failed to save bob state 1"); + + let bob_swaps = db.all_bob().unwrap(); + assert_eq!(bob_swaps.len(), 1); + assert!(bob_swaps.contains(&(bob_swap_id, bob_state))); + + let alice_state = Alice::Done(AliceEndState::BtcPunished); + let alice_swap = Swap::Alice(alice_state); + let alice_swap_id = Uuid::new_v4(); + db.insert_latest_state(alice_swap_id, alice_swap) + .await + .expect("Failed to save alice state 1"); - let swaps = db.all().unwrap(); + let err = db.all_bob().unwrap_err(); - assert_eq!(swaps.len(), 2); - assert!(swaps.contains(&(swap_id_1, state_1))); - assert!(swaps.contains(&(swap_id_2, state_2))); + assert_eq!(err.downcast_ref::().unwrap(), &NotBob); } }