diff --git a/xmr-btc/src/alice.rs b/xmr-btc/src/alice.rs index 3c87b514..18d4fa83 100644 --- a/xmr-btc/src/alice.rs +++ b/xmr-btc/src/alice.rs @@ -82,7 +82,7 @@ pub async fn next_state< } #[allow(clippy::large_enum_variant)] -#[derive(Debug)] +#[derive(Debug, Deserialize, Serialize)] pub enum State { State0(State0), State1(State1), diff --git a/xmr-btc/src/bob.rs b/xmr-btc/src/bob.rs index 1dfb8b81..60da0e69 100644 --- a/xmr-btc/src/bob.rs +++ b/xmr-btc/src/bob.rs @@ -82,7 +82,7 @@ pub async fn next_state< } } -#[derive(Debug)] +#[derive(Debug, Deserialize, Serialize)] pub enum State { State0(State0), State1(State1), diff --git a/xmr-btc/tests/e2e.rs b/xmr-btc/tests/e2e.rs index 148dcbe4..b6d51217 100644 --- a/xmr-btc/tests/e2e.rs +++ b/xmr-btc/tests/e2e.rs @@ -18,6 +18,7 @@ mod tests { use monero_harness::Monero; use rand::rngs::OsRng; + use crate::harness::storage::Database; use std::{convert::TryInto, path::Path}; use testcontainers::clients::Cli; use tracing_subscriber::util::SubscriberInitExt; @@ -251,8 +252,10 @@ mod tests { let cli = Cli::default(); let (monero, _container) = Monero::new(&cli); let bitcoind = init_bitcoind(&cli).await; - let alice_db = harness::storage::Database::open(Path::new(ALICE_TEST_DB_FOLDER)).unwrap(); - let bob_db = harness::storage::Database::open(Path::new(BOB_TEST_DB_FOLDER)).unwrap(); + let alice_db: Database = + harness::storage::Database::open(Path::new(ALICE_TEST_DB_FOLDER)).unwrap(); + let bob_db: Database = + harness::storage::Database::open(Path::new(BOB_TEST_DB_FOLDER)).unwrap(); let ( alice_state0, @@ -281,29 +284,26 @@ mod tests { .await .unwrap(); - let alice_state5: alice::State5 = alice_state.try_into().unwrap(); - let bob_state3: bob::State3 = bob_state.try_into().unwrap(); - // save state to db - alice_db.insert_latest_state(&alice_state5).await.unwrap(); - bob_db.insert_latest_state(&bob_state3).await.unwrap(); + alice_db.insert_latest_state(&alice_state).await.unwrap(); + bob_db.insert_latest_state(&bob_state).await.unwrap(); }; let (alice_state6, bob_state5) = { // recover state from db - let alice_state5: alice::State5 = alice_db.get_latest_state().unwrap(); - let bob_state3: bob::State3 = bob_db.get_latest_state().unwrap(); + let alice_state = alice_db.get_latest_state().unwrap(); + let bob_state = bob_db.get_latest_state().unwrap(); let (alice_state, bob_state) = future::try_join( run_alice_until( &mut alice_node, - alice_state5.into(), + alice_state, harness::alice::is_state6, &mut OsRng, ), run_bob_until( &mut bob_node, - bob_state3.into(), + bob_state, harness::bob::is_state5, &mut OsRng, ), diff --git a/xmr-btc/tests/harness/storage.rs b/xmr-btc/tests/harness/storage.rs index cc8dd899..cfb8eceb 100644 --- a/xmr-btc/tests/harness/storage.rs +++ b/xmr-btc/tests/harness/storage.rs @@ -2,24 +2,31 @@ use anyhow::{anyhow, Context, Result}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::path::Path; -pub struct Database { +pub struct Database +where + T: Serialize + DeserializeOwned, +{ db: sled::Db, + _marker: std::marker::PhantomData, } -impl Database { +impl Database +where + T: Serialize + DeserializeOwned, +{ const LAST_STATE_KEY: &'static str = "latest_state"; pub fn open(path: &Path) -> Result { let db = sled::open(path).with_context(|| format!("Could not open the DB at {:?}", path))?; - Ok(Database { db }) + Ok(Database { + db, + _marker: Default::default(), + }) } - pub async fn insert_latest_state(&self, state: &T) -> Result<()> - where - T: Serialize + DeserializeOwned, - { + pub async fn insert_latest_state(&self, state: &T) -> Result<()> { let key = serialize(&Self::LAST_STATE_KEY)?; let new_value = serialize(&state).context("Could not serialize new state value")?; @@ -37,10 +44,7 @@ impl Database { .context("Could not flush db") } - pub fn get_latest_state(&self) -> anyhow::Result - where - T: DeserializeOwned, - { + pub fn get_latest_state(&self) -> anyhow::Result { let key = serialize(&Self::LAST_STATE_KEY)?; let encoded = self