use crate::common::{authentication::Verifier, Listener, Response, Transport}; use async_trait::async_trait; use log::*; use serde::{de::DeserializeOwned, Serialize}; use std::{io, sync::Arc, time::Duration}; use tokio::sync::{broadcast, RwLock}; mod builder; pub use builder::*; mod config; pub use config::*; mod connection; use connection::*; mod context; pub use context::*; mod r#ref; pub use r#ref::*; mod reply; pub use reply::*; mod state; use state::*; mod shutdown_timer; use shutdown_timer::*; /// Represents a server that can be used to receive requests & send responses to clients. pub struct Server { /// Custom configuration details associated with the server config: ServerConfig, /// Handler used to process various server events handler: T, /// Performs authentication using various methods verifier: Verifier, } /// Interface for a handler that receives connections and requests #[async_trait] pub trait ServerHandler: Send { /// Type of data received by the server type Request; /// Type of data sent back by the server type Response; /// Type of data to store locally tied to the specific connection type LocalData: Send; /// Invoked upon a new connection becoming established. /// /// ### Note /// /// This can be useful in performing some additional initialization on the connection's local /// data prior to it being used anywhere else. #[allow(unused_variables)] async fn on_accept(&self, ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { Ok(()) } /// Invoked upon receiving a request from a client. The server should process this /// request, which can be found in `ctx`, and send one or more replies in response. async fn on_request(&self, ctx: ServerCtx); } impl Server<()> { /// Creates a new [`Server`], starting with a default configuration, no authentication methods, /// and no [`ServerHandler`]. pub fn new() -> Self { Self { config: Default::default(), handler: (), verifier: Verifier::empty(), } } /// Creates a new [`TcpServerBuilder`] that is used to construct a [`Server`]. pub fn tcp() -> TcpServerBuilder<()> { TcpServerBuilder::default() } /// Creates a new [`UnixSocketServerBuilder`] that is used to construct a [`Server`]. #[cfg(unix)] pub fn unix_socket() -> UnixSocketServerBuilder<()> { UnixSocketServerBuilder::default() } /// Creates a new [`WindowsPipeServerBuilder`] that is used to construct a [`Server`]. #[cfg(windows)] pub fn windows_pipe() -> WindowsPipeServerBuilder<()> { WindowsPipeServerBuilder::default() } } impl Default for Server<()> { fn default() -> Self { Self::new() } } impl Server { /// Consumes the current server, replacing its config with `config` and returning it. pub fn config(self, config: ServerConfig) -> Self { Self { config, handler: self.handler, verifier: self.verifier, } } /// Consumes the current server, replacing its handler with `handler` and returning it. pub fn handler(self, handler: U) -> Server { Server { config: self.config, handler, verifier: self.verifier, } } /// Consumes the current server, replacing its verifier with `verifier` and returning it. pub fn verifier(self, verifier: Verifier) -> Self { Self { config: self.config, handler: self.handler, verifier, } } } impl Server where T: ServerHandler + Sync + 'static, T::Request: DeserializeOwned + Send + Sync + 'static, T::Response: Serialize + Send + 'static, T::LocalData: Default + Send + Sync + 'static, { /// Consumes the server, starting a task to process connections from the `listener` and /// returning a [`ServerRef`] that can be used to control the active server instance. pub fn start(self, listener: L) -> io::Result> where L: Listener + 'static, L::Output: Transport + 'static, { let state = Arc::new(ServerState::new()); let (tx, rx) = broadcast::channel(1); let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx)); Ok(Box::new(GenericServerRef { shutdown: tx, task })) } /// Internal task that is run to receive connections and spawn connection tasks async fn task( self, state: Arc>>, mut listener: L, shutdown_tx: broadcast::Sender<()>, shutdown_rx: broadcast::Receiver<()>, ) where L: Listener + 'static, L::Output: Transport + 'static, { let Server { config, handler, verifier, } = self; let handler = Arc::new(handler); let timer = ShutdownTimer::start(config.shutdown); let mut notification = timer.clone_notification(); let timer = Arc::new(RwLock::new(timer)); let verifier = Arc::new(verifier); let mut connection_tasks = Vec::new(); loop { // Receive a new connection, exiting if no longer accepting connections or if the shutdown // signal has been received let transport = tokio::select! { result = listener.accept() => { match result { Ok(x) => x, Err(x) => { error!("Server no longer accepting connections: {x}"); timer.read().await.abort(); break; } } } _ = notification.wait() => { info!( "Server shutdown triggered after {}s", config.shutdown.duration().unwrap_or_default().as_secs_f32(), ); let _ = shutdown_tx.send(()); break; } }; // Ensure that the shutdown timer is cancelled now that we have a connection timer.read().await.stop(); connection_tasks.push( ConnectionTask::build() .handler(Arc::downgrade(&handler)) .state(Arc::downgrade(&state)) .keychain(state.keychain.clone()) .transport(transport) .shutdown(shutdown_rx.resubscribe()) .shutdown_timer(Arc::downgrade(&timer)) .sleep_duration(config.connection_sleep) .heartbeat_duration(config.connection_heartbeat) .verifier(Arc::downgrade(&verifier)) .spawn(), ); } // Once we stop listening, we still want to wait until all connections have terminated info!("Server waiting for active connections to terminate"); loop { connection_tasks.retain(|task| !task.is_finished()); if connection_tasks.is_empty() { break; } tokio::time::sleep(Duration::from_millis(50)).await; } info!("Server task terminated"); } } #[cfg(test)] mod tests { use super::*; use crate::common::{ authentication::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod}, Connection, InmemoryTransport, MpscListener, Request, Response, }; use async_trait::async_trait; use std::time::Duration; use test_log::test; use tokio::sync::mpsc; pub struct TestServerHandler; #[async_trait] impl ServerHandler for TestServerHandler { type Request = u16; type Response = String; type LocalData = (); async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> { Ok(()) } async fn on_request(&self, ctx: ServerCtx) { // Always send back "hello" ctx.reply.send("hello".to_string()).await.unwrap(); } } #[inline] fn make_test_server(config: ServerConfig) -> Server { let methods: Vec> = vec![Box::new(NoneAuthenticationMethod::new())]; Server { config, handler: TestServerHandler, verifier: Verifier::new(methods), } } #[allow(clippy::type_complexity)] fn make_listener( buffer: usize, ) -> ( mpsc::Sender, MpscListener, ) { MpscListener::channel(buffer) } #[test(tokio::test)] async fn should_invoke_handler_upon_receiving_a_request() { // Create a test listener where we will forward a connection let (tx, listener) = make_listener(100); // Make bounded transport pair and send off one of them to act as our connection let (transport, connection) = InmemoryTransport::pair(100); tx.send(connection) .await .expect("Failed to feed listener a connection"); let _server = make_test_server(ServerConfig::default()) .start(listener) .expect("Failed to start server"); // Perform handshake and authentication with the server before beginning to send data let mut connection = Connection::client(transport, DummyAuthHandler) .await .expect("Failed to connect to server"); connection .write_frame(Request::new(123).to_vec().unwrap()) .await .expect("Failed to send request"); // Wait for a response let frame = connection.read_frame().await.unwrap().unwrap(); let response: Response = Response::from_slice(frame.as_item()).unwrap(); assert_eq!(response.payload, "hello"); } #[test(tokio::test)] async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() { let (_tx, listener) = make_listener(100); let server = make_test_server(ServerConfig { shutdown: Shutdown::Lonely(Duration::from_millis(100)), ..Default::default() }) .start(listener) .expect("Failed to start server"); // Wait for some time tokio::time::sleep(Duration::from_millis(300)).await; assert!(server.is_finished(), "Server shutdown not triggered!"); } #[test(tokio::test)] async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs( ) { // Create a test listener where we will forward a connection let (tx, listener) = make_listener(100); // Make bounded transport pair and send off one of them to act as our connection let (transport, connection) = InmemoryTransport::pair(100); tx.send(connection) .await .expect("Failed to feed listener a connection"); let server = make_test_server(ServerConfig { shutdown: Shutdown::Lonely(Duration::from_millis(100)), ..Default::default() }) .start(listener) .expect("Failed to start server"); // Drop the connection by dropping the transport drop(transport); // Wait for some time tokio::time::sleep(Duration::from_millis(300)).await; assert!(server.is_finished(), "Server shutdown not triggered!"); } #[test(tokio::test)] async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() { // Create a test listener where we will forward a connection let (tx, listener) = make_listener(100); // Make bounded transport pair and send off one of them to act as our connection let (_transport, connection) = InmemoryTransport::pair(100); tx.send(connection) .await .expect("Failed to feed listener a connection"); let server = make_test_server(ServerConfig { shutdown: Shutdown::Lonely(Duration::from_millis(100)), ..Default::default() }) .start(listener) .expect("Failed to start server"); // Wait for some time tokio::time::sleep(Duration::from_millis(300)).await; assert!(!server.is_finished(), "Server shutdown when it should not!"); } #[test(tokio::test)] async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() { let (tx, listener) = make_listener(100); // Make bounded transport pair and send off one of them to act as our connection let (_transport, connection) = InmemoryTransport::pair(100); tx.send(connection) .await .expect("Failed to feed listener a connection"); let server = make_test_server(ServerConfig { shutdown: Shutdown::After(Duration::from_millis(100)), ..Default::default() }) .start(listener) .expect("Failed to start server"); // Wait for some time tokio::time::sleep(Duration::from_millis(300)).await; assert!(server.is_finished(), "Server shutdown not triggered!"); } #[test(tokio::test)] async fn should_shutdown_after_n_seconds_if_config_set_to_after() { let (_tx, listener) = make_listener(100); let server = make_test_server(ServerConfig { shutdown: Shutdown::After(Duration::from_millis(100)), ..Default::default() }) .start(listener) .expect("Failed to start server"); // Wait for some time tokio::time::sleep(Duration::from_millis(300)).await; assert!(server.is_finished(), "Server shutdown not triggered!"); } #[test(tokio::test)] async fn should_never_shutdown_if_config_set_to_never() { let (_tx, listener) = make_listener(100); let server = make_test_server(ServerConfig { shutdown: Shutdown::Never, ..Default::default() }) .start(listener) .expect("Failed to start server"); // Wait for some time tokio::time::sleep(Duration::from_millis(300)).await; assert!(!server.is_finished(), "Server shutdown when it should not!"); } }