@ -11,6 +11,7 @@ use tokio::sync::oneshot;
use crate ::common ::InmemoryTransport ;
use crate ::common ::{
Backup , FramedTransport , HeapSecretKey , Keychain , KeychainResult , Reconnectable , Transport ,
TransportExt , Version ,
} ;
/// Id of the connection
@ -110,6 +111,19 @@ where
debug ! ( "[Conn {id}] Re-establishing connection" ) ;
Reconnectable ::reconnect ( transport ) . await ? ;
// Wait for exactly version bytes (24 where 8 bytes for major, minor, patch)
// but with a reconnect we don't actually validate it because we did that
// the first time we connected
//
// NOTE: We do this with the raw transport and not the framed version!
debug ! ( "[Conn {id}] Waiting for server version" ) ;
if transport . as_mut_inner ( ) . read_exact ( & mut [ 0 u8 ; 24 ] ) . await ? ! = 24 {
return Err ( io ::Error ::new (
io ::ErrorKind ::InvalidData ,
"Wrong version byte len received" ,
) ) ;
}
// Perform a handshake to ensure that the connection is properly established and encrypted
debug ! ( "[Conn {id}] Performing handshake" ) ;
transport . client_handshake ( ) . await ? ;
@ -190,13 +204,42 @@ where
/// Transforms a raw [`Transport`] into an established [`Connection`] from the client-side by
/// performing the following:
///
/// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 2. Authenticates the established connection to ensure it is valid
/// 3. Restores pre-existing state using the provided backup, replaying any missing frames and
/// 1. Performs a version check with the server
/// 2. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 3. Authenticates the established connection to ensure it is valid
/// 4. Restores pre-existing state using the provided backup, replaying any missing frames and
/// receiving any frames from the other side
pub async fn client < H : AuthHandler + Send > ( transport : T , handler : H ) -> io ::Result < Self > {
pub async fn client < H : AuthHandler + Send > (
transport : T ,
handler : H ,
version : Version ,
) -> io ::Result < Self > {
let id : ConnectionId = rand ::random ( ) ;
// Wait for exactly version bytes (24 where 8 bytes for major, minor, patch)
debug ! ( "[Conn {id}] Waiting for server version" ) ;
let mut version_bytes = [ 0 u8 ; 24 ] ;
if transport . read_exact ( & mut version_bytes ) . await ? ! = 24 {
return Err ( io ::Error ::new (
io ::ErrorKind ::InvalidData ,
"Wrong version byte len received" ,
) ) ;
}
// Compare versions for compatibility and drop the connection if incompatible
let server_version = Version ::from_be_bytes ( version_bytes ) ;
debug ! (
"[Conn {id}] Checking compatibility between client {version} & server {server_version}"
) ;
if ! version . is_compatible_with ( & server_version ) {
return Err ( io ::Error ::new (
io ::ErrorKind ::Other ,
format! (
"Client version {version} is incompatible with server version {server_version}"
) ,
) ) ;
}
// Perform a handshake to ensure that the connection is properly established and encrypted
debug ! ( "[Conn {id}] Performing handshake" ) ;
let mut transport : FramedTransport < T > =
@ -238,19 +281,25 @@ where
/// Transforms a raw [`Transport`] into an established [`Connection`] from the server-side by
/// performing the following:
///
/// 1. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 2. Authenticates the established connection to ensure it is valid by either using the
/// 1. Performs a version check with the client
/// 2. Handshakes to derive the appropriate [`Codec`](crate::Codec) to use
/// 3. Authenticates the established connection to ensure it is valid by either using the
/// given `verifier` or, if working with an existing client connection, will validate an OTP
/// from our database
/// 3 . Restores pre-existing state using the provided backup, replaying any missing frames and
/// 4 . Restores pre-existing state using the provided backup, replaying any missing frames and
/// receiving any frames from the other side
pub async fn server (
transport : T ,
verifier : & Verifier ,
keychain : Keychain < oneshot ::Receiver < Backup > > ,
version : Version ,
) -> io ::Result < Self > {
let id : ConnectionId = rand ::random ( ) ;
// Write the version as bytes
debug ! ( "[Conn {id}] Sending version {version}" ) ;
transport . write_all ( & version . to_be_bytes ( ) ) . await ? ;
// Perform a handshake to ensure that the connection is properly established and encrypted
debug ! ( "[Conn {id}] Performing handshake" ) ;
let mut transport : FramedTransport < T > =
@ -464,6 +513,60 @@ mod tests {
use super ::* ;
use crate ::common ::Frame ;
macro_rules! server_version {
( ) = > {
Version ::new ( 1 , 2 , 3 )
} ;
}
macro_rules! send_server_version {
( $transport :expr , $version :expr ) = > { {
( $transport )
. as_mut_inner ( )
. write_all ( & $version . to_be_bytes ( ) )
. await
. unwrap ( ) ;
} } ;
( $transport :expr ) = > {
send_server_version ! ( $transport , server_version ! ( ) ) ;
} ;
}
macro_rules! receive_version {
( $transport :expr ) = > { {
let mut bytes = [ 0 u8 ; 24 ] ;
assert_eq! (
( $transport )
. as_mut_inner ( )
. read_exact ( & mut bytes )
. await
. unwrap ( ) ,
24 ,
"Wrong version len received"
) ;
Version ::from_be_bytes ( bytes )
} } ;
}
#[ test(tokio::test) ]
async fn client_should_fail_when_server_sends_incompatible_version ( ) {
let ( mut t1 , t2 ) = FramedTransport ::pair ( 100 ) ;
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio ::spawn ( async move {
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler , Version ::new ( 1 , 2 , 3 ) )
. await
. unwrap ( )
} ) ;
// Send invalid version to fail the handshake
send_server_version ! ( t1 , Version ::new ( 2 , 0 , 0 ) ) ;
// Client should fail
task . await . unwrap_err ( ) ;
}
#[ test(tokio::test) ]
async fn client_should_fail_if_codec_handshake_fails ( ) {
let ( mut t1 , t2 ) = FramedTransport ::pair ( 100 ) ;
@ -471,11 +574,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio ::spawn ( async move {
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler )
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Send server version for client to confirm
send_server_version ! ( t1 ) ;
// Send garbage to fail the handshake
t1 . write_frame ( Frame ::new ( b" invalid " ) ) . await . unwrap ( ) ;
@ -490,11 +596,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio ::spawn ( async move {
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler )
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Send server version for client to confirm
send_server_version ! ( t1 ) ;
// Perform first step of connection by establishing the codec
t1 . server_handshake ( ) . await . unwrap ( ) ;
@ -519,11 +628,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio ::spawn ( async move {
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler )
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Send server version for client to confirm
send_server_version ! ( t1 ) ;
// Perform first step of connection by establishing the codec
t1 . server_handshake ( ) . await . unwrap ( ) ;
@ -559,11 +671,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio ::spawn ( async move {
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler )
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Send server version for client to confirm
send_server_version ! ( t1 ) ;
// Perform first step of connection by establishing the codec
t1 . server_handshake ( ) . await . unwrap ( ) ;
@ -597,11 +712,14 @@ mod tests {
// Spawn a task to perform the client connection so we don't deadlock while simulating the
// server actions on the other side
let task = tokio ::spawn ( async move {
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler )
Connection ::client ( t2 . into_inner ( ) , DummyAuthHandler , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Send server version for client to confirm
send_server_version ! ( t1 ) ;
// Perform first step of connection by establishing the codec
t1 . server_handshake ( ) . await . unwrap ( ) ;
@ -629,6 +747,30 @@ mod tests {
assert_eq! ( client . otp ( ) , Some ( & otp ) ) ;
}
#[ test(tokio::test) ]
async fn server_should_fail_if_client_drops_due_to_version ( ) {
let ( mut t1 , t2 ) = FramedTransport ::pair ( 100 ) ;
let verifier = Verifier ::none ( ) ;
let keychain = Keychain ::new ( ) ;
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Drop client connection as a result of an "incompatible version"
drop ( t1 ) ;
// Server should fail
task . await . unwrap_err ( ) ;
}
#[ test(tokio::test) ]
async fn server_should_fail_if_codec_handshake_fails ( ) {
let ( mut t1 , t2 ) = FramedTransport ::pair ( 100 ) ;
@ -638,11 +780,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Send garbage to fail the handshake
t1 . write_frame ( Frame ::new ( b" invalid " ) ) . await . unwrap ( ) ;
@ -659,11 +804,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -683,11 +831,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -717,11 +868,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -750,11 +904,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -790,11 +947,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -828,11 +988,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -866,11 +1029,14 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock while simulating the
// client actions on the other side
let task = tokio ::spawn ( async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -904,12 +1070,15 @@ mod tests {
let task = tokio ::spawn ( {
let keychain = keychain . clone ( ) ;
async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
}
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -969,12 +1138,15 @@ mod tests {
let task = tokio ::spawn ( {
let keychain = keychain . clone ( ) ;
async move {
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain )
Connection ::server ( t2 . into_inner ( ) , & verifier , keychain , server_version ! ( ) )
. await
. unwrap ( )
}
} ) ;
// Receive the version from the server
let _ = receive_version ! ( t1 ) ;
// Perform first step of completing client-side of handshake
t1 . client_handshake ( ) . await . unwrap ( ) ;
@ -1029,13 +1201,13 @@ mod tests {
// Spawn a task to perform the server connection so we don't deadlock
let task = tokio ::spawn ( async move {
Connection ::server ( t2 , & verifier , keychain )
Connection ::server ( t2 , & verifier , keychain , server_version ! ( ) )
. await
. expect ( "Failed to connect from server" )
} ) ;
// Perform the client-side of the connection
let mut client = Connection ::client ( t1 , DummyAuthHandler )
let mut client = Connection ::client ( t1 , DummyAuthHandler , server_version ! ( ) )
. await
. expect ( "Failed to connect from client" ) ;
let mut server = task . await . unwrap ( ) ;
@ -1063,14 +1235,14 @@ mod tests {
let verifier = Arc ::clone ( & verifier ) ;
let keychain = keychain . clone ( ) ;
tokio ::spawn ( async move {
Connection ::server ( t2 , & verifier , keychain )
Connection ::server ( t2 , & verifier , keychain , server_version ! ( ) )
. await
. expect ( "Failed to connect from server" )
} )
} ;
// Perform the client-side of the connection
let mut client = Connection ::client ( t1 , DummyAuthHandler )
let mut client = Connection ::client ( t1 , DummyAuthHandler , server_version ! ( ) )
. await
. expect ( "Failed to connect from client" ) ;
@ -1093,6 +1265,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio ::spawn ( async move { client . reconnect ( ) . await . unwrap ( ) } ) ;
// Send a version, although it'll be ignored by a reconnecting client
send_server_version ! ( transport ) ;
// Send garbage to fail handshake from server-side
transport . write_frame ( b" hello " ) . await . unwrap ( ) ;
@ -1108,6 +1283,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio ::spawn ( async move { client . reconnect ( ) . await . unwrap ( ) } ) ;
// Send a version, although it'll be ignored by a reconnecting client
send_server_version ! ( transport ) ;
// Perform first step of completing server-side of handshake
transport . server_handshake ( ) . await . unwrap ( ) ;
@ -1126,6 +1304,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio ::spawn ( async move { client . reconnect ( ) . await . unwrap ( ) } ) ;
// Send a version, although it'll be ignored by a reconnecting client
send_server_version ! ( transport ) ;
// Perform first step of completing server-side of handshake
transport . server_handshake ( ) . await . unwrap ( ) ;
@ -1162,6 +1343,9 @@ mod tests {
// Spawn a task to perform the client reconnection so we don't deadlock
let task = tokio ::spawn ( async move { client . reconnect ( ) . await . unwrap ( ) } ) ;
// Send a version, although it'll be ignored by a reconnecting client
send_server_version ! ( transport ) ;
// Perform first step of completing server-side of handshake
transport . server_handshake ( ) . await . unwrap ( ) ;
@ -1205,6 +1389,9 @@ mod tests {
client
} ) ;
// Send a version, although it'll be ignored by a reconnecting client
send_server_version ! ( transport ) ;
// Perform first step of completing server-side of handshake
transport . server_handshake ( ) . await . unwrap ( ) ;
@ -1275,7 +1462,7 @@ mod tests {
// Spawn a task to perform the server reconnection so we don't deadlock
let task = tokio ::spawn ( async move {
Connection ::server ( transport , & verifier , keychain )
Connection ::server ( transport , & verifier , keychain , server_version ! ( ) )
. await
. expect ( "Failed to connect from server" )
} ) ;