|
|
|
use std::io;
|
|
|
|
use std::sync::Arc;
|
|
|
|
use std::time::Duration;
|
|
|
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
use distant_auth::Verifier;
|
|
|
|
use log::*;
|
|
|
|
use serde::de::DeserializeOwned;
|
|
|
|
use serde::Serialize;
|
|
|
|
use tokio::sync::{broadcast, RwLock};
|
|
|
|
|
|
|
|
use crate::common::{ConnectionId, Listener, Response, Transport, Version};
|
|
|
|
|
|
|
|
mod builder;
|
|
|
|
pub use builder::*;
|
|
|
|
|
|
|
|
mod config;
|
|
|
|
pub use config::*;
|
|
|
|
|
|
|
|
mod connection;
|
|
|
|
use connection::*;
|
|
|
|
|
|
|
|
mod context;
|
|
|
|
pub use context::*;
|
|
|
|
|
|
|
|
mod r#ref;
|
|
|
|
pub use r#ref::*;
|
|
|
|
|
|
|
|
mod reply;
|
|
|
|
pub use reply::*;
|
|
|
|
|
|
|
|
mod state;
|
|
|
|
use state::*;
|
|
|
|
|
|
|
|
mod shutdown_timer;
|
|
|
|
use shutdown_timer::*;
|
|
|
|
|
|
|
|
/// Represents a server that can be used to receive requests & send responses to clients.
|
|
|
|
pub struct Server<T> {
|
|
|
|
/// Custom configuration details associated with the server
|
|
|
|
config: ServerConfig,
|
|
|
|
|
|
|
|
/// Handler used to process various server events
|
|
|
|
handler: T,
|
|
|
|
|
|
|
|
/// Performs authentication using various methods
|
|
|
|
verifier: Verifier,
|
|
|
|
|
|
|
|
/// Version associated with the server used by clients to verify compatibility
|
|
|
|
version: Version,
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Interface for a handler that receives connections and requests
|
|
|
|
#[async_trait]
|
|
|
|
pub trait ServerHandler: Send {
|
|
|
|
/// Type of data received by the server
|
|
|
|
type Request;
|
|
|
|
|
|
|
|
/// Type of data sent back by the server
|
|
|
|
type Response;
|
|
|
|
|
|
|
|
/// Invoked upon a new connection becoming established.
|
|
|
|
#[allow(unused_variables)]
|
|
|
|
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Invoked upon an existing connection getting dropped.
|
|
|
|
#[allow(unused_variables)]
|
|
|
|
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
|
|
|
|
Ok(())
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Invoked upon receiving a request from a client. The server should process this
|
|
|
|
/// request, which can be found in `ctx`, and send one or more replies in response.
|
|
|
|
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Server<()> {
|
|
|
|
/// Creates a new [`Server`], starting with a default configuration, no authentication methods,
|
|
|
|
/// and no [`ServerHandler`].
|
|
|
|
pub fn new() -> Self {
|
|
|
|
Self {
|
|
|
|
config: Default::default(),
|
|
|
|
handler: (),
|
|
|
|
verifier: Verifier::empty(),
|
|
|
|
version: Default::default(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates a new [`TcpServerBuilder`] that is used to construct a [`Server`].
|
|
|
|
pub fn tcp() -> TcpServerBuilder<()> {
|
|
|
|
TcpServerBuilder::default()
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates a new [`UnixSocketServerBuilder`] that is used to construct a [`Server`].
|
|
|
|
#[cfg(unix)]
|
|
|
|
pub fn unix_socket() -> UnixSocketServerBuilder<()> {
|
|
|
|
UnixSocketServerBuilder::default()
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates a new [`WindowsPipeServerBuilder`] that is used to construct a [`Server`].
|
|
|
|
#[cfg(windows)]
|
|
|
|
pub fn windows_pipe() -> WindowsPipeServerBuilder<()> {
|
|
|
|
WindowsPipeServerBuilder::default()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Default for Server<()> {
|
|
|
|
fn default() -> Self {
|
|
|
|
Self::new()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T> Server<T> {
|
|
|
|
/// Consumes the current server, replacing its config with `config` and returning it.
|
|
|
|
pub fn config(self, config: ServerConfig) -> Self {
|
|
|
|
Self {
|
|
|
|
config,
|
|
|
|
handler: self.handler,
|
|
|
|
verifier: self.verifier,
|
|
|
|
version: self.version,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Consumes the current server, replacing its handler with `handler` and returning it.
|
|
|
|
pub fn handler<U>(self, handler: U) -> Server<U> {
|
|
|
|
Server {
|
|
|
|
config: self.config,
|
|
|
|
handler,
|
|
|
|
verifier: self.verifier,
|
|
|
|
version: self.version,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Consumes the current server, replacing its verifier with `verifier` and returning it.
|
|
|
|
pub fn verifier(self, verifier: Verifier) -> Self {
|
|
|
|
Self {
|
|
|
|
config: self.config,
|
|
|
|
handler: self.handler,
|
|
|
|
verifier,
|
|
|
|
version: self.version,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Consumes the current server, replacing its version with `version` and returning it.
|
|
|
|
pub fn version(self, version: Version) -> Self {
|
|
|
|
Self {
|
|
|
|
config: self.config,
|
|
|
|
handler: self.handler,
|
|
|
|
verifier: self.verifier,
|
|
|
|
version,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T> Server<T>
|
|
|
|
where
|
|
|
|
T: ServerHandler + Sync + 'static,
|
|
|
|
T::Request: DeserializeOwned + Send + Sync + 'static,
|
|
|
|
T::Response: Serialize + Send + 'static,
|
|
|
|
{
|
|
|
|
/// Consumes the server, starting a task to process connections from the `listener` and
|
|
|
|
/// returning a [`ServerRef`] that can be used to control the active server instance.
|
|
|
|
pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
|
|
|
|
where
|
|
|
|
L: Listener + 'static,
|
|
|
|
L::Output: Transport + 'static,
|
|
|
|
{
|
|
|
|
let state = Arc::new(ServerState::new());
|
|
|
|
let (tx, rx) = broadcast::channel(1);
|
|
|
|
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
|
|
|
|
|
|
|
|
Ok(ServerRef { shutdown: tx, task })
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Internal task that is run to receive connections and spawn connection tasks
|
|
|
|
async fn task<L>(
|
|
|
|
self,
|
|
|
|
state: Arc<ServerState<Response<T::Response>>>,
|
|
|
|
mut listener: L,
|
|
|
|
shutdown_tx: broadcast::Sender<()>,
|
|
|
|
shutdown_rx: broadcast::Receiver<()>,
|
|
|
|
) where
|
|
|
|
L: Listener + 'static,
|
|
|
|
L::Output: Transport + 'static,
|
|
|
|
{
|
|
|
|
let Server {
|
|
|
|
config,
|
|
|
|
handler,
|
|
|
|
verifier,
|
|
|
|
version,
|
|
|
|
} = self;
|
|
|
|
|
|
|
|
let handler = Arc::new(handler);
|
|
|
|
let timer = ShutdownTimer::start(config.shutdown);
|
|
|
|
let mut notification = timer.clone_notification();
|
|
|
|
let timer = Arc::new(RwLock::new(timer));
|
|
|
|
let verifier = Arc::new(verifier);
|
|
|
|
|
|
|
|
let mut connection_tasks = Vec::new();
|
|
|
|
loop {
|
|
|
|
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
|
|
|
|
// signal has been received
|
|
|
|
let transport = tokio::select! {
|
|
|
|
result = listener.accept() => {
|
|
|
|
match result {
|
|
|
|
Ok(x) => x,
|
|
|
|
Err(x) => {
|
|
|
|
error!("Server no longer accepting connections: {x}");
|
|
|
|
timer.read().await.abort();
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
_ = notification.wait() => {
|
|
|
|
info!(
|
|
|
|
"Server shutdown triggered after {}s",
|
|
|
|
config.shutdown.duration().unwrap_or_default().as_secs_f32(),
|
|
|
|
);
|
|
|
|
|
|
|
|
let _ = shutdown_tx.send(());
|
|
|
|
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
// Ensure that the shutdown timer is cancelled now that we have a connection
|
|
|
|
timer.read().await.stop();
|
|
|
|
|
|
|
|
connection_tasks.push(
|
|
|
|
ConnectionTask::build()
|
|
|
|
.handler(Arc::downgrade(&handler))
|
|
|
|
.state(Arc::downgrade(&state))
|
|
|
|
.keychain(state.keychain.clone())
|
|
|
|
.transport(transport)
|
|
|
|
.shutdown(shutdown_rx.resubscribe())
|
|
|
|
.shutdown_timer(Arc::downgrade(&timer))
|
|
|
|
.sleep_duration(config.connection_sleep)
|
|
|
|
.heartbeat_duration(config.connection_heartbeat)
|
|
|
|
.verifier(Arc::downgrade(&verifier))
|
|
|
|
.version(version.clone())
|
|
|
|
.spawn(),
|
|
|
|
);
|
|
|
|
|
|
|
|
// Clean up current tasks being tracked
|
|
|
|
connection_tasks.retain(|task| !task.is_finished());
|
|
|
|
}
|
|
|
|
|
|
|
|
// Once we stop listening, we still want to wait until all connections have terminated
|
|
|
|
info!("Server waiting for active connections to terminate");
|
|
|
|
loop {
|
|
|
|
connection_tasks.retain(|task| !task.is_finished());
|
|
|
|
if connection_tasks.is_empty() {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
tokio::time::sleep(Duration::from_millis(50)).await;
|
|
|
|
}
|
|
|
|
info!("Server task terminated");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_trait::async_trait;
    use distant_auth::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod};
    use test_log::test;
    use tokio::sync::mpsc;

    use super::*;
    use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response};

    /// Sending half and listener of an in-memory connection channel.
    type ListenerPair = (
        mpsc::Sender<InmemoryTransport>,
        MpscListener<InmemoryTransport>,
    );

    /// Version that both the test server and the test clients advertise.
    fn server_version() -> Version {
        Version::new(1, 2, 3)
    }

    /// Handler that answers every request with the string `"hello"`.
    pub struct TestServerHandler;

    #[async_trait]
    impl ServerHandler for TestServerHandler {
        type Request = u16;
        type Response = String;

        async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
            ctx.reply.send(String::from("hello")).unwrap();
        }
    }

    /// Builds a server around [`TestServerHandler`] whose verifier offers only the "none"
    /// authentication method.
    fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
        let methods: Vec<Box<dyn AuthenticationMethod>> =
            vec![Box::new(NoneAuthenticationMethod::new())];

        Server::new()
            .config(config)
            .verifier(Verifier::new(methods))
            .version(server_version())
            .handler(TestServerHandler)
    }

    /// Creates an in-memory listener with room for `buffer` pending connections.
    fn make_listener(buffer: usize) -> ListenerPair {
        MpscListener::channel(buffer)
    }

    /// Creates a bounded transport pair, gives one end to the listener through `tx`, and
    /// returns the other end for the test to use as the client side.
    async fn feed_connection(tx: &mpsc::Sender<InmemoryTransport>) -> InmemoryTransport {
        let (transport, connection) = InmemoryTransport::pair(100);
        tx.send(connection)
            .await
            .expect("Failed to feed listener a connection");
        transport
    }

    /// Starts a test server on `listener` that uses the given shutdown policy.
    fn start_with_shutdown(
        shutdown: Shutdown,
        listener: MpscListener<InmemoryTransport>,
    ) -> ServerRef {
        make_test_server(ServerConfig {
            shutdown,
            ..Default::default()
        })
        .start(listener)
        .expect("Failed to start server")
    }

    /// Sleeps well past the 100ms shutdown windows that the tests below configure.
    async fn wait_past_shutdown() {
        tokio::time::sleep(Duration::from_millis(300)).await;
    }

    #[test(tokio::test)]
    async fn should_invoke_handler_upon_receiving_a_request() {
        let (tx, listener) = make_listener(100);
        let transport = feed_connection(&tx).await;

        let _server = make_test_server(ServerConfig::default())
            .start(listener)
            .expect("Failed to start server");

        // The handshake and authentication must complete before any data is sent
        let mut connection = Connection::client(transport, DummyAuthHandler, server_version())
            .await
            .expect("Failed to connect to server");

        let request = Request::new(123).to_vec().unwrap();
        connection
            .write_frame(request)
            .await
            .expect("Failed to send request");

        let frame = connection.read_frame().await.unwrap().unwrap();
        let response: Response<String> = Response::from_slice(frame.as_item()).unwrap();
        assert_eq!(response.payload, "hello");
    }

    #[test(tokio::test)]
    async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
        let (_tx, listener) = make_listener(100);
        let server = start_with_shutdown(Shutdown::Lonely(Duration::from_millis(100)), listener);

        wait_past_shutdown().await;
        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs(
    ) {
        let (tx, listener) = make_listener(100);
        let transport = feed_connection(&tx).await;
        let server = start_with_shutdown(Shutdown::Lonely(Duration::from_millis(100)), listener);

        // Dropping the client end terminates the only connection
        drop(transport);

        wait_past_shutdown().await;
        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() {
        let (tx, listener) = make_listener(100);

        // Keep the client end bound (not `_`) so the connection stays open for the whole test
        let _transport = feed_connection(&tx).await;
        let server = start_with_shutdown(Shutdown::Lonely(Duration::from_millis(100)), listener);

        wait_past_shutdown().await;
        assert!(!server.is_finished(), "Server shutdown when it should not!");
    }

    #[test(tokio::test)]
    async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() {
        let (tx, listener) = make_listener(100);

        // Keep the client end bound (not `_`) so the connection stays open for the whole test
        let _transport = feed_connection(&tx).await;
        let server = start_with_shutdown(Shutdown::After(Duration::from_millis(100)), listener);

        wait_past_shutdown().await;
        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_shutdown_after_n_seconds_if_config_set_to_after() {
        let (_tx, listener) = make_listener(100);
        let server = start_with_shutdown(Shutdown::After(Duration::from_millis(100)), listener);

        wait_past_shutdown().await;
        assert!(server.is_finished(), "Server shutdown not triggered!");
    }

    #[test(tokio::test)]
    async fn should_never_shutdown_if_config_set_to_never() {
        let (_tx, listener) = make_listener(100);
        let server = start_with_shutdown(Shutdown::Never, listener);

        wait_past_shutdown().await;
        assert!(!server.is_finished(), "Server shutdown when it should not!");
    }
}
|