@ -1,7 +1,10 @@
use super ::{ ConnectionCtx , ServerCtx , ServerHandler , ServerReply , ServerState , ShutdownTimer } ;
use super ::{
ConnectionCtx , ConnectionState , ServerCtx , ServerHandler , ServerReply , ServerState ,
ShutdownTimer ,
} ;
use crate ::common ::{
authentication ::{ Keychain , Verifier } ,
Backup , Connection , ConnectionId , Interest , Response , Transport , UntypedRequest ,
Backup , Connection , Frame , Interest , Response , Transport , UntypedRequest ,
} ;
use log ::* ;
use serde ::{ de ::DeserializeOwned , Serialize } ;
@ -11,56 +14,33 @@ use std::{
pin ::Pin ,
sync ::{ Arc , Weak } ,
task ::{ Context , Poll } ,
time ::Duration ,
time ::{ Duration , Instant } ,
} ;
use tokio ::{
sync ::{ mpsc, oneshot , RwLock } ,
sync ::{ broadcast, mpsc, oneshot , RwLock } ,
task ::JoinHandle ,
} ;
pub type ServerKeychain = Keychain < oneshot ::Receiver < Backup > > ;
/// Time to wait inbetween connection read/write when nothing was read or written on last pass
/// Time to wait inbetween connection read/write when nothing was read or written on last pass .
const SLEEP_DURATION : Duration = Duration ::from_millis ( 1 ) ;
/// Represents an individual connection on the server
pub struct ConnectionTask {
/// Unique identifier tied to the connection
id : ConnectionId ,
/// Minimum time between heartbeats to communicate to the client connection.
const MINIMUM_HEARTBEAT_DURATION : Duration = Duration ::from_secs ( 5 ) ;
/// Task that is processing requests and responses
task : JoinHandle < io ::Result < ( ) > > ,
}
/// Represents an individual connection on the server.
pub ( super ) struct ConnectionTask ( JoinHandle < io ::Result < ( ) > > ) ;
impl ConnectionTask {
/// Starts building a new connection
pub fn build ( ) -> ConnectionTaskBuilder < ( ) , ( ) > {
let id : ConnectionId = rand ::random ( ) ;
ConnectionTaskBuilder {
id ,
handler : Weak ::new ( ) ,
state : Weak ::new ( ) ,
keychain : Keychain ::new ( ) ,
transport : ( ) ,
shutdown_timer : Weak ::new ( ) ,
sleep_duration : SLEEP_DURATION ,
verifier : Weak ::new ( ) ,
}
}
/// Returns the id associated with the connection
pub fn id ( & self ) -> ConnectionId {
self . id
pub fn build ( ) -> ConnectionTaskBuilder < ( ) , ( ) , ( ) > {
ConnectionTaskBuilder ::new ( )
}
/// Returns true if the task has finished
pub fn is_finished ( & self ) -> bool {
self . task . is_finished ( )
}
/// Aborts the connection
pub fn abort ( & self ) {
self . task . abort ( ) ;
self . 0. is_finished ( )
}
}
@ -68,7 +48,7 @@ impl Future for ConnectionTask {
type Output = io ::Result < ( ) > ;
fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self ::Output > {
match Future ::poll ( Pin ::new ( & mut self . task ) , cx ) {
match Future ::poll ( Pin ::new ( & mut self . 0 ) , cx ) {
Poll ::Pending = > Poll ::Pending ,
Poll ::Ready ( x ) = > match x {
Ok ( x ) = > Poll ::Ready ( x ) ,
@ -78,114 +58,171 @@ impl Future for ConnectionTask {
}
}
pub struct ConnectionTaskBuilder < H , T > {
id : ConnectionId ,
/// Represents a builder for a new connection task.
pub ( super ) struct ConnectionTaskBuilder < H , S , T > {
handler : Weak < H > ,
state : Weak < ServerState >,
state : Weak < ServerState <S > >,
keychain : Keychain < oneshot ::Receiver < Backup > > ,
transport : T ,
shutdown : broadcast ::Receiver < ( ) > ,
shutdown_timer : Weak < RwLock < ShutdownTimer > > ,
sleep_duration : Duration ,
heartbeat_duration : Duration ,
verifier : Weak < Verifier > ,
}
impl < H , T > ConnectionTaskBuilder < H , T > {
pub fn handler < U > ( self , handler : Weak < U > ) -> ConnectionTaskBuilder < U , T > {
impl ConnectionTaskBuilder < ( ) , ( ) , ( ) > {
/// Starts building a new connection.
pub fn new ( ) -> Self {
Self {
handler : Weak ::new ( ) ,
state : Weak ::new ( ) ,
keychain : Keychain ::new ( ) ,
transport : ( ) ,
shutdown : broadcast ::channel ( 1 ) . 1 ,
shutdown_timer : Weak ::new ( ) ,
sleep_duration : SLEEP_DURATION ,
heartbeat_duration : MINIMUM_HEARTBEAT_DURATION ,
verifier : Weak ::new ( ) ,
}
}
}
impl < H , S , T > ConnectionTaskBuilder < H , S , T > {
pub fn handler < U > ( self , handler : Weak < U > ) -> ConnectionTaskBuilder < U , S , T > {
ConnectionTaskBuilder {
id : self . id ,
handler ,
state : self . state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn state ( self , state : Weak < ServerState > ) -> ConnectionTaskBuilder < H , T > {
pub fn state < U > ( self , state : Weak < ServerState <U > >) -> ConnectionTaskBuilder < H , U , T > {
ConnectionTaskBuilder {
id : self . id ,
handler : self . handler ,
state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn keychain ( self , keychain : ServerKeychain ) -> ConnectionTaskBuilder < H , T> {
pub fn keychain ( self , keychain : ServerKeychain ) -> ConnectionTaskBuilder < H , S, T> {
ConnectionTaskBuilder {
id : self . id ,
handler : self . handler ,
state : self . state ,
keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn transport < U > ( self , transport : U ) -> ConnectionTaskBuilder < H , U> {
pub fn transport < U > ( self , transport : U ) -> ConnectionTaskBuilder < H , S, U> {
ConnectionTaskBuilder {
id : self . id ,
handler : self . handler ,
keychain : self . keychain ,
state : self . state ,
transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub ( crate ) fn shutdown_timer (
pub fn shutdown ( self , shutdown : broadcast ::Receiver < ( ) > ) -> ConnectionTaskBuilder < H , S , T > {
ConnectionTaskBuilder {
handler : self . handler ,
state : self . state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn shutdown_timer (
self ,
shutdown_timer : Weak < RwLock < ShutdownTimer > > ,
) -> ConnectionTaskBuilder < H , T > {
) -> ConnectionTaskBuilder < H , S, T> {
ConnectionTaskBuilder {
id : self . id ,
handler : self . handler ,
state : self . state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn sleep_duration ( self , sleep_duration : Duration ) -> ConnectionTaskBuilder < H , T> {
pub fn sleep_duration ( self , sleep_duration : Duration ) -> ConnectionTaskBuilder < H , S, T> {
ConnectionTaskBuilder {
id : self . id ,
handler : self . handler ,
state : self . state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn verifier ( self , verifier : Weak < Verifier > ) -> ConnectionTaskBuilder < H , T > {
pub fn heartbeat_duration (
self ,
heartbeat_duration : Duration ,
) -> ConnectionTaskBuilder < H , S , T > {
ConnectionTaskBuilder {
id : self . id ,
handler : self . handler ,
state : self . state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration ,
verifier : self . verifier ,
}
}
pub fn verifier ( self , verifier : Weak < Verifier > ) -> ConnectionTaskBuilder < H , S , T > {
ConnectionTaskBuilder {
handler : self . handler ,
state : self . state ,
keychain : self . keychain ,
transport : self . transport ,
shutdown : self . shutdown ,
shutdown_timer : self . shutdown_timer ,
sleep_duration : self . sleep_duration ,
heartbeat_duration : self . heartbeat_duration ,
verifier ,
}
}
}
impl < H , T > ConnectionTaskBuilder < H , T >
impl < H , T > ConnectionTaskBuilder < H , Response< H ::Response > , T>
where
H : ServerHandler + Sync + ' static ,
H ::Request : DeserializeOwned + Send + Sync + ' static ,
@ -194,52 +231,86 @@ where
T : Transport + ' static ,
{
pub fn spawn ( self ) -> ConnectionTask {
let id = self . id ;
ConnectionTask {
id ,
task : tokio ::spawn ( self . run ( ) ) ,
}
ConnectionTask ( tokio ::spawn ( self . run ( ) ) )
}
async fn run ( self ) -> io ::Result < ( ) > {
let ConnectionTaskBuilder {
id ,
handler ,
state ,
keychain ,
transport ,
mut shutdown ,
shutdown_timer ,
sleep_duration ,
heartbeat_duration ,
verifier ,
} = self ;
// NOTE: This exists purely to make the compiler happy for macro_rules declaration order.
let ( mut local_shutdown , channel_tx , connection_state ) = ConnectionState ::channel ( ) ;
// Will check if no more connections and restart timer if that's the case
macro_rules! terminate_connection {
// Prints an error message before terminating the connection by panicking
( @ error $( $msg :tt ) + ) = > {
// Prints an error message and does not store state
( @ fatal $( $msg :tt ) + ) = > {
error ! ( $( $msg ) + ) ;
terminate_connection ! ( ) ;
return Err ( io ::Error ::new ( io ::ErrorKind ::Other , format! ( $( $msg ) + ) ) ) ;
} ;
// Prints a debug message before terminating the connection by cleanly returning
( @ debug $( $msg :tt ) + ) = > {
// Prints an error message and stores state before terminating
( @ error ( $tx :ident , $rx :ident ) $( $msg :tt ) + ) = > {
error ! ( $( $msg ) + ) ;
terminate_connection ! ( $tx , $rx ) ;
return Err ( io ::Error ::new ( io ::ErrorKind ::Other , format! ( $( $msg ) + ) ) ) ;
} ;
// Prints a debug message and stores state before terminating
( @ debug ( $tx :ident , $rx :ident ) $( $msg :tt ) + ) = > {
debug ! ( $( $msg ) + ) ;
terminate_connection ! ( $tx , $rx ) ;
return Ok ( ( ) ) ;
} ;
// Prints a shutdown message with no connection id and exit without sending state
( @ shutdown ) = > {
debug ! ( "Shutdown triggered before a connection could be fully established" ) ;
terminate_connection ! ( ) ;
return Ok ( ( ) ) ;
} ;
// Prints a shutdown message with no connection id and stores state before terminating
( @ shutdown ) = > {
debug ! ( "Shutdown triggered before a connection could be fully established" ) ;
terminate_connection ! ( ) ;
return Ok ( ( ) ) ;
} ;
// Prints a shutdown message and stores state before terminating
( @ shutdown ( $id :ident , $tx :ident , $rx :ident ) ) = > { {
debug ! ( "[Conn {}] Shutdown triggered" , $id ) ;
terminate_connection ! ( $tx , $rx ) ;
return Ok ( ( ) ) ;
} } ;
// Performs the connection termination by removing it from server state and
// restarting the shutdown timer if it was the last connection
( $tx :ident , $rx :ident ) = > {
// Send the channels back
let _ = channel_tx . send ( ( $tx , $rx ) ) ;
terminate_connection ! ( ) ;
} ;
// Performs the connection termination by removing it from server state and
// restarting the shutdown timer if it was the last connection
( ) = > {
// Remove the connection from our state if it has closed
// Re start our shutdown timer if this is the last connection
if let Some ( state ) = Weak ::upgrade ( & state ) {
state . connections . write ( ) . await . remove ( & self . id ) ;
// If we have no more connections, start the timer
if let Some ( timer ) = Weak ::upgrade ( & shutdown_timer ) {
if state . connections . read ( ) . await . is_empty ( ) {
if state . connections . read ( ) . await . values ( ) . filter ( | conn | ! conn . is_finished ( ) ) . count ( ) < = 1 {
debug ! ( "Last connection terminating, so restarting shutdown timer" ) ;
timer . write ( ) . await . restart ( ) ;
}
}
@ -247,58 +318,160 @@ where
} ;
}
// Properly establish the connection's transport
debug ! ( "[Conn {id}] Establishing full connection" ) ;
let mut connection = match Weak ::upgrade ( & verifier ) {
Some ( verifier ) = > {
match Connection ::server ( transport , verifier . as_ref ( ) , keychain ) . await {
Ok ( connection ) = > connection ,
Err ( x ) = > {
terminate_connection ! ( @ error "[Conn {id}] Failed to setup connection: {x}" ) ;
/// Awaits a future to complete, or detects if a signal was received by either the global
/// or local shutdown channel. Shutdown only occurs if a signal was received, and any
/// errors received by either shutdown channel are ignored.
macro_rules! await_or_shutdown {
( $( @ save ( $id :ident , $tx :ident , $rx :ident ) ) ? $future :expr ) = > { {
let mut f = $future ;
loop {
let use_shutdown = match shutdown . try_recv ( ) {
Ok ( _ ) = > {
terminate_connection ! ( @ shutdown $( ( $id , $tx , $rx ) ) ? ) ;
}
Err ( broadcast ::error ::TryRecvError ::Empty ) = > true ,
Err ( broadcast ::error ::TryRecvError ::Lagged ( _ ) ) = > true ,
Err ( broadcast ::error ::TryRecvError ::Closed ) = > false ,
} ;
let use_local_shutdown = match local_shutdown . try_recv ( ) {
Ok ( _ ) = > {
terminate_connection ! ( @ shutdown $( ( $id , $tx , $rx ) ) ? ) ;
}
Err ( oneshot ::error ::TryRecvError ::Empty ) = > true ,
Err ( oneshot ::error ::TryRecvError ::Closed ) = > false ,
} ;
if use_shutdown & & use_local_shutdown {
tokio ::select ! {
x = shutdown . recv ( ) = > {
if x . is_err ( ) {
continue ;
}
None = > {
terminate_connection ! ( @ error "[Conn {id}] Verifier has been dropped" ) ;
terminate_connection ! ( @ shutdown $( ( $id , $tx , $rx ) ) ? ) ;
}
x = & mut local_shutdown = > {
if x . is_err ( ) {
continue ;
}
terminate_connection ! ( @ shutdown $( ( $id , $tx , $rx ) ) ? ) ;
}
x = & mut f = > { break x ; }
}
} else if use_shutdown {
tokio ::select ! {
x = shutdown . recv ( ) = > {
if x . is_err ( ) {
continue ;
}
terminate_connection ! ( @ shutdown $( ( $id , $tx , $rx ) ) ? ) ;
}
x = & mut f = > { break x ; }
}
} else if use_local_shutdown {
tokio ::select ! {
x = & mut local_shutdown = > {
if x . is_err ( ) {
continue ;
}
terminate_connection ! ( @ shutdown $( ( $id , $tx , $rx ) ) ? ) ;
}
x = & mut f = > { break x ; }
}
} else {
break f . await ;
}
}
} } ;
}
} ;
// Attempt to upgrade our handler for use with the connection going forward
debug ! ( "[Conn {id}] Preparing connection handler" ) ;
let handler = match Weak ::upgrade ( & handler ) {
Some ( handler ) = > handler ,
None = > {
terminate_connection ! ( @ error "[Conn {id}] Handler has been dropped" ) ;
terminate_connection ! ( @ fatal "Failed to setup connection because handler dropped") ;
}
} ;
// Construct a queue of outgoing responses
let ( tx , mut rx ) = mpsc ::channel ::< Response < H ::Response > > ( 1 ) ;
// Attempt to upgrade our state for use with the connection going forward
let state = match Weak ::upgrade ( & state ) {
Some ( state ) = > state ,
None = > {
terminate_connection ! ( @ fatal "Failed to setup connection because state dropped" ) ;
}
} ;
// Properly establish the connection's transport
debug ! ( "Establishing full connection using {transport:?}" ) ;
let mut connection = match Weak ::upgrade ( & verifier ) {
Some ( verifier ) = > {
match await_or_shutdown ! ( Box ::pin ( Connection ::server (
transport ,
verifier . as_ref ( ) ,
keychain
) ) ) {
Ok ( connection ) = > connection ,
Err ( x ) = > {
terminate_connection ! ( @ fatal "Failed to setup connection: {x}" ) ;
}
}
}
None = > {
terminate_connection ! ( @ fatal "Verifier has been dropped" ) ;
}
} ;
// Update our id to be the connection id
let id = connection . id ( ) ;
// Create local data for the connection and then process it
debug ! ( "[Conn {id}] Officially accepting connection" ) ;
let mut local_data = H ::LocalData ::default ( ) ;
if let Err ( x ) = handler
. on_accept ( ConnectionCtx {
if let Err ( x ) = await_or_shutdown ! ( handler . on_accept ( ConnectionCtx {
connection_id : id ,
local_data : & mut local_data ,
} )
. await
{
terminate_connection ! ( @ error "[Conn {id}] Accepting connection failed: {x}" ) ;
local_data : & mut local_data
} ) ) {
terminate_connection ! ( @ fatal "[Conn {id}] Accepting connection failed: {x}" ) ;
}
let local_data = Arc ::new ( local_data ) ;
let mut last_heartbeat = Instant ::now ( ) ;
// Restore our connection's channels if we have them, otherwise make new ones
let ( tx , mut rx ) = match state . connections . write ( ) . await . remove ( & id ) {
Some ( conn ) = > match conn . shutdown_and_wait ( ) . await {
Some ( x ) = > {
debug ! ( "[Conn {id}] Marked as existing connection" ) ;
x
}
None = > {
warn ! ( "[Conn {id}] Existing connection with id, but channels not saved" ) ;
mpsc ::channel ::< Response < H ::Response > > ( 1 )
}
} ,
None = > {
debug ! ( "[Conn {id}] Marked as new connection" ) ;
mpsc ::channel ::< Response < H ::Response > > ( 1 )
}
} ;
// Store our connection details
state . connections . write ( ) . await . insert ( id , connection_state ) ;
debug ! ( "[Conn {id}] Beginning read/write loop" ) ;
loop {
let ready = match connection
. ready ( Interest ::READABLE | Interest ::WRITABLE )
. await
{
let ready = match await_or_shutdown! (
@ save ( id , tx , rx )
Box ::pin ( connection . ready ( Interest ::READABLE | Interest ::WRITABLE ) )
) {
Ok ( ready ) = > ready ,
Err ( x ) = > {
terminate_connection ! ( @ error "[Conn {id}] Failed to examine ready state: {x}" ) ;
terminate_connection ! ( @ error ( tx , rx ) "[Conn {id}] Failed to examine ready state: {x}" ) ;
}
} ;
@ -311,15 +484,14 @@ where
Ok ( Some ( frame ) ) = > match UntypedRequest ::from_slice ( frame . as_item ( ) ) {
Ok ( request ) = > match request . to_typed_request ( ) {
Ok ( request ) = > {
let reply = ServerReply {
origin_id : request . id . clone ( ) ,
tx : tx . clone ( ) ,
} ;
let origin_id = request . id . clone ( ) ;
let ctx = ServerCtx {
connection_id : id ,
request ,
reply : reply . clone ( ) ,
reply : ServerReply {
origin_id ,
tx : tx . clone ( ) ,
} ,
local_data : Arc ::clone ( & local_data ) ,
} ;
@ -344,11 +516,11 @@ where
}
} ,
Ok ( None ) = > {
terminate_connection ! ( @ debug "[Conn {id}] Connection closed" ) ;
terminate_connection ! ( @ debug ( tx , rx ) "[Conn {id}] Connection closed" ) ;
}
Err ( x ) if x . kind ( ) = = io ::ErrorKind ::WouldBlock = > read_blocked = true ,
Err ( x ) = > {
terminate_connection ! ( @ error "[Conn {id}] {x}" ) ;
terminate_connection ! ( @ error ( tx , rx ) "[Conn {id}] {x}" ) ;
}
}
}
@ -356,10 +528,20 @@ where
// If our socket is ready to be written to, we try to get the next item from
// the queue and process it
if ready . is_writable ( ) {
// Send a heartbeat if we have exceeded our last time
if last_heartbeat . elapsed ( ) > = heartbeat_duration {
trace ! ( "[Conn {id}] Sending heartbeat via empty frame" ) ;
match connection . try_write_frame ( Frame ::empty ( ) ) {
Ok ( ( ) ) = > ( ) ,
Err ( x ) if x . kind ( ) = = io ::ErrorKind ::WouldBlock = > write_blocked = true ,
Err ( x ) = > error ! ( "[Conn {id}] Send failed: {x}" ) ,
}
last_heartbeat = Instant ::now ( ) ;
}
// If we get more data to write, attempt to write it, which will result in writing
// any queued bytes as well. Othewise, we attempt to flush any pending outgoing
// bytes that weren't sent earlier.
if let Ok ( response ) = rx . try_recv ( ) {
else if let Ok ( response ) = rx . try_recv ( ) {
// Log our message as a string, which can be expensive
if log_enabled ! ( Level ::Trace ) {
trace ! (
@ -541,7 +723,7 @@ mod tests {
let err = task . await . unwrap_err ( ) ;
assert! (
err . to_string ( ) . contains ( " Handler has been dropped") ,
err . to_string ( ) . contains ( " handler dropped") ,
"Unexpected error: {err}"
) ;
}
@ -610,6 +792,7 @@ mod tests {
let shutdown_timer = Arc ::new ( RwLock ::new ( ShutdownTimer ::start ( Shutdown ::Never ) ) ) ;
let verifier = Arc ::new ( Verifier ::none ( ) ) ;
#[ derive(Debug) ]
struct FakeTransport {
inner : InmemoryTransport ,
fail_ready : Arc < AtomicBool > ,
@ -678,7 +861,7 @@ mod tests {
let err = task . await . unwrap_err ( ) ;
assert! (
err . to_string ( ) . contains ( " Failed to examine ready stat e") ,
err . to_string ( ) . contains ( " targeted ready failur e") ,
"Unexpected error: {err}"
) ;
}
@ -722,7 +905,7 @@ mod tests {
let shutdown_timer = Arc ::new ( RwLock ::new ( ShutdownTimer ::start ( Shutdown ::Never ) ) ) ;
let verifier = Arc ::new ( Verifier ::none ( ) ) ;
ConnectionTask ::build ( )
let _conn = ConnectionTask ::build ( )
. handler ( Arc ::downgrade ( & handler ) )
. state ( Arc ::downgrade ( & state ) )
. keychain ( keychain )
@ -748,4 +931,205 @@ mod tests {
let response = task . await . unwrap ( ) ;
assert_eq! ( response . payload , "hello" ) ;
}
#[ test(tokio::test) ]
async fn should_send_heartbeat_via_empty_frame_every_minimum_duration ( ) {
let handler = Arc ::new ( TestServerHandler ) ;
let state = Arc ::new ( ServerState ::default ( ) ) ;
let keychain = ServerKeychain ::new ( ) ;
let ( t1 , t2 ) = InmemoryTransport ::pair ( 100 ) ;
let shutdown_timer = Arc ::new ( RwLock ::new ( ShutdownTimer ::start ( Shutdown ::Never ) ) ) ;
let verifier = Arc ::new ( Verifier ::none ( ) ) ;
let _conn = ConnectionTask ::build ( )
. handler ( Arc ::downgrade ( & handler ) )
. state ( Arc ::downgrade ( & state ) )
. keychain ( keychain )
. transport ( t1 )
. shutdown_timer ( Arc ::downgrade ( & shutdown_timer ) )
. heartbeat_duration ( Duration ::from_millis ( 200 ) )
. verifier ( Arc ::downgrade ( & verifier ) )
. spawn ( ) ;
// Spawn a task to handle establishing connection from client-side
let task = tokio ::spawn ( async move {
let mut client = Connection ::client ( t2 , DummyAuthHandler )
. await
. expect ( "Fail to establish client-side connection" ) ;
// Verify we don't get a frame immediately
assert_eq! (
client . try_read_frame ( ) . unwrap_err ( ) . kind ( ) ,
io ::ErrorKind ::WouldBlock ,
"got a frame early"
) ;
// Sleep more than our minimum heartbeat duration to ensure we get one
tokio ::time ::sleep ( Duration ::from_millis ( 250 ) ) . await ;
assert_eq! (
client . read_frame ( ) . await . unwrap ( ) . unwrap ( ) ,
Frame ::empty ( ) ,
"non-empty frame"
) ;
// Verify we don't get a frame immediately
assert_eq! (
client . try_read_frame ( ) . unwrap_err ( ) . kind ( ) ,
io ::ErrorKind ::WouldBlock ,
"got a frame early"
) ;
// Sleep more than our minimum heartbeat duration to ensure we get one
tokio ::time ::sleep ( Duration ::from_millis ( 250 ) ) . await ;
assert_eq! (
client . read_frame ( ) . await . unwrap ( ) . unwrap ( ) ,
Frame ::empty ( ) ,
"non-empty frame"
) ;
} ) ;
task . await . unwrap ( ) ;
}
#[ test(tokio::test) ]
async fn should_be_able_to_shutdown_while_establishing_connection ( ) {
let handler = Arc ::new ( TestServerHandler ) ;
let state = Arc ::new ( ServerState ::default ( ) ) ;
let keychain = ServerKeychain ::new ( ) ;
let ( t1 , _t2 ) = InmemoryTransport ::pair ( 100 ) ;
let shutdown_timer = Arc ::new ( RwLock ::new ( ShutdownTimer ::start ( Shutdown ::Never ) ) ) ;
let verifier = Arc ::new ( Verifier ::none ( ) ) ;
let ( shutdown_tx , shutdown_rx ) = broadcast ::channel ( 1 ) ;
let conn = ConnectionTask ::build ( )
. handler ( Arc ::downgrade ( & handler ) )
. state ( Arc ::downgrade ( & state ) )
. keychain ( keychain )
. transport ( t1 )
. shutdown ( shutdown_rx )
. shutdown_timer ( Arc ::downgrade ( & shutdown_timer ) )
. heartbeat_duration ( Duration ::from_millis ( 200 ) )
. verifier ( Arc ::downgrade ( & verifier ) )
. spawn ( ) ;
// Shutdown server connection task while it is establishing a full connection with the
// client, verifying that we do not get an error in return
shutdown_tx
. send ( ( ) )
. expect ( "Failed to send shutdown signal" ) ;
conn . await . unwrap ( ) ;
}
#[ test(tokio::test) ]
async fn should_be_able_to_shutdown_while_accepting_connection ( ) {
struct HangingAcceptServerHandler ;
#[ async_trait ]
impl ServerHandler for HangingAcceptServerHandler {
type Request = ( ) ;
type Response = ( ) ;
type LocalData = ( ) ;
async fn on_accept ( & self , _ : ConnectionCtx < ' _ , Self ::LocalData > ) -> io ::Result < ( ) > {
// Wait "forever" so we can ensure that we fail at this step
tokio ::time ::sleep ( Duration ::MAX ) . await ;
Err ( io ::Error ::new ( io ::ErrorKind ::Other , "bad accept" ) )
}
async fn on_request (
& self ,
_ : ServerCtx < Self ::Request , Self ::Response , Self ::LocalData > ,
) {
unreachable! ( ) ;
}
}
let handler = Arc ::new ( HangingAcceptServerHandler ) ;
let state = Arc ::new ( ServerState ::default ( ) ) ;
let keychain = ServerKeychain ::new ( ) ;
let ( t1 , t2 ) = InmemoryTransport ::pair ( 100 ) ;
let shutdown_timer = Arc ::new ( RwLock ::new ( ShutdownTimer ::start ( Shutdown ::Never ) ) ) ;
let verifier = Arc ::new ( Verifier ::none ( ) ) ;
let ( shutdown_tx , shutdown_rx ) = broadcast ::channel ( 1 ) ;
let conn = ConnectionTask ::build ( )
. handler ( Arc ::downgrade ( & handler ) )
. state ( Arc ::downgrade ( & state ) )
. keychain ( keychain )
. transport ( t1 )
. shutdown ( shutdown_rx )
. shutdown_timer ( Arc ::downgrade ( & shutdown_timer ) )
. heartbeat_duration ( Duration ::from_millis ( 200 ) )
. verifier ( Arc ::downgrade ( & verifier ) )
. spawn ( ) ;
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio ::spawn ( Connection ::client ( t2 , DummyAuthHandler ) ) ;
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
shutdown_tx
. send ( ( ) )
. expect ( "Failed to send shutdown signal" ) ;
conn . await . unwrap ( ) ;
}
#[ test(tokio::test) ]
async fn should_be_able_to_shutdown_while_waiting_for_connection_to_be_ready ( ) {
struct AcceptServerHandler {
tx : mpsc ::Sender < ( ) > ,
}
#[ async_trait ]
impl ServerHandler for AcceptServerHandler {
type Request = ( ) ;
type Response = ( ) ;
type LocalData = ( ) ;
async fn on_accept ( & self , _ : ConnectionCtx < ' _ , Self ::LocalData > ) -> io ::Result < ( ) > {
self . tx . send ( ( ) ) . await . unwrap ( ) ;
Ok ( ( ) )
}
async fn on_request (
& self ,
_ : ServerCtx < Self ::Request , Self ::Response , Self ::LocalData > ,
) {
unreachable! ( ) ;
}
}
let ( tx , mut rx ) = mpsc ::channel ( 100 ) ;
let handler = Arc ::new ( AcceptServerHandler { tx } ) ;
let state = Arc ::new ( ServerState ::default ( ) ) ;
let keychain = ServerKeychain ::new ( ) ;
let ( t1 , t2 ) = InmemoryTransport ::pair ( 100 ) ;
let shutdown_timer = Arc ::new ( RwLock ::new ( ShutdownTimer ::start ( Shutdown ::Never ) ) ) ;
let verifier = Arc ::new ( Verifier ::none ( ) ) ;
let ( shutdown_tx , shutdown_rx ) = broadcast ::channel ( 1 ) ;
let conn = ConnectionTask ::build ( )
. handler ( Arc ::downgrade ( & handler ) )
. state ( Arc ::downgrade ( & state ) )
. keychain ( keychain )
. transport ( t1 )
. shutdown ( shutdown_rx )
. shutdown_timer ( Arc ::downgrade ( & shutdown_timer ) )
. heartbeat_duration ( Duration ::from_millis ( 200 ) )
. verifier ( Arc ::downgrade ( & verifier ) )
. spawn ( ) ;
// Spawn a task to handle the client-side establishment of a full connection
let _client_task = tokio ::spawn ( Connection ::client ( t2 , DummyAuthHandler ) ) ;
// Wait to ensure we complete the accept call first
let _ = rx . recv ( ) . await ;
// Shutdown server connection task while it is accepting the connection, verifying that we
// do not get an error in return
shutdown_tx
. send ( ( ) )
. expect ( "Failed to send shutdown signal" ) ;
conn . await . unwrap ( ) ;
}
}