diff --git a/src/cli/exit.rs b/src/cli/exit.rs index 5e98daa..b65149c 100644 --- a/src/cli/exit.rs +++ b/src/cli/exit.rs @@ -88,7 +88,6 @@ impl ExitCodeError for RemoteProcessError { match self { Self::BadResponse => ExitCode::DataErr, Self::ChannelDead => ExitCode::Unavailable, - Self::Overloaded => ExitCode::Software, Self::TransportError(x) => x.to_exit_code(), Self::UnexpectedEof => ExitCode::IoError, Self::WaitFailed(_) => ExitCode::Software, diff --git a/src/cli/opt.rs b/src/cli/opt.rs index 4e82b05..3b653b8 100644 --- a/src/cli/opt.rs +++ b/src/cli/opt.rs @@ -95,6 +95,9 @@ pub enum Subcommand { /// Begins listening for incoming requests Listen(ListenSubcommand), + + /// Specialized treatment of running a remote LSP process + Lsp(LspSubcommand), } impl Subcommand { @@ -104,6 +107,7 @@ impl Subcommand { Self::Action(cmd) => subcommand::action::run(cmd, opt)?, Self::Launch(cmd) => subcommand::launch::run(cmd, opt)?, Self::Listen(cmd) => subcommand::listen::run(cmd, opt)?, + Self::Lsp(cmd) => subcommand::lsp::run(cmd, opt)?, } Ok(()) @@ -499,3 +503,42 @@ impl ListenSubcommand { .map(Duration::from_secs_f32) } } + +/// Represents subcommand to execute some LSP server on a remote machine +#[derive(Debug, StructOpt)] +#[structopt(verbatim_doc_comment)] +pub struct LspSubcommand { + /// Represents the format that results should be returned + /// + /// Currently, there are two possible formats: + /// + /// 1. "json": printing out JSON for external program usage + /// + /// 2. "shell": printing out human-readable results for interactive shell usage + #[structopt( + short, + long, + case_insensitive = true, + default_value = Format::Shell.into(), + possible_values = Format::VARIANTS + )] + pub format: Format, + + /// Represents the medium for retrieving a session to use when running a remote LSP server + #[structopt( + long, + default_value = SessionInput::default().into(), + possible_values = SessionInput::VARIANTS + )] + pub session: SessionInput, + + /// Contains additional information related to sessions + #[structopt(flatten)] + pub session_data: SessionOpt, + + /// Command to run on the remote machine that represents an LSP server + pub cmd: String, + + /// Additional arguments to supply to the remote machine + pub args: Vec, +} diff --git a/src/cli/session.rs b/src/cli/session.rs index b93d1ae..618fdde 100644 --- a/src/cli/session.rs +++ b/src/cli/session.rs @@ -11,7 +11,6 @@ use log::*; use std::{io, thread}; use structopt::StructOpt; use tokio::{sync::mpsc, task::JoinHandle}; -use tokio_stream::{wrappers::BroadcastStream, StreamExt}; /// Represents a wrapper around a session that provides CLI functionality such as reading from /// stdin and piping results back out to stdout @@ -29,9 +28,11 @@ impl CliSession { let (stdin_thread, stdin_rx) = stdin::spawn_channel(MAX_PIPE_CHUNK_SIZE); let (exit_tx, exit_rx) = mpsc::channel(1); - let stream = session.to_response_broadcast_stream(); + let broadcast = session.broadcast.take().unwrap(); let res_task = - tokio::spawn(async move { process_incoming_responses(stream, format, exit_rx).await }); + tokio::spawn( + async move { process_incoming_responses(broadcast, format, exit_rx).await }, + ); let map_line = move |line: &str| match format { Format::Json => serde_json::from_str(&line) @@ -76,16 +77,15 @@ impl CliSession { /// Helper function that loops, processing incoming responses not tied to a request to be sent out /// over stdout/stderr async fn process_incoming_responses( - mut stream: BroadcastStream, + mut broadcast: mpsc::Receiver, format: Format, mut exit: mpsc::Receiver<()>, ) -> io::Result<()> { loop { tokio::select! { - res = stream.next() => { + res = broadcast.recv() => { match res { - Some(Ok(res)) => ResponseOut::new(format, res)?.print(), - Some(Err(x)) => return Err(io::Error::new(io::ErrorKind::BrokenPipe, x)), + Some(res) => ResponseOut::new(format, res)?.print(), None => return Ok(()), } } diff --git a/src/cli/subcommand/launch.rs b/src/cli/subcommand/launch.rs index 7609485..2857706 100644 --- a/src/cli/subcommand/launch.rs +++ b/src/cli/subcommand/launch.rs @@ -157,7 +157,7 @@ async fn socket_loop( debug!("Binding to unix socket: {:?}", socket_path.as_ref()); let listener = tokio::net::UnixListener::bind(socket_path)?; - let server = RelayServer::initialize(session, listener, shutdown_after).await?; + let server = RelayServer::initialize(session, listener, shutdown_after)?; server .wait() .await diff --git a/src/cli/subcommand/lsp.rs b/src/cli/subcommand/lsp.rs new file mode 100644 index 0000000..67116cc --- /dev/null +++ b/src/cli/subcommand/lsp.rs @@ -0,0 +1,141 @@ +use crate::{ + cli::{ + link::RemoteProcessLink, + opt::{CommonOpt, LspSubcommand, SessionInput}, + ExitCode, ExitCodeError, + }, + core::{ + client::{ + self, LspData, RemoteLspProcess, RemoteProcessError, Session, SessionInfo, + SessionInfoFile, + }, + net::DataStream, + }, +}; +use derive_more::{Display, Error, From}; +use tokio::io; + +#[derive(Debug, Display, Error, From)] +pub enum Error { + #[display(fmt = "Process failed with exit code: {}", _0)] + BadProcessExit(#[error(not(source))] i32), + IoError(io::Error), + RemoteProcessError(RemoteProcessError), +} + +impl ExitCodeError for Error { + fn to_exit_code(&self) -> ExitCode { + match self { + Self::BadProcessExit(x) => ExitCode::Custom(*x), + Self::IoError(x) => x.to_exit_code(), + Self::RemoteProcessError(x) => x.to_exit_code(), + } + } +} + +pub fn run(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> { + let rt = tokio::runtime::Runtime::new()?; + + rt.block_on(async { run_async(cmd, opt).await }) +} + +async fn run_async(cmd: LspSubcommand, opt: CommonOpt) -> Result<(), Error> { + let timeout = opt.to_timeout_duration(); + + match cmd.session { + SessionInput::Environment => { + start( + cmd, + Session::tcp_connect_timeout(SessionInfo::from_environment()?, timeout).await?, + None, + ) + .await + } + SessionInput::File => { + let path = cmd.session_data.session_file.clone(); + start( + cmd, + Session::tcp_connect_timeout( + SessionInfoFile::load_from(path).await?.into(), + timeout, + ) + .await?, + None, + ) + .await + } + SessionInput::Pipe => { + start( + cmd, + Session::tcp_connect_timeout(SessionInfo::from_stdin()?, timeout).await?, + None, + ) + .await + } + SessionInput::Lsp => { + let mut data = + LspData::from_buf_reader(&mut std::io::stdin().lock()).map_err(io::Error::from)?; + let info = data.take_session_info().map_err(io::Error::from)?; + start( + cmd, + Session::tcp_connect_timeout(info, timeout).await?, + Some(data), + ) + .await + } + #[cfg(unix)] + SessionInput::Socket => { + let path = cmd.session_data.session_socket.clone(); + start( + cmd, + Session::unix_connect_timeout(path, None, timeout).await?, + None, + ) + .await + } + } +} + +async fn start( + cmd: LspSubcommand, + session: Session, + lsp_data: Option, +) -> Result<(), Error> +where + T: DataStream + 'static, +{ + let mut proc = + RemoteLspProcess::spawn(client::new_tenant(), session, cmd.cmd, cmd.args).await?; + + // If we also parsed an LSP's initialize request for its session, we want to forward + // it along in the case of a process call + if let Some(data) = lsp_data { + proc.stdin + .as_mut() + .unwrap() + .write(&data.to_string()) + .await?; + } + + // Now, map the remote LSP server's stdin/stdout/stderr to our own process + let link = RemoteProcessLink::from_remote_lsp_pipes( + proc.stdin.take().unwrap(), + proc.stdout.take().unwrap(), + proc.stderr.take().unwrap(), + ); + + let (success, exit_code) = proc.wait().await?; + + // Shut down our link + link.shutdown().await; + + if !success { + if let Some(code) = exit_code { + return Err(Error::BadProcessExit(code)); + } else { + return Err(Error::BadProcessExit(1)); + } + } + + Ok(()) +} diff --git a/src/cli/subcommand/mod.rs b/src/cli/subcommand/mod.rs index 90771a3..0ae0742 100644 --- a/src/cli/subcommand/mod.rs +++ b/src/cli/subcommand/mod.rs @@ -1,3 +1,4 @@ pub mod action; pub mod launch; pub mod listen; +pub mod lsp; diff --git a/src/core/client/lsp/mod.rs b/src/core/client/lsp/mod.rs index ac82fc9..40fe0ae 100644 --- a/src/core/client/lsp/mod.rs +++ b/src/core/client/lsp/mod.rs @@ -41,6 +41,11 @@ impl RemoteLspProcess { stderr, }) } + + /// Waits for the process to terminate, returning the success status and an optional exit code + pub async fn wait(self) -> Result<(bool, Option), RemoteProcessError> { + self.inner.wait().await + } } impl Deref for RemoteLspProcess { diff --git a/src/core/client/process.rs b/src/core/client/process.rs index babc746..5491a5f 100644 --- a/src/core/client/process.rs +++ b/src/core/client/process.rs @@ -10,7 +10,6 @@ use tokio::{ sync::mpsc, task::{JoinError, JoinHandle}, }; -use tokio_stream::{wrappers::BroadcastStream, StreamExt}; #[derive(Debug, Display, Error, From)] pub enum RemoteProcessError { @@ -20,10 +19,6 @@ pub enum RemoteProcessError { /// When attempting to relay stdout/stderr over channels, but the channels fail ChannelDead, - /// When process is unable to read stdout/stderr from the server - /// fast enough, resulting in dropped data - Overloaded, - /// When the communication over the wire has issues TransportError(TransportError), @@ -97,9 +92,9 @@ impl RemoteProcess { // Now we spawn a task to handle future responses that are async // such as ProcStdout, ProcStderr, and ProcDone - let stream = session.to_response_broadcast_stream(); + let broadcast = session.broadcast.take().unwrap(); let res_task = tokio::spawn(async move { - process_incoming_responses(id, stream, stdout_tx, stderr_tx).await + process_incoming_responses(id, broadcast, stdout_tx, stderr_tx).await }); // Spawn a task that takes stdin from our channel and forwards it to the remote process @@ -234,53 +229,45 @@ where /// Helper function that loops, processing incoming stdout & stderr requests from a remote process async fn process_incoming_responses( proc_id: usize, - mut stream: BroadcastStream, + mut broadcast: mpsc::Receiver, stdout_tx: mpsc::Sender, stderr_tx: mpsc::Sender, ) -> Result<(bool, Option), RemoteProcessError> { let mut result = Err(RemoteProcessError::UnexpectedEof); - while let Some(res) = stream.next().await { - match res { - Ok(res) => { - // Check if any of the payload data is the termination - let exit_status = res.payload.iter().find_map(|data| match data { - ResponseData::ProcDone { id, success, code } if *id == proc_id => { - Some((*success, *code)) - } - _ => None, - }); - - // Next, check for stdout/stderr and send them along our channels - // TODO: What should we do about unexpected data? For now, just ignore - for data in res.payload { - match data { - ResponseData::ProcStdout { id, data } if id == proc_id => { - if let Err(_) = stdout_tx.send(data).await { - result = Err(RemoteProcessError::ChannelDead); - break; - } - } - ResponseData::ProcStderr { id, data } if id == proc_id => { - if let Err(_) = stderr_tx.send(data).await { - result = Err(RemoteProcessError::ChannelDead); - break; - } - } - _ => {} + while let Some(res) = broadcast.recv().await { + // Check if any of the payload data is the termination + let exit_status = res.payload.iter().find_map(|data| match data { + ResponseData::ProcDone { id, success, code } if *id == proc_id => { + Some((*success, *code)) + } + _ => None, + }); + + // Next, check for stdout/stderr and send them along our channels + // TODO: What should we do about unexpected data? For now, just ignore + for data in res.payload { + match data { + ResponseData::ProcStdout { id, data } if id == proc_id => { + if let Err(_) = stdout_tx.send(data).await { + result = Err(RemoteProcessError::ChannelDead); + break; } } - - // If we got a termination, then exit accordingly - if let Some((success, code)) = exit_status { - result = Ok((success, code)); - break; + ResponseData::ProcStderr { id, data } if id == proc_id => { + if let Err(_) = stderr_tx.send(data).await { + result = Err(RemoteProcessError::ChannelDead); + break; + } } + _ => {} } - Err(_) => { - result = Err(RemoteProcessError::Overloaded); - break; - } + } + + // If we got a termination, then exit accordingly + if let Some((success, code)) = exit_status { + result = Ok((success, code)); + break; } } diff --git a/src/core/client/session/mod.rs b/src/core/client/session/mod.rs index d43a4d5..7596541 100644 --- a/src/core/client/session/mod.rs +++ b/src/core/client/session/mod.rs @@ -13,11 +13,10 @@ use std::{ use tokio::{ io, net::TcpStream, - sync::{broadcast, oneshot}, + sync::{mpsc, oneshot}, task::{JoinError, JoinHandle}, time::Duration, }; -use tokio_stream::wrappers::BroadcastStream; mod info; pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError}; @@ -35,16 +34,11 @@ where /// Collection of callbacks to be invoked upon receiving a response to a request callbacks: Callbacks, - /// Callback to trigger when a response is received without an origin or with an origin - /// not found in the list of callbacks - broadcast: broadcast::Sender, - - /// Represents an initial receiver for broadcasted responses that can capture responses - /// prior to a stream being established and consumed - init_broadcast_receiver: Option>, - /// Contains the task that is running to receive responses from a server response_task: JoinHandle<()>, + + /// Represents the receiver for broadcasted responses (ones with no callback) + pub broadcast: Option>, } impl Session { @@ -116,12 +110,10 @@ where pub async fn initialize(transport: Transport) -> io::Result { let (mut t_read, t_write) = transport.into_split(); let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new())); - let (broadcast, init_broadcast_receiver) = - broadcast::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY); + let (broadcast_tx, broadcast_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY); // Start a task that continually checks for responses and triggers callbacks let callbacks_2 = Arc::clone(&callbacks); - let broadcast_2 = broadcast.clone(); let response_task = tokio::spawn(async move { loop { match t_read.receive::().await { @@ -142,7 +134,7 @@ where // Otherwise, this goes into the junk draw of response handlers } else { trace!("Callback missing for response! Broadcasting!"); - if let Err(x) = broadcast_2.send(res) { + if let Err(x) = broadcast_tx.send(res).await { error!("Failed to trigger broadcast: {}", x); } } @@ -159,8 +151,7 @@ where Ok(Self { t_write, callbacks, - broadcast, - init_broadcast_receiver: Some(init_broadcast_receiver), + broadcast: Some(broadcast_rx), response_task, }) } @@ -220,21 +211,6 @@ where .map_err(TransportError::from) .and_then(convert::identity) } - - /// Clones a new instance of the broadcaster used by the session - pub fn to_response_broadcaster(&self) -> broadcast::Sender { - self.broadcast.clone() - } - - /// Creates and returns a new stream of responses that are received that do not match the - /// response to a `send` request - pub fn to_response_broadcast_stream(&mut self) -> BroadcastStream { - BroadcastStream::new( - self.init_broadcast_receiver - .take() - .unwrap_or_else(|| self.broadcast.subscribe()), - ) - } } #[cfg(test)] diff --git a/src/core/net/listener.rs b/src/core/net/listener.rs index 5c1795b..997d87d 100644 --- a/src/core/net/listener.rs +++ b/src/core/net/listener.rs @@ -46,3 +46,26 @@ impl Listener for tokio::net::UnixListener { Box::pin(accept(self)) } } + +#[cfg(test)] +impl Listener for tokio::sync::Mutex> { + type Conn = T; + + fn accept<'a>(&'a self) -> Pin> + Send + 'a>> + where + Self: Sync + 'a, + { + async fn accept( + _self: &tokio::sync::Mutex>, + ) -> io::Result { + _self + .lock() + .await + .recv() + .await + .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe)) + } + + Box::pin(accept(self)) + } +} diff --git a/src/core/server/distant/handler.rs b/src/core/server/distant/handler.rs index 375d574..8941e74 100644 --- a/src/core/server/distant/handler.rs +++ b/src/core/server/distant/handler.rs @@ -457,7 +457,7 @@ async fn proc_run( // Spawn a task that sends stdin to the process let mut stdin = child.stdin.take().unwrap(); let (stdin_tx, mut stdin_rx) = mpsc::channel::(1); - tokio::spawn(async move { + let stdin_task = tokio::spawn(async move { while let Some(line) = stdin_rx.recv().await { if let Err(x) = stdin.write_all(line.as_bytes()).await { error!("Failed to send stdin to process {}: {}", id, x); @@ -469,9 +469,13 @@ async fn proc_run( // Spawn a task that waits on the process to exit but can also // kill the process when triggered let (kill_tx, kill_rx) = oneshot::channel(); - tokio::spawn(async move { + let wait_task = tokio::spawn(async move { tokio::select! { status = child.wait() => { + if let Err(x) = stdin_task.await { + error!("Join on stdin task failed: {}", x); + } + if let Err(x) = stderr_task.await { error!("Join on stderr task failed: {}", x); } @@ -554,6 +558,7 @@ async fn proc_run( id, stdin_tx, kill_tx, + wait_task, }; state.lock().await.push_process(conn_id, process); @@ -1280,4 +1285,155 @@ mod tests { // Also verify that the directory was actually created assert!(path.exists(), "Directory not created"); } + + #[tokio::test] + async fn remove_should_send_error_on_failure() { + todo!(); + } + + #[tokio::test] + async fn remove_should_support_deleting_a_directory() { + todo!(); + } + + #[tokio::test] + async fn remove_should_delete_nonempty_directory_if_force_is_true() { + todo!(); + } + + #[tokio::test] + async fn remove_should_support_deleting_a_single_file() { + todo!(); + } + + #[tokio::test] + async fn copy_should_send_error_on_failure() { + todo!(); + } + + #[tokio::test] + async fn copy_should_support_copying_an_entire_directory() { + todo!(); + } + + #[tokio::test] + async fn copy_should_support_copying_a_single_file() { + todo!(); + } + + #[tokio::test] + async fn rename_should_send_error_on_failure() { + todo!(); + } + + #[tokio::test] + async fn rename_should_support_renaming_an_entire_directory() { + todo!(); + } + + #[tokio::test] + async fn rename_should_support_renaming_a_single_file() { + todo!(); + } + + #[tokio::test] + async fn exists_should_send_error_on_failure() { + todo!(); + } + + #[tokio::test] + async fn exists_should_send_true_if_path_exists() { + todo!(); + } + + #[tokio::test] + async fn exists_should_send_false_if_path_does_not_exist() { + todo!(); + } + + #[tokio::test] + async fn metadata_should_send_error_on_failure() { + todo!(); + } + + #[tokio::test] + async fn metadata_should_send_back_metadata_on_file_if_exists() { + todo!(); + } + + #[tokio::test] + async fn metadata_should_send_back_metadata_on_dir_if_exists() { + todo!(); + } + + #[tokio::test] + async fn metadata_should_include_canonicalized_path_if_flag_specified() { + todo!(); + } + + #[tokio::test] + async fn proc_run_should_send_error_on_failure() { + todo!(); + } + + #[tokio::test] + async fn proc_run_should_send_back_proc_start_on_success() { + todo!(); + } + + #[tokio::test] + async fn proc_run_should_send_back_stdout_periodically_when_available() { + todo!(); + } + + #[tokio::test] + async fn proc_run_should_send_back_stderr_periodically_when_available() { + todo!(); + } + + #[tokio::test] + async fn proc_run_should_send_back_done_when_proc_finishes() { + // Make sure to verify that process also removed from state + todo!(); + } + + #[tokio::test] + async fn proc_run_should_send_back_done_when_killed() { + // Make sure to verify that process also removed from state + todo!(); + } + + #[tokio::test] + async fn proc_kill_should_send_error_on_failure() { + // Can verify that if the process is not running, will fail + todo!(); + } + + #[tokio::test] + async fn proc_kill_should_send_ok_on_success() { + // Verify that we trigger sending done + todo!(); + } + + #[tokio::test] + async fn proc_stdin_should_send_error_on_failure() { + // Can verify that if the process is not running, will fail + todo!(); + } + + #[tokio::test] + async fn proc_stdin_should_send_ok_on_success() { + // Verify that we trigger sending stdin to process + todo!(); + } + + #[tokio::test] + async fn proc_list_should_send_proc_entry_list() { + todo!(); + } + + #[tokio::test] + async fn system_info_should_send_system_info_based_on_binary() { + todo!(); + } } diff --git a/src/core/server/distant/state.rs b/src/core/server/distant/state.rs index 43ba9ec..f8801ab 100644 --- a/src/core/server/distant/state.rs +++ b/src/core/server/distant/state.rs @@ -1,6 +1,14 @@ use log::*; -use std::collections::HashMap; -use tokio::sync::{mpsc, oneshot}; +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + sync::{mpsc, oneshot}, + task::{JoinError, JoinHandle}, +}; /// Holds state related to multiple clients managed by a server #[derive(Default)] @@ -62,4 +70,22 @@ pub struct Process { /// Transport channel to report that the process should be killed pub kill_tx: oneshot::Sender<()>, + + /// Task used to wait on the process to complete or be killed + pub wait_task: JoinHandle<()>, +} + +impl Process { + pub async fn kill_and_wait(self) -> Result<(), JoinError> { + let _ = self.kill_tx.send(()); + self.wait_task.await + } +} + +impl Future for Process { + type Output = Result<(), JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.wait_task).poll(cx) + } } diff --git a/src/core/server/relay.rs b/src/core/server/relay.rs index cdaeb0d..ac5bd36 100644 --- a/src/core/server/relay.rs +++ b/src/core/server/relay.rs @@ -9,7 +9,7 @@ use log::*; use std::{collections::HashMap, marker::Unpin, sync::Arc}; use tokio::{ io::{self, AsyncRead, AsyncWrite}, - sync::{broadcast, mpsc, oneshot, Mutex}, + sync::{mpsc, oneshot, Mutex}, task::{JoinError, JoinHandle}, time::Duration, }; @@ -17,13 +17,14 @@ use tokio::{ /// Represents a server that relays requests & responses between connections and the /// actual server pub struct RelayServer { - forward_task: JoinHandle<()>, accept_task: JoinHandle<()>, + broadcast_task: JoinHandle<()>, + forward_task: JoinHandle<()>, conns: Arc>>, } impl RelayServer { - pub async fn initialize( + pub fn initialize( mut session: Session, listener: L, shutdown_after: Option, @@ -33,10 +34,34 @@ impl RelayServer { T2: DataStream + Send + 'static, L: Listener + 'static, { - // Get a copy of our session's broadcaster so we can have each connection - // subscribe to it for new messages filtered by tenant - debug!("Acquiring session broadcaster"); - let broadcaster = session.to_response_broadcaster(); + let conns: Arc>> = Arc::new(Mutex::new(HashMap::new())); + + // Spawn task to send server responses to the appropriate connections + let conns_2 = Arc::clone(&conns); + debug!("Spawning response broadcast task"); + let mut broadcast = session.broadcast.take().unwrap(); + let broadcast_task = tokio::spawn(async move { + while let Some(res) = broadcast.recv().await { + // Search for all connections with a tenant that matches the response's tenant + for conn in conns_2.lock().await.values_mut() { + if conn.state.lock().await.tenant.as_deref() == Some(res.tenant.as_str()) { + debug!( + "Forwarding response of type{} {} to connection {}", + if res.payload.len() > 1 { "s" } else { "" }, + res.to_payload_type_string(), + conn.id + ); + if let Err(x) = conn.forward_response(res).await { + error!("Failed to pass forwarding message: {}", x); + } + + // NOTE: We assume that tenant is unique, so we can break after + // forwarding the message to the first matching tenant + break; + } + } + } + }); // Spawn task to send to the server requests from connections debug!("Spawning request forwarding task"); @@ -56,7 +81,6 @@ impl RelayServer { }); let (shutdown, tracker) = ShutdownTask::maybe_initialize(shutdown_after); - let conns = Arc::new(Mutex::new(HashMap::new())); let conns_2 = Arc::clone(&conns); let accept_task = tokio::spawn(async move { let inner = async move { @@ -66,7 +90,6 @@ impl RelayServer { let result = Conn::initialize( stream, req_tx.clone(), - broadcaster.clone(), tracker.as_ref().map(Arc::clone), ) .await; @@ -96,22 +119,24 @@ impl RelayServer { }); Ok(Self { - forward_task, accept_task, + broadcast_task, + forward_task, conns, }) } pub async fn wait(self) -> Result<(), JoinError> { - match tokio::try_join!(self.forward_task, self.accept_task) { + match tokio::try_join!(self.accept_task, self.broadcast_task, self.forward_task) { Ok(_) => Ok(()), Err(x) => Err(x), } } pub async fn abort(&self) { - self.forward_task.abort(); self.accept_task.abort(); + self.broadcast_task.abort(); + self.forward_task.abort(); self.conns .lock() .await @@ -124,11 +149,14 @@ struct Conn { id: usize, req_task: JoinHandle<()>, res_task: JoinHandle<()>, + res_tx: mpsc::Sender, + state: Arc>, } /// Represents state associated with a connection #[derive(Default)] struct ConnState { + tenant: Option, processes: Vec, } @@ -136,7 +164,6 @@ impl Conn { pub async fn initialize( stream: T, req_tx: mpsc::Sender, - res_broadcaster: broadcast::Sender, ct: Option>>, ) -> io::Result where @@ -164,7 +191,7 @@ impl Conn { // Spawn task to continually receive responses from the session that // may or may not be relevant to the connection, which will filter // by tenant and then along any response that matches - let res_rx = res_broadcaster.subscribe(); + let (res_tx, res_rx) = mpsc::channel::(CLIENT_BROADCAST_CHANNEL_CAPACITY); let state_2 = Arc::clone(&state); let res_task = tokio::spawn(async move { handle_conn_outgoing(id, state_2, t_write, tenant_rx, res_rx).await; @@ -173,11 +200,12 @@ impl Conn { // Spawn task to continually read requests from connection and forward // them along to be sent via the session let req_tx = req_tx.clone(); + let state_2 = Arc::clone(&state); let req_task = tokio::spawn(async move { if let Some(ct) = ct.as_ref() { ct.lock().await.increment(); } - handle_conn_incoming(id, state, t_read, tenant_tx, req_tx).await; + handle_conn_incoming(id, state_2, t_read, tenant_tx, req_tx).await; if let Some(ct) = ct.as_ref() { ct.lock().await.decrement(); } @@ -188,6 +216,8 @@ impl Conn { id, req_task, res_task, + res_tx, + state, }) } @@ -201,6 +231,14 @@ impl Conn { self.req_task.abort(); self.res_task.abort(); } + + /// Forwards a response back through this connection + pub async fn forward_response( + &mut self, + res: Response, + ) -> Result<(), mpsc::error::SendError> { + self.res_tx.send(res).await + } } /// Conn::Request -> Session::Fire @@ -284,7 +322,7 @@ async fn handle_conn_outgoing( state: Arc>, mut writer: TransportWriteHalf, tenant_rx: oneshot::Receiver, - mut res_rx: broadcast::Receiver, + mut res_rx: mpsc::Receiver, ) where T: AsyncWrite + Unpin, { @@ -294,43 +332,81 @@ async fn handle_conn_outgoing( // all responses before we know the tenant if let Ok(tenant) = tenant_rx.await { debug!("Associated tenant {} with conn {}", tenant, conn_id); - loop { - match res_rx.recv().await { - // Forward along responses that are for our connection - Ok(res) if res.tenant == tenant => { - debug!( - "Conn {} being sent response of type{} {}", - conn_id, - if res.payload.len() > 1 { "s" } else { "" }, - res.to_payload_type_string(), - ); - - // If a new process was started, we want to capture the id and - // associate it with the connection - let ids = res.payload.iter().filter_map(|x| match x { - ResponseData::ProcStart { id } => Some(*id), - _ => None, - }); - for id in ids { - debug!("Tracking proc {} for conn {}", id, conn_id); - state.lock().await.processes.push(id); - } + state.lock().await.tenant = Some(tenant.clone()); - if let Err(x) = writer.send(res).await { - error!("Failed to send response through unix connection: {}", x); - break; - } - } - // Skip responses that are not for our connection - Ok(_) => {} - Err(x) => { - error!( - "Conn {} failed to receive broadcast response: {}", - conn_id, x - ); - break; - } + while let Some(res) = res_rx.recv().await { + debug!( + "Conn {} being sent response of type{} {}", + conn_id, + if res.payload.len() > 1 { "s" } else { "" }, + res.to_payload_type_string(), + ); + + // If a new process was started, we want to capture the id and + // associate it with the connection + let ids = res.payload.iter().filter_map(|x| match x { + ResponseData::ProcStart { id } => Some(*id), + _ => None, + }); + for id in ids { + debug!("Tracking proc {} for conn {}", id, conn_id); + state.lock().await.processes.push(id); + } + + if let Err(x) = writer.send(res).await { + error!("Failed to send response through unix connection: {}", x); + break; } } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn wait_should_return_ok_when_all_inner_tasks_complete() { + todo!(); + } + + #[test] + fn wait_should_return_error_when_server_aborted() { + todo!(); + } + + #[test] + fn abort_should_abort_inner_tasks_and_all_connections() { + todo!(); + } + + #[test] + fn server_should_shutdown_if_no_connections_after_shutdown_duration() { + todo!(); + } + + #[test] + fn server_shutdown_should_abort_all_connections() { + todo!(); + } + + #[test] + fn server_should_forward_connection_requests_to_session() { + todo!(); + } + + #[test] + fn server_should_forward_session_responses_to_connection_with_matching_tenant() { + todo!(); + } + + #[test] + fn connection_abort_should_abort_inner_tasks() { + todo!(); + } + + #[test] + fn connection_abort_should_send_process_kill_requests_through_session() { + todo!(); + } +}