Splitting out to broken individual crates

feat/RusshSupport
Chip Senkbeil 7 months ago
parent fc67e9e693
commit d67002421d
No known key found for this signature in database
GPG Key ID: 35EF1F8EC72A4131

3
Cargo.lock generated

@ -866,6 +866,7 @@ version = "0.21.0"
dependencies = [
"async-trait",
"derive_more",
"distant-core-net",
"env_logger",
"log",
"serde",
@ -879,6 +880,7 @@ version = "0.21.0"
dependencies = [
"async-trait",
"derive_more",
"distant-core-net",
"env_logger",
"log",
"serde",
@ -941,6 +943,7 @@ version = "0.21.0"
dependencies = [
"async-trait",
"derive_more",
"distant-core-net",
"env_logger",
"log",
"serde",

@ -14,6 +14,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
async-trait = "0.1.68"
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error"] }
distant-core-net = { version = "=0.21.0", path = "../distant-core-net" }
log = "0.4.18"
serde = { version = "1.0.163", features = ["derive"] }

File diff suppressed because it is too large Load Diff

@ -14,6 +14,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
async-trait = "0.1.68"
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error"] }
distant-core-net = { version = "=0.21.0", path = "../distant-core-net" }
log = "0.4.18"
serde = { version = "1.0.163", features = ["derive"] }

@ -0,0 +1,16 @@
mod client;
mod data;
mod server;
pub use client::*;
pub use data::*;
pub use server::*;
use crate::common::Version;
/// Represents the version associated with the manager's protocol.
pub const PROTOCOL_VERSION: Version = Version::new(
const_str::parse!(env!("CARGO_PKG_VERSION_MAJOR"), u64),
const_str::parse!(env!("CARGO_PKG_VERSION_MINOR"), u64),
const_str::parse!(env!("CARGO_PKG_VERSION_PATCH"), u64),
);

@ -1,8 +1,8 @@
[package]
name = "distant-core-net"
description = "Core network library for distant, providing implementations to support client/server architecture"
description = "Core network library for distant, providing primitives for use in network communication"
categories = ["network-programming"]
keywords = ["api", "async"]
keywords = ["api", "async", "network", "primitives"]
version = "0.21.0"
authors = ["Chip Senkbeil <chip@senkbeil.org>"]
edition = "2021"

@ -1,19 +1,19 @@
use std::any::Any;
/// Trait used for casting support into the [`Any`] trait object
/// Trait used for casting support into the [`Any`] trait object.
pub trait AsAny: Any {
/// Converts reference to [`Any`]
/// Converts reference to [`Any`].
fn as_any(&self) -> &dyn Any;
/// Converts mutable reference to [`Any`]
/// Converts mutable reference to [`Any`].
fn as_mut_any(&mut self) -> &mut dyn Any;
/// Consumes and produces `Box<dyn Any>`
/// Consumes and produces `Box<dyn Any>`.
fn into_any(self: Box<Self>) -> Box<dyn Any>;
}
/// Blanket implementation that enables any `'static` reference to convert
/// to the [`Any`] type
/// to the [`Any`] type.
impl<T: 'static> AsAny for T {
fn as_any(&self) -> &dyn Any {
self

@ -5,7 +5,7 @@ use distant_core_auth::msg::*;
use distant_core_auth::{AuthHandler, Authenticate, Authenticator};
use log::*;
use crate::common::{utils, FramedTransport, Transport};
use crate::{utils, FramedTransport, Transport};
macro_rules! write_frame {
($transport:expr, $data:expr) => {{

File diff suppressed because it is too large Load Diff

@ -1,21 +0,0 @@
mod any;
mod connection;
mod key;
mod keychain;
mod listener;
mod packet;
mod port;
mod transport;
pub(crate) mod utils;
mod version;
pub use any::*;
pub(crate) use connection::Connection;
pub use connection::ConnectionId;
pub use key::*;
pub use keychain::*;
pub use listener::*;
pub use packet::*;
pub use port::*;
pub use transport::*;
pub use version::*;

@ -8,16 +8,16 @@ use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
#[cfg(test)]
use crate::common::InmemoryTransport;
use crate::common::{
use crate::InmemoryTransport;
use crate::{
Backup, FramedTransport, HeapSecretKey, Keychain, KeychainResult, Reconnectable, Transport,
TransportExt, Version,
};
/// Id of the connection
/// Id of the connection.
pub type ConnectionId = u32;
/// Represents a connection from either the client or server side
/// Represents a connection from either the client or server side.
#[derive(Debug)]
pub enum Connection<T> {
/// Connection from the client side
@ -179,7 +179,7 @@ where
}
}
/// Type of connection to perform
/// Type of connection to perform.
#[derive(Debug, Serialize, Deserialize)]
enum ConnectType {
/// Indicates that the connection from client to server is no and not a reconnection

@ -3,7 +3,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use crate::common::HeapSecretKey;
use crate::HeapSecretKey;
/// Represents the result of a request to the database.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]

@ -4,14 +4,28 @@
#[cfg(doctest)]
pub struct ReadmeDoctests;
mod any;
mod authentication;
pub mod client;
pub mod common;
pub mod manager;
pub mod server;
mod connection;
mod key;
mod keychain;
mod listener;
mod packet;
mod port;
mod transport;
pub(crate) mod utils;
mod version;
pub use any::*;
pub use connection::*;
pub use key::*;
pub use keychain::*;
pub use listener::*;
pub use packet::*;
pub use port::*;
pub use transport::*;
pub use version::*;
pub use client::{Client, ReconnectStrategy};
/// Authentication functionality tied to network operations.
pub use distant_core_auth as auth;
pub use server::Server;
pub use {log, paste};

@ -5,7 +5,7 @@ use async_trait::async_trait;
use tokio::net::TcpListener as TokioTcpListener;
use super::Listener;
use crate::common::{PortRange, TcpTransport};
use crate::{PortRange, TcpTransport};
/// Represents a [`Listener`] for incoming connections over TCP
pub struct TcpListener {

@ -6,7 +6,7 @@ use async_trait::async_trait;
use tokio::net::{UnixListener, UnixStream};
use super::Listener;
use crate::common::UnixSocketTransport;
use crate::UnixSocketTransport;
/// Represents a [`Listener`] for incoming connections over a Unix socket
pub struct UnixSocketListener {

@ -5,7 +5,7 @@ use async_trait::async_trait;
use tokio::net::windows::named_pipe::{NamedPipeServer, ServerOptions};
use super::Listener;
use crate::common::{NamedPipe, WindowsPipeTransport};
use crate::{NamedPipe, WindowsPipeTransport};
/// Represents a [`Listener`] for incoming connections over a named windows pipe
pub struct WindowsPipeListener {

@ -1,16 +0,0 @@
mod client;
mod data;
mod server;
pub use client::*;
pub use data::*;
pub use server::*;
use crate::common::Version;
/// Represents the version associated with the manager's protocol.
pub const PROTOCOL_VERSION: Version = Version::new(
const_str::parse!(env!("CARGO_PKG_VERSION_MAJOR"), u64),
const_str::parse!(env!("CARGO_PKG_VERSION_MINOR"), u64),
const_str::parse!(env!("CARGO_PKG_VERSION_PATCH"), u64),
);

@ -6,7 +6,7 @@ use derive_more::IntoIterator;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::common::{utils, Value};
use crate::{utils, Value};
/// Generates a new [`Header`] of key/value pairs based on literals.
///
@ -18,7 +18,7 @@ use crate::common::{utils, Value};
#[macro_export]
macro_rules! header {
($($key:literal -> $value:expr),* $(,)?) => {{
let mut _header = $crate::common::Header::default();
let mut _header = $crate::Header::default();
$(
_header.insert($key, $value);

@ -6,8 +6,7 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id};
use crate::common::utils;
use crate::header;
use crate::{header, utils};
/// Represents a request to send
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]

@ -6,8 +6,7 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use super::{read_header_bytes, read_key_eq, read_str_bytes, Header, Id};
use crate::common::utils;
use crate::header;
use crate::{header, utils};
/// Represents a response received related to some response
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]

@ -5,7 +5,7 @@ use std::ops::{Deref, DerefMut};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::common::utils;
use crate::utils;
/// Generic value type for data passed through header.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]

@ -1,473 +0,0 @@
use std::io;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use distant_core_auth::Verifier;
use log::*;
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::{broadcast, RwLock};
use crate::common::{ConnectionId, Listener, Response, Transport, Version};
mod builder;
pub use builder::*;
mod config;
pub use config::*;
mod connection;
use connection::*;
mod context;
pub use context::*;
mod r#ref;
pub use r#ref::*;
mod reply;
pub use reply::*;
mod state;
use state::*;
mod shutdown_timer;
use shutdown_timer::*;
/// Represents a server that can be used to receive requests & send responses to clients.
pub struct Server<T> {
/// Custom configuration details associated with the server
config: ServerConfig,
/// Handler used to process various server events
handler: T,
/// Performs authentication using various methods
verifier: Verifier,
/// Version associated with the server used by clients to verify compatibility
version: Version,
}
/// Interface for a handler that receives connections and requests
#[async_trait]
pub trait ServerHandler: Send {
/// Type of data received by the server
type Request;
/// Type of data sent back by the server
type Response;
/// Invoked upon a new connection becoming established.
#[allow(unused_variables)]
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon an existing connection getting dropped.
#[allow(unused_variables)]
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon receiving a request from a client. The server should process this
/// request, which can be found in `ctx`, and send one or more replies in response.
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
}
impl Server<()> {
/// Creates a new [`Server`], starting with a default configuration, no authentication methods,
/// and no [`ServerHandler`].
pub fn new() -> Self {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
version: Default::default(),
}
}
/// Creates a new [`TcpServerBuilder`] that is used to construct a [`Server`].
pub fn tcp() -> TcpServerBuilder<()> {
TcpServerBuilder::default()
}
/// Creates a new [`UnixSocketServerBuilder`] that is used to construct a [`Server`].
#[cfg(unix)]
pub fn unix_socket() -> UnixSocketServerBuilder<()> {
UnixSocketServerBuilder::default()
}
/// Creates a new [`WindowsPipeServerBuilder`] that is used to construct a [`Server`].
#[cfg(windows)]
pub fn windows_pipe() -> WindowsPipeServerBuilder<()> {
WindowsPipeServerBuilder::default()
}
}
impl Default for Server<()> {
fn default() -> Self {
Self::new()
}
}
impl<T> Server<T> {
/// Consumes the current server, replacing its config with `config` and returning it.
pub fn config(self, config: ServerConfig) -> Self {
Self {
config,
handler: self.handler,
verifier: self.verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its handler with `handler` and returning it.
pub fn handler<U>(self, handler: U) -> Server<U> {
Server {
config: self.config,
handler,
verifier: self.verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its verifier with `verifier` and returning it.
pub fn verifier(self, verifier: Verifier) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its version with `version` and returning it.
pub fn version(self, version: Version) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier: self.verifier,
version,
}
}
}
impl<T> Server<T>
where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
{
/// Consumes the server, starting a task to process connections from the `listener` and
/// returning a [`ServerRef`] that can be used to control the active server instance.
pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
where
L: Listener + 'static,
L::Output: Transport + 'static,
{
let state = Arc::new(ServerState::new());
let (tx, rx) = broadcast::channel(1);
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
Ok(ServerRef { shutdown: tx, task })
}
/// Internal task that is run to receive connections and spawn connection tasks
async fn task<L>(
self,
state: Arc<ServerState<Response<T::Response>>>,
mut listener: L,
shutdown_tx: broadcast::Sender<()>,
shutdown_rx: broadcast::Receiver<()>,
) where
L: Listener + 'static,
L::Output: Transport + 'static,
{
let Server {
config,
handler,
verifier,
version,
} = self;
let handler = Arc::new(handler);
let timer = ShutdownTimer::start(config.shutdown);
let mut notification = timer.clone_notification();
let timer = Arc::new(RwLock::new(timer));
let verifier = Arc::new(verifier);
let mut connection_tasks = Vec::new();
loop {
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
// signal has been received
let transport = tokio::select! {
result = listener.accept() => {
match result {
Ok(x) => x,
Err(x) => {
error!("Server no longer accepting connections: {x}");
timer.read().await.abort();
break;
}
}
}
_ = notification.wait() => {
info!(
"Server shutdown triggered after {}s",
config.shutdown.duration().unwrap_or_default().as_secs_f32(),
);
let _ = shutdown_tx.send(());
break;
}
};
// Ensure that the shutdown timer is cancelled now that we have a connection
timer.read().await.stop();
connection_tasks.push(
ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(state.keychain.clone())
.transport(transport)
.shutdown(shutdown_rx.resubscribe())
.shutdown_timer(Arc::downgrade(&timer))
.sleep_duration(config.connection_sleep)
.heartbeat_duration(config.connection_heartbeat)
.verifier(Arc::downgrade(&verifier))
.version(version.clone())
.spawn(),
);
// Clean up current tasks being tracked
connection_tasks.retain(|task| !task.is_finished());
}
// Once we stop listening, we still want to wait until all connections have terminated
info!("Server waiting for active connections to terminate");
loop {
connection_tasks.retain(|task| !task.is_finished());
if connection_tasks.is_empty() {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
info!("Server task terminated");
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use async_trait::async_trait;
use distant_core_auth::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod};
use test_log::test;
use tokio::sync::mpsc;
use super::*;
use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response};
macro_rules! server_version {
() => {
Version::new(1, 2, 3)
};
}
pub struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type Request = u16;
type Response = String;
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Always send back "hello"
ctx.reply.send("hello".to_string()).unwrap();
}
}
#[inline]
fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(NoneAuthenticationMethod::new())];
Server {
config,
handler: TestServerHandler,
verifier: Verifier::new(methods),
version: server_version!(),
}
}
#[allow(clippy::type_complexity)]
fn make_listener(
buffer: usize,
) -> (
mpsc::Sender<InmemoryTransport>,
MpscListener<InmemoryTransport>,
) {
MpscListener::channel(buffer)
}
#[test(tokio::test)]
async fn should_invoke_handler_upon_receiving_a_request() {
// Create a test listener where we will forward a connection
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let _server = make_test_server(ServerConfig::default())
.start(listener)
.expect("Failed to start server");
// Perform handshake and authentication with the server before beginning to send data
let mut connection = Connection::client(transport, DummyAuthHandler, server_version!())
.await
.expect("Failed to connect to server");
connection
.write_frame(Request::new(123).to_vec().unwrap())
.await
.expect("Failed to send request");
// Wait for a response
let frame = connection.read_frame().await.unwrap().unwrap();
let response: Response<String> = Response::from_slice(frame.as_item()).unwrap();
assert_eq!(response.payload, "hello");
}
#[test(tokio::test)]
async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs(
) {
// Create a test listener where we will forward a connection
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Drop the connection by dropping the transport
drop(transport);
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() {
// Create a test listener where we will forward a connection
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(!server.is_finished(), "Server shutdown when it should not!");
}
#[test(tokio::test)]
async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_shutdown_after_n_seconds_if_config_set_to_after() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_never_shutdown_if_config_set_to_never() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Never,
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(!server.is_finished(), "Server shutdown when it should not!");
}
}

@ -9,7 +9,7 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use super::{InmemoryTransport, Interest, Ready, Reconnectable, Transport};
use crate::common::{utils, SecretKey32};
use crate::{utils, SecretKey32};
mod backup;
mod codec;

@ -3,7 +3,7 @@ use std::{fmt, io};
use derive_more::Display;
use super::{Codec, Frame};
use crate::common::{SecretKey, SecretKey32};
use crate::{SecretKey, SecretKey32};
/// Represents the type of encryption for a [`EncryptionCodec`]
#[derive(

@ -6,7 +6,7 @@ use p256::PublicKey;
use rand::rngs::OsRng;
use sha2::Sha256;
use crate::common::SecretKey32;
use crate::SecretKey32;
mod pkb;
pub use pkb::PublicKeyBytes;

@ -203,7 +203,7 @@ mod tests {
use test_log::test;
use super::*;
use crate::common::TransportExt;
use crate::TransportExt;
#[test]
fn is_rx_closed_should_properly_reflect_if_internal_rx_channel_is_closed() {

@ -79,7 +79,7 @@ mod tests {
use tokio::task::JoinHandle;
use super::*;
use crate::common::TransportExt;
use crate::TransportExt;
async fn find_ephemeral_addr() -> SocketAddr {
// Start a listener on a distinct port, get its port, and kill it

@ -69,7 +69,7 @@ mod tests {
use tokio::task::JoinHandle;
use super::*;
use crate::common::TransportExt;
use crate::TransportExt;
async fn start_and_run_server(tx: oneshot::Sender<PathBuf>) -> io::Result<()> {
// Generate a socket path and delete the file after so there is nothing there

@ -93,7 +93,7 @@ mod tests {
use tokio::task::JoinHandle;
use super::*;
use crate::common::TransportExt;
use crate::TransportExt;
async fn start_and_run_server(tx: oneshot::Sender<String>) -> io::Result<()> {
let pipe = start_server(tx).await?;

@ -14,6 +14,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
async-trait = "0.1.68"
derive_more = { version = "0.99.17", default-features = false, features = ["display", "from", "error"] }
distant-core-net = { version = "=0.21.0", path = "../distant-core-net" }
log = "0.4.18"
serde = { version = "1.0.163", features = ["derive"] }

@ -0,0 +1,473 @@
use std::io;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use distant_core_auth::Verifier;
use log::*;
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::{broadcast, RwLock};
use crate::common::{ConnectionId, Listener, Response, Transport, Version};
mod builder;
pub use builder::*;
mod config;
pub use config::*;
mod connection;
use connection::*;
mod context;
pub use context::*;
mod r#ref;
pub use r#ref::*;
mod reply;
pub use reply::*;
mod state;
use state::*;
mod shutdown_timer;
use shutdown_timer::*;
/// Represents a server that can be used to receive requests & send responses to clients.
pub struct Server<T> {
/// Custom configuration details associated with the server
config: ServerConfig,
/// Handler used to process various server events
handler: T,
/// Performs authentication using various methods
verifier: Verifier,
/// Version associated with the server used by clients to verify compatibility
version: Version,
}
/// Interface for a handler that receives connections and requests
#[async_trait]
pub trait ServerHandler: Send {
/// Type of data received by the server
type Request;
/// Type of data sent back by the server
type Response;
/// Invoked upon a new connection becoming established.
#[allow(unused_variables)]
async fn on_connect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon an existing connection getting dropped.
#[allow(unused_variables)]
async fn on_disconnect(&self, id: ConnectionId) -> io::Result<()> {
Ok(())
}
/// Invoked upon receiving a request from a client. The server should process this
/// request, which can be found in `ctx`, and send one or more replies in response.
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>);
}
impl Server<()> {
/// Creates a new [`Server`], starting with a default configuration, no authentication methods,
/// and no [`ServerHandler`].
pub fn new() -> Self {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
version: Default::default(),
}
}
/// Creates a new [`TcpServerBuilder`] that is used to construct a [`Server`].
pub fn tcp() -> TcpServerBuilder<()> {
TcpServerBuilder::default()
}
/// Creates a new [`UnixSocketServerBuilder`] that is used to construct a [`Server`].
#[cfg(unix)]
pub fn unix_socket() -> UnixSocketServerBuilder<()> {
UnixSocketServerBuilder::default()
}
/// Creates a new [`WindowsPipeServerBuilder`] that is used to construct a [`Server`].
#[cfg(windows)]
pub fn windows_pipe() -> WindowsPipeServerBuilder<()> {
WindowsPipeServerBuilder::default()
}
}
impl Default for Server<()> {
fn default() -> Self {
Self::new()
}
}
impl<T> Server<T> {
/// Consumes the current server, replacing its config with `config` and returning it.
pub fn config(self, config: ServerConfig) -> Self {
Self {
config,
handler: self.handler,
verifier: self.verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its handler with `handler` and returning it.
pub fn handler<U>(self, handler: U) -> Server<U> {
Server {
config: self.config,
handler,
verifier: self.verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its verifier with `verifier` and returning it.
pub fn verifier(self, verifier: Verifier) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier,
version: self.version,
}
}
/// Consumes the current server, replacing its version with `version` and returning it.
pub fn version(self, version: Version) -> Self {
Self {
config: self.config,
handler: self.handler,
verifier: self.verifier,
version,
}
}
}
impl<T> Server<T>
where
T: ServerHandler + Sync + 'static,
T::Request: DeserializeOwned + Send + Sync + 'static,
T::Response: Serialize + Send + 'static,
{
/// Consumes the server, starting a task to process connections from the `listener` and
/// returning a [`ServerRef`] that can be used to control the active server instance.
pub fn start<L>(self, listener: L) -> io::Result<ServerRef>
where
L: Listener + 'static,
L::Output: Transport + 'static,
{
let state = Arc::new(ServerState::new());
let (tx, rx) = broadcast::channel(1);
let task = tokio::spawn(self.task(Arc::clone(&state), listener, tx.clone(), rx));
Ok(ServerRef { shutdown: tx, task })
}
/// Internal task that is run to receive connections and spawn connection tasks
async fn task<L>(
self,
state: Arc<ServerState<Response<T::Response>>>,
mut listener: L,
shutdown_tx: broadcast::Sender<()>,
shutdown_rx: broadcast::Receiver<()>,
) where
L: Listener + 'static,
L::Output: Transport + 'static,
{
let Server {
config,
handler,
verifier,
version,
} = self;
let handler = Arc::new(handler);
let timer = ShutdownTimer::start(config.shutdown);
let mut notification = timer.clone_notification();
let timer = Arc::new(RwLock::new(timer));
let verifier = Arc::new(verifier);
let mut connection_tasks = Vec::new();
loop {
// Receive a new connection, exiting if no longer accepting connections or if the shutdown
// signal has been received
let transport = tokio::select! {
result = listener.accept() => {
match result {
Ok(x) => x,
Err(x) => {
error!("Server no longer accepting connections: {x}");
timer.read().await.abort();
break;
}
}
}
_ = notification.wait() => {
info!(
"Server shutdown triggered after {}s",
config.shutdown.duration().unwrap_or_default().as_secs_f32(),
);
let _ = shutdown_tx.send(());
break;
}
};
// Ensure that the shutdown timer is cancelled now that we have a connection
timer.read().await.stop();
connection_tasks.push(
ConnectionTask::build()
.handler(Arc::downgrade(&handler))
.state(Arc::downgrade(&state))
.keychain(state.keychain.clone())
.transport(transport)
.shutdown(shutdown_rx.resubscribe())
.shutdown_timer(Arc::downgrade(&timer))
.sleep_duration(config.connection_sleep)
.heartbeat_duration(config.connection_heartbeat)
.verifier(Arc::downgrade(&verifier))
.version(version.clone())
.spawn(),
);
// Clean up current tasks being tracked
connection_tasks.retain(|task| !task.is_finished());
}
// Once we stop listening, we still want to wait until all connections have terminated
info!("Server waiting for active connections to terminate");
loop {
connection_tasks.retain(|task| !task.is_finished());
if connection_tasks.is_empty() {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
info!("Server task terminated");
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use async_trait::async_trait;
use distant_core_auth::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod};
use test_log::test;
use tokio::sync::mpsc;
use super::*;
use crate::common::{Connection, InmemoryTransport, MpscListener, Request, Response};
macro_rules! server_version {
() => {
Version::new(1, 2, 3)
};
}
pub struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type Request = u16;
type Response = String;
async fn on_request(&self, ctx: RequestCtx<Self::Request, Self::Response>) {
// Always send back "hello"
ctx.reply.send("hello".to_string()).unwrap();
}
}
#[inline]
fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(NoneAuthenticationMethod::new())];
Server {
config,
handler: TestServerHandler,
verifier: Verifier::new(methods),
version: server_version!(),
}
}
#[allow(clippy::type_complexity)]
fn make_listener(
buffer: usize,
) -> (
mpsc::Sender<InmemoryTransport>,
MpscListener<InmemoryTransport>,
) {
MpscListener::channel(buffer)
}
#[test(tokio::test)]
async fn should_invoke_handler_upon_receiving_a_request() {
// Create a test listener where we will forward a connection
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let _server = make_test_server(ServerConfig::default())
.start(listener)
.expect("Failed to start server");
// Perform handshake and authentication with the server before beginning to send data
let mut connection = Connection::client(transport, DummyAuthHandler, server_version!())
.await
.expect("Failed to connect to server");
connection
.write_frame(Request::new(123).to_vec().unwrap())
.await
.expect("Failed to send request");
// Wait for a response
let frame = connection.read_frame().await.unwrap().unwrap();
let response: Response<String> = Response::from_slice(frame.as_item()).unwrap();
assert_eq!(response.payload, "hello");
}
#[test(tokio::test)]
async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs(
) {
// Create a test listener where we will forward a connection
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Drop the connection by dropping the transport
drop(transport);
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() {
// Create a test listener where we will forward a connection
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(!server.is_finished(), "Server shutdown when it should not!");
}
#[test(tokio::test)]
async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() {
let (tx, listener) = make_listener(100);
// Make bounded transport pair and send off one of them to act as our connection
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_shutdown_after_n_seconds_if_config_set_to_after() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
#[test(tokio::test)]
async fn should_never_shutdown_if_config_set_to_never() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Never,
..Default::default()
})
.start(listener)
.expect("Failed to start server");
// Wait for some time
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(!server.is_finished(), "Server shutdown when it should not!");
}
}

@ -1,5 +1,22 @@
use std::io;
use std::sync::mpsc;
use async_trait::async_trait;
use distant_core_protocol::{Request, Response};
///
/// Full API for a distant-compatible client.
#[async_trait]
pub trait Client {}
pub trait Client {
/// Sends a request without waiting for a response; this method is able to be used even
/// if the session's receiving line to the remote server has been severed.
async fn fire(&mut self, request: Request) -> io::Result<()>;
/// Sends a request and returns a mailbox that can receive one or more responses, failing if
/// unable to send a request or if the session's receiving line to the remote server has
/// already been severed.
async fn mail(&mut self, request: Request) -> io::Result<mpsc::Receiver<Response>>;
/// Sends a request and waits for a response, failing if unable to send a request or if
/// the session's receiving line to the remote server has already been severed
async fn send(&mut self, request: Request) -> io::Result<Response>;
}

@ -4,6 +4,7 @@ use std::io;
use async_trait::async_trait;
use distant_core_auth::Authenticator;
use crate::client::Client;
use crate::common::{Destination, Map};
/// Boxed [`LaunchHandler`].
@ -96,21 +97,21 @@ pub trait ConnectHandler: Send + Sync {
destination: &Destination,
options: &Map,
authenticator: &mut dyn Authenticator,
) -> io::Result<UntypedClient>;
) -> io::Result<Box<dyn Client>>;
}
#[async_trait]
impl<F, R> ConnectHandler for F
where
F: Fn(&Destination, &Map, &mut dyn Authenticator) -> R + Send + Sync + 'static,
R: Future<Output = io::Result<UntypedClient>> + Send + 'static,
R: Future<Output = io::Result<Box<dyn Client>>> + Send + 'static,
{
async fn connect(
&self,
destination: &Destination,
options: &Map,
authenticator: &mut dyn Authenticator,
) -> io::Result<UntypedClient> {
) -> io::Result<Box<dyn Client>> {
self(destination, options, authenticator).await
}
}
@ -156,10 +157,10 @@ macro_rules! boxed_connect_handler {
#[cfg(test)]
mod tests {
use distant_core_auth::*;
use test_log::test;
use super::*;
use crate::common::FramedTransport;
#[inline]
fn test_destination() -> Destination {
@ -171,9 +172,48 @@ mod tests {
Map::default()
}
/// Creates an authenticator that does nothing.
#[inline]
fn test_authenticator() -> impl Authenticator {
FramedTransport::pair(1).0
struct __TestAuthenticator;
impl Authenticator for __TestAuthenticator {
async fn initialize(
&mut self,
initialization: Initialization,
) -> io::Result<InitializationResponse> {
unimplemented!()
}
async fn challenge(&mut self, challenge: Challenge) -> io::Result<ChallengeResponse> {
unimplemented!()
}
async fn verify(
&mut self,
verification: Verification,
) -> io::Result<VerificationResponse> {
unimplemented!()
}
async fn info(&mut self, info: Info) -> io::Result<()> {
unimplemented!()
}
async fn error(&mut self, error: Error) -> io::Result<()> {
unimplemented!()
}
async fn start_method(&mut self, start_method: StartMethod) -> io::Result<()> {
unimplemented!()
}
async fn finished(&mut self) -> io::Result<()> {
unimplemented!()
}
}
__TestAuthenticator
}
#[test(tokio::test)]

Loading…
Cancel
Save