Unfinished timing

pull/38/head
Chip Senkbeil 3 years ago
parent df80f261bc
commit 956f7e0119

@@ -12,6 +12,7 @@ use std::{
net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::PathBuf,
str::FromStr,
time::Duration,
};
use structopt::StructOpt;
use strum::{EnumString, EnumVariantNames, IntoStaticStr, VariantNames};
@@ -53,10 +54,17 @@ pub struct CommonOpt {
#[structopt(long, global = true)]
pub log_file: Option<PathBuf>,
/// Represents the maximum time (in milliseconds) to wait for a network
/// Represents the maximum time (in seconds) to wait for a network
/// request before timing out; a timeout of 0 implies waiting indefinitely
#[structopt(short, long, global = true, default_value = &TIMEOUT_STR)]
pub timeout: usize,
pub timeout: f32,
}
impl CommonOpt {
/// Creates a new duration representing the timeout in seconds
pub fn to_timeout_duration(&self) -> Duration {
Duration::from_secs_f32(self.timeout)
}
}
/// Contains options related to sessions
@@ -331,6 +339,20 @@ pub struct LaunchSubcommand {
#[structopt(long)]
pub fail_if_socket_exists: bool,
/// The time in seconds before shutting down the server if there are no active
/// connections. The countdown begins once all connections have closed and
/// stops when a new connection is made. If not specified, the server will not
/// shut down at any point when there are no active connections.
///
/// In the case of launch, this is only applicable when the session type is set
/// to socket, as this controls when the unix socket listener shuts down, not
/// when the remote server it is connected to shuts down.
///
/// To configure the remote server's shutdown time, provide it as an argument
/// via `--extra-server-args`
#[structopt(long)]
pub shutdown_after: Option<f32>,
/// Runs in background via daemon mode (does nothing on Windows); only applies
/// when session is socket
#[structopt(short, long)]
@@ -391,6 +413,16 @@ pub struct LaunchSubcommand {
pub host: String,
}
impl LaunchSubcommand {
/// Creates a new duration representing the shutdown period in seconds
pub fn to_shutdown_after_duration(&self) -> Option<Duration> {
self.shutdown_after
.as_ref()
.copied()
.map(Duration::from_secs_f32)
}
}
/// Represents some range of ports
#[derive(Clone, Debug, Display, PartialEq, Eq)]
#[display(
@@ -483,6 +515,13 @@ pub struct ListenSubcommand {
#[structopt(long, default_value = "1000")]
pub max_msg_capacity: u16,
/// The time in seconds before shutting down the server if there are no active
/// connections. The countdown begins once all connections have closed and
/// stops when a new connection is made. If not specified, the server will not
/// shut down at any point when there are no active connections.
#[structopt(long)]
pub shutdown_after: Option<f32>,
/// Changes the current working directory (cwd) to the specified directory
#[structopt(long)]
pub current_dir: Option<PathBuf>,
@@ -496,3 +535,13 @@ pub struct ListenSubcommand {
#[structopt(short, long, value_name = "PORT[:PORT2]", default_value = "8080:8099")]
pub port: PortRange,
}
impl ListenSubcommand {
/// Creates a new duration representing the shutdown period in seconds
pub fn to_shutdown_after_duration(&self) -> Option<Duration> {
self.shutdown_after
.as_ref()
.copied()
.map(Duration::from_secs_f32)
}
}
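The sketch below is not part of the commit; it just illustrates how the new fractional-second flags map onto std::time::Duration via from_secs_f32, mirroring to_timeout_duration() and to_shutdown_after_duration() above. The function name and the sample values are made up for illustration.

fn duration_flag_sketch() {
    use std::time::Duration;

    // --timeout 0.5 waits half a second; --timeout 0 yields a zero duration,
    // which the doc comment above says is treated as waiting indefinitely
    assert_eq!(Duration::from_secs_f32(0.5), Duration::from_millis(500));
    assert_eq!(Duration::from_secs_f32(0.0), Duration::from_secs(0));

    // --shutdown-after 30 becomes Some(30s); omitting the flag leaves None,
    // which disables the idle shutdown entirely
    let shutdown_after: Option<f32> = Some(30.0);
    assert_eq!(
        shutdown_after.map(Duration::from_secs_f32),
        Some(Duration::from_secs(30))
    );
}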

@@ -29,7 +29,7 @@ pub fn run(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
}
async fn run_async(cmd: ActionSubcommand, opt: CommonOpt) -> Result<(), Error> {
let timeout = Duration::from_millis(opt.timeout as u64);
let timeout = opt.to_timeout_duration();
match cmd.session {
SessionInput::Environment => {

@@ -17,6 +17,7 @@ use std::{marker::Unpin, path::Path, string::FromUtf8Error, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
process::Command,
runtime::{Handle, Runtime},
sync::{broadcast, mpsc, oneshot, Mutex},
time::Duration,
};
@@ -40,7 +41,7 @@ struct ConnState {
}
pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
let rt = tokio::runtime::Runtime::new()?;
let rt = Runtime::new()?;
let session_output = cmd.session;
let mode = cmd.mode;
let is_daemon = cmd.daemon;
@@ -48,7 +49,8 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
let session_file = cmd.session_data.session_file.clone();
let session_socket = cmd.session_data.session_socket.clone();
let fail_if_socket_exists = cmd.fail_if_socket_exists;
let timeout = Duration::from_millis(opt.timeout as u64);
let timeout = opt.to_timeout_duration();
let shutdown_after = cmd.to_shutdown_after_duration();
let session = rt.block_on(async { spawn_remote_server(cmd, opt).await })?;
@@ -81,9 +83,16 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
// NOTE: We need to create a runtime within the forked process as
// tokio's runtime doesn't support being transferred from
// parent to child in a fork
let rt = tokio::runtime::Runtime::new()?;
let rt = Runtime::new()?;
rt.block_on(async {
socket_loop(session_socket, session, timeout, fail_if_socket_exists).await
socket_loop(
session_socket,
session,
timeout,
fail_if_socket_exists,
shutdown_after,
)
.await
})?
}
Ok(_) => {}
@@ -97,7 +106,14 @@ pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
session_socket
);
rt.block_on(async {
socket_loop(session_socket, session, timeout, fail_if_socket_exists).await
socket_loop(
session_socket,
session,
timeout,
fail_if_socket_exists,
shutdown_after,
)
.await
})?
}
#[cfg(not(unix))]
@@ -133,6 +149,7 @@ async fn socket_loop(
session: Session,
duration: Duration,
fail_if_socket_exists: bool,
shutdown_after: Option<Duration>,
) -> io::Result<()> {
// We need to form a connection with the actual server to forward requests
// and responses between connections
@@ -171,44 +188,63 @@
debug!("Binding to unix socket: {:?}", socket_path.as_ref());
let listener = tokio::net::UnixListener::bind(socket_path)?;
while let Ok((conn, _)) = listener.accept().await {
// Create a unique id to associate with the connection since its address
// is not guaranteed to have an identifiable string
let conn_id: usize = rand::random();
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(conn, None).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {:?}> Failed handshake: {}", conn_id, x);
continue;
let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);
loop {
tokio::select! {
result = listener.accept() => {match result {
Ok((conn, _)) => {
// Create a unique id to associate with the connection since its address
// is not guaranteed to have an identifiable string
let conn_id: usize = rand::random();
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(conn, None).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {:?}> Failed handshake: {}", conn_id, x);
continue;
}
};
let (t_read, t_write) = transport.into_split();
// Used to alert our response task of the connection's tenant name
// based on the first request received
let (tenant_tx, tenant_rx) = oneshot::channel();
// Create a state we use to keep track of connection-specific data
debug!("<Client @ {}> Initializing internal state", conn_id);
let state = Arc::new(Mutex::new(ConnState::default()));
// Spawn task to continually receive responses from the client that
// may or may not be relevant to the connection, which will filter
// by tenant and then send along any response that matches
let res_rx = broadcaster.subscribe();
let state_2 = Arc::clone(&state);
tokio::spawn(async move {
handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await;
});
// Spawn task to continually read requests from connection and forward
// them along to be sent via the client
let req_tx = req_tx.clone();
let ct_2 = Arc::clone(&ct);
tokio::spawn(async move {
ct_2.lock().await.increment();
handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await;
ct_2.lock().await.decrement();
});
}
Err(x) => {
error!("Listener failed: {}", x);
break;
}
}}
_ = notify.notified() => {
warn!("Reached shutdown timeout, so terminating");
break;
}
};
let (t_read, t_write) = transport.into_split();
// Used to alert our response task of the connection's tenant name
// based on the first request received
let (tenant_tx, tenant_rx) = oneshot::channel();
// Create a state we use to keep track of connection-specific data
debug!("<Client @ {}> Initializing internal state", conn_id);
let state = Arc::new(Mutex::new(ConnState::default()));
// Spawn task to continually receive responses from the client that
// may or may not be relevant to the connection, which will filter
// by tenant and then send along any response that matches
let res_rx = broadcaster.subscribe();
let state_2 = Arc::clone(&state);
tokio::spawn(async move {
handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await;
});
// Spawn task to continually read requests from connection and forward
// them along to be sent via the client
let req_tx = req_tx.clone();
tokio::spawn(async move {
handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await;
});
}
}
Ok(())
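Below is a minimal, self-contained sketch (not part of the commit) of the accept-or-shutdown pattern used in socket_loop above: connections are serviced until the shutdown Notify handle fires. It assumes tokio with its net, sync, and macros features enabled; the function and variable names are illustrative only.

use std::sync::Arc;
use tokio::{net::TcpListener, sync::Notify};

async fn accept_until_notified(listener: TcpListener, shutdown: Arc<Notify>) {
    loop {
        tokio::select! {
            result = listener.accept() => { match result {
                Ok((_stream, addr)) => {
                    // A real server would hand the stream off to a spawned task here
                    println!("Accepted connection from {}", addr);
                }
                Err(x) => {
                    eprintln!("Listener failed: {}", x);
                    break;
                }
            }}
            _ = shutdown.notified() => {
                // Fired by the idle-shutdown task once the timeout elapses
                println!("Reached shutdown timeout, so terminating");
                break;
            }
        }
    }
}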

@@ -5,6 +5,7 @@ use crate::{
net::{Transport, TransportReadHalf, TransportWriteHalf},
session::Session,
state::ServerState,
utils,
},
};
use derive_more::{Display, Error, From};
@@ -15,6 +16,7 @@ use std::{net::SocketAddr, sync::Arc};
use tokio::{
io,
net::{tcp, TcpListener},
runtime::Handle,
sync::{mpsc, Mutex},
};
@@ -54,6 +56,7 @@ pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> Result<(), Error> {
let addr = cmd.host.to_ip_addr(cmd.use_ipv6)?;
let socket_addrs = cmd.port.make_socket_addrs(addr);
let shutdown_after = cmd.to_shutdown_after_duration();
// If specified, change the current working directory of this program
if let Some(path) = cmd.current_dir.as_ref() {
@@ -89,37 +92,55 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
// Build our state for the server
let state: Arc<Mutex<ServerState<SocketAddr>>> = Arc::new(Mutex::new(ServerState::default()));
let (ct, notify) = utils::new_shutdown_task(Handle::current(), shutdown_after);
// Wait for a client connection, then spawn a new task to handle
// receiving data from the client
while let Ok((client, addr)) = listener.accept().await {
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(client, Some(Arc::clone(&key))).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {}> Failed handshake: {}", addr, x);
continue;
}
};
// Split the transport into read and write halves so we can handle input
// and output concurrently
let (t_read, t_write) = transport.into_split();
let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize);
// Spawn a new task that loops to handle requests from the client
tokio::spawn({
let f = request_loop(addr, Arc::clone(&state), t_read, tx);
let state = Arc::clone(&state);
async move {
f.await;
state.lock().await.cleanup_client(addr).await;
loop {
tokio::select! {
result = listener.accept() => {match result {
Ok((client, addr)) => {
// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(client, Some(Arc::clone(&key))).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {}> Failed handshake: {}", addr, x);
continue;
}
};
// Split the transport into read and write halves so we can handle input
// and output concurrently
let (t_read, t_write) = transport.into_split();
let (tx, rx) = mpsc::channel(cmd.max_msg_capacity as usize);
let ct_2 = Arc::clone(&ct);
// Spawn a new task that loops to handle requests from the client
tokio::spawn({
let f = request_loop(addr, Arc::clone(&state), t_read, tx);
let state = Arc::clone(&state);
async move {
ct_2.lock().await.increment();
f.await;
state.lock().await.cleanup_client(addr).await;
ct_2.lock().await.decrement();
}
});
// Spawn a new task that loops to handle responses to the client
tokio::spawn(async move { response_loop(addr, t_write, rx).await });
}
Err(x) => {
error!("Listener failed: {}", x);
break;
}
}}
_ = notify.notified() => {
warn!("Reached shutdown timeout, so terminating");
break;
}
});
// Spawn a new task that loops to handle responses to the client
tokio::spawn(async move { response_loop(addr, t_write, rx).await });
}
}
Ok(())
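A short sketch (not part of the commit) of the ConnTracker bookkeeping that the spawned tasks above perform: increment() and decrement() stamp the time of the most recent connection event, so has_exceeded_timeout() only reports true once the configured duration has passed with no activity. It assumes the ConnTracker type from the utils module later in this commit is in scope; the function name is illustrative.

fn conn_tracker_sketch() {
    use std::time::Duration;

    let mut ct = ConnTracker::new();
    ct.increment(); // a client connected
    ct.decrement(); // the client disconnected; the timestamp refreshes
    // Right after the last event, the idle timeout cannot have elapsed yet
    assert!(!ct.has_exceeded_timeout(Duration::from_secs(5)));
}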

@@ -1,9 +1,16 @@
use log::*;
use std::{
future::Future,
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
};
use tokio::{io, time};
use tokio::{
io,
runtime::Handle,
sync::{Mutex, Notify},
time::{self, Instant},
};
// Generates a new tenant name
pub fn new_tenant() -> String {
@@ -21,6 +28,85 @@
.map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
}
pub struct ConnTracker {
time: Instant,
cnt: usize,
}
impl ConnTracker {
pub fn new() -> Self {
Self {
time: Instant::now(),
cnt: 0,
}
}
pub fn time(&self) -> Instant {
self.time
}
pub fn increment(&mut self) {
self.time = Instant::now();
self.cnt += 1;
}
pub fn decrement(&mut self) {
if self.cnt > 0 {
self.time = Instant::now();
self.cnt -= 1;
}
}
pub fn has_exceeded_timeout(&self, duration: Duration) -> bool {
self.time.elapsed() > duration
}
}
/// Spawns a new task that continually monitors how long the server has gone
/// without an active connection, signaling the returned Notify handle once
/// the provided duration is exceeded
pub fn new_shutdown_task(
handle: Handle,
duration: Option<Duration>,
) -> (Arc<Mutex<ConnTracker>>, Arc<Notify>) {
let ct = Arc::new(Mutex::new(ConnTracker::new()));
let notify = Arc::new(Notify::new());
let ct_2 = Arc::clone(&ct);
let notify_2 = Arc::clone(&notify);
if let Some(duration) = duration {
handle.spawn(async move {
loop {
// Get the time we should wait based on when the last connection
// was dropped; this closes the gap in the case where we start
// sometime later than exactly duration since the last check
match ct_2.lock().await.time().checked_add(duration) {
Some(next_time) => {
// Wait until we've reached our desired duration since the
// last connection was dropped
time::sleep_until(next_time).await;
// If the idle timeout has not actually elapsed (a connection event
// reset the tracker while we slept), keep waiting
if !ct_2.lock().await.has_exceeded_timeout(duration) {
continue;
}
// Otherwise we should now exit, which we do by notifying the listener
notify_2.notify_one();
break;
}
None => {
error!("Shutdown check time forecast failed! Task is exiting!");
break;
}
}
}
});
}
(ct, notify)
}
/// Wraps a string to provide some friendly read and write methods
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StringBuf(String);
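To round this out, here is a hedged usage sketch (not part of the commit) of wiring new_shutdown_task into a runtime: tasks bump the tracker while they run, and the returned Notify resolves once the idle duration has elapsed after the last connection event. It assumes this utils module is in scope and tokio with the rt-multi-thread, time, and sync features; the function name and the five-second value are illustrative.

fn shutdown_task_sketch() -> std::io::Result<()> {
    use std::{sync::Arc, time::Duration};
    use tokio::runtime::Runtime;

    let rt = Runtime::new()?;
    let (ct, notify) = new_shutdown_task(rt.handle().clone(), Some(Duration::from_secs(5)));

    rt.block_on(async move {
        // Simulate one short-lived connection: increment on open, decrement on close
        let ct_2 = Arc::clone(&ct);
        tokio::spawn(async move {
            ct_2.lock().await.increment();
            tokio::time::sleep(Duration::from_secs(1)).await;
            ct_2.lock().await.decrement();
        });

        // Resolves roughly five seconds after the simulated connection closes
        notify.notified().await;
    });

    Ok(())
}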
