From 956f7e0119770d68818819a79be900a2df81485e Mon Sep 17 00:00:00 2001 From: Chip Senkbeil Date: Thu, 5 Aug 2021 22:35:06 -0500 Subject: [PATCH] Unfinished timing --- src/cli/opt.rs | 53 +++++++++++++- src/cli/subcommand/action/mod.rs | 2 +- src/cli/subcommand/launch.rs | 120 ++++++++++++++++++++----------- src/cli/subcommand/listen/mod.rs | 75 ++++++++++++------- src/core/utils.rs | 88 ++++++++++++++++++++++- 5 files changed, 265 insertions(+), 73 deletions(-) diff --git a/src/cli/opt.rs b/src/cli/opt.rs index 194bb85..d746ae9 100644 --- a/src/cli/opt.rs +++ b/src/cli/opt.rs @@ -12,6 +12,7 @@ use std::{ net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::PathBuf, str::FromStr, + time::Duration, }; use structopt::StructOpt; use strum::{EnumString, EnumVariantNames, IntoStaticStr, VariantNames}; @@ -53,10 +54,17 @@ pub struct CommonOpt { #[structopt(long, global = true)] pub log_file: Option, - /// Represents the maximum time (in milliseconds) to wait for a network + /// Represents the maximum time (in seconds) to wait for a network /// request before timing out; a timeout of 0 implies waiting indefinitely #[structopt(short, long, global = true, default_value = &TIMEOUT_STR)] - pub timeout: usize, + pub timeout: f32, +} + +impl CommonOpt { + /// Creates a new duration representing the timeout in seconds + pub fn to_timeout_duration(&self) -> Duration { + Duration::from_secs_f32(self.timeout) + } } /// Contains options related sessions @@ -331,6 +339,20 @@ pub struct LaunchSubcommand { #[structopt(long)] pub fail_if_socket_exists: bool, + /// The time in seconds before shutting down the server if there are no active + /// connections. The countdown begins once all connections have closed and + /// stops when a new connection is made. In not specified, the server will not + /// shutdown at any point when there are no active connections. + /// + /// In the case of launch, this is only applicable when it is set to socket session + /// as this controls when the unix socket listener would shutdown, not when the + /// remote server it is connected to will shutdown. + /// + /// To configure the remote server's shutdown time, provide it as an argument + /// via `--extra-server-args` + #[structopt(long)] + pub shutdown_after: Option, + /// Runs in background via daemon-mode (does nothing on windows); only applies /// when session is socket #[structopt(short, long)] @@ -391,6 +413,16 @@ pub struct LaunchSubcommand { pub host: String, } +impl LaunchSubcommand { + /// Creates a new duration representing the shutdown period in seconds + pub fn to_shutdown_after_duration(&self) -> Option { + self.shutdown_after + .as_ref() + .copied() + .map(Duration::from_secs_f32) + } +} + /// Represents some range of ports #[derive(Clone, Debug, Display, PartialEq, Eq)] #[display( @@ -483,6 +515,13 @@ pub struct ListenSubcommand { #[structopt(long, default_value = "1000")] pub max_msg_capacity: u16, + /// The time in seconds before shutting down the server if there are no active + /// connections. The countdown begins once all connections have closed and + /// stops when a new connection is made. In not specified, the server will not + /// shutdown at any point when there are no active connections. + #[structopt(long)] + pub shutdown_after: Option, + /// Changes the current working directory (cwd) to the specified directory #[structopt(long)] pub current_dir: Option, @@ -496,3 +535,13 @@ pub struct ListenSubcommand { #[structopt(short, long, value_name = "PORT[:PORT2]", default_value = "8080:8099")] pub port: PortRange, } + +impl ListenSubcommand { + /// Creates a new duration representing the shutdown period in seconds + pub fn to_shutdown_after_duration(&self) -> Option { + self.shutdown_after + .as_ref() + .copied() + .map(Duration::from_secs_f32) + } +} diff --git a/src/cli/subcommand/action/mod.rs b/src/cli/subcommand/action/mod.rs index 928fd4b..63bff18 100644 --- a/src/cli/subcommand/action/mod.rs +++ b/src/cli/subcommand/action/mod.rs @@ -29,7 +29,7 @@ pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { } async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> { - let timeout = Duration::from_millis(opt.timeout as u64); + let timeout = opt.to_timeout_duration(); match cmd.session { SessionInput::Environment => { diff --git a/src/cli/subcommand/launch.rs b/src/cli/subcommand/launch.rs index 1aef89d..d42001f 100644 --- a/src/cli/subcommand/launch.rs +++ b/src/cli/subcommand/launch.rs @@ -17,6 +17,7 @@ use std::{marker::Unpin, path::Path, string::FromUtf8Error, sync::Arc}; use tokio::{ io::{self, AsyncRead, AsyncWrite}, process::Command, + runtime::{Handle, Runtime}, sync::{broadcast, mpsc, oneshot, Mutex}, time::Duration, }; @@ -40,7 +41,7 @@ struct ConnState { } pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { - let rt = tokio::runtime::Runtime::new()?; + let rt = Runtime::new()?; let session_output = cmd.session; let mode = cmd.mode; let is_daemon = cmd.daemon; @@ -48,7 +49,8 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { let session_file = cmd.session_data.session_file.clone(); let session_socket = cmd.session_data.session_socket.clone(); let fail_if_socket_exists = cmd.fail_if_socket_exists; - let timeout = Duration::from_millis(opt.timeout as u64); + let timeout = opt.to_timeout_duration(); + let shutdown_after = cmd.to_shutdown_after_duration(); let session = rt.block_on(async { spawn_remote_server(cmd, opt).await })?; @@ -81,9 +83,16 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { // NOTE: We need to create a runtime within the forked process as // tokio's runtime doesn't support being transferred from // parent to child in a fork - let rt = tokio::runtime::Runtime::new()?; + let rt = Runtime::new()?; rt.block_on(async { - socket_loop(session_socket, session, timeout, fail_if_socket_exists).await + socket_loop( + session_socket, + session, + timeout, + fail_if_socket_exists, + shutdown_after, + ) + .await })? } Ok(_) => {} @@ -97,7 +106,14 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> { session_socket ); rt.block_on(async { - socket_loop(session_socket, session, timeout, fail_if_socket_exists).await + socket_loop( + session_socket, + session, + timeout, + fail_if_socket_exists, + shutdown_after, + ) + .await })? } #[cfg(not(unix))] @@ -133,6 +149,7 @@ async fn socket_loop( session: Session, duration: Duration, fail_if_socket_exists: bool, + shutdown_after: Option, ) -> io::Result<()> { // We need to form a connection with the actual server to forward requests // and responses between connections @@ -171,44 +188,63 @@ async fn socket_loop( debug!("Binding to unix socket: {:?}", socket_path.as_ref()); let listener = tokio::net::UnixListener::bind(socket_path)?; - while let Ok((conn, _)) = listener.accept().await { - // Create a unique id to associate with the connection since its address - // is not guaranteed to have an identifiable string - let conn_id: usize = rand::random(); - - // Establish a proper connection via a handshake, discarding the connection otherwise - let transport = match Transport::from_handshake(conn, None).await { - Ok(transport) => transport, - Err(x) => { - error!(" Failed handshake: {}", conn_id, x); - continue; + let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after); + + loop { + tokio::select! { + result = listener.accept() => {match result { + Ok((conn, _)) => { + // Create a unique id to associate with the connection since its address + // is not guaranteed to have an identifiable string + let conn_id: usize = rand::random(); + + // Establish a proper connection via a handshake, discarding the connection otherwise + let transport = match Transport::from_handshake(conn, None).await { + Ok(transport) => transport, + Err(x) => { + error!(" Failed handshake: {}", conn_id, x); + continue; + } + }; + let (t_read, t_write) = transport.into_split(); + + // Used to alert our response task of the connection's tenant name + // based on the first + let (tenant_tx, tenant_rx) = oneshot::channel(); + + // Create a state we use to keep track of connection-specific data + debug!(" Initializing internal state", conn_id); + let state = Arc::new(Mutex::new(ConnState::default())); + + // Spawn task to continually receive responses from the client that + // may or may not be relevant to the connection, which will filter + // by tenant and then along any response that matches + let res_rx = broadcaster.subscribe(); + let state_2 = Arc::clone(&state); + tokio::spawn(async move { + handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await; + }); + + // Spawn task to continually read requests from connection and forward + // them along to be sent via the client + let req_tx = req_tx.clone(); + let ct_2 = Arc::clone(&ct); + tokio::spawn(async move { + ct_2.lock().await.increment(); + handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await; + ct_2.lock().await.decrement(); + }); + } + Err(x) => { + error!("Listener failed: {}", x); + break; + } + }} + _ = notify.notified() => { + warn!("Reached shutdown timeout, so terminating"); + break; } - }; - let (t_read, t_write) = transport.into_split(); - - // Used to alert our response task of the connection's tenant name - // based on the first - let (tenant_tx, tenant_rx) = oneshot::channel(); - - // Create a state we use to keep track of connection-specific data - debug!(" Initializing internal state", conn_id); - let state = Arc::new(Mutex::new(ConnState::default())); - - // Spawn task to continually receive responses from the client that - // may or may not be relevant to the connection, which will filter - // by tenant and then along any response that matches - let res_rx = broadcaster.subscribe(); - let state_2 = Arc::clone(&state); - tokio::spawn(async move { - handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await; - }); - - // Spawn task to continually read requests from connection and forward - // them along to be sent via the client - let req_tx = req_tx.clone(); - tokio::spawn(async move { - handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await; - }); + } } Ok(()) diff --git a/src/cli/subcommand/listen/mod.rs b/src/cli/subcommand/listen/mod.rs index 2c80ec9..ce247d5 100644 --- a/src/cli/subcommand/listen/mod.rs +++ b/src/cli/subcommand/listen/mod.rs @@ -5,6 +5,7 @@ use crate::{ net::{Transport, TransportReadHalf, TransportWriteHalf}, session::Session, state::ServerState, + utils, }, }; use derive_more::{Display, Error, From}; @@ -15,6 +16,7 @@ use std::{net::SocketAddr, sync::Arc}; use tokio::{ io, net::{tcp, TcpListener}, + runtime::Handle, sync::{mpsc, Mutex}, }; @@ -54,6 +56,7 @@ pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> { async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> { let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?; let socket_addrs = cmd.port.make_socket_addrs(addr); + let shutdown_after = cmd.to_shutdown_after_duration(); // If specified, change the current working directory of this program if let Some(path) = cmd.current_dir.as_ref() { @@ -89,37 +92,55 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R // Build our state for the server let state: Arc>> = Arc::new(Mutex::new(ServerState::default())); + let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after); // Wait for a client connection, then spawn a new task to handle // receiving data from the client - while let Ok((client, addr)) = listener.accept().await { - // Establish a proper connection via a handshake, discarding the connection otherwise - let transport = match Transport::from_handshake(client, Some(Arc::clone(&key))).await { - Ok(transport) => transport, - Err(x) => { - error!(" Failed handshake: {}", addr, x); - continue; - } - }; - - // Split the transport into read and write halves so we can handle input - // and output concurrently - let (t_read, t_write) = transport.into_split(); - let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize); - - // Spawn a new task that loops to handle requests from the client - tokio::spawn({ - let f = request_loop(addr, Arc::clone(&state), t_read, tx); - - let state = Arc::clone(&state); - async move { - f.await; - state.lock().await.cleanup_client(addr).await; + loop { + tokio::select! { + result = listener.accept() => {match result { + Ok((client, addr)) => { + // Establish a proper connection via a handshake, discarding the connection otherwise + let transport = match Transport::from_handshake(client, Some(Arc::clone(&key))).await { + Ok(transport) => transport, + Err(x) => { + error!(" Failed handshake: {}", addr, x); + continue; + } + }; + + // Split the transport into read and write halves so we can handle input + // and output concurrently + let (t_read, t_write) = transport.into_split(); + let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize); + let ct_2 = Arc::clone(&ct); + + // Spawn a new task that loops to handle requests from the client + tokio::spawn({ + let f = request_loop(addr, Arc::clone(&state), t_read, tx); + + let state = Arc::clone(&state); + async move { + ct_2.lock().await.increment(); + f.await; + state.lock().await.cleanup_client(addr).await; + ct_2.lock().await.decrement(); + } + }); + + // Spawn a new task that loops to handle responses to the client + tokio::spawn(async move { response_loop(addr, t_write, rx).await }); + } + Err(x) => { + error!("Listener failed: {}", x); + break; + } + }} + _ = notify.notified() => { + warn!("Reached shutdown timeout, so terminating"); + break; } - }); - - // Spawn a new task that loops to handle responses to the client - tokio::spawn(async move { response_loop(addr, t_write, rx).await }); + } } Ok(()) diff --git a/src/core/utils.rs b/src/core/utils.rs index f708041..82da1bd 100644 --- a/src/core/utils.rs +++ b/src/core/utils.rs @@ -1,9 +1,16 @@ +use log::*; use std::{ future::Future, ops::{Deref, DerefMut}, + sync::Arc, time::Duration, }; -use tokio::{io, time}; +use tokio::{ + io, + runtime::Handle, + sync::{Mutex, Notify}, + time::{self, Instant}, +}; // Generates a new tenant name pub fn new_tenant() -> String { @@ -21,6 +28,85 @@ where .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x)) } +pub struct ConnTracker { + time: Instant, + cnt: usize, +} + +impl ConnTracker { + pub fn new() -> Self { + Self { + time: Instant::now(), + cnt: 0, + } + } + + pub fn time(&self) -> Instant { + self.time + } + + pub fn increment(&mut self) { + self.time = Instant::now(); + self.cnt += 1; + } + + pub fn decrement(&mut self) { + if self.cnt > 0 { + self.time = Instant::now(); + self.cnt -= 1; + } + } + + pub fn has_exceeded_timeout(&self, duration: Duration) -> bool { + self.time.elapsed() > duration + } +} + +/// Spawns a new task that continues to monitor the time since a +/// connection on the server existed, shutting down the runtime +/// if the time is exceeded +pub fn new_shutdown_task( + handle: Handle, + duration: Option, +) -> (Arc>, Arc) { + let ct = Arc::new(Mutex::new(ConnTracker::new())); + let notify = Arc::new(Notify::new()); + + let ct_2 = Arc::clone(&ct); + let notify_2 = Arc::clone(¬ify); + if let Some(duration) = duration { + handle.spawn(async move { + loop { + // Get the time we should wait based on when the last connection + // was dropped; this closes the gap in the case where we start + // sometime later than exactly duration since the last check + match ct_2.lock().await.time().checked_add(duration) { + Some(next_time) => { + // Wait until we've reached our desired duration since the + // last connection was dropped + time::sleep_until(next_time).await; + + // If we do have a connection at this point, don't exit + if ct_2.lock().await.has_exceeded_timeout(duration) { + continue; + } + + // Otherwise, we now should exit, which we do by reporting + notify_2.notify_one(); + break; + } + None => { + error!("Shutdown check time forecast failed! Task is exiting!"); + break; + } + } + } + }); + } + + (ct, notify) +} + /// Wraps a string to provide some friendly read and write methods #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct StringBuf(String);