Fix shutdown-after such that it now functions

pull/137/head
Chip Senkbeil 2 years ago
parent a0c7c492bd
commit 1ff3ef2db1

@@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+### Fixed
+
+- `shutdown-after` cli parameter and config option now properly shuts down
+  server after N seconds with no connections
+
 ## [0.17.5] - 2022-08-18
 ### Fixed

@@ -3,7 +3,7 @@ use crate::{
     ConnectionId, DistantMsg, DistantRequestData, DistantResponseData,
 };
 use async_trait::async_trait;
-use distant_net::{Reply, Server, ServerCtx};
+use distant_net::{Reply, Server, ServerConfig, ServerCtx};
 use log::*;
 use std::{io, path::PathBuf, sync::Arc};
@@ -39,9 +39,9 @@ where
 impl DistantApiServer<LocalDistantApi, <LocalDistantApi as DistantApi>::LocalData> {
     /// Creates a new server using the [`LocalDistantApi`] implementation
-    pub fn local() -> io::Result<Self> {
+    pub fn local(config: ServerConfig) -> io::Result<Self> {
         Ok(Self {
-            api: LocalDistantApi::initialize()?,
+            api: LocalDistantApi::initialize(config)?,
         })
     }
 }
@@ -60,6 +60,11 @@ fn unsupported<T>(label: &str) -> io::Result<T> {
 pub trait DistantApi {
     type LocalData: Send + Sync;
 
+    /// Returns config associated with API server
+    fn config(&self) -> ServerConfig {
+        ServerConfig::default()
+    }
+
     /// Invoked whenever a new connection is established, providing a mutable reference to the
     /// newly-created local data. This is a way to support modifying local data before it is used.
     #[allow(unused_variables)]
@@ -385,6 +390,11 @@ where
     type Response = DistantMsg<DistantResponseData>;
     type LocalData = D;
 
+    /// Overridden to leverage [`DistantApi`] implementation of `config`
+    fn config(&self) -> ServerConfig {
+        T::config(&self.api)
+    }
+
     /// Overridden to leverage [`DistantApi`] implementation of `on_accept`
     async fn on_accept(&self, local_data: &mut Self::LocalData) {
         T::on_accept(&self.api, local_data).await

@@ -6,6 +6,7 @@ use crate::{
     DistantApi, DistantCtx,
 };
 use async_trait::async_trait;
+use distant_net::ServerConfig;
 use log::*;
 use std::{
     io,
@@ -25,13 +26,15 @@ use state::*;
 /// implementation of the API instead of a proxy to another machine as seen with
 /// implementations on top of SSH and other protocols
 pub struct LocalDistantApi {
+    config: ServerConfig,
     state: GlobalState,
 }
 
 impl LocalDistantApi {
     /// Initialize the api instance
-    pub fn initialize() -> io::Result<Self> {
+    pub fn initialize(config: ServerConfig) -> io::Result<Self> {
         Ok(Self {
+            config,
             state: GlobalState::initialize()?,
         })
     }
@@ -41,6 +44,10 @@ impl LocalDistantApi {
 impl DistantApi for LocalDistantApi {
     type LocalData = ConnectionState;
 
+    fn config(&self) -> ServerConfig {
+        self.config.clone()
+    }
+
     /// Injects the global channels into the local connection
     async fn on_accept(&self, local_data: &mut Self::LocalData) {
         local_data.process_channel = self.state.process.clone_channel();
@@ -547,7 +554,7 @@ mod tests {
         DistantCtx<ConnectionState>,
         mpsc::Receiver<DistantResponseData>,
     ) {
-        let api = LocalDistantApi::initialize().unwrap();
+        let api = LocalDistantApi::initialize(Default::default()).unwrap();
         let (reply, rx) = make_reply(buffer);
         let mut local_data = ConnectionState::default();
         DistantApi::on_accept(&api, &mut local_data).await;

@@ -33,7 +33,7 @@ async fn should_be_able_to_establish_a_single_connection_and_communicate() {
     let (t1, t2) = FramedTransport::pair(100);
 
     // Spawn a server on one end
-    let _ = DistantApiServer::local()
+    let _ = DistantApiServer::local(Default::default())
         .unwrap()
         .start(OneshotListener::from_value(t2.into_split()))?;

@@ -25,7 +25,7 @@ impl DistantClientCtx {
         let key = SecretKey::default();
         let codec = XChaCha20Poly1305Codec::from(key.clone());
 
-        if let Ok(api) = LocalDistantApi::initialize() {
+        if let Ok(api) = LocalDistantApi::initialize(Default::default()) {
             let port: PortRange = "0".parse().unwrap();
             let port = {
                 let server_ref = DistantApiServer::new(api)

@@ -1,6 +1,9 @@
 use async_trait::async_trait;
 use serde::{de::DeserializeOwned, Serialize};
 
+mod config;
+pub use config::*;
+
 mod connection;
 pub use connection::*;
@@ -31,6 +34,11 @@ pub trait Server: Send {
     /// Type of data to store locally tied to the specific connection
     type LocalData: Send + Sync;
 
+    /// Returns configuration tied to server instance
+    fn config(&self) -> ServerConfig {
+        ServerConfig::default()
+    }
+
     /// Invoked immediately on server start, being provided the raw listener to use (untyped
     /// transport), and returning the listener when ready to start (enabling servers that need to
     /// tweak a listener to do so)

@@ -0,0 +1,10 @@
+use serde::{Deserialize, Serialize};
+use std::time::Duration;
+
+/// Represents a general-purpose set of properties tied with a server instance
+#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
+pub struct ServerConfig {
+    /// If provided, will cause server to shut down if duration is exceeded with no active
+    /// connections
+    pub shutdown_after: Option<Duration>,
+}

@@ -1,11 +1,14 @@
 use crate::{
-    GenericServerRef, Listener, Request, Response, Server, ServerConnection, ServerCtx, ServerRef,
-    ServerReply, ServerState, TypedAsyncRead, TypedAsyncWrite,
+    utils::Timer, GenericServerRef, Listener, Request, Response, Server, ServerConnection,
+    ServerCtx, ServerRef, ServerReply, ServerState, TypedAsyncRead, TypedAsyncWrite,
 };
 use log::*;
 use serde::{de::DeserializeOwned, Serialize};
-use std::{io, sync::Arc};
-use tokio::sync::mpsc;
+use std::{
+    io,
+    sync::{Arc, Weak},
+};
+use tokio::sync::{mpsc, Mutex};
 
 mod tcp;
 pub use tcp::*;
@@ -72,91 +75,162 @@ where
     R: TypedAsyncRead<Request<Req>> + Send + 'static,
     W: TypedAsyncWrite<Response<Res>> + Send + 'static,
 {
+    // Grab a copy of our server's configuration so we can leverage it below
+    let config = server.config();
+
+    // Create the timer that will be used to shut down the server after duration elapsed
+    let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1);
+
+    // NOTE: We do a manual map such that the shutdown sender is not captured and dropped when
+    //       there is no shutdown after configured. This is because we need the future for the
+    //       shutdown receiver to last forever in the event that there is no shutdown configured,
+    //       not return immediately, which is what would happen if the sender was dropped.
+    #[allow(clippy::manual_map)]
+    let mut shutdown_timer = match config.shutdown_after {
+        Some(duration) => Some(Timer::new(duration, async move {
+            let _ = shutdown_tx.send(()).await;
+        })),
+        None => None,
+    };
+
+    if let Some(timer) = shutdown_timer.as_mut() {
+        info!(
+            "Server shutdown timer configured: {}s",
+            timer.duration().as_secs_f32()
+        );
+        timer.start();
+    }
+
+    let mut shutdown_timer = shutdown_timer.map(|timer| Arc::new(Mutex::new(timer)));
+
     loop {
         let server = Arc::clone(&server);
-        match listener.accept().await {
-            Ok((mut writer, mut reader)) => {
-                let mut connection = ServerConnection::new();
-                let connection_id = connection.id;
-                let state = Arc::clone(&state);
-
-                // Create some default data for the new connection and pass it
-                // to the callback prior to processing new requests
-                let local_data = {
-                    let mut data = Data::default();
-                    server.on_accept(&mut data).await;
-                    Arc::new(data)
-                };
-
-                // Start a writer task that reads from a channel and forwards all
-                // data through the writer
-                let (tx, mut rx) = mpsc::channel::<Response<Res>>(1);
-                connection.writer_task = Some(tokio::spawn(async move {
-                    while let Some(data) = rx.recv().await {
-                        // trace!("[Conn {}] Sending {:?}", connection_id, data.payload);
-                        if let Err(x) = writer.write(data).await {
-                            error!("[Conn {}] Failed to send {:?}", connection_id, x);
-                            break;
-                        }
-                    }
-                }));
-
-                // Start a reader task that reads requests and processes them
-                // using the provided handler
-                connection.reader_task = Some(tokio::spawn(async move {
-                    loop {
-                        match reader.read().await {
-                            Ok(Some(request)) => {
-                                let reply = ServerReply {
-                                    origin_id: request.id.clone(),
-                                    tx: tx.clone(),
-                                };
-                                let ctx = ServerCtx {
-                                    connection_id,
-                                    request,
-                                    reply: reply.clone(),
-                                    local_data: Arc::clone(&local_data),
-                                };
-                                server.on_request(ctx).await;
-                            }
-                            Ok(None) => {
-                                debug!("[Conn {}] Connection closed", connection_id);
-                                break;
-                            }
-                            Err(x) => {
-                                // NOTE: We do NOT break out of the loop, as this could happen
-                                //       if someone sends bad data at any point, but does not
-                                //       mean that the reader itself has failed. This can
-                                //       happen from getting non-compliant typed data
-                                error!("[Conn {}] {}", connection_id, x);
-                            }
-                        }
-                    }
-                }));
-
-                state
-                    .connections
-                    .write()
-                    .await
-                    .insert(connection_id, connection);
-            }
-            Err(x) => {
-                error!("Server no longer accepting connections: {}", x);
-                break;
-            }
-        }
+        // Receive a new connection, exiting if no longer accepting connections or if the shutdown
+        // signal has been received
+        let (mut writer, mut reader) = tokio::select! {
+            result = listener.accept() => {
+                match result {
+                    Ok(x) => x,
+                    Err(x) => {
+                        error!("Server no longer accepting connections: {}", x);
+                        if let Some(timer) = shutdown_timer.take() {
+                            timer.lock().await.abort();
+                        }
+                        break;
+                    }
+                }
+            }
+            _ = shutdown_rx.recv() => {
+                info!(
+                    "Server shutdown triggered after {}s",
+                    config.shutdown_after.unwrap_or_default().as_secs_f32(),
+                );
+                break;
+            }
+        };
+
+        let mut connection = ServerConnection::new();
+        let connection_id = connection.id;
+        let state = Arc::clone(&state);
+
+        // Ensure that the shutdown timer is cancelled now that we have a connection
+        if let Some(timer) = shutdown_timer.as_ref() {
+            timer.lock().await.stop();
+        }
+
+        // Create some default data for the new connection and pass it
+        // to the callback prior to processing new requests
+        let local_data = {
+            let mut data = Data::default();
+            server.on_accept(&mut data).await;
+            Arc::new(data)
+        };
+
+        // Start a writer task that reads from a channel and forwards all
+        // data through the writer
+        let (tx, mut rx) = mpsc::channel::<Response<Res>>(1);
+        connection.writer_task = Some(tokio::spawn(async move {
+            while let Some(data) = rx.recv().await {
+                // trace!("[Conn {}] Sending {:?}", connection_id, data.payload);
+                if let Err(x) = writer.write(data).await {
+                    error!("[Conn {}] Failed to send {:?}", connection_id, x);
+                    break;
+                }
+            }
+        }));
+
+        // Start a reader task that reads requests and processes them
+        // using the provided handler
+        let weak_state = Arc::downgrade(&state);
+        let weak_shutdown_timer = shutdown_timer
+            .as_ref()
+            .map(Arc::downgrade)
+            .unwrap_or_default();
+        connection.reader_task = Some(tokio::spawn(async move {
+            loop {
+                match reader.read().await {
+                    Ok(Some(request)) => {
+                        let reply = ServerReply {
+                            origin_id: request.id.clone(),
+                            tx: tx.clone(),
+                        };
+                        let ctx = ServerCtx {
+                            connection_id,
+                            request,
+                            reply: reply.clone(),
+                            local_data: Arc::clone(&local_data),
+                        };
+                        server.on_request(ctx).await;
+                    }
+                    Ok(None) => {
+                        debug!("[Conn {}] Connection closed", connection_id);
+
+                        // Remove the connection from our state if it has closed
+                        if let Some(state) = Weak::upgrade(&weak_state) {
+                            state.connections.write().await.remove(&connection_id);
+
+                            // If we have no more connections, start the timer
+                            if let Some(timer) = Weak::upgrade(&weak_shutdown_timer) {
+                                if state.connections.read().await.is_empty() {
+                                    timer.lock().await.start();
+                                }
+                            }
+                        }
+                        break;
+                    }
+                    Err(x) => {
+                        // NOTE: We do NOT break out of the loop, as this could happen
+                        //       if someone sends bad data at any point, but does not
+                        //       mean that the reader itself has failed. This can
+                        //       happen from getting non-compliant typed data
+                        error!("[Conn {}] {}", connection_id, x);
+                    }
+                }
+            }
+        }));
+
+        state
+            .connections
+            .write()
+            .await
+            .insert(connection_id, connection);
     }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{IntoSplit, MpscListener, MpscTransport};
+    use crate::{
+        IntoSplit, MpscListener, MpscTransport, MpscTransportReadHalf, MpscTransportWriteHalf,
+        ServerConfig,
+    };
     use async_trait::async_trait;
+    use std::time::Duration;
 
-    pub struct TestServer;
+    pub struct TestServer(ServerConfig);
 
     #[async_trait]
     impl Server for TestServer {
@@ -164,16 +238,36 @@ mod tests {
         type Response = String;
         type LocalData = ();
 
+        fn config(&self) -> ServerConfig {
+            self.0.clone()
+        }
+
         async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
             // Always send back "hello"
             ctx.reply.send("hello".to_string()).await.unwrap();
         }
     }
 
+    #[allow(clippy::type_complexity)]
+    fn make_listener(
+        buffer: usize,
+    ) -> (
+        mpsc::Sender<(
+            MpscTransportWriteHalf<Response<String>>,
+            MpscTransportReadHalf<Request<u16>>,
+        )>,
+        MpscListener<(
+            MpscTransportWriteHalf<Response<String>>,
+            MpscTransportReadHalf<Request<u16>>,
+        )>,
+    ) {
+        MpscListener::channel(buffer)
+    }
+
     #[tokio::test]
     async fn should_invoke_handler_upon_receiving_a_request() {
         // Create a test listener where we will forward a connection
-        let (tx, listener) = MpscListener::channel(100);
+        let (tx, listener) = make_listener(100);
 
         // Make bounded transport pair and send off one of them to act as our connection
         let (mut transport, connection) =
@@ -182,7 +276,8 @@ mod tests {
             .await
             .expect("Failed to feed listener a connection");
 
-        let _server = ServerExt::start(TestServer, listener).expect("Failed to start server");
+        let _server = ServerExt::start(TestServer(ServerConfig::default()), listener)
+            .expect("Failed to start server");
 
         transport
             .write(Request::new(123))
@@ -192,4 +287,93 @@ mod tests {
         let response: Response<String> = transport.read().await.unwrap().unwrap();
         assert_eq!(response.payload, "hello");
     }
+
+    #[tokio::test]
+    async fn should_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
+        let (_tx, listener) = make_listener(100);
+
+        let server = ServerExt::start(
+            TestServer(ServerConfig {
+                shutdown_after: Some(Duration::from_millis(100)),
+            }),
+            listener,
+        )
+        .expect("Failed to start server");
+
+        // Wait for some time
+        tokio::time::sleep(Duration::from_millis(300)).await;
+
+        assert!(server.is_finished(), "Server shutdown not triggered!");
+    }
+
+    #[tokio::test]
+    async fn should_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs() {
+        // Create a test listener where we will forward a connection
+        let (tx, listener) = make_listener(100);
+
+        // Make bounded transport pair and send off one of them to act as our connection
+        let (transport, connection) = MpscTransport::<Request<u16>, Response<String>>::pair(100);
+        tx.send(connection.into_split())
+            .await
+            .expect("Failed to feed listener a connection");
+
+        let server = ServerExt::start(
+            TestServer(ServerConfig {
+                shutdown_after: Some(Duration::from_millis(100)),
+            }),
+            listener,
+        )
+        .expect("Failed to start server");
+
+        // Drop the connection by dropping the transport
+        drop(transport);
+
+        // Wait for some time
+        tokio::time::sleep(Duration::from_millis(300)).await;
+
+        assert!(server.is_finished(), "Server shutdown not triggered!");
+    }
+
+    #[tokio::test]
+    async fn should_not_shutdown_as_long_as_a_connection_exists() {
+        // Create a test listener where we will forward a connection
+        let (tx, listener) = make_listener(100);
+
+        // Make bounded transport pair and send off one of them to act as our connection
+        let (_transport, connection) = MpscTransport::<Request<u16>, Response<String>>::pair(100);
+        tx.send(connection.into_split())
+            .await
+            .expect("Failed to feed listener a connection");
+
+        let server = ServerExt::start(
+            TestServer(ServerConfig {
+                shutdown_after: Some(Duration::from_millis(100)),
+            }),
+            listener,
+        )
+        .expect("Failed to start server");
+
+        // Wait for some time
+        tokio::time::sleep(Duration::from_millis(300)).await;
+
+        assert!(!server.is_finished(), "Server shutdown when it should not!");
+    }
+
+    #[tokio::test]
+    async fn should_never_shutdown_if_config_not_set() {
+        let (_tx, listener) = make_listener(100);
+
+        let server = ServerExt::start(
+            TestServer(ServerConfig {
+                shutdown_after: None,
+            }),
+            listener,
+        )
+        .expect("Failed to start server");
+
+        // Wait for some time
+        tokio::time::sleep(Duration::from_millis(300)).await;
+
+        assert!(!server.is_finished(), "Server shutdown when it should not!");
+    }
 }
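The `clippy::manual_map` NOTE in the accept loop above relies on a tokio channel property worth making explicit: once every `mpsc::Sender` is dropped, `recv()` resolves immediately with `None` rather than pending forever, which would fire the shutdown branch of the `tokio::select!` right at startup. A minimal standalone sketch (plain tokio, not part of the commit) of that behavior:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);

    // Dropping the only sender closes the channel...
    drop(shutdown_tx);

    // ...so recv() returns None at once instead of waiting forever,
    // which in the server loop would trigger an immediate shutdown.
    assert!(shutdown_rx.recv().await.is_none());
}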

@@ -1,5 +1,6 @@
 use serde::{de::DeserializeOwned, Serialize};
-use std::io;
+use std::{future::Future, io, time::Duration};
+use tokio::{sync::mpsc, task::JoinHandle};
 
 pub fn serialize_to_vec<T: Serialize>(value: &T) -> io::Result<Vec<u8>> {
     rmp_serde::encode::to_vec_named(value).map_err(|x| {
@@ -18,3 +19,147 @@ pub fn deserialize_from_slice<T: DeserializeOwned>(slice: &[u8]) -> io::Result<T
         )
     })
 }
+
+pub(crate) struct Timer<T>
+where
+    T: Send + 'static,
+{
+    active_timer: Option<JoinHandle<()>>,
+    callback: JoinHandle<T>,
+    duration: Duration,
+    trigger: mpsc::Sender<bool>,
+}
+
+impl<T> Timer<T>
+where
+    T: Send + 'static,
+{
+    /// Create a new callback to trigger `future` that will be executed after `duration` is
+    /// exceeded. The timer is not started yet until `start` is invoked
+    pub fn new<F>(duration: Duration, future: F) -> Self
+    where
+        F: Future<Output = T> + Send + 'static,
+    {
+        let (trigger, mut trigger_rx) = mpsc::channel(1);
+        let callback = tokio::spawn(async move {
+            trigger_rx.recv().await;
+            future.await
+        });
+
+        Self {
+            active_timer: None,
+            callback,
+            duration,
+            trigger,
+        }
+    }
+
+    /// Returns duration of the timer
+    pub fn duration(&self) -> Duration {
+        self.duration
+    }
+
+    /// Starts the timer, re-starting the countdown if already running. If the callback has already
+    /// been completed, this timer will not invoke it again; however, this will start the timer
+    /// itself, which will wait the duration and then fail to trigger the callback
+    pub fn start(&mut self) {
+        // Cancel the active timer task
+        self.stop();
+
+        // Exit early if callback completed as starting will do nothing
+        if self.callback.is_finished() {
+            return;
+        }
+
+        // Create a new active timer task
+        let duration = self.duration;
+        let trigger = self.trigger.clone();
+        self.active_timer = Some(tokio::spawn(async move {
+            tokio::time::sleep(duration).await;
+            let _ = trigger.send(true).await;
+        }));
+    }
+
+    /// Stops the timer, cancelling the internal task, but leaving the callback in place in case
+    /// the timer is re-started later
+    pub fn stop(&mut self) {
+        // Delete the active timer task
+        if let Some(task) = self.active_timer.take() {
+            task.abort();
+        }
+    }
+
+    /// Aborts the timer's callback task and internal task to trigger the callback, which means
+    /// that the timer will never complete the callback and starting will have no effect
+    pub fn abort(&self) {
+        if let Some(task) = self.active_timer.as_ref() {
+            task.abort();
+        }
+        self.callback.abort();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    mod timer {
+        use super::*;
+
+        #[tokio::test]
+        async fn should_not_invoke_callback_regardless_of_time_if_not_started() {
+            let timer = Timer::new(Duration::default(), async {});
+
+            tokio::time::sleep(Duration::from_millis(300)).await;
+
+            assert!(
+                !timer.callback.is_finished(),
+                "Callback completed unexpectedly"
+            );
+        }
+
+        #[tokio::test]
+        async fn should_not_invoke_callback_if_only_stop_called() {
+            let mut timer = Timer::new(Duration::default(), async {});
+            timer.stop();
+
+            tokio::time::sleep(Duration::from_millis(300)).await;
+
+            assert!(
+                !timer.callback.is_finished(),
+                "Callback completed unexpectedly"
+            );
+        }
+
+        #[tokio::test]
+        async fn should_finish_callback_but_not_trigger_it_if_abort_called() {
+            let (tx, mut rx) = mpsc::channel(1);
+            let timer = Timer::new(Duration::default(), async move {
+                let _ = tx.send(()).await;
+            });
+            timer.abort();
+
+            tokio::time::sleep(Duration::from_millis(300)).await;
+
+            assert!(timer.callback.is_finished(), "Callback not finished");
+            assert!(rx.try_recv().is_err(), "Callback triggered unexpectedly");
+        }
+
+        #[tokio::test]
+        async fn should_trigger_callback_after_time_elapses_once_started() {
+            let (tx, mut rx) = mpsc::channel(1);
+            let mut timer = Timer::new(Duration::default(), async move {
+                let _ = tx.send(()).await;
+            });
+            timer.start();
+
+            tokio::time::sleep(Duration::from_millis(300)).await;
+
+            assert!(timer.callback.is_finished(), "Callback not finished");
+            assert!(rx.try_recv().is_ok(), "Callback not triggered");
+        }
+    }
+}
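For orientation, a hypothetical usage of this `Timer` from elsewhere in the crate (it is `pub(crate)`, so only internal call sites apply): the callback fires only after the countdown elapses, and `start` both begins and restarts the countdown, which is exactly how the accept loop re-arms it when the last connection closes:

// Hypothetical call site; the names here are illustrative only
let mut timer = Timer::new(Duration::from_secs(10), async move {
    println!("10s elapsed with no restart");
});
timer.start(); // begin the countdown
timer.stop();  // cancel it; the callback stays armed for a later start()
timer.start(); // restart the countdown from the full duration
// timer.abort(); // would permanently cancel the callback instead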

@@ -529,6 +529,7 @@ pub async fn launched_client(
     let client = ssh_client
         .launch_and_connect(DistantLaunchOpts {
             binary,
+            args: "--shutdown-after 10".to_string(),
             ..Default::default()
         })
         .await

@@ -5,11 +5,17 @@ use crate::{
 use anyhow::Context;
 use clap::Subcommand;
 use distant_core::{
-    net::{SecretKey32, ServerRef, TcpServerExt, XChaCha20Poly1305Codec},
+    net::{
+        SecretKey32, ServerConfig as NetServerConfig, ServerRef, TcpServerExt,
+        XChaCha20Poly1305Codec,
+    },
     DistantApiServer, DistantSingleKeyCredentials, Host,
 };
 use log::*;
-use std::io::{self, Read, Write};
+use std::{
+    io::{self, Read, Write},
+    time::Duration,
+};
 
 #[derive(Debug, Subcommand)]
 pub enum ServerSubcommand {
@ -36,18 +42,18 @@ pub enum ServerSubcommand {
} }
impl ServerSubcommand { impl ServerSubcommand {
pub fn run(self, _config: ServerConfig) -> CliResult { pub fn run(self, config: ServerConfig) -> CliResult {
match &self { match &self {
Self::Listen { daemon, .. } if *daemon => Self::run_daemon(self), Self::Listen { daemon, .. } if *daemon => Self::run_daemon(self, config),
Self::Listen { .. } => { Self::Listen { .. } => {
let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?; let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?;
rt.block_on(Self::async_run(self, false)) rt.block_on(Self::async_run(self, config, false))
} }
} }
} }
#[cfg(windows)] #[cfg(windows)]
fn run_daemon(self) -> CliResult { fn run_daemon(self, _config: ServerConfig) -> CliResult {
use crate::cli::Spawner; use crate::cli::Spawner;
use distant_core::net::{Listener, WindowsPipeListener}; use distant_core::net::{Listener, WindowsPipeListener};
use std::ffi::OsString; use std::ffi::OsString;
@@ -96,7 +102,7 @@ impl ServerSubcommand {
     }
 
     #[cfg(unix)]
-    fn run_daemon(self) -> CliResult {
+    fn run_daemon(self, config: ServerConfig) -> CliResult {
         use fork::{daemon, Fork};
 
         // NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
@@ -104,7 +110,7 @@ impl ServerSubcommand {
         match daemon(true, true) {
             Ok(Fork::Child) => {
                 let rt = tokio::runtime::Runtime::new().context("Failed to start up runtime")?;
-                rt.block_on(async { Self::async_run(self, true).await })?;
+                rt.block_on(async { Self::async_run(self, config, true).await })?;
                 Ok(())
             }
             Ok(Fork::Parent(pid)) => {
@@ -119,21 +125,30 @@ impl ServerSubcommand {
         }
     }
 
-    async fn async_run(self, _is_forked: bool) -> CliResult {
+    async fn async_run(self, config: ServerConfig, _is_forked: bool) -> CliResult {
         match self {
             Self::Listen {
-                config,
+                config: listen_config,
                 key_from_stdin,
                 #[cfg(windows)]
                 output_to_local_pipe,
                 ..
             } => {
-                let host = config.host.unwrap_or(BindAddress::Any);
+                macro_rules! get {
+                    (@flag $field:ident) => {{
+                        config.listen.$field || listen_config.$field
+                    }};
+                    ($field:ident) => {{
+                        config.listen.$field.or(listen_config.$field)
+                    }};
+                }
+
+                let host = get!(host).unwrap_or(BindAddress::Any);
                 trace!("Starting server using unresolved host '{}'", host);
-                let addr = host.resolve(config.use_ipv6)?;
+                let addr = host.resolve(get!(@flag use_ipv6))?;
 
                 // If specified, change the current working directory of this program
-                if let Some(path) = config.current_dir.as_ref() {
+                if let Some(path) = get!(current_dir) {
                     debug!("Setting current directory to {:?}", path);
                     std::env::set_current_dir(path)
                         .context("Failed to set new current directory")?;
@@ -156,25 +171,26 @@ impl ServerSubcommand {
                 debug!(
                     "Starting local API server, binding to {} {}",
                     addr,
-                    match config.port {
+                    match get!(port) {
                         Some(range) => format!("with port in range {}", range),
                         None => "using an ephemeral port".to_string(),
                     }
                 );
-                let server = DistantApiServer::local()
-                    .context("Failed to create local distant api")?
-                    .start(addr, config.port.unwrap_or_else(|| 0.into()), codec)
-                    .await
-                    .with_context(|| {
-                        format!(
-                            "Failed to start server @ {} with {}",
-                            addr,
-                            config
-                                .port
-                                .map(|p| format!("port in range {p}"))
-                                .unwrap_or_else(|| String::from("ephemeral port"))
-                        )
-                    })?;
+                let server = DistantApiServer::local(NetServerConfig {
+                    shutdown_after: get!(shutdown_after).map(Duration::from_secs_f32),
+                })
+                .context("Failed to create local distant api")?
+                .start(addr, get!(port).unwrap_or_else(|| 0.into()), codec)
+                .await
+                .with_context(|| {
+                    format!(
+                        "Failed to start server @ {} with {}",
+                        addr,
+                        get!(port)
+                            .map(|p| format!("port in range {p}"))
+                            .unwrap_or_else(|| String::from("ephemeral port"))
+                    )
+                })?;
 
                 let credentials = DistantSingleKeyCredentials {
                     host: Host::from(addr),
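To make the `get!` macro above concrete: it merges the loaded `config.listen` values with the CLI-parsed `listen_config`, preferring the former via `Option::or` (or boolean `||` for flags). The calls expand roughly to:

// get!(port):
config.listen.port.or(listen_config.port)

// get!(@flag use_ipv6):
config.listen.use_ipv6 || listen_config.use_ipv6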
