use super::{ authentication::{AuthHandler, Authenticate, Keychain, KeychainResult, Verifier}, Backup, FramedTransport, HeapSecretKey, Reconnectable, Transport, }; use async_trait::async_trait; use log::*; use serde::{Deserialize, Serialize}; use std::io; use std::ops::{Deref, DerefMut}; use tokio::sync::oneshot; #[cfg(test)] use super::InmemoryTransport; /// Id of the connection pub type ConnectionId = u32; /// Represents a connection from either the client or server side #[derive(Debug)] pub enum Connection { /// Connection from the client side Client { /// Unique id associated with the connection id: ConnectionId, /// One-time password (OTP) for use in reauthenticating with the server reauth_otp: HeapSecretKey, /// Underlying transport used to communicate transport: FramedTransport, }, /// Connection from the server side Server { /// Unique id associated with the connection id: ConnectionId, /// Used to send the backup into storage when the connection is dropped tx: oneshot::Sender, /// Underlying transport used to communicate transport: FramedTransport, }, } impl Deref for Connection { type Target = FramedTransport; fn deref(&self) -> &Self::Target { match self { Self::Client { transport, .. } => transport, Self::Server { transport, .. } => transport, } } } impl DerefMut for Connection { fn deref_mut(&mut self) -> &mut Self::Target { match self { Self::Client { transport, .. } => transport, Self::Server { transport, .. } => transport, } } } impl Drop for Connection { /// On drop for a server connection, the connection's backup will be sent via `tx`. For a /// client connection, nothing happens. fn drop(&mut self) { match self { Self::Client { .. } => (), Self::Server { tx, transport, .. 
} => { // NOTE: We grab the current backup state and store it using the tx, replacing // the backup with a default and the tx with a disconnected one let backup = std::mem::take(&mut transport.backup); let tx = std::mem::replace(tx, oneshot::channel().0); let _ = tx.send(backup); } } } } #[async_trait] impl Reconnectable for Connection where T: Transport, { /// Attempts to re-establish a connection. /// /// ### Client /// /// For a client, this means performing an actual [`reconnect`] on the underlying /// [`Transport`], re-establishing an encrypted codec, submitting a request to the server to /// reauthenticate using a previously-derived OTP, and refreshing the connection id and OTP for /// use in a future reauthentication. /// /// ### Server /// /// For a server, this will fail as unsupported. /// /// [`reconnect`]: Reconnectable::reconnect async fn reconnect(&mut self) -> io::Result<()> { async fn reconnect_client( id: &mut ConnectionId, reauth_otp: &mut HeapSecretKey, transport: &mut FramedTransport, ) -> io::Result<()> { // Re-establish a raw connection debug!("[Conn {id}] Re-establishing connection"); Reconnectable::reconnect(transport).await?; // Perform a handshake to ensure that the connection is properly established and encrypted debug!("[Conn {id}] Performing handshake"); transport.client_handshake().await?; // Communicate that we are an existing connection debug!("[Conn {id}] Performing re-authentication"); transport .write_frame_for(&ConnectType::Reconnect { id: *id, otp: reauth_otp.unprotected_as_bytes().to_vec(), }) .await?; // Receive the new id for the connection // NOTE: If we fail re-authentication above, // this will fail as the connection is dropped debug!("[Conn {id}] Receiving new connection id"); let new_id = transport .read_frame_as::() .await? 
.ok_or_else(|| { io::Error::new(io::ErrorKind::Other, "Missing connection id frame") })?; debug!("[Conn {id}] Resetting id to {new_id}"); *id = new_id; // Derive an OTP for reauthentication debug!("[Conn {id}] Deriving future OTP for reauthentication"); *reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); Ok(()) } match self { Self::Client { id, transport, reauth_otp, } => { // Freeze our backup as we don't want the connection logic to alter it transport.backup.freeze(); // Attempt to perform the reconnection and unfreeze our backup regardless of the // result let result = reconnect_client(id, reauth_otp, transport).await; transport.backup.unfreeze(); result?; // Perform synchronization debug!("[Conn {id}] Synchronizing frame state"); transport.synchronize().await?; Ok(()) } Self::Server { .. } => Err(io::Error::new( io::ErrorKind::Unsupported, "Server connection cannot reconnect", )), } } } /// Type of connection to perform #[derive(Debug, Serialize, Deserialize)] enum ConnectType { /// Indicates that the connection from client to server is no and not a reconnection Connect, /// Indicates that the connection from client to server is a reconnection and should attempt to /// use the connection id and OTP to authenticate Reconnect { /// Id of the connection to reauthenticate id: ConnectionId, /// Raw bytes of the OTP #[serde(with = "serde_bytes")] otp: Vec, }, } impl Connection where T: Transport, { /// Transforms a raw [`Transport`] into an established [`Connection`] from the client-side by /// performing the following: /// /// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use /// 2. Authenticates the established connection to ensure it is valid /// 3. 
Restores pre-existing state using the provided backup, replaying any missing frames and /// receiving any frames from the other side pub async fn client(transport: T, handler: H) -> io::Result { let id: ConnectionId = rand::random(); // Perform a handshake to ensure that the connection is properly established and encrypted debug!("[Conn {id}] Performing handshake"); let mut transport: FramedTransport = FramedTransport::from_client_handshake(transport).await?; // Communicate that we are a new connection debug!("[Conn {id}] Communicating that this is a new connection"); transport.write_frame_for(&ConnectType::Connect).await?; // Receive the new id for the connection let id = { debug!("[Conn {id}] Receiving new connection id"); let new_id = transport .read_frame_as::() .await? .ok_or_else(|| { io::Error::new(io::ErrorKind::Other, "Missing connection id frame") })?; debug!("[Conn {id}] Resetting id to {new_id}"); new_id }; // Authenticate the transport with the server-side debug!("[Conn {id}] Performing authentication"); transport.authenticate(handler).await?; // Derive an OTP for reauthentication debug!("[Conn {id}] Deriving future OTP for reauthentication"); let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); Ok(Self::Client { id, reauth_otp, transport, }) } /// Transforms a raw [`Transport`] into an established [`Connection`] from the server-side by /// performing the following: /// /// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use /// 2. Authenticates the established connection to ensure it is valid by either using the /// given `verifier` or, if working with an existing client connection, will validate an OTP /// from our database /// 3. 
Restores pre-existing state using the provided backup, replaying any missing frames and /// receiving any frames from the other side pub async fn server( transport: T, verifier: &Verifier, keychain: Keychain>, ) -> io::Result { let id: ConnectionId = rand::random(); // Perform a handshake to ensure that the connection is properly established and encrypted debug!("[Conn {id}] Performing handshake"); let mut transport: FramedTransport = FramedTransport::from_server_handshake(transport).await?; // Receive a client id, look up to see if the client id exists already // // 1. If it already exists, wait for a password to follow, which is a one-time password used by // the client. If the password is correct, then generate a new one-time client id and // password for a future connection (only updating if the connection fully completes) and // send it to the client, and then perform a replay situation // // 2. If it does not exist, ignore the client id and password. Generate a new client id to send // to the client. Perform verification like usual. Then generate a one-time password and // send it to the client. debug!("[Conn {id}] Waiting for connection type"); let connection_type = transport .read_frame_as::() .await? 
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Missing connection type frame"))?; // Create a oneshot channel used to relay the backup when the connection is dropped let (tx, rx) = oneshot::channel(); // Based on the connection type, we either try to find and validate an existing connection // or we perform normal verification match connection_type { ConnectType::Connect => { // Communicate the connection id debug!("[Conn {id}] Telling other side to change connection id"); transport.write_frame_for(&id).await?; // Perform authentication to ensure the connection is valid debug!("[Conn {id}] Verifying connection"); verifier.verify(&mut transport).await?; // Derive an OTP for reauthentication debug!("[Conn {id}] Deriving future OTP for reauthentication"); let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); // Store the id, OTP, and backup retrieval in our database keychain.insert(id.to_string(), reauth_otp, rx).await; } ConnectType::Reconnect { id: other_id, otp } => { let reauth_otp = HeapSecretKey::from(otp); debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP"); match keychain .remove_if_has_key(other_id.to_string(), reauth_otp) .await { KeychainResult::Ok(x) => { // Communicate the connection id debug!("[Conn {id}] Telling other side to change connection id"); transport.write_frame_for(&id).await?; // Derive an OTP for reauthentication debug!("[Conn {id}] Deriving future OTP for reauthentication"); let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key(); // Grab the old backup and swap it into our transport debug!("[Conn {id}] Acquiring backup for existing connection"); match x.await { Ok(backup) => { transport.backup = backup; } Err(_) => { warn!("[Conn {id}] Missing backup"); } } // Synchronize using the provided backup debug!("[Conn {id}] Synchronizing frame state"); transport.synchronize().await?; // Store the id, OTP, and backup retrieval in our database keychain.insert(id.to_string(), 
reauth_otp, rx).await; } KeychainResult::InvalidPassword => { return Err(io::Error::new( io::ErrorKind::PermissionDenied, "Invalid OTP for reconnect", )); } KeychainResult::InvalidId => { return Err(io::Error::new( io::ErrorKind::PermissionDenied, "Invalid id for reconnect", )); } } } } Ok(Self::Server { id, tx, transport }) } } #[cfg(test)] impl Connection { /// Establishes a pair of [`Connection`]s using [`InmemoryTransport`] underneath, returning /// them in the form (client, server). /// /// ### Note /// /// This skips handshakes, authentication, and backup processing. These connections cannot be /// reconnected and have no encryption. pub fn pair(buffer: usize) -> (Self, Self) { let id = rand::random::(); let (t1, t2) = FramedTransport::pair(buffer); let client = Connection::Client { id, reauth_otp: HeapSecretKey::generate(32).unwrap(), transport: t1, }; let server = Connection::Server { id, tx: oneshot::channel().0, transport: t2, }; (client, server) } } #[cfg(test)] impl Connection { /// Returns the id of the connection. pub fn id(&self) -> ConnectionId { match self { Self::Client { id, .. } => *id, Self::Server { id, .. } => *id, } } /// Returns the OTP associated with the connection, or none if connection is server-side. pub fn otp(&self) -> Option<&HeapSecretKey> { match self { Self::Client { reauth_otp, .. } => Some(reauth_otp), Self::Server { .. } => None, } } /// Returns a reference to the underlying transport. pub fn transport(&self) -> &FramedTransport { match self { Self::Client { transport, .. } => transport, Self::Server { transport, .. } => transport, } } /// Returns a mutable reference to the underlying transport. pub fn mut_transport(&mut self) -> &mut FramedTransport { match self { Self::Client { transport, .. } => transport, Self::Server { transport, .. 
} => transport, } } } #[cfg(test)] impl Connection { pub fn test_client(transport: T) -> Self { Self::Client { id: rand::random(), reauth_otp: HeapSecretKey::generate(32).unwrap(), transport: FramedTransport::plain(transport), } } } #[cfg(test)] mod tests { use super::*; use crate::common::{ authentication::{msg::Challenge, Authenticator, DummyAuthHandler}, Frame, }; use std::sync::Arc; use test_log::test; #[test(tokio::test)] async fn client_should_fail_if_codec_handshake_fails() { let (mut t1, t2) = FramedTransport::pair(100); // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { Connection::client(t2.into_inner(), DummyAuthHandler) .await .unwrap() }); // Send garbage to fail the handshake t1.write_frame(Frame::new(b"invalid")).await.unwrap(); // Client should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn client_should_fail_if_unable_to_receive_connection_id_from_server() { let (mut t1, t2) = FramedTransport::pair(100); // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { Connection::client(t2.into_inner(), DummyAuthHandler) .await .unwrap() }); // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); // Receive a type that indicates a new connection let ct = t1.read_frame_as::().await.unwrap().unwrap(); assert!( matches!(ct, ConnectType::Connect), "Unexpected connect type: {ct:?}" ); // Drop to cause id retrieval on client to fail drop(t1); // Client should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn client_should_fail_if_authentication_fails() { let (mut t1, t2) = FramedTransport::pair(100); // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { 
Connection::client(t2.into_inner(), DummyAuthHandler) .await .unwrap() }); // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); // Receive a type that indicates a new connection let ct = t1.read_frame_as::().await.unwrap().unwrap(); assert!( matches!(ct, ConnectType::Connect), "Unexpected connect type: {ct:?}" ); // Send a connection id as second step of connection t1.write_frame_for(&rand::random::()) .await .unwrap(); // Perform an authentication request that will fail on the client side, which will // cause the client to drop and therefore this transport to fail in getting a response t1.challenge(Challenge { questions: Vec::new(), options: Default::default(), }) .await .unwrap_err(); // Client should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn client_should_fail_if_unable_to_exchange_otp_for_reauthentication() { let (mut t1, t2) = FramedTransport::pair(100); // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { Connection::client(t2.into_inner(), DummyAuthHandler) .await .unwrap() }); // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); // Receive a type that indicates a new connection let ct = t1.read_frame_as::().await.unwrap().unwrap(); assert!( matches!(ct, ConnectType::Connect), "Unexpected connect type: {ct:?}" ); // Send a connection id as second step of connection t1.write_frame_for(&rand::random::()) .await .unwrap(); // Perform verification as third step using none method, which should always succeed // without challenging Verifier::none().verify(&mut t1).await.unwrap(); // Send garbage to fail the key exchange t1.write_frame(Frame::new(b"invalid")).await.unwrap(); // Client should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn client_should_succeed_if_establishes_connection_with_server() { let (mut t1, t2) = 
FramedTransport::pair(100); // Spawn a task to perform the client connection so we don't deadlock while simulating the // server actions on the other side let task = tokio::spawn(async move { Connection::client(t2.into_inner(), DummyAuthHandler) .await .unwrap() }); // Perform first step of connection by establishing the codec t1.server_handshake().await.unwrap(); // Receive a type that indicates a new connection let ct = t1.read_frame_as::().await.unwrap().unwrap(); assert!( matches!(ct, ConnectType::Connect), "Unexpected connect type: {ct:?}" ); // Send a connection id as second step of connection t1.write_frame_for(&rand::random::()) .await .unwrap(); // Perform verification as third step using none method, which should always succeed // without challenging Verifier::none().verify(&mut t1).await.unwrap(); // Perform fourth step of key exchange for OTP let otp = t1.exchange_keys().await.unwrap().into_heap_secret_key(); // Client should succeed and have an OTP that matches the server-side version let client = task.await.unwrap(); assert_eq!(client.otp(), Some(&otp)); } #[test(tokio::test)] async fn server_should_fail_if_codec_handshake_fails() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Send garbage to fail the handshake t1.write_frame(Frame::new(b"invalid")).await.unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_unable_to_receive_connect_type() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = 
tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send some garbage that is not the connection type t1.write_frame(Frame::new(b"hello")).await.unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_unable_to_verify_new_client() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::static_key(HeapSecretKey::generate(32).unwrap()); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send type to indicate a new connection t1.write_frame_for(&ConnectType::Connect).await.unwrap(); // Receive the connection id let _id = t1.read_frame_as::().await.unwrap().unwrap(); // Fail verification using the dummy handler that will fail when asked for a static key t1.authenticate(DummyAuthHandler).await.unwrap_err(); // Drop the transport so we kill the server-side connection // NOTE: If we don't drop here, the above authentication failure won't kill the server drop(t1); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_new_client() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of 
handshake t1.client_handshake().await.unwrap(); // Send type to indicate a new connection t1.write_frame_for(&ConnectType::Connect).await.unwrap(); // Receive the connection id let _id = t1.read_frame_as::().await.unwrap().unwrap(); // Pass verification using the dummy handler since our verifier supports no authentication t1.authenticate(DummyAuthHandler).await.unwrap(); // Send some garbage to fail the exchange t1.write_frame(Frame::new(b"hello")).await.unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_existing_client_id_is_invalid() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send type to indicate an existing connection, which should cause the server-side to fail // because there is no matching id t1.write_frame_for(&ConnectType::Reconnect { id: 1234, otp: HeapSecretKey::generate(32) .unwrap() .unprotected_into_bytes(), }) .await .unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_existing_client_otp_is_invalid() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); keychain .insert( 1234.to_string(), HeapSecretKey::generate(32).unwrap(), oneshot::channel().1, ) .await; // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of handshake 
t1.client_handshake().await.unwrap(); // Send type to indicate an existing connection, which should cause the server-side to fail // because the OTP is wrong for the given id t1.write_frame_for(&ConnectType::Reconnect { id: 1234, otp: HeapSecretKey::generate(32) .unwrap() .unprotected_into_bytes(), }) .await .unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_existing_client( ) { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); let key = HeapSecretKey::generate(32).unwrap(); keychain .insert(1234.to_string(), key.clone(), oneshot::channel().1) .await; // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send type to indicate an existing connection, which should cause the server-side to fail // because the OTP is wrong for the given id t1.write_frame_for(&ConnectType::Reconnect { id: 1234, otp: key.unprotected_into_bytes(), }) .await .unwrap(); // Receive a new client id let _id = t1.read_frame_as::().await.unwrap().unwrap(); // Send garbage to fail the otp exchange t1.write_frame(Frame::new(b"hello")).await.unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_fail_if_unable_to_synchronize_with_existing_client() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); let key = HeapSecretKey::generate(32).unwrap(); keychain .insert(1234.to_string(), key.clone(), oneshot::channel().1) .await; // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the 
other side let task = tokio::spawn(async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send type to indicate an existing connection, which should cause the server-side to fail // because the OTP is wrong for the given id t1.write_frame_for(&ConnectType::Reconnect { id: 1234, otp: key.unprotected_into_bytes(), }) .await .unwrap(); // Receive a new client id let _id = t1.read_frame_as::().await.unwrap().unwrap(); // Perform otp exchange let _otp = t1.exchange_keys().await.unwrap(); // Send garbage to fail synchronization t1.write_frame(b"hello").await.unwrap(); // Server should fail task.await.unwrap_err(); } #[test(tokio::test)] async fn server_should_succeed_if_establishes_connection_with_new_client() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn({ let keychain = keychain.clone(); async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() } }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send type to indicate a new connection t1.write_frame_for(&ConnectType::Connect).await.unwrap(); // Receive the connection id let id = t1.read_frame_as::().await.unwrap().unwrap(); // Pass verification using the dummy handler since our verifier supports no authentication t1.authenticate(DummyAuthHandler).await.unwrap(); // Perform otp exchange let otp = t1.exchange_keys().await.unwrap(); // Server connection should be established, and have received some replayed frames let server = task.await.unwrap(); // Validate the connection ids match assert_eq!(server.id(), id); // Validate the OTP was stored in our keychain assert!( keychain 
.has_key(id.to_string(), otp.into_heap_secret_key()) .await, "Missing OTP" ); } #[test(tokio::test)] async fn server_should_succeed_if_establishes_connection_with_existing_client() { let (mut t1, t2) = FramedTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); let key = HeapSecretKey::generate(32).unwrap(); keychain .insert(1234.to_string(), key.clone(), { // Create a custom backup we'll use to replay frames from the server-side let mut backup = Backup::new(); backup.push_frame(Frame::new(b"hello")); backup.push_frame(Frame::new(b"world")); backup.increment_sent_cnt(); backup.increment_sent_cnt(); let (tx, rx) = oneshot::channel(); tx.send(backup).unwrap(); rx }) .await; // Spawn a task to perform the server connection so we don't deadlock while simulating the // client actions on the other side let task = tokio::spawn({ let keychain = keychain.clone(); async move { Connection::server(t2.into_inner(), &verifier, keychain) .await .unwrap() } }); // Perform first step of completing client-side of handshake t1.client_handshake().await.unwrap(); // Send type to indicate an existing connection, which should cause the server-side to fail // because the OTP is wrong for the given id t1.write_frame_for(&ConnectType::Reconnect { id: 1234, otp: key.unprotected_into_bytes(), }) .await .unwrap(); // Receive a new client id let id = t1.read_frame_as::().await.unwrap().unwrap(); // Perform otp exchange let otp = t1.exchange_keys().await.unwrap(); // Queue up some frames to send to the server t1.backup.clear(); t1.backup.push_frame(Frame::new(b"foo")); t1.backup.push_frame(Frame::new(b"bar")); t1.backup.increment_sent_cnt(); t1.backup.increment_sent_cnt(); // Perform synchronization t1.synchronize().await.unwrap(); // Verify that we received frames from the server assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello"); assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world"); // Server connection should be established, and have 
received some replayed frames let mut server = task.await.unwrap(); assert_eq!(server.read_frame().await.unwrap().unwrap(), b"foo"); assert_eq!(server.read_frame().await.unwrap().unwrap(), b"bar"); // Validate the connection ids match assert_eq!(server.id(), id); // Check that our old connection id is no longer contained in the keychain assert!(!keychain.has_id("1234").await, "Old OTP still exists"); // Validate the OTP was stored in our keychain assert!( keychain .has_key(id.to_string(), otp.into_heap_secret_key()) .await, "Missing OTP" ); } #[test(tokio::test)] async fn client_server_new_connection_e2e_should_establish_connection() { let (t1, t2) = InmemoryTransport::pair(100); let verifier = Verifier::none(); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock let task = tokio::spawn(async move { Connection::server(t2, &verifier, keychain) .await .expect("Failed to connect from server") }); // Perform the client-side of the connection let mut client = Connection::client(t1, DummyAuthHandler) .await .expect("Failed to connect from client"); let mut server = task.await.unwrap(); // Test out the connection client.write_frame(Frame::new(b"hello")).await.unwrap(); assert_eq!(server.read_frame().await.unwrap().unwrap(), b"hello"); server.write_frame(Frame::new(b"goodbye")).await.unwrap(); assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye"); } /// Helper utility to set up for a client reconnection async fn setup_reconnect_scenario() -> ( Connection, InmemoryTransport, Arc, Keychain>, ) { let (t1, t2) = InmemoryTransport::pair(100); let verifier = Arc::new(Verifier::none()); let keychain = Keychain::new(); // Spawn a task to perform the server connection so we don't deadlock let task = { let verifier = Arc::clone(&verifier); let keychain = keychain.clone(); tokio::spawn(async move { Connection::server(t2, &verifier, keychain) .await .expect("Failed to connect from server") }) }; // Perform the client-side 
// of the connection
        let mut client = Connection::client(t1, DummyAuthHandler)
            .await
            .expect("Failed to connect from client");

        // Ensure the server is established and then drop it
        let server = task.await.unwrap();
        drop(server);

        // Create a new inmemory transport and link it to the client
        let mut t2 = InmemoryTransport::pair(100).0;
        t2.link(client.mut_transport().as_mut_inner(), 100);

        (client, t2, verifier, keychain)
    }

    /// The client's reconnect must fail when the peer answers the handshake with garbage.
    ///
    /// The reconnect runs in a spawned task that unwraps its result, so a failed reconnect
    /// surfaces here as a join error from the panicked task.
    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_handshake_fails() {
        let (mut client, transport, _verifier, _keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        // Spawn a task to perform the client reconnection so we don't deadlock
        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        // Send garbage to fail handshake from server-side
        transport.write_frame(b"hello").await.unwrap();

        // Client should fail
        task.await.unwrap_err();
    }

    /// The client's reconnect must fail when the peer completes the handshake but goes away
    /// before sending back a new connection id.
    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_unable_to_receive_new_connection_id() {
        let (mut client, transport, _verifier, _keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        // Spawn a task to perform the client reconnection so we don't deadlock
        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        // Perform first step of completing server-side of handshake
        transport.server_handshake().await.unwrap();

        // Drop transport to cause client to fail in not receiving connection id
        drop(transport);

        // Client should fail
        task.await.unwrap_err();
    }

    /// The client's reconnect must fail when the peer accepts the reconnect request and sends
    /// a new connection id, but then sends garbage instead of taking part in the key exchange
    /// that derives the next OTP.
    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_unable_to_exchange_otp_with_server() {
        let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        // Spawn a task to perform the client reconnection so we don't deadlock
        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        // Perform first step of completing server-side of handshake
        transport.server_handshake().await.unwrap();

        // Receive reconnect data from client-side
        let (id, otp) = match transport.read_frame_as::<ConnectType>().await {
            Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)),
            x => panic!("Unexpected result: {x:?}"),
        };

        // Verify the id and OTP matches the one stored into our keychain from the setup
        assert!(
            keychain.has_key(id.to_string(), otp).await,
            "Wrong id or OTP"
        );

        // Send a new id back to the client connection
        transport
            .write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        // Send garbage to fail the key exchange for new OTP
        transport.write_frame(Frame::new(b"hello")).await.unwrap();

        // Client should fail
        task.await.unwrap_err();
    }

    /// The client's reconnect must fail when everything up to and including the key exchange
    /// succeeds, but the peer disappears before frame synchronization can take place.
    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_unable_to_synchronize_with_server() {
        let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        // Spawn a task to perform the client reconnection so we don't deadlock
        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        // Perform first step of completing server-side of handshake
        transport.server_handshake().await.unwrap();

        // Receive reconnect data from client-side
        let (id, otp) = match transport.read_frame_as::<ConnectType>().await {
            Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)),
            x => panic!("Unexpected result: {x:?}"),
        };

        // Verify the id and OTP matches the one stored into our keychain from the setup
        assert!(
            keychain.has_key(id.to_string(), otp).await,
            "Wrong id or OTP"
        );

        // Send a new id back to the client connection
        transport
            .write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        // Complete the key exchange so the client gets past deriving its new OTP; the
        // resulting key is irrelevant to this test
        let _new_otp = transport.exchange_keys().await.unwrap();

        // Drop transport instead of synchronizing so the client fails during synchronization
        drop(transport);

        // Client should fail
        task.await.unwrap_err();
    }

    /// Happy path: the test plays the server role by hand (handshake, reconnect request,
    /// backup retrieval, new id, key exchange, synchronization) and then verifies that the
    /// client holds the newly derived OTP, that its backup counters are unchanged, and that
    /// frames flow in both directions.
    #[test(tokio::test)]
    async fn reconnect_should_succeed_if_client_side_connection_fully_connects_and_synchronizes_with_server(
    ) {
        let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        // Copy client backup for verification later
        let client_backup = client.transport().backup.clone();

        // Spawn a task to perform the client reconnection so we don't deadlock
        let task = tokio::spawn(async move {
            client.reconnect().await.unwrap();
            client
        });

        // Perform first step of completing server-side of handshake
        transport.server_handshake().await.unwrap();

        // Receive reconnect data from client-side
        let (id, otp) = match transport.read_frame_as::<ConnectType>().await {
            Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)),
            x => panic!("Unexpected result: {x:?}"),
        };

        // Retrieve server backup
        let backup = keychain
            .remove_if_has_key(id.to_string(), otp)
            .await
            .into_ok()
            .expect("Invalid id or OTP")
            .await
            .expect("Failed to retrieve backup");

        // Send a new id back to the client connection
        transport
            .write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        // Perform key exchange
        let otp = transport.exchange_keys().await.unwrap();

        // Perform synchronization after restoring backup
        transport.backup = backup;
        transport.synchronize().await.unwrap();

        // Client should succeed
        let mut client = task.await.unwrap();
        assert_eq!(client.otp(), Some(&otp.into_heap_secret_key()));

        // Verify client backup sent/received count was not modified (stored frames may be
        // truncated, though)
        assert_eq!(
            client.transport().backup.sent_cnt(),
            client_backup.sent_cnt(),
            "Client backup sent cnt altered"
        );
        assert_eq!(
            client.transport().backup.received_cnt(),
            client_backup.received_cnt(),
            "Client backup received cnt altered"
        );

        // Verify that client can send a frame and receive a frame, and that there is
        // nothing unexpected in the buffers on either side
        client.write_frame(Frame::new(b"hello")).await.unwrap();
        assert_eq!(transport.read_frame().await.unwrap().unwrap(), b"hello");
        transport.write_frame(Frame::new(b"goodbye")).await.unwrap();
        assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye");
    }

    /// Reconnecting a server-side connection is not supported and must report
    /// `io::ErrorKind::Unsupported`.
    #[test(tokio::test)]
    async fn reconnect_should_fail_if_connection_is_server_side() {
        let mut connection = Connection::Server {
            id: rand::random(),
            tx: oneshot::channel().0,
            transport: FramedTransport::pair(100).0,
        };

        assert_eq!(
            connection.reconnect().await.unwrap_err().kind(),
            io::ErrorKind::Unsupported
        );
    }

    /// End-to-end: a real `Connection::server` accepts the returning client, after which
    /// frames can be exchanged in both directions.
    #[test(tokio::test)]
    async fn client_server_returning_connection_e2e_should_reestablish_connection() {
        let (mut client, transport, verifier, keychain) = setup_reconnect_scenario().await;

        // Spawn a task to perform the server reconnection so we don't deadlock
        let task = tokio::spawn(async move {
            Connection::server(transport, &verifier, keychain)
                .await
                .expect("Failed to connect from server")
        });

        // Reconnect and verify that the connection still works
        client
            .reconnect()
            .await
            .expect("Failed to reconnect from client");

        // Ensure the server-side connection was established and keep it for the exchange below
        let mut server = task.await.unwrap();

        // Test out the connection
        client.write_frame(Frame::new(b"hello")).await.unwrap();
        assert_eq!(server.read_frame().await.unwrap().unwrap(), b"hello");
        server.write_frame(Frame::new(b"goodbye")).await.unwrap();
        assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye");
    }
}